Skip to content

Commit

Permalink
[feat][index] Add Scalar Data with post filter with coprocessor.
Browse files Browse the repository at this point in the history
  • Loading branch information
Haijun Yu authored and ketor committed Mar 6, 2024
1 parent 02c3e5d commit ec9d291
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 11 deletions.
87 changes: 76 additions & 11 deletions src/vector/vector_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,18 @@ butil::Status VectorReader::SearchVector(
uint32_t top_n = parameter.top_n();
bool enable_range_search = parameter.enable_range_search();

if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() == 0)) {
if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() == 0) &&
!parameter.has_vector_coprocessor()) {
butil::Status status = VectorReader::SearchAndRangeSearchWrapper(
vector_index, region_range, vector_with_ids, parameter, vector_with_distance_results, top_n, {});
if (!status.ok()) {
DINGO_LOG(ERROR) << status.error_cstr();
return status;
}

} else {
} else if (parameter.has_vector_coprocessor()) {
if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() != 0)) {
DINGO_LOG(WARNING) << "vector_with_ids[0].scalar_data() deprecated. use coprocessor.";
}
top_n *= 10;
butil::Status status = VectorReader::SearchAndRangeSearchWrapper(vector_index, region_range, vector_with_ids,
parameter, tmp_results, top_n, {});
Expand All @@ -110,17 +113,26 @@ butil::Status VectorReader::SearchVector(
return status;
}

std::shared_ptr<RawCoprocessor> scalar_coprocessor = std::make_shared<CoprocessorScalar>();

status = scalar_coprocessor->Open(CoprocessorPbWrapper{parameter.vector_coprocessor()});
if (!status.ok()) {
DINGO_LOG(ERROR) << "scalar coprocessor::Open failed " << status.error_cstr();
return status;
}

for (auto& vector_with_distance_result : tmp_results) {
pb::index::VectorWithDistanceResult new_vector_with_distance_result;

for (auto& temp_vector_with_distance : *vector_with_distance_result.mutable_vector_with_distances()) {
int64_t temp_id = temp_vector_with_distance.vector_with_id().id();
bool compare_result = false;
butil::Status status = CompareVectorScalarData(region_range, partition_id, temp_id,
vector_with_ids[0].scalar_data(), compare_result);
butil::Status status = CompareVectorScalarDataWithCoprocessor(region_range, partition_id, temp_id,
scalar_coprocessor, compare_result);
if (!status.ok()) {
return status;
}

if (!compare_result) {
continue;
}
Expand All @@ -135,6 +147,12 @@ butil::Status VectorReader::SearchVector(
}
vector_with_distance_results.emplace_back(std::move(new_vector_with_distance_result));
}

} else { //! parameter.has_vector_coprocessor() && vector_with_ids[0].scalar_data().scalar_data_size() != 0
butil::Status status;
std::string s = fmt::format("CompareVectorScalarData deprecated. use coprocessor instead.");
DINGO_LOG(ERROR) << s;
return butil::Status(pb::error::EVECTOR_NOT_SUPPORT, s);
}
} else if (dingodb::pb::common::VectorFilter::VECTOR_ID_FILTER == vector_filter) { // vector id array search
butil::Status status = DoVectorSearchForVectorIdPreFilter(vector_index, vector_with_ids, parameter, region_range,
Expand All @@ -143,7 +161,6 @@ butil::Status VectorReader::SearchVector(
DINGO_LOG(ERROR) << fmt::format("DoVectorSearchForVectorIdPreFilter failed");
return status;
}

} else if (dingodb::pb::common::VectorFilter::SCALAR_FILTER == vector_filter &&
dingodb::pb::common::VectorFilterType::QUERY_PRE == vector_filter_type) { // scalar pre filter search

Expand All @@ -153,7 +170,6 @@ butil::Status VectorReader::SearchVector(
DINGO_LOG(ERROR) << fmt::format("DoVectorSearchForScalarPreFilter failed ");
return status;
}

} else if (dingodb::pb::common::VectorFilter::TABLE_FILTER ==
vector_filter) { // table coprocessor pre filter search. not impl
butil::Status status = DoVectorSearchForTableCoprocessor(vector_index, region_range, vector_with_ids, parameter,
Expand Down Expand Up @@ -325,6 +341,43 @@ butil::Status VectorReader::CompareVectorScalarData(const pb::common::Range& reg
return butil::Status();
}

butil::Status VectorReader::CompareVectorScalarDataWithCoprocessor(
const pb::common::Range& region_range, int64_t partition_id, int64_t vector_id,
const std::shared_ptr<RawCoprocessor>& scalar_coprocessor, bool& compare_result) {
compare_result = false;
std::string key, value;

VectorCodec::EncodeVectorKey(region_range.start_key()[0], partition_id, vector_id, key);

auto status = reader_->KvGet(Constant::kVectorScalarCF, key, value);
if (!status.ok()) {
DINGO_LOG(WARNING) << fmt::format("Get vector scalar data failed, vector_id: {} error: {} ", vector_id,
status.error_str());
return status;
}

pb::common::VectorScalardata vector_scalar;
if (!vector_scalar.ParseFromString(value)) {
return butil::Status(pb::error::EINTERNAL, "Decode vector scalar data failed");
}

auto lambda_scalar_compare_with_coprocessor_function =
[&scalar_coprocessor](const pb::common::VectorScalardata& internal_vector_scalar) {
bool is_reverse = false;
butil::Status status = scalar_coprocessor->Filter(internal_vector_scalar, is_reverse);
if (!status.ok()) {
LOG(ERROR) << "[" << __PRETTY_FUNCTION__ << "] "
<< "scalar coprocessor::Filter failed " << status.error_cstr();
return false;
}
return is_reverse;
};

compare_result = lambda_scalar_compare_with_coprocessor_function(vector_scalar);

return butil::Status();
}

butil::Status VectorReader::VectorBatchSearch(std::shared_ptr<Engine::VectorReader::Context> ctx,
std::vector<pb::index::VectorWithDistanceResult>& results) { // NOLINT
// Search vectors by vectors
Expand Down Expand Up @@ -985,15 +1038,19 @@ butil::Status VectorReader::SearchVectorDebug(
uint32_t top_n = parameter.top_n();
bool enable_range_search = parameter.enable_range_search();

if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() == 0)) {
if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() == 0) &&
!parameter.has_vector_coprocessor()) {
butil::Status status = VectorReader::SearchAndRangeSearchWrapper(
vector_index, region_range, vector_with_ids, parameter, vector_with_distance_results, top_n, {});
if (!status.ok()) {
DINGO_LOG(ERROR) << status.error_cstr();
return status;
}
} else {
} else if (parameter.has_vector_coprocessor()) {
auto start = lambda_time_now_function();
if (BAIDU_UNLIKELY(vector_with_ids[0].scalar_data().scalar_data_size() != 0)) {
DINGO_LOG(WARNING) << "vector_with_ids[0].scalar_data() deprecated. use coprocessor.";
}
top_n *= 10;
butil::Status status = VectorReader::SearchAndRangeSearchWrapper(vector_index, region_range, vector_with_ids,
parameter, tmp_results, top_n, {});
Expand All @@ -1005,14 +1062,22 @@ butil::Status VectorReader::SearchVectorDebug(
search_time_us = lambda_time_diff_microseconds_function(start, end);

auto start_kv_get = lambda_time_now_function();
std::shared_ptr<RawCoprocessor> scalar_coprocessor = std::make_shared<CoprocessorScalar>();

status = scalar_coprocessor->Open(CoprocessorPbWrapper{parameter.vector_coprocessor()});
if (!status.ok()) {
DINGO_LOG(ERROR) << "scalar coprocessor::Open failed " << status.error_cstr();
return status;
}

for (auto& vector_with_distance_result : tmp_results) {
pb::index::VectorWithDistanceResult new_vector_with_distance_result;

for (auto& temp_vector_with_distance : *vector_with_distance_result.mutable_vector_with_distances()) {
int64_t temp_id = temp_vector_with_distance.vector_with_id().id();
bool compare_result = false;
butil::Status status = CompareVectorScalarData(region_range, partition_id, temp_id,
vector_with_ids[0].scalar_data(), compare_result);
butil::Status status = CompareVectorScalarDataWithCoprocessor(region_range, partition_id, temp_id,
scalar_coprocessor, compare_result);
if (!status.ok()) {
return status;
}
Expand Down
6 changes: 6 additions & 0 deletions src/vector/vector_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <vector>

#include "butil/status.h"
#include "coprocessor/raw_coprocessor.h"
#include "engine/engine.h"
#include "engine/raw_engine.h"
#include "proto/common.pb.h"
Expand Down Expand Up @@ -82,6 +83,11 @@ class VectorReader {
butil::Status CompareVectorScalarData(const pb::common::Range& region_range, int64_t partition_id, int64_t vector_id,
const pb::common::VectorScalardata& source_scalar_data, bool& compare_result);

butil::Status CompareVectorScalarDataWithCoprocessor(const pb::common::Range& region_range, int64_t partition_id,
int64_t vector_id,
const std::shared_ptr<RawCoprocessor>& scalar_coprocessor,
bool& compare_result);

butil::Status QueryVectorTableData(const pb::common::Range& region_range, int64_t partition_id,
pb::common::VectorWithId& vector_with_id);
butil::Status QueryVectorTableData(const pb::common::Range& region_range, int64_t partition_id,
Expand Down

0 comments on commit ec9d291

Please sign in to comment.