From 3b89c9b0210fe04d79068fb60515668837c60361 Mon Sep 17 00:00:00 2001 From: Ketor Date: Fri, 8 Mar 2024 20:25:29 +0800 Subject: [PATCH] [fix][index] Fix vector ids prefilter for bruteforce index. Signed-off-by: Ketor --- src/client/dingodb_client.cc | 2 + src/client/store_client_function.cc | 137 +++++++++++++++++++------- src/vector/vector_index.cc | 33 +++++-- src/vector/vector_index.h | 12 +-- src/vector/vector_index_bruteforce.cc | 8 +- src/vector/vector_index_bruteforce.h | 8 +- src/vector/vector_index_flat.cc | 8 +- src/vector/vector_index_flat.h | 8 +- src/vector/vector_index_hnsw.cc | 9 +- src/vector/vector_index_hnsw.h | 8 +- src/vector/vector_index_ivf_flat.cc | 9 +- src/vector/vector_index_ivf_flat.h | 8 +- src/vector/vector_index_ivf_pq.cc | 8 +- src/vector/vector_index_ivf_pq.h | 8 +- src/vector/vector_index_raw_ivf_pq.cc | 11 ++- src/vector/vector_index_raw_ivf_pq.h | 8 +- src/vector/vector_reader.cc | 100 ++++++++++++------- src/vector/vector_reader.h | 9 +- 18 files changed, 252 insertions(+), 142 deletions(-) diff --git a/src/client/dingodb_client.cc b/src/client/dingodb_client.cc index 21fce9a43..29b7db9f9 100644 --- a/src/client/dingodb_client.cc +++ b/src/client/dingodb_client.cc @@ -157,6 +157,8 @@ DEFINE_bool(with_vector_ids, false, "Search vector with vector ids list default DEFINE_bool(with_scalar_pre_filter, false, "Search vector with scalar data pre filter"); DEFINE_bool(with_scalar_post_filter, false, "Search vector with scalar data post filter"); DEFINE_bool(with_table_pre_filter, false, "Search vector with table data pre filter"); +DEFINE_string(scalar_key, "", "Request scalar_key"); +DEFINE_string(scalar_value, "", "Request scalar_value"); DEFINE_int32(vector_ids_count, 100, "vector ids count"); DEFINE_string(csv_data, "", "csv data"); DEFINE_string(json_data, "", "json data"); diff --git a/src/client/store_client_function.cc b/src/client/store_client_function.cc index 93c25373c..6271a4e45 100644 --- a/src/client/store_client_function.cc +++ b/src/client/store_client_function.cc @@ -70,6 +70,8 @@ DECLARE_string(vector_data); DECLARE_string(csv_data); DECLARE_string(csv_output); DECLARE_int32(dimension); +DECLARE_string(scalar_key); +DECLARE_string(scalar_value); // for calc distance DEFINE_string(vector_data1, "", "vector data 1"); @@ -503,27 +505,66 @@ void SendVectorSearch(int64_t region_id, uint32_t dimension, uint32_t topn) { request.mutable_parameter()->set_vector_filter(::dingodb::pb::common::VectorFilter::SCALAR_FILTER); request.mutable_parameter()->set_vector_filter_type(::dingodb::pb::common::VectorFilterType::QUERY_PRE); - for (int k = 0; k < 2; k++) { - ::dingodb::pb::common::ScalarValue scalar_value; - scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING); - ::dingodb::pb::common::ScalarField* field = scalar_value.add_fields(); - field->set_string_data("value" + std::to_string(k)); + if (FLAGS_scalar_filter_key.empty() || FLAGS_scalar_filter_value.empty()) { + DINGO_LOG(ERROR) << "scalar_filter_key or scalar_filter_value is empty"; + return; + } - vector->mutable_scalar_data()->mutable_scalar_data()->insert({"key" + std::to_string(k), scalar_value}); + if (!FLAGS_scalar_filter_key.empty()) { + dingodb::pb::common::ScalarValue scalar_value; + scalar_value.set_field_type(dingodb::pb::common::ScalarFieldType::STRING); + dingodb::pb::common::ScalarField* field = scalar_value.add_fields(); + field->set_string_data(FLAGS_scalar_filter_value); + + vector->mutable_scalar_data()->mutable_scalar_data()->insert({FLAGS_scalar_filter_key, scalar_value}); + + DINGO_LOG(INFO) << "scalar_filter_key: " << FLAGS_scalar_filter_key + << " scalar_filter_value: " << FLAGS_scalar_filter_value; + } + + if (!FLAGS_scalar_filter_key2.empty()) { + dingodb::pb::common::ScalarValue scalar_value; + scalar_value.set_field_type(dingodb::pb::common::ScalarFieldType::STRING); + dingodb::pb::common::ScalarField* field = scalar_value.add_fields(); + field->set_string_data(FLAGS_scalar_filter_value2); + + vector->mutable_scalar_data()->mutable_scalar_data()->insert({FLAGS_scalar_filter_key2, scalar_value}); + + DINGO_LOG(INFO) << "scalar_filter_key2: " << FLAGS_scalar_filter_key2 + << " scalar_filter_value2: " << FLAGS_scalar_filter_value2; } } if (FLAGS_with_scalar_post_filter) { request.mutable_parameter()->set_vector_filter(::dingodb::pb::common::VectorFilter::SCALAR_FILTER); request.mutable_parameter()->set_vector_filter_type(::dingodb::pb::common::VectorFilterType::QUERY_POST); + if (FLAGS_scalar_filter_key.empty() || FLAGS_scalar_filter_value.empty()) { + DINGO_LOG(ERROR) << "scalar_filter_key or scalar_filter_value is empty"; + return; + } - for (int k = 0; k < 2; k++) { - ::dingodb::pb::common::ScalarValue scalar_value; - scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING); - ::dingodb::pb::common::ScalarField* field = scalar_value.add_fields(); - field->set_string_data("value" + std::to_string(k)); + if (!FLAGS_scalar_filter_key.empty()) { + dingodb::pb::common::ScalarValue scalar_value; + scalar_value.set_field_type(dingodb::pb::common::ScalarFieldType::STRING); + dingodb::pb::common::ScalarField* field = scalar_value.add_fields(); + field->set_string_data(FLAGS_scalar_filter_value); - vector->mutable_scalar_data()->mutable_scalar_data()->insert({"key" + std::to_string(k), scalar_value}); + vector->mutable_scalar_data()->mutable_scalar_data()->insert({FLAGS_scalar_filter_key, scalar_value}); + + DINGO_LOG(INFO) << "scalar_filter_key: " << FLAGS_scalar_filter_key + << " scalar_filter_value: " << FLAGS_scalar_filter_value; + } + + if (!FLAGS_scalar_filter_key2.empty()) { + dingodb::pb::common::ScalarValue scalar_value; + scalar_value.set_field_type(dingodb::pb::common::ScalarFieldType::STRING); + dingodb::pb::common::ScalarField* field = scalar_value.add_fields(); + field->set_string_data(FLAGS_scalar_filter_value2); + + vector->mutable_scalar_data()->mutable_scalar_data()->insert({FLAGS_scalar_filter_key2, scalar_value}); + + DINGO_LOG(INFO) << "scalar_filter_key2: " << FLAGS_scalar_filter_key2 + << " scalar_filter_value2: " << FLAGS_scalar_filter_value2; } } @@ -580,21 +621,23 @@ void SendVectorSearch(int64_t region_id, uint32_t dimension, uint32_t topn) { } } - for (const auto& vector_with_distance : batch_result.vector_with_distances()) { - std::vector x_i = dingodb::Helper::PbRepeatedToVector(vector->vector().float_values()); - std::vector y_j = - dingodb::Helper::PbRepeatedToVector(vector_with_distance.vector_with_id().vector().float_values()); + if (!FLAGS_without_vector) { + for (const auto& vector_with_distance : batch_result.vector_with_distances()) { + std::vector x_i = dingodb::Helper::PbRepeatedToVector(vector->vector().float_values()); + std::vector y_j = + dingodb::Helper::PbRepeatedToVector(vector_with_distance.vector_with_id().vector().float_values()); - auto faiss_l2 = dingodb::Helper::DingoFaissL2sqr(x_i.data(), y_j.data(), dimension); - auto faiss_ip = dingodb::Helper::DingoFaissInnerProduct(x_i.data(), y_j.data(), dimension); - auto hnsw_l2 = dingodb::Helper::DingoHnswL2Sqr(x_i.data(), y_j.data(), dimension); - auto hnsw_ip = dingodb::Helper::DingoHnswInnerProduct(x_i.data(), y_j.data(), dimension); - auto hnsw_ip_dist = dingodb::Helper::DingoHnswInnerProductDistance(x_i.data(), y_j.data(), dimension); + auto faiss_l2 = dingodb::Helper::DingoFaissL2sqr(x_i.data(), y_j.data(), dimension); + auto faiss_ip = dingodb::Helper::DingoFaissInnerProduct(x_i.data(), y_j.data(), dimension); + auto hnsw_l2 = dingodb::Helper::DingoHnswL2Sqr(x_i.data(), y_j.data(), dimension); + auto hnsw_ip = dingodb::Helper::DingoHnswInnerProduct(x_i.data(), y_j.data(), dimension); + auto hnsw_ip_dist = dingodb::Helper::DingoHnswInnerProductDistance(x_i.data(), y_j.data(), dimension); - DINGO_LOG(INFO) << "vector_id: " << vector_with_distance.vector_with_id().id() - << ", distance: " << vector_with_distance.distance() << ", [faiss_l2: " << faiss_l2 - << ", faiss_ip: " << faiss_ip << ", hnsw_l2: " << hnsw_l2 << ", hnsw_ip: " << hnsw_ip - << ", hnsw_ip_dist: " << hnsw_ip_dist << "]"; + DINGO_LOG(INFO) << "vector_id: " << vector_with_distance.vector_with_id().id() + << ", distance: " << vector_with_distance.distance() << ", [faiss_l2: " << faiss_l2 + << ", faiss_ip: " << faiss_ip << ", hnsw_l2: " << hnsw_l2 << ", hnsw_ip: " << hnsw_ip + << ", hnsw_ip_dist: " << hnsw_ip_dist << "]"; + } } } @@ -1732,19 +1775,37 @@ int SendBatchVectorAdd(int64_t region_id, uint32_t dimension, std::vectormutable_scalar_data()->mutable_scalar_data(); - dingodb::pb::common::ScalarValue scalar_value; - scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING); - scalar_value.add_fields()->set_string_data(fmt::format("scalar_value{}", k)); - (*scalar_data)[fmt::format("scalar_key{}", k)] = scalar_value; - } - for (int k = 2; k < 4; ++k) { - auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data(); - dingodb::pb::common::ScalarValue scalar_value; - scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::INT64); - scalar_value.add_fields()->set_long_data(k); - (*scalar_data)[fmt::format("scalar_key{}", k)] = scalar_value; + if (FLAGS_scalar_filter_key.empty() || FLAGS_scalar_filter_value.empty()) { + for (int k = 0; k < 2; ++k) { + auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data(); + dingodb::pb::common::ScalarValue scalar_value; + scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING); + scalar_value.add_fields()->set_string_data(fmt::format("scalar_value{}", k)); + (*scalar_data)[fmt::format("scalar_key{}", k)] = scalar_value; + } + for (int k = 2; k < 4; ++k) { + auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data(); + dingodb::pb::common::ScalarValue scalar_value; + scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::INT64); + scalar_value.add_fields()->set_long_data(k); + (*scalar_data)[fmt::format("scalar_key{}", k)] = scalar_value; + } + } else { + if (!FLAGS_scalar_filter_key.empty()) { + auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data(); + dingodb::pb::common::ScalarValue scalar_value; + scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING); + scalar_value.add_fields()->set_string_data(FLAGS_scalar_filter_value); + (*scalar_data)[FLAGS_scalar_filter_key] = scalar_value; + } + + if (!FLAGS_scalar_filter_key2.empty()) { + auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data(); + dingodb::pb::common::ScalarValue scalar_value; + scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING); + scalar_value.add_fields()->set_string_data(FLAGS_scalar_filter_value2); + (*scalar_data)[FLAGS_scalar_filter_key2] = scalar_value; + } } } diff --git a/src/vector/vector_index.cc b/src/vector/vector_index.cc index 7fa77049f..730a1c8de 100644 --- a/src/vector/vector_index.cc +++ b/src/vector/vector_index.cc @@ -858,7 +858,12 @@ butil::Status VectorIndexWrapper::Search(std::vector v if (region_range.start_key() != index_range.start_key() || region_range.end_key() != index_range.end_key()) { int64_t min_vector_id = 0, max_vector_id = 0; VectorCodec::DecodeRangeToVectorId(region_range, min_vector_id, max_vector_id); - VectorIndexWrapper::SetVectorIndexFilter(vector_index, filters, min_vector_id, max_vector_id); + auto ret = VectorIndexWrapper::SetVectorIndexRangeFilter(vector_index, filters, min_vector_id, max_vector_id); + if (!ret.ok()) { + DINGO_LOG(ERROR) << fmt::format("[vector_index.wrapper][index_id({})] set vector index filter failed, error: {}", + Id(), ret.error_str()); + return ret; + } } return vector_index->Search(vector_with_ids, topk, filters, reconstruct, parameter, results); @@ -914,7 +919,12 @@ butil::Status VectorIndexWrapper::RangeSearch(std::vectorRangeSearch(vector_with_ids, radius, filters, reconstruct, parameter, results); @@ -935,14 +945,13 @@ bool VectorIndexWrapper::IsPermanentHoldVectorIndex(int64_t region_id) { return true; } -butil::Status VectorIndexWrapper::SetVectorIndexFilter( - VectorIndexPtr vector_index, - std::vector>& filters, // NOLINT +butil::Status VectorIndexWrapper::SetVectorIndexRangeFilter( + VectorIndexPtr vector_index, std::vector>& filters, int64_t min_vector_id, int64_t max_vector_id) { if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_HNSW) { filters.push_back(std::make_shared(min_vector_id, max_vector_id)); - } else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_FLAT) { - // filters.push_back(std::make_shared(min_vector_id, max_vector_id)); + } else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_FLAT || + vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_BRUTEFORCE) { filters.push_back(std::make_shared(min_vector_id, max_vector_id)); } else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_IVF_FLAT) { filters.push_back(std::make_shared(min_vector_id, max_vector_id)); @@ -952,8 +961,16 @@ butil::Status VectorIndexWrapper::SetVectorIndexFilter( } else if (vector_index->VectorIndexSubType() == pb::common::VECTOR_INDEX_TYPE_FLAT) { filters.push_back(std::make_shared(min_vector_id, max_vector_id)); } else { - // do nothing + return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT, + fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}", + pb::common::VectorIndexType_Name(vector_index->VectorIndexType()), + pb::common::VectorIndexType_Name(vector_index->VectorIndexSubType()))); } + } else { + return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT, + fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}", + pb::common::VectorIndexType_Name(vector_index->VectorIndexType()), + pb::common::VectorIndexType_Name(vector_index->VectorIndexSubType()))); } return butil::Status::OK(); diff --git a/src/vector/vector_index.h b/src/vector/vector_index.h index 97439fd25..1e0d87937 100644 --- a/src/vector/vector_index.h +++ b/src/vector/vector_index.h @@ -150,14 +150,14 @@ class VectorIndex { virtual butil::Status Load(const std::string& path); - virtual butil::Status Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, + virtual butil::Status Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) = 0; - virtual butil::Status RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, bool reconstruct, - const pb::common::VectorSearchParameter& parameter, + virtual butil::Status RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, + bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) = 0; virtual void LockWrite() = 0; @@ -359,7 +359,7 @@ class VectorIndexWrapper : public std::enable_shared_from_this& results); - static butil::Status SetVectorIndexFilter( + static butil::Status SetVectorIndexRangeFilter( VectorIndexPtr vector_index, std::vector>& filters, // NOLINT int64_t min_vector_id, int64_t max_vector_id); diff --git a/src/vector/vector_index_bruteforce.cc b/src/vector/vector_index_bruteforce.cc index 386076b7d..90f9b91f8 100644 --- a/src/vector/vector_index_bruteforce.cc +++ b/src/vector/vector_index_bruteforce.cc @@ -48,16 +48,16 @@ butil::Status VectorIndexBruteforce::Add(const std::vector& /*delete_ids*/) { return butil::Status::OK(); } -butil::Status VectorIndexBruteforce::Search(std::vector /*vector_with_ids*/, - uint32_t /*topk*/, std::vector> /*filters*/, +butil::Status VectorIndexBruteforce::Search(const std::vector& /*vector_with_ids*/, + uint32_t /*topk*/, const std::vector>& /*filters*/, bool, const pb::common::VectorSearchParameter&, std::vector& /*results*/) { return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT, "not support"); } -butil::Status VectorIndexBruteforce::RangeSearch(std::vector /*vector_with_ids*/, +butil::Status VectorIndexBruteforce::RangeSearch(const std::vector& /*vector_with_ids*/, float /*radius*/, - std::vector> /*filters*/, + const std::vector>& /*filters*/, bool /*reconstruct*/, const pb::common::VectorSearchParameter& /*parameter*/, std::vector& /*results*/) { diff --git a/src/vector/vector_index_bruteforce.h b/src/vector/vector_index_bruteforce.h index 0801442ac..6223331a6 100644 --- a/src/vector/vector_index_bruteforce.h +++ b/src/vector/vector_index_bruteforce.h @@ -51,13 +51,13 @@ class VectorIndexBruteforce : public VectorIndex { butil::Status Delete(const std::vector& delete_ids) override; - butil::Status Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, + butil::Status Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; - butil::Status RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, bool reconstruct, + butil::Status RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; diff --git a/src/vector/vector_index_flat.cc b/src/vector/vector_index_flat.cc index 47de3aa63..6b879158a 100644 --- a/src/vector/vector_index_flat.cc +++ b/src/vector/vector_index_flat.cc @@ -166,8 +166,8 @@ butil::Status VectorIndexFlat::Delete(const std::vector& delete_ids) { return butil::Status::OK(); } -butil::Status VectorIndexFlat::Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool, +butil::Status VectorIndexFlat::Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool, const pb::common::VectorSearchParameter&, std::vector& results) { if (vector_with_ids.empty()) { @@ -230,8 +230,8 @@ butil::Status VectorIndexFlat::Search(std::vector vect return butil::Status::OK(); } -butil::Status VectorIndexFlat::RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, +butil::Status VectorIndexFlat::RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, bool /*reconstruct*/, const pb::common::VectorSearchParameter& /*parameter*/, std::vector& results) { if (vector_with_ids.empty()) { diff --git a/src/vector/vector_index_flat.h b/src/vector/vector_index_flat.h index 88d3e48bd..8b730afbe 100644 --- a/src/vector/vector_index_flat.h +++ b/src/vector/vector_index_flat.h @@ -81,13 +81,13 @@ class VectorIndexFlat : public VectorIndex { butil::Status Delete(const std::vector& delete_ids) override; - butil::Status Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, + butil::Status Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; - butil::Status RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, bool reconstruct, + butil::Status RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; diff --git a/src/vector/vector_index_hnsw.cc b/src/vector/vector_index_hnsw.cc index 253d6d39e..83bbec4d6 100644 --- a/src/vector/vector_index_hnsw.cc +++ b/src/vector/vector_index_hnsw.cc @@ -296,8 +296,8 @@ butil::Status VectorIndexHnsw::Load(const std::string& path) { } } -butil::Status VectorIndexHnsw::Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, +butil::Status VectorIndexHnsw::Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& search_parameter, std::vector& results) { if (vector_with_ids.empty()) { @@ -484,8 +484,9 @@ butil::Status VectorIndexHnsw::Search(std::vector vect return butil::Status::OK(); } -butil::Status VectorIndexHnsw::RangeSearch(std::vector /*vector_with_ids*/, float /*radius*/, - std::vector> /*filters*/, +butil::Status VectorIndexHnsw::RangeSearch(const std::vector& /*vector_with_ids*/, + float /*radius*/, + const std::vector>& /*filters*/, bool /*reconstruct*/, const pb::common::VectorSearchParameter& /*parameter*/, std::vector& /*results*/) { return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT, "RangeSearch not support in Hnsw!!!"); diff --git a/src/vector/vector_index_hnsw.h b/src/vector/vector_index_hnsw.h index 189a765f1..c4e6595fe 100644 --- a/src/vector/vector_index_hnsw.h +++ b/src/vector/vector_index_hnsw.h @@ -60,13 +60,13 @@ class VectorIndexHnsw : public VectorIndex { void LockWrite() override; void UnlockWrite() override; - butil::Status Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, + butil::Status Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; - butil::Status RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, bool reconstruct, + butil::Status RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; diff --git a/src/vector/vector_index_ivf_flat.cc b/src/vector/vector_index_ivf_flat.cc index f30dc5435..159c19aac 100644 --- a/src/vector/vector_index_ivf_flat.cc +++ b/src/vector/vector_index_ivf_flat.cc @@ -183,8 +183,8 @@ butil::Status VectorIndexIvfFlat::Delete(const std::vector& delete_ids) return butil::Status::OK(); } -butil::Status VectorIndexIvfFlat::Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool, +butil::Status VectorIndexIvfFlat::Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool, const pb::common::VectorSearchParameter& parameter, std::vector& results) { // NOLINT if (vector_with_ids.empty()) { @@ -267,8 +267,9 @@ butil::Status VectorIndexIvfFlat::Search(std::vector v return butil::Status::OK(); } -butil::Status VectorIndexIvfFlat::RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, +butil::Status VectorIndexIvfFlat::RangeSearch(const std::vector& vector_with_ids, + float radius, + const std::vector>& filters, bool /*reconstruct*/, const pb::common::VectorSearchParameter& parameter, std::vector& results) { if (vector_with_ids.empty()) { diff --git a/src/vector/vector_index_ivf_flat.h b/src/vector/vector_index_ivf_flat.h index b7435c152..ecf619e35 100644 --- a/src/vector/vector_index_ivf_flat.h +++ b/src/vector/vector_index_ivf_flat.h @@ -80,13 +80,13 @@ class VectorIndexIvfFlat : public VectorIndex { butil::Status Delete(const std::vector& delete_ids) override; - butil::Status Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, + butil::Status Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; - butil::Status RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, bool reconstruct, + butil::Status RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; diff --git a/src/vector/vector_index_ivf_pq.cc b/src/vector/vector_index_ivf_pq.cc index d9ce4bbd6..5c7223b00 100644 --- a/src/vector/vector_index_ivf_pq.cc +++ b/src/vector/vector_index_ivf_pq.cc @@ -145,8 +145,8 @@ butil::Status VectorIndexIvfPq::Delete(const std::vector& delete_ids) { return butil::Status::OK(); } -butil::Status VectorIndexIvfPq::Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, +butil::Status VectorIndexIvfPq::Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) { BvarLatencyGuard bvar_guard(&g_ivf_pq_search_latency); @@ -167,8 +167,8 @@ butil::Status VectorIndexIvfPq::Search(std::vector vec return butil::Status::OK(); } -butil::Status VectorIndexIvfPq::RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, +butil::Status VectorIndexIvfPq::RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) { BvarLatencyGuard bvar_guard(&g_ivf_pq_range_search_latency); diff --git a/src/vector/vector_index_ivf_pq.h b/src/vector/vector_index_ivf_pq.h index dab419d27..e7165d293 100644 --- a/src/vector/vector_index_ivf_pq.h +++ b/src/vector/vector_index_ivf_pq.h @@ -56,13 +56,13 @@ class VectorIndexIvfPq : public VectorIndex { butil::Status Delete(const std::vector& delete_ids) override; - butil::Status Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, + butil::Status Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; - butil::Status RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, bool reconstruct, + butil::Status RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; diff --git a/src/vector/vector_index_raw_ivf_pq.cc b/src/vector/vector_index_raw_ivf_pq.cc index 5e310137c..291b26d9e 100644 --- a/src/vector/vector_index_raw_ivf_pq.cc +++ b/src/vector/vector_index_raw_ivf_pq.cc @@ -192,9 +192,9 @@ butil::Status VectorIndexRawIvfPq::Delete(const std::vector& delete_ids return butil::Status::OK(); } -butil::Status VectorIndexRawIvfPq::Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool /*reconstruct*/, - const pb::common::VectorSearchParameter& parameter, +butil::Status VectorIndexRawIvfPq::Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, + bool /*reconstruct*/, const pb::common::VectorSearchParameter& parameter, std::vector& results) { // NOLINT if (vector_with_ids.empty()) { DINGO_LOG(WARNING) << "vector_with_ids is empty"; @@ -272,8 +272,9 @@ butil::Status VectorIndexRawIvfPq::Search(std::vector return butil::Status::OK(); } -butil::Status VectorIndexRawIvfPq::RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, +butil::Status VectorIndexRawIvfPq::RangeSearch(const std::vector& vector_with_ids, + float radius, + const std::vector>& filters, bool /*reconstruct*/, const pb::common::VectorSearchParameter& parameter, std::vector& results) { if (vector_with_ids.empty()) { diff --git a/src/vector/vector_index_raw_ivf_pq.h b/src/vector/vector_index_raw_ivf_pq.h index 381a86282..a617b8367 100644 --- a/src/vector/vector_index_raw_ivf_pq.h +++ b/src/vector/vector_index_raw_ivf_pq.h @@ -80,13 +80,13 @@ class VectorIndexRawIvfPq : public VectorIndex { butil::Status Delete(const std::vector& delete_ids) override; - butil::Status Search(std::vector vector_with_ids, uint32_t topk, - std::vector> filters, bool reconstruct, + butil::Status Search(const std::vector& vector_with_ids, uint32_t topk, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; - butil::Status RangeSearch(std::vector vector_with_ids, float radius, - std::vector> filters, bool reconstruct, + butil::Status RangeSearch(const std::vector& vector_with_ids, float radius, + const std::vector>& filters, bool reconstruct, const pb::common::VectorSearchParameter& parameter, std::vector& results) override; diff --git a/src/vector/vector_reader.cc b/src/vector/vector_reader.cc index 2fcf88c72..9abf07755 100644 --- a/src/vector/vector_reader.cc +++ b/src/vector/vector_reader.cc @@ -797,10 +797,15 @@ butil::Status VectorReader::DoVectorSearchForVectorIdPreFilter( // NOLINT const pb::common::VectorSearchParameter& parameter, const pb::common::Range& region_range, std::vector& vector_with_distance_results) { std::vector> filters; - VectorReader::SetVectorIndexFilter(vector_index, filters, Helper::PbRepeatedToVector(parameter.vector_ids())); + auto status = + VectorReader::SetVectorIndexIdsFilter(vector_index, filters, Helper::PbRepeatedToVector(parameter.vector_ids())); + if (!status.ok()) { + DINGO_LOG(ERROR) << status.error_str(); + return status; + } - butil::Status status = VectorReader::SearchAndRangeSearchWrapper( - vector_index, region_range, vector_with_ids, parameter, vector_with_distance_results, parameter.top_n(), filters); + status = VectorReader::SearchAndRangeSearchWrapper(vector_index, region_range, vector_with_ids, parameter, + vector_with_distance_results, parameter.top_n(), filters); if (!status.ok()) { DINGO_LOG(ERROR) << status.error_cstr(); return status; @@ -816,7 +821,14 @@ butil::Status VectorReader::DoVectorSearchForScalarPreFilter( // scalar pre filter search butil::Status status; -#if !defined(ENABLE_SCALAR_WITH_COPROCESSOR) + bool use_coprocessor = parameter.has_vector_coprocessor(); + + if (!use_coprocessor && vector_with_ids[0].scalar_data().scalar_data_size() == 0) { + std::string s = fmt::format("vector_with_ids[0].scalar_data() empty not support"); + DINGO_LOG(ERROR) << s; + return butil::Status(pb::error::Errno::EILLEGAL_PARAMTETERS, s); + } + const auto& std_vector_scalar = vector_with_ids[0].scalar_data(); auto lambda_scalar_compare_function = [&std_vector_scalar](const pb::common::VectorScalardata& internal_vector_scalar) { @@ -834,21 +846,16 @@ butil::Status VectorReader::DoVectorSearchForScalarPreFilter( return true; }; -#else // ENABLE_SCALAR_WITH_COPROCESSOR - if (!parameter.has_vector_coprocessor()) { - std::string s = fmt::format("vector_coprocessor empty not support"); - DINGO_LOG(ERROR) << s; - return butil::Status(pb::error::Errno::EILLEGAL_PARAMTETERS, s); - } - const auto& coprocessor = parameter.vector_coprocessor(); std::shared_ptr scalar_coprocessor = std::make_shared(); - status = scalar_coprocessor->Open(CoprocessorPbWrapper{coprocessor}); - if (!status.ok()) { - DINGO_LOG(ERROR) << "scalar coprocessor::Open failed " << status.error_cstr(); - return status; + if (use_coprocessor) { + status = scalar_coprocessor->Open(CoprocessorPbWrapper{coprocessor}); + if (!status.ok()) { + DINGO_LOG(ERROR) << "scalar coprocessor::Open failed " << status.error_cstr(); + return status; + } } auto lambda_scalar_compare_with_coprocessor_function = @@ -863,10 +870,6 @@ butil::Status VectorReader::DoVectorSearchForScalarPreFilter( return is_reverse; }; -#endif // #if defined(!ENABLE_SCALAR_WITH_COPROCESSOR) - - // std::string start_key = VectorCodec::FillVectorScalarPrefix(region_range.start_key()); - // std::string end_key = VectorCodec::FillVectorScalarPrefix(region_range.end_key()); const std::string& start_key = region_range.start_key(); const std::string& end_key = region_range.end_key(); @@ -889,11 +892,9 @@ butil::Status VectorReader::DoVectorSearchForScalarPreFilter( return butil::Status(pb::error::EINTERNAL, "Internal error, decode VectorScalar failed"); } -#if !defined(ENABLE_SCALAR_WITH_COPROCESSOR) - bool compare_result = lambda_scalar_compare_function(internal_vector_scalar); -#else - bool compare_result = lambda_scalar_compare_with_coprocessor_function(internal_vector_scalar); -#endif + bool compare_result = use_coprocessor ? lambda_scalar_compare_with_coprocessor_function(internal_vector_scalar) + : lambda_scalar_compare_function(internal_vector_scalar); + if (compare_result) { std::string key(iter->Key()); int64_t internal_vector_id = VectorCodec::DecodeVectorId(key); @@ -907,7 +908,11 @@ butil::Status VectorReader::DoVectorSearchForScalarPreFilter( } std::vector> filters; - VectorReader::SetVectorIndexFilter(vector_index, filters, vector_ids); + status = VectorReader::SetVectorIndexIdsFilter(vector_index, filters, vector_ids); + if (!status.ok()) { + DINGO_LOG(ERROR) << status.error_cstr(); + return status; + } status = VectorReader::SearchAndRangeSearchWrapper(vector_index, region_range, vector_with_ids, parameter, vector_with_distance_results, parameter.top_n(), filters); @@ -997,7 +1002,11 @@ butil::Status VectorReader::DoVectorSearchForTableCoprocessor( // NOLINT(*stati } std::vector> filters; - VectorReader::SetVectorIndexFilter(vector_index, filters, vector_ids); + status = VectorReader::SetVectorIndexIdsFilter(vector_index, filters, vector_ids); + if (!status.ok()) { + DINGO_LOG(ERROR) << status.error_cstr(); + return status; + } status = VectorReader::SearchAndRangeSearchWrapper(vector_index, region_range, vector_with_ids, parameter, vector_with_distance_results, parameter.top_n(), filters); @@ -1225,15 +1234,20 @@ butil::Status VectorReader::DoVectorSearchForVectorIdPreFilterDebug( // NOLINT std::vector> filters; auto start_ids = lambda_time_now_function(); - VectorReader::SetVectorIndexFilter(vector_index, filters, Helper::PbRepeatedToVector(parameter.vector_ids())); + auto status = + VectorReader::SetVectorIndexIdsFilter(vector_index, filters, Helper::PbRepeatedToVector(parameter.vector_ids())); + if (!status.ok()) { + DINGO_LOG(ERROR) << status.error_cstr(); + return status; + } auto end_ids = lambda_time_now_function(); deserialization_id_time_us = lambda_time_diff_microseconds_function(start_ids, end_ids); auto start_search = lambda_time_now_function(); - butil::Status status = VectorReader::SearchAndRangeSearchWrapper( - vector_index, region_range, vector_with_ids, parameter, vector_with_distance_results, parameter.top_n(), filters); + status = VectorReader::SearchAndRangeSearchWrapper(vector_index, region_range, vector_with_ids, parameter, + vector_with_distance_results, parameter.top_n(), filters); if (!status.ok()) { DINGO_LOG(ERROR) << status.error_cstr(); return status; @@ -1350,7 +1364,11 @@ butil::Status VectorReader::DoVectorSearchForScalarPreFilterDebug( std::vector> filters; - VectorReader::SetVectorIndexFilter(vector_index, filters, vector_ids); + status = VectorReader::SetVectorIndexIdsFilter(vector_index, filters, vector_ids); + if (!status.ok()) { + DINGO_LOG(ERROR) << status.error_cstr(); + return status; + } auto start_search = lambda_time_now_function(); @@ -1366,12 +1384,13 @@ butil::Status VectorReader::DoVectorSearchForScalarPreFilterDebug( return butil::Status::OK(); } -butil::Status VectorReader::SetVectorIndexFilter(VectorIndexWrapperPtr vector_index, - std::vector>& filters, - const std::vector& vector_ids) { +butil::Status VectorReader::SetVectorIndexIdsFilter(VectorIndexWrapperPtr vector_index, + std::vector>& filters, + const std::vector& vector_ids) { if (vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_HNSW) { filters.push_back(std::make_shared(vector_ids)); - } else if (vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_FLAT) { + } else if (vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_FLAT || + vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_BRUTEFORCE) { filters.push_back(std::make_shared(vector_ids)); } else if (vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_IVF_FLAT) { filters.push_back(std::make_shared(vector_ids)); @@ -1381,9 +1400,18 @@ butil::Status VectorReader::SetVectorIndexFilter(VectorIndexWrapperPtr vector_in } else if (vector_index->SubType() == pb::common::VECTOR_INDEX_TYPE_FLAT) { filters.push_back(std::make_shared(vector_ids)); } else { - // do nothing + return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT, + fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}", + pb::common::VectorIndexType_Name(vector_index->Type()), + pb::common::VectorIndexType_Name(vector_index->SubType()))); } + } else { + return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT, + fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}", + pb::common::VectorIndexType_Name(vector_index->Type()), + pb::common::VectorIndexType_Name(vector_index->SubType()))); } + return butil::Status::OK(); } @@ -1435,7 +1463,7 @@ butil::Status VectorReader::SearchAndRangeSearchWrapper( status = vector_index->Search(vector_with_ids, topk, region_range, filters, with_vector_data, parameter, vector_with_distance_results); if (status.error_code() == pb::error::Errno::EVECTOR_NOT_SUPPORT) { - DINGO_LOG(INFO) << "Search vector index not support, try brute force, id: " << vector_index->Id(); + DINGO_LOG(DEBUG) << "Search vector index not support, try brute force, id: " << vector_index->Id(); return BruteForceSearch(vector_index, vector_with_ids, topk, region_range, filters, with_vector_data, parameter, vector_with_distance_results); } else if (!status.ok()) { diff --git a/src/vector/vector_reader.h b/src/vector/vector_reader.h index ab299e0ee..4257d9b4d 100644 --- a/src/vector/vector_reader.h +++ b/src/vector/vector_reader.h @@ -134,12 +134,11 @@ class VectorReader { VectorIndexWrapperPtr vector_index, pb::common::Range region_range, const std::vector& vector_with_ids, const pb::common::VectorSearchParameter& parameter, std::vector& vector_with_distance_results, int64_t& scan_scalar_time_us, - int64_t& search_time_us); // NOLINT + int64_t& search_time_us); - static butil::Status SetVectorIndexFilter( - VectorIndexWrapperPtr vector_index, - std::vector>& filters, // NOLINT - const std::vector& vector_ids); + static butil::Status SetVectorIndexIdsFilter(VectorIndexWrapperPtr vector_index, + std::vector>& filters, + const std::vector& vector_ids); butil::Status SearchAndRangeSearchWrapper( VectorIndexWrapperPtr vector_index, pb::common::Range region_range,