Skip to content

Commit

Permalink
Added dictionary based filtering for IN (#2893)
Browse files Browse the repository at this point in the history
Added dictinary based filtering where if target is a list, it's compared via IN. Cleaned tests a bit.
  • Loading branch information
istranical authored Jul 2, 2024
1 parent fdabf5a commit fa20c2f
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 16 deletions.
14 changes: 12 additions & 2 deletions deeplake/core/vectorstore/test_deeplake_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,15 @@ def filter_fn(x):
) # One for each return_tensors
assert len(data_e_f.keys()) == 2

# Run a filter query using a json with indra
# Run a filter query using a list
data_e_j = vector_store.search(
k=2,
return_tensors=["id", "text"],
filter={"text": texts[0:2]},
)
assert len(data_e_j["text"]) == 2

# Run a filter query using a json with indra. Wrap text as list to make sure it works
data_ce_f = vector_store_cloud.search(
embedding=query_embedding,
exec_option="compute_engine",
Expand All @@ -398,7 +406,9 @@ def filter_fn(x):
"metadata": vector_store_cloud.dataset_handler.dataset.metadata[0].data()[
"value"
],
"text": vector_store_cloud.dataset_handler.dataset.text[0].data()["value"],
"text": [
vector_store_cloud.dataset_handler.dataset.text[0].data()["value"]
],
},
)
assert len(data_ce_f["text"]) == 1
Expand Down
38 changes: 29 additions & 9 deletions deeplake/core/vectorstore/vector_search/filter/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@


def dp_filter_python(x: dict, filter: Dict) -> bool:
"""Filter helper function for Deep Lake"""
"""Filter helper function for Deep Lake
For non-dict tensors, perform exact match if target data is not a list, and perform "IN" match if target data is a list.
For dict tensors, perform exact match for each key-value pair in the target data.
"""

result = True

Expand All @@ -22,7 +25,10 @@ def dp_filter_python(x: dict, filter: Dict) -> bool:
k in data and v == data[k] for k, v in filter[tensor].items()
)
else:
result = result and data == filter[tensor]
if type(filter[tensor]) == list:
result = result and data in filter[tensor]
else:
result = result and data == filter[tensor]

return result

Expand Down Expand Up @@ -50,6 +56,11 @@ def attribute_based_filtering_python(
def attribute_based_filtering_tql(
view, filter: Optional[Dict] = None, debug_mode=False, logger=None
):
"""Filter helper function converting filter dictionary to TQL Deep Lake
For non-dict tensors, perform exact match if target data is not a list, and perform "IN" match if target data is a list.
For dict tensors, perform exact match for each key-value pair in the target data.
"""

tql_filter = ""

if filter is not None:
Expand All @@ -64,13 +75,22 @@ def attribute_based_filtering_tql(
val_str = f"'{value}'" if type(value) == str else f"{value}"
tql_filter += f"{tensor}['{key}'] == {val_str} and "
else:
val_str = (
f"'{filter[tensor]}'"
if isinstance(filter[tensor], str)
or isinstance(filter[tensor], np.str_)
else f"{filter[tensor]}"
)
tql_filter += f"{tensor} == {val_str} and "
if type(filter[tensor]) == list:
val_str = str(filter[tensor])[
1:-1
] # Remove square bracked and add rounded brackets below.

tql_filter += f"{tensor} in ({val_str}) and "

else:
val_str = (
f"'{filter[tensor]}'"
if isinstance(filter[tensor], str)
or isinstance(filter[tensor], np.str_)
else f"{filter[tensor]}"
)
tql_filter += f"{tensor} == {val_str} and "

tql_filter = tql_filter[:-5]

if debug_mode and logger is not None:
Expand Down
25 changes: 20 additions & 5 deletions deeplake/core/vectorstore/vector_search/filter/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ def test_attribute_based_filtering():
ds.create_tensor("metadata", htype="json")
ds.create_tensor("metadata2", htype="json")
ds.create_tensor("text", htype="text")
ds.create_tensor("text2", htype="text")
ds.metadata.extend([{"k": 1}, {"k": 2}, {"k": 3}, {"k": 4}])
ds.metadata2.extend([{"kk": "a"}, {"kk": "b"}, {"kk": "c"}, {"kk": "d"}])
ds.text.extend(["AA", "BB", "CC", "DD"])
ds.text2.extend(["11", "22", "33", "DD"])

# Test basic filter
filter_dict = {"metadata": {"k": 1}, "metadata2": {"kk": "a"}, "text": "AA"}

def filter_udf(x):
Expand All @@ -23,9 +26,7 @@ def filter_udf(x):

view_udf = filter_utils.attribute_based_filtering_python(ds, filter=filter_udf)

view_tql, tql_filter = filter_utils.attribute_based_filtering_tql(
ds, filter=filter_dict
)
view_tql, _ = filter_utils.attribute_based_filtering_tql(ds, filter=filter_dict)

assert view_dict.metadata.data()["value"][0] == filter_dict["metadata"]
assert view_dict.metadata2.data()["value"][0] == filter_dict["metadata2"]
Expand All @@ -34,10 +35,24 @@ def filter_udf(x):
assert view_udf.metadata.data()["value"][0] == filter_dict["metadata"]

assert len(view_tql) == len(ds)
assert (
tql_filter == "metadata['k'] == 1 and metadata2['kk'] == 'a' and text == 'AA'"

# Test filter with list
filter_dict_list = {"text2": ["11", "DD"]}

view_dict_list = filter_utils.attribute_based_filtering_python(
ds, filter=filter_dict_list
)

view_tql_list, _ = filter_utils.attribute_based_filtering_tql(
ds, filter=filter_dict_list
)

assert view_dict_list.text2.data()["value"][0] in filter_dict_list["text2"]
assert view_dict_list.text2.data()["value"][1] in filter_dict_list["text2"]

assert len(view_tql_list) == len(ds)

# Test bad tensor
filter_dict_bad_tensor = {
"metadata_bad": {"k": 1},
"metadata2": {"kk": "a"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def run(
tql_filter,
return_tensors,
)

view = self._get_view(
tql_query,
runtime=self.runtime,
Expand Down

0 comments on commit fa20c2f

Please sign in to comment.