diff --git a/src/vector/vector_reader.cc b/src/vector/vector_reader.cc index 7f9fa2a45..48b2e8862 100644 --- a/src/vector/vector_reader.cc +++ b/src/vector/vector_reader.cc @@ -93,15 +93,18 @@ butil::Status VectorReader::SearchVector( uint32_t top_n = parameter.top_n(); bool enable_range_search = parameter.enable_range_search(); - if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() == 0)) { + if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() == 0) && + !parameter.has_vector_coprocessor()) { butil::Status status = VectorReader::SearchAndRangeSearchWrapper( vector_index, region_range, vector_with_ids, parameter, vector_with_distance_results, top_n, {}); if (!status.ok()) { DINGO_LOG(ERROR) << status.error_cstr(); return status; } - - } else { + } else if (parameter.has_vector_coprocessor()) { + if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() != 0)) { + DINGO_LOG(WARNING) << "vector_with_ids[0].scalar_data() deprecated. use coprocessor."; + } top_n *= 10; butil::Status status = VectorReader::SearchAndRangeSearchWrapper(vector_index, region_range, vector_with_ids, parameter, tmp_results, top_n, {}); @@ -110,17 +113,26 @@ butil::Status VectorReader::SearchVector( return status; } + std::shared_ptr scalar_coprocessor = std::make_shared(); + + status = scalar_coprocessor->Open(CoprocessorPbWrapper{parameter.vector_coprocessor()}); + if (!status.ok()) { + DINGO_LOG(ERROR) << "scalar coprocessor::Open failed " << status.error_cstr(); + return status; + } + for (auto& vector_with_distance_result : tmp_results) { pb::index::VectorWithDistanceResult new_vector_with_distance_result; for (auto& temp_vector_with_distance : *vector_with_distance_result.mutable_vector_with_distances()) { int64_t temp_id = temp_vector_with_distance.vector_with_id().id(); bool compare_result = false; - butil::Status status = CompareVectorScalarData(region_range, partition_id, temp_id, - vector_with_ids[0].scalar_data(), compare_result); + butil::Status status = CompareVectorScalarDataWithCoprocessor(region_range, partition_id, temp_id, + scalar_coprocessor, compare_result); if (!status.ok()) { return status; } + if (!compare_result) { continue; } @@ -135,6 +147,12 @@ butil::Status VectorReader::SearchVector( } vector_with_distance_results.emplace_back(std::move(new_vector_with_distance_result)); } + + } else { //! parameter.has_vector_coprocessor() && vector_with_ids[0].scalar_data().scalar_data_size() != 0 + butil::Status status; + std::string s = fmt::format("CompareVectorScalarData deprecated. use coprocessor instead."); + DINGO_LOG(ERROR) << s; + return butil::Status(pb::error::EVECTOR_NOT_SUPPORT, s); } } else if (dingodb::pb::common::VectorFilter::VECTOR_ID_FILTER == vector_filter) { // vector id array search butil::Status status = DoVectorSearchForVectorIdPreFilter(vector_index, vector_with_ids, parameter, region_range, @@ -143,7 +161,6 @@ butil::Status VectorReader::SearchVector( DINGO_LOG(ERROR) << fmt::format("DoVectorSearchForVectorIdPreFilter failed"); return status; } - } else if (dingodb::pb::common::VectorFilter::SCALAR_FILTER == vector_filter && dingodb::pb::common::VectorFilterType::QUERY_PRE == vector_filter_type) { // scalar pre filter search @@ -153,7 +170,6 @@ butil::Status VectorReader::SearchVector( DINGO_LOG(ERROR) << fmt::format("DoVectorSearchForScalarPreFilter failed "); return status; } - } else if (dingodb::pb::common::VectorFilter::TABLE_FILTER == vector_filter) { // table coprocessor pre filter search. not impl butil::Status status = DoVectorSearchForTableCoprocessor(vector_index, region_range, vector_with_ids, parameter, @@ -325,6 +341,43 @@ butil::Status VectorReader::CompareVectorScalarData(const pb::common::Range& reg return butil::Status(); } +butil::Status VectorReader::CompareVectorScalarDataWithCoprocessor( + const pb::common::Range& region_range, int64_t partition_id, int64_t vector_id, + const std::shared_ptr& scalar_coprocessor, bool& compare_result) { + compare_result = false; + std::string key, value; + + VectorCodec::EncodeVectorKey(region_range.start_key()[0], partition_id, vector_id, key); + + auto status = reader_->KvGet(Constant::kVectorScalarCF, key, value); + if (!status.ok()) { + DINGO_LOG(WARNING) << fmt::format("Get vector scalar data failed, vector_id: {} error: {} ", vector_id, + status.error_str()); + return status; + } + + pb::common::VectorScalardata vector_scalar; + if (!vector_scalar.ParseFromString(value)) { + return butil::Status(pb::error::EINTERNAL, "Decode vector scalar data failed"); + } + + auto lambda_scalar_compare_with_coprocessor_function = + [&scalar_coprocessor](const pb::common::VectorScalardata& internal_vector_scalar) { + bool is_reverse = false; + butil::Status status = scalar_coprocessor->Filter(internal_vector_scalar, is_reverse); + if (!status.ok()) { + LOG(ERROR) << "[" << __PRETTY_FUNCTION__ << "] " + << "scalar coprocessor::Filter failed " << status.error_cstr(); + return false; + } + return is_reverse; + }; + + compare_result = lambda_scalar_compare_with_coprocessor_function(vector_scalar); + + return butil::Status(); +} + butil::Status VectorReader::VectorBatchSearch(std::shared_ptr ctx, std::vector& results) { // NOLINT // Search vectors by vectors @@ -985,15 +1038,19 @@ butil::Status VectorReader::SearchVectorDebug( uint32_t top_n = parameter.top_n(); bool enable_range_search = parameter.enable_range_search(); - if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() == 0)) { + if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() == 0) && + !parameter.has_vector_coprocessor()) { butil::Status status = VectorReader::SearchAndRangeSearchWrapper( vector_index, region_range, vector_with_ids, parameter, vector_with_distance_results, top_n, {}); if (!status.ok()) { DINGO_LOG(ERROR) << status.error_cstr(); return status; } - } else { + } else if (parameter.has_vector_coprocessor()) { auto start = lambda_time_now_function(); + if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() != 0)) { + DINGO_LOG(WARNING) << "vector_with_ids[0].scalar_data() deprecated. use coprocessor."; + } top_n *= 10; butil::Status status = VectorReader::SearchAndRangeSearchWrapper(vector_index, region_range, vector_with_ids, parameter, tmp_results, top_n, {}); @@ -1005,14 +1062,22 @@ butil::Status VectorReader::SearchVectorDebug( search_time_us = lambda_time_diff_microseconds_function(start, end); auto start_kv_get = lambda_time_now_function(); + std::shared_ptr scalar_coprocessor = std::make_shared(); + + status = scalar_coprocessor->Open(CoprocessorPbWrapper{parameter.vector_coprocessor()}); + if (!status.ok()) { + DINGO_LOG(ERROR) << "scalar coprocessor::Open failed " << status.error_cstr(); + return status; + } + for (auto& vector_with_distance_result : tmp_results) { pb::index::VectorWithDistanceResult new_vector_with_distance_result; for (auto& temp_vector_with_distance : *vector_with_distance_result.mutable_vector_with_distances()) { int64_t temp_id = temp_vector_with_distance.vector_with_id().id(); bool compare_result = false; - butil::Status status = CompareVectorScalarData(region_range, partition_id, temp_id, - vector_with_ids[0].scalar_data(), compare_result); + butil::Status status = CompareVectorScalarDataWithCoprocessor(region_range, partition_id, temp_id, + scalar_coprocessor, compare_result); if (!status.ok()) { return status; } diff --git a/src/vector/vector_reader.h b/src/vector/vector_reader.h index c18150db9..ab299e0ee 100644 --- a/src/vector/vector_reader.h +++ b/src/vector/vector_reader.h @@ -22,6 +22,7 @@ #include #include "butil/status.h" +#include "coprocessor/raw_coprocessor.h" #include "engine/engine.h" #include "engine/raw_engine.h" #include "proto/common.pb.h" @@ -82,6 +83,11 @@ class VectorReader { butil::Status CompareVectorScalarData(const pb::common::Range& region_range, int64_t partition_id, int64_t vector_id, const pb::common::VectorScalardata& source_scalar_data, bool& compare_result); + butil::Status CompareVectorScalarDataWithCoprocessor(const pb::common::Range& region_range, int64_t partition_id, + int64_t vector_id, + const std::shared_ptr& scalar_coprocessor, + bool& compare_result); + butil::Status QueryVectorTableData(const pb::common::Range& region_range, int64_t partition_id, pb::common::VectorWithId& vector_with_id); butil::Status QueryVectorTableData(const pb::common::Range& region_range, int64_t partition_id,