Skip to content

Commit

Permalink
fix: set subindices directly via access path (#488)
Browse files Browse the repository at this point in the history
* test: add test for proper error message

* fix: enable proper error messages when subindex is set

* test: add test for access path setting of subindex

* fix: access path setting with subindex

* test: fix backend test

* test: fix misconfigured test

* fix: use context manager

* refactor: clean up imports

* refactor: update docarray/array/storage/base/getsetdel.py

Co-authored-by: AlaeddineAbdessalem <[email protected]>

* refactor: update docarray/array/storage/base/getsetdel.py

Co-authored-by: AlaeddineAbdessalem <[email protected]>

* refactor: update docarray/array/storage/base/getsetdel.py

* fix: catch attempted del by access path

Co-authored-by: AlaeddineAbdessalem <[email protected]>
  • Loading branch information
JohannesMessner and alaeddine-13 authored Aug 18, 2022
1 parent 7e3988c commit 461b996
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 24 deletions.
2 changes: 2 additions & 0 deletions docarray/array/storage/base/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def _init_subindices(self, *args, **kwargs):
config_subindex = (
dict() if config_subindex is None else config_subindex
) # allow None as input
if is_dataclass(config_subindex):
config_subindex = asdict(config_subindex)
config_joined = {**config, **config_subindex}
config_joined = self._ensure_unique_config(
config, config_subindex, config_joined, name
Expand Down
54 changes: 36 additions & 18 deletions docarray/array/storage/base/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
from docarray import Document, DocumentArray


def _check_valid_values_nested_set(docs, values):
    """Validate that ``values`` may overwrite ``docs`` in a nested (access path) set.

    Both arguments are wrapped in a ``DocumentArray`` before they are compared.

    :param docs: the Documents that are about to be overwritten
    :param values: the Documents that are meant to take their place
    :raises ValueError: if the two collections differ in length, or if their
        ``'id'`` columns are not equal
    """
    docs = DocumentArray(docs)
    values = DocumentArray(values)

    same_length = len(docs) == len(values)
    if not same_length:
        raise ValueError(
            f'length of docs to set({len(docs)}) does not match '
            f'length of values({len(values)})'
        )

    # Compare the full ID columns, so the order of the IDs matters as well.
    expected_ids = docs[:, 'id']
    received_ids = values[:, 'id']
    if expected_ids != received_ids:
        raise ValueError(
            'Setting Documents by traversal paths with different IDs is not supported'
        )


class BaseGetSetDelMixin(ABC):
"""Provide abstract methods and derived methods for ``__getitem__``, ``__setitem__`` and ``__delitem__``
Expand Down Expand Up @@ -128,6 +141,8 @@ def _del_docs_by_ids(self, ids):
self._del_doc_by_id(_id)

def _update_subindices_del(self, ids):
if isinstance(ids, str) and ids.startswith('@'):
return # deleting via access path is not supported
if getattr(self, '_subindices', None):
for selector, da in self._subindices.items():
ids_subindex = DocumentArray(self[ids])[selector, 'id']
Expand Down Expand Up @@ -173,15 +188,26 @@ def _set_docs_by_ids(self, ids, docs: Iterable['Document'], mismatch_ids: Dict):
for _id, doc in zip(ids, docs):
self._set_doc_by_id(_id, doc)

def _update_subindices_set(self, ids, docs):
if getattr(self, '_subindices', None):
for selector, da in self._subindices.items():
old_ids = DocumentArray(self[ids])[
selector, 'id'
] # hack to get the Document['@c'] without having to do Document.chunks
with da:
del da[old_ids]
da.extend(DocumentArray(docs)[selector]) # same hack here
def _update_subindices_set(self, set_index, docs):
subindices = getattr(self, '_subindices', None)
if not subindices:
return
if isinstance(set_index, tuple): # handled later in recursive call
return
if isinstance(set_index, str) and set_index.startswith('@'):
# 'nested' (non root-level) set, update entire subindex directly
_check_valid_values_nested_set(self[set_index], docs)
if set_index in subindices:
subindex_da = subindices[set_index]
with subindex_da:
subindex_da.clear()
subindex_da.extend(docs)
else: # root level set, update subindices iteratively
for subindex_selector, subindex_da in subindices.items():
old_ids = DocumentArray(self[set_index])[subindex_selector, 'id']
with subindex_da:
del subindex_da[old_ids]
subindex_da.extend(DocumentArray(docs)[subindex_selector])

def _set_docs(self, ids, docs: Iterable['Document']):
docs = list(docs)
Expand Down Expand Up @@ -228,17 +254,9 @@ def _set_doc_value_pairs_nested(
:param values: the value docs will be updated to
"""
docs = list(docs)
if len(docs) != len(values):
raise ValueError(
f'length of docs to set({len(docs)}) does not match '
f'length of values({len(values)})'
)
_check_valid_values_nested_set(docs, values)

for _d, _v in zip(docs, values):
if _d.id != _v.id:
raise ValueError(
'Setting Documents by traversal paths with different IDs is not supported'
)
_d._data = _v._data
if _d not in self:
root_d = self._find_root_doc_and_modify(_d)
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/array/mixins/test_getset.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,16 @@ def test_set_on_subindex(storage, config):
matches = da.find(query=np.random.random(2), on='@c')
assert matches
assert len(matches[0].embedding) == 2


def test_raise_correct_error_subindex_set():
    """Setting ``'@c'`` to unrelated Documents must raise a ``ValueError``."""
    chunk_texts = ('hello', 'world')
    root_docs = [Document(chunks=[Document(text=text)]) for text in chunk_texts]
    da = DocumentArray(root_docs, subindex_configs={'@c': None})

    with pytest.raises(ValueError):
        # Same number of Documents as there are chunks; presumably the freshly
        # created Documents carry IDs that differ from the existing chunks.
        da['@c'] = DocumentArray(Document() for _ in range(2))
68 changes: 62 additions & 6 deletions tests/unit/array/test_advance_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,8 @@ def test_path_syntax_indexing(storage, config, start_storage):
('elasticsearch', ElasticConfig(n_dim=123)),
],
)
def test_path_syntax_indexing_set(storage, config, start_storage):
@pytest.mark.parametrize('use_subindex', [False, True])
def test_path_syntax_indexing_set(storage, config, use_subindex, start_storage):
da = DocumentArray.empty(3)
for i, d in enumerate(da):
d.chunks = DocumentArray.empty(5)
Expand All @@ -348,13 +349,22 @@ def test_path_syntax_indexing_set(storage, config, start_storage):
repeat = lambda s, l: [s] * l
da['@r,c,m,cc', 'text'] = repeat('a', 3 + 5 * 3 + 7 * 3 + 3 * 5 * 3)

if storage != 'memory':
if config:
da = DocumentArray(da, storage=storage, config=config)
else:
da = DocumentArray(da, storage=storage)
if config:
da = DocumentArray(
da,
storage=storage,
config=config,
subindex_configs={'@c': {'n_dim': 123}} if use_subindex else None,
)
else:
da = DocumentArray(
da, storage=storage, subindex_configs={'@c': None} if use_subindex else None
)

assert da['@c'].texts == repeat('a', 3 * 5)
assert da['@c', 'text'] == repeat('a', 3 * 5)
if use_subindex:
assert da._subindices['@c'].texts == repeat('a', 3 * 5)
assert da['@c:1', 'text'] == repeat('a', 3)
assert da['@c-1:', 'text'] == repeat('a', 3)
assert da['@c1', 'text'] == repeat('a', 3)
Expand All @@ -372,6 +382,8 @@ def test_path_syntax_indexing_set(storage, config, start_storage):
da['@m,cc', 'text'] = repeat('b', 3 + 5 * 3 + 7 * 3 + 3 * 5 * 3)

assert da['@c', 'text'] == repeat('a', 3 * 5)
if use_subindex:
assert da._subindices['@c'].texts == repeat('a', 3 * 5)
assert da['@c:1', 'text'] == repeat('a', 3)
assert da['@c-1:', 'text'] == repeat('a', 3)
assert da['@c1', 'text'] == repeat('a', 3)
Expand Down Expand Up @@ -410,6 +422,50 @@ def test_path_syntax_indexing_set(storage, config, start_storage):
assert da[2].id == 'new_id'


@pytest.mark.parametrize(
    'storage,config',
    [
        ('memory', None),
        ('sqlite', None),
        ('weaviate', WeaviateConfig(n_dim=123)),
        ('annlite', AnnliteConfig(n_dim=123)),
        ('qdrant', QdrantConfig(n_dim=123)),
        ('elasticsearch', ElasticConfig(n_dim=123)),
    ],
)
def test_getset_subindex(storage, config, start_storage):
    """Setting chunks through an access path must update the ``'@c'`` subindex.

    Covers two cases: overwriting every chunk (``da['@c'] = ...``) and
    overwriting only a slice of the chunks of each Document
    (``da['@c:3'] = ...``).
    """
    # The parametrized ``storage`` has to be forwarded; otherwise every
    # parametrization silently runs against the default backend. ``config`` is
    # only passed when one is given, mirroring test_path_syntax_indexing_set.
    init_kwargs = {
        'storage': storage,
        'subindex_configs': {'@c': {'n_dim': 123}} if config else {'@c': None},
    }
    if config:
        init_kwargs['config'] = config
    da = DocumentArray(
        [Document(chunks=[Document() for _ in range(5)]) for _ in range(3)],
        **init_kwargs,
    )
    assert len(da['@c']) == 15
    assert len(da._subindices['@c']) == 15

    # set entire subindex
    chunks_ids = [c.id for c in da['@c']]
    new_chunks = [Document(id=cid, text=f'{i}') for i, cid in enumerate(chunks_ids)]
    da['@c'] = new_chunks
    new_chunks = DocumentArray(new_chunks)
    assert da['@c'] == new_chunks
    assert da._subindices['@c'] == new_chunks
    # the chunks reachable through the root Documents must agree as well
    collected_chunks = DocumentArray.empty(0)
    for d in da:
        collected_chunks.extend(d.chunks)
    assert collected_chunks == new_chunks

    # set part of a subindex
    chunks_ids = [c.id for c in da['@c:3']]
    new_chunks = [Document(id=cid, text=f'{2*i}') for i, cid in enumerate(chunks_ids)]
    da['@c:3'] = new_chunks
    new_chunks = DocumentArray(new_chunks)
    assert da['@c:3'] == new_chunks
    for d in new_chunks:
        assert d in da._subindices['@c']
    collected_chunks = DocumentArray.empty(0)
    for d in da:
        collected_chunks.extend(d.chunks[:3])
    assert collected_chunks == new_chunks


@pytest.mark.parametrize('size', [1, 5])
@pytest.mark.parametrize(
'storage,config_gen',
Expand Down

0 comments on commit 461b996

Please sign in to comment.