Skip to content

Commit

Permalink
fix(weaviate): remove ndim requirement in weaviate (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Feb 28, 2022
1 parent fd1f0f9 commit 78cadd0
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 47 deletions.
3 changes: 2 additions & 1 deletion docarray/array/storage/base/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,5 @@ def _save_offset2ids(self):
...

def __del__(self):
self._save_offset2ids()
if hasattr(self, '_offset2ids'):
self._save_offset2ids()
33 changes: 14 additions & 19 deletions docarray/array/storage/weaviate/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class WeaviateConfig:
"""This class stores the config variables to initialize
connection to the Weaviate server"""

n_dim: int
host: Optional[str] = field(default="localhost")
host: Optional[str] = field(default='localhost')
port: Optional[int] = field(default=8080)
protocol: Optional[int] = field(default="http")
protocol: Optional[str] = field(default='http')
name: Optional[str] = None
serialize_config: Dict = field(default_factory=dict)
n_dim: Optional[int] = None # deprecated, not used anymore since weaviate 1.10


class BackendMixin(BaseBackendMixin):
Expand All @@ -60,13 +60,12 @@ def _init_storage(
"""

if not config:
raise ValueError('Config object must be specified')
config = WeaviateConfig()
elif isinstance(config, dict):
config = dataclass_from_dict(WeaviateConfig, config)

from ... import DocumentArray

self._n_dim = config.n_dim
self._serialize_config = config.serialize_config

if config.name and config.name != config.name.capitalize():
Expand Down Expand Up @@ -278,25 +277,21 @@ def _doc2weaviate_create_payload(self, value: 'Document'):
:param value: document to create a payload for
:return: the payload dictionary
"""
if value.embedding is None:
embedding = np.zeros(self._n_dim)
else:
if value.embedding is not None:
from ....math.ndarray import to_numpy_array

embedding = to_numpy_array(value.embedding)

if embedding.ndim > 1:
embedding = np.asarray(embedding).squeeze()
if embedding.shape != (self._n_dim,):
raise ValueError(
f'All documents must have embedding of shape n_dim: {self._n_dim}, receiving shape: {embedding.shape}'
)
if embedding.ndim > 1:
embedding = np.asarray(embedding).squeeze()

# Weaviate expects vector to have dim 2 at least
# or get weaviate.exceptions.UnexpectedStatusCodeException: models.C11yVector
# hence we cast it to list of a single element
if len(embedding) == 1:
embedding = [embedding[0]]
# Weaviate expects vector to have dim 2 at least
# or get weaviate.exceptions.UnexpectedStatusCodeException: models.C11yVector
# hence we cast it to list of a single element
if len(embedding) == 1:
embedding = [embedding[0]]
else:
embedding = None

return dict(
data_object={'_serialized': value.to_base64(**self._serialize_config)},
Expand Down
19 changes: 10 additions & 9 deletions docarray/array/storage/weaviate/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ def _getitem(self, wid: str) -> 'Document':
:raises KeyError: raise error when weaviate id does not exist in storage
:return: Document
"""
resp = self._client.data_object.get_by_id(wid, with_vector=True)
if not resp:
raise KeyError(wid)
return Document.from_base64(
resp['properties']['_serialized'], **self._serialize_config
)
try:
resp = self._client.data_object.get_by_id(wid, with_vector=True)
return Document.from_base64(
resp['properties']['_serialized'], **self._serialize_config
)
except Exception as ex:
raise KeyError(wid) from ex

def _get_doc_by_id(self, _id: str) -> 'Document':
"""Concrete implementation of base class' ``_get_doc_by_id``
Expand All @@ -37,10 +38,10 @@ def _set_doc_by_id(self, _id: str, value: 'Document'):
"""
if _id != value.id:
self._del_doc_by_id(_id)
wid = self._wmap(value.id)

payload = self._doc2weaviate_create_payload(value)
if self._client.data_object.exists(wid):
self._client.data_object.delete(wid)
if self._client.data_object.exists(payload['uuid']):
self._client.data_object.delete(payload['uuid'])
self._client.data_object.create(**payload)

def _del_doc_by_id(self, _id: str):
Expand Down
7 changes: 3 additions & 4 deletions docs/advanced/document-store/weaviate.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ services:
- '8080'
- --scheme
- http
image: semitechnologies/weaviate:1.9.0
image: semitechnologies/weaviate:1.10.0
ports:
- 8080:8080
restart: on-failure:0
Expand All @@ -48,7 +48,7 @@ Assuming service is started using the default configuration (i.e. server address
```python
from docarray import DocumentArray

da = DocumentArray(storage='weaviate', config={'n_dim': 10})
da = DocumentArray(storage='weaviate')
```

The usage would be the same as the ordinary DocumentArray.
Expand All @@ -60,7 +60,7 @@ Note that the `name` parameter in `config` needs to be capitalized.
```python
from docarray import DocumentArray

da = DocumentArray(storage='weaviate', config={'name': 'Persisted', 'host': 'localhost', 'port': 1234, 'n_dim': 10})
da = DocumentArray(storage='weaviate', config={'name': 'Persisted', 'host': 'localhost', 'port': 1234})

da.summary()
```
Expand All @@ -73,7 +73,6 @@ The following configs can be set:

| Name | Description | Default |
|--------------------|----------------------------------------------------------------------------------------|-----------------------------|
| `n_dim` | Number of dimensions of embeddings to be stored and retrieved | **This is always required** |
| `host` | Hostname of the Weaviate server | 'localhost' |
| `port` | Port of the Weaviate server | 8080 |
| `protocol` | Protocol to be used. Can be 'http' or 'https' | 'http' |
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/array/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: "3.3"
services:
weaviate:
image: semitechnologies/weaviate:1.9.0
image: semitechnologies/weaviate:1.10.0
ports:
- 8080:8080
environment:
Expand Down
13 changes: 0 additions & 13 deletions tests/unit/array/mixins/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,3 @@ def test_embeddings_setter(da_len, da_cls, config, start_storage):
da.embeddings = np.random.rand(da_len, 5)
for doc in da:
assert doc.embedding.shape == (5,)


@pytest.mark.parametrize('da_len', [0, 1])
@pytest.mark.parametrize('da_cls', [DocumentArrayWeaviate])
@pytest.mark.parametrize(
'config, n_dim', [({'n_dim': 1}, 1), (WeaviateConfig(n_dim=5), 5)]
)
def test_content_by_config(da_len, da_cls, config, n_dim):
with pytest.raises(ValueError):
da_cls(da_len)

da = da_cls.empty(da_len, config=config)
assert da._n_dim == n_dim

0 comments on commit 78cadd0

Please sign in to comment.