diff --git a/deeplake/core/vectorstore/test_deeplake_vectorstore.py b/deeplake/core/vectorstore/test_deeplake_vectorstore.py index 6579cb3349..c3af4c549a 100644 --- a/deeplake/core/vectorstore/test_deeplake_vectorstore.py +++ b/deeplake/core/vectorstore/test_deeplake_vectorstore.py @@ -388,7 +388,15 @@ def filter_fn(x): ) # One for each return_tensors assert len(data_e_f.keys()) == 2 - # Run a filter query using a json with indra + # Run a filter query using a list + data_e_j = vector_store.search( + k=2, + return_tensors=["id", "text"], + filter={"text": texts[0:2]}, + ) + assert len(data_e_j["text"]) == 2 + + # Run a filter query using a json with indra. Wrap text as list to make sure it works data_ce_f = vector_store_cloud.search( embedding=query_embedding, exec_option="compute_engine", @@ -398,7 +406,9 @@ def filter_fn(x): "metadata": vector_store_cloud.dataset_handler.dataset.metadata[0].data()[ "value" ], - "text": vector_store_cloud.dataset_handler.dataset.text[0].data()["value"], + "text": [ + vector_store_cloud.dataset_handler.dataset.text[0].data()["value"] + ], }, ) assert len(data_ce_f["text"]) == 1 diff --git a/deeplake/core/vectorstore/vector_search/filter/filter.py b/deeplake/core/vectorstore/vector_search/filter/filter.py index 77d50d780f..638fbd4c87 100644 --- a/deeplake/core/vectorstore/vector_search/filter/filter.py +++ b/deeplake/core/vectorstore/vector_search/filter/filter.py @@ -9,7 +9,10 @@ def dp_filter_python(x: dict, filter: Dict) -> bool: - """Filter helper function for Deep Lake""" + """Filter helper function for Deep Lake + For non-dict tensors, perform exact match if target data is not a list, and perform "IN" match if target data is a list. + For dict tensors, perform exact match for each key-value pair in the target data. + """ result = True @@ -22,7 +25,10 @@ def dp_filter_python(x: dict, filter: Dict) -> bool: k in data and v == data[k] for k, v in filter[tensor].items() ) else: - result = result and data == filter[tensor] + if type(filter[tensor]) == list: + result = result and data in filter[tensor] + else: + result = result and data == filter[tensor] return result @@ -50,6 +56,11 @@ def attribute_based_filtering_python( def attribute_based_filtering_tql( view, filter: Optional[Dict] = None, debug_mode=False, logger=None ): + """Filter helper function converting filter dictionary to TQL Deep Lake + For non-dict tensors, perform exact match if target data is not a list, and perform "IN" match if target data is a list. + For dict tensors, perform exact match for each key-value pair in the target data. + """ + tql_filter = "" if filter is not None: @@ -64,13 +75,22 @@ def attribute_based_filtering_tql( val_str = f"'{value}'" if type(value) == str else f"{value}" tql_filter += f"{tensor}['{key}'] == {val_str} and " else: - val_str = ( - f"'{filter[tensor]}'" - if isinstance(filter[tensor], str) - or isinstance(filter[tensor], np.str_) - else f"{filter[tensor]}" - ) - tql_filter += f"{tensor} == {val_str} and " + if type(filter[tensor]) == list: + val_str = str(filter[tensor])[ + 1:-1 + ] # Remove square bracked and add rounded brackets below. + + tql_filter += f"{tensor} in ({val_str}) and " + + else: + val_str = ( + f"'{filter[tensor]}'" + if isinstance(filter[tensor], str) + or isinstance(filter[tensor], np.str_) + else f"{filter[tensor]}" + ) + tql_filter += f"{tensor} == {val_str} and " + tql_filter = tql_filter[:-5] if debug_mode and logger is not None: diff --git a/deeplake/core/vectorstore/vector_search/filter/test_filter.py b/deeplake/core/vectorstore/vector_search/filter/test_filter.py index 2ffb8cfbe1..9af4c3d21b 100644 --- a/deeplake/core/vectorstore/vector_search/filter/test_filter.py +++ b/deeplake/core/vectorstore/vector_search/filter/test_filter.py @@ -9,10 +9,13 @@ def test_attribute_based_filtering(): ds.create_tensor("metadata", htype="json") ds.create_tensor("metadata2", htype="json") ds.create_tensor("text", htype="text") + ds.create_tensor("text2", htype="text") ds.metadata.extend([{"k": 1}, {"k": 2}, {"k": 3}, {"k": 4}]) ds.metadata2.extend([{"kk": "a"}, {"kk": "b"}, {"kk": "c"}, {"kk": "d"}]) ds.text.extend(["AA", "BB", "CC", "DD"]) + ds.text2.extend(["11", "22", "33", "DD"]) + # Test basic filter filter_dict = {"metadata": {"k": 1}, "metadata2": {"kk": "a"}, "text": "AA"} def filter_udf(x): @@ -23,9 +26,7 @@ def filter_udf(x): view_udf = filter_utils.attribute_based_filtering_python(ds, filter=filter_udf) - view_tql, tql_filter = filter_utils.attribute_based_filtering_tql( - ds, filter=filter_dict - ) + view_tql, _ = filter_utils.attribute_based_filtering_tql(ds, filter=filter_dict) assert view_dict.metadata.data()["value"][0] == filter_dict["metadata"] assert view_dict.metadata2.data()["value"][0] == filter_dict["metadata2"] @@ -34,10 +35,24 @@ def filter_udf(x): assert view_udf.metadata.data()["value"][0] == filter_dict["metadata"] assert len(view_tql) == len(ds) - assert ( - tql_filter == "metadata['k'] == 1 and metadata2['kk'] == 'a' and text == 'AA'" + + # Test filter with list + filter_dict_list = {"text2": ["11", "DD"]} + + view_dict_list = filter_utils.attribute_based_filtering_python( + ds, filter=filter_dict_list ) + view_tql_list, _ = filter_utils.attribute_based_filtering_tql( + ds, filter=filter_dict_list + ) + + assert view_dict_list.text2.data()["value"][0] in filter_dict_list["text2"] + assert view_dict_list.text2.data()["value"][1] in filter_dict_list["text2"] + + assert len(view_tql_list) == len(ds) + + # Test bad tensor filter_dict_bad_tensor = { "metadata_bad": {"k": 1}, "metadata2": {"kk": "a"}, diff --git a/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py b/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py index ad75894e51..a7a745fe60 100644 --- a/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py +++ b/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py @@ -53,6 +53,7 @@ def run( tql_filter, return_tensors, ) + view = self._get_view( tql_query, runtime=self.runtime,