Skip to content

Commit

Permalink
[fix][index] Fix vector ids prefilter for bruteforce index.
Browse files Browse the repository at this point in the history
Signed-off-by: Ketor <[email protected]>
  • Loading branch information
ketor committed Mar 8, 2024
1 parent fbcb0c1 commit 3b89c9b
Show file tree
Hide file tree
Showing 18 changed files with 252 additions and 142 deletions.
2 changes: 2 additions & 0 deletions src/client/dingodb_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ DEFINE_bool(with_vector_ids, false, "Search vector with vector ids list default
DEFINE_bool(with_scalar_pre_filter, false, "Search vector with scalar data pre filter");
DEFINE_bool(with_scalar_post_filter, false, "Search vector with scalar data post filter");
DEFINE_bool(with_table_pre_filter, false, "Search vector with table data pre filter");
DEFINE_string(scalar_key, "", "Request scalar_key");
DEFINE_string(scalar_value, "", "Request scalar_value");
DEFINE_int32(vector_ids_count, 100, "vector ids count");
DEFINE_string(csv_data, "", "csv data");
DEFINE_string(json_data, "", "json data");
Expand Down
137 changes: 99 additions & 38 deletions src/client/store_client_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ DECLARE_string(vector_data);
DECLARE_string(csv_data);
DECLARE_string(csv_output);
DECLARE_int32(dimension);
DECLARE_string(scalar_key);
DECLARE_string(scalar_value);

// for calc distance
DEFINE_string(vector_data1, "", "vector data 1");
Expand Down Expand Up @@ -503,27 +505,66 @@ void SendVectorSearch(int64_t region_id, uint32_t dimension, uint32_t topn) {
request.mutable_parameter()->set_vector_filter(::dingodb::pb::common::VectorFilter::SCALAR_FILTER);
request.mutable_parameter()->set_vector_filter_type(::dingodb::pb::common::VectorFilterType::QUERY_PRE);

for (int k = 0; k < 2; k++) {
::dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING);
::dingodb::pb::common::ScalarField* field = scalar_value.add_fields();
field->set_string_data("value" + std::to_string(k));
if (FLAGS_scalar_filter_key.empty() || FLAGS_scalar_filter_value.empty()) {
DINGO_LOG(ERROR) << "scalar_filter_key or scalar_filter_value is empty";
return;
}

vector->mutable_scalar_data()->mutable_scalar_data()->insert({"key" + std::to_string(k), scalar_value});
if (!FLAGS_scalar_filter_key.empty()) {
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(dingodb::pb::common::ScalarFieldType::STRING);
dingodb::pb::common::ScalarField* field = scalar_value.add_fields();
field->set_string_data(FLAGS_scalar_filter_value);

vector->mutable_scalar_data()->mutable_scalar_data()->insert({FLAGS_scalar_filter_key, scalar_value});

DINGO_LOG(INFO) << "scalar_filter_key: " << FLAGS_scalar_filter_key
<< " scalar_filter_value: " << FLAGS_scalar_filter_value;
}

if (!FLAGS_scalar_filter_key2.empty()) {
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(dingodb::pb::common::ScalarFieldType::STRING);
dingodb::pb::common::ScalarField* field = scalar_value.add_fields();
field->set_string_data(FLAGS_scalar_filter_value2);

vector->mutable_scalar_data()->mutable_scalar_data()->insert({FLAGS_scalar_filter_key2, scalar_value});

DINGO_LOG(INFO) << "scalar_filter_key2: " << FLAGS_scalar_filter_key2
<< " scalar_filter_value2: " << FLAGS_scalar_filter_value2;
}
}

if (FLAGS_with_scalar_post_filter) {
request.mutable_parameter()->set_vector_filter(::dingodb::pb::common::VectorFilter::SCALAR_FILTER);
request.mutable_parameter()->set_vector_filter_type(::dingodb::pb::common::VectorFilterType::QUERY_POST);
if (FLAGS_scalar_filter_key.empty() || FLAGS_scalar_filter_value.empty()) {
DINGO_LOG(ERROR) << "scalar_filter_key or scalar_filter_value is empty";
return;
}

for (int k = 0; k < 2; k++) {
::dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING);
::dingodb::pb::common::ScalarField* field = scalar_value.add_fields();
field->set_string_data("value" + std::to_string(k));
if (!FLAGS_scalar_filter_key.empty()) {
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(dingodb::pb::common::ScalarFieldType::STRING);
dingodb::pb::common::ScalarField* field = scalar_value.add_fields();
field->set_string_data(FLAGS_scalar_filter_value);

vector->mutable_scalar_data()->mutable_scalar_data()->insert({"key" + std::to_string(k), scalar_value});
vector->mutable_scalar_data()->mutable_scalar_data()->insert({FLAGS_scalar_filter_key, scalar_value});

DINGO_LOG(INFO) << "scalar_filter_key: " << FLAGS_scalar_filter_key
<< " scalar_filter_value: " << FLAGS_scalar_filter_value;
}

if (!FLAGS_scalar_filter_key2.empty()) {
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(dingodb::pb::common::ScalarFieldType::STRING);
dingodb::pb::common::ScalarField* field = scalar_value.add_fields();
field->set_string_data(FLAGS_scalar_filter_value2);

vector->mutable_scalar_data()->mutable_scalar_data()->insert({FLAGS_scalar_filter_key2, scalar_value});

DINGO_LOG(INFO) << "scalar_filter_key2: " << FLAGS_scalar_filter_key2
<< " scalar_filter_value2: " << FLAGS_scalar_filter_value2;
}
}

Expand Down Expand Up @@ -580,21 +621,23 @@ void SendVectorSearch(int64_t region_id, uint32_t dimension, uint32_t topn) {
}
}

for (const auto& vector_with_distance : batch_result.vector_with_distances()) {
std::vector<float> x_i = dingodb::Helper::PbRepeatedToVector(vector->vector().float_values());
std::vector<float> y_j =
dingodb::Helper::PbRepeatedToVector(vector_with_distance.vector_with_id().vector().float_values());
if (!FLAGS_without_vector) {
for (const auto& vector_with_distance : batch_result.vector_with_distances()) {
std::vector<float> x_i = dingodb::Helper::PbRepeatedToVector(vector->vector().float_values());
std::vector<float> y_j =
dingodb::Helper::PbRepeatedToVector(vector_with_distance.vector_with_id().vector().float_values());

auto faiss_l2 = dingodb::Helper::DingoFaissL2sqr(x_i.data(), y_j.data(), dimension);
auto faiss_ip = dingodb::Helper::DingoFaissInnerProduct(x_i.data(), y_j.data(), dimension);
auto hnsw_l2 = dingodb::Helper::DingoHnswL2Sqr(x_i.data(), y_j.data(), dimension);
auto hnsw_ip = dingodb::Helper::DingoHnswInnerProduct(x_i.data(), y_j.data(), dimension);
auto hnsw_ip_dist = dingodb::Helper::DingoHnswInnerProductDistance(x_i.data(), y_j.data(), dimension);
auto faiss_l2 = dingodb::Helper::DingoFaissL2sqr(x_i.data(), y_j.data(), dimension);
auto faiss_ip = dingodb::Helper::DingoFaissInnerProduct(x_i.data(), y_j.data(), dimension);
auto hnsw_l2 = dingodb::Helper::DingoHnswL2Sqr(x_i.data(), y_j.data(), dimension);
auto hnsw_ip = dingodb::Helper::DingoHnswInnerProduct(x_i.data(), y_j.data(), dimension);
auto hnsw_ip_dist = dingodb::Helper::DingoHnswInnerProductDistance(x_i.data(), y_j.data(), dimension);

DINGO_LOG(INFO) << "vector_id: " << vector_with_distance.vector_with_id().id()
<< ", distance: " << vector_with_distance.distance() << ", [faiss_l2: " << faiss_l2
<< ", faiss_ip: " << faiss_ip << ", hnsw_l2: " << hnsw_l2 << ", hnsw_ip: " << hnsw_ip
<< ", hnsw_ip_dist: " << hnsw_ip_dist << "]";
DINGO_LOG(INFO) << "vector_id: " << vector_with_distance.vector_with_id().id()
<< ", distance: " << vector_with_distance.distance() << ", [faiss_l2: " << faiss_l2
<< ", faiss_ip: " << faiss_ip << ", hnsw_l2: " << hnsw_l2 << ", hnsw_ip: " << hnsw_ip
<< ", hnsw_ip_dist: " << hnsw_ip_dist << "]";
}
}
}

Expand Down Expand Up @@ -1732,19 +1775,37 @@ int SendBatchVectorAdd(int64_t region_id, uint32_t dimension, std::vector<int64_
}

if (with_scalar) {
for (int k = 0; k < 2; ++k) {
auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data();
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING);
scalar_value.add_fields()->set_string_data(fmt::format("scalar_value{}", k));
(*scalar_data)[fmt::format("scalar_key{}", k)] = scalar_value;
}
for (int k = 2; k < 4; ++k) {
auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data();
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::INT64);
scalar_value.add_fields()->set_long_data(k);
(*scalar_data)[fmt::format("scalar_key{}", k)] = scalar_value;
if (FLAGS_scalar_filter_key.empty() || FLAGS_scalar_filter_value.empty()) {
for (int k = 0; k < 2; ++k) {
auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data();
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING);
scalar_value.add_fields()->set_string_data(fmt::format("scalar_value{}", k));
(*scalar_data)[fmt::format("scalar_key{}", k)] = scalar_value;
}
for (int k = 2; k < 4; ++k) {
auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data();
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::INT64);
scalar_value.add_fields()->set_long_data(k);
(*scalar_data)[fmt::format("scalar_key{}", k)] = scalar_value;
}
} else {
if (!FLAGS_scalar_filter_key.empty()) {
auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data();
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING);
scalar_value.add_fields()->set_string_data(FLAGS_scalar_filter_value);
(*scalar_data)[FLAGS_scalar_filter_key] = scalar_value;
}

if (!FLAGS_scalar_filter_key2.empty()) {
auto* scalar_data = vector_with_id->mutable_scalar_data()->mutable_scalar_data();
dingodb::pb::common::ScalarValue scalar_value;
scalar_value.set_field_type(::dingodb::pb::common::ScalarFieldType::STRING);
scalar_value.add_fields()->set_string_data(FLAGS_scalar_filter_value2);
(*scalar_data)[FLAGS_scalar_filter_key2] = scalar_value;
}
}
}

Expand Down
33 changes: 25 additions & 8 deletions src/vector/vector_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,12 @@ butil::Status VectorIndexWrapper::Search(std::vector<pb::common::VectorWithId> v
if (region_range.start_key() != index_range.start_key() || region_range.end_key() != index_range.end_key()) {
int64_t min_vector_id = 0, max_vector_id = 0;
VectorCodec::DecodeRangeToVectorId(region_range, min_vector_id, max_vector_id);
VectorIndexWrapper::SetVectorIndexFilter(vector_index, filters, min_vector_id, max_vector_id);
auto ret = VectorIndexWrapper::SetVectorIndexRangeFilter(vector_index, filters, min_vector_id, max_vector_id);
if (!ret.ok()) {
DINGO_LOG(ERROR) << fmt::format("[vector_index.wrapper][index_id({})] set vector index filter failed, error: {}",
Id(), ret.error_str());
return ret;
}
}

return vector_index->Search(vector_with_ids, topk, filters, reconstruct, parameter, results);
Expand Down Expand Up @@ -914,7 +919,12 @@ butil::Status VectorIndexWrapper::RangeSearch(std::vector<pb::common::VectorWith
if (region_range.start_key() != index_range.start_key() || region_range.end_key() != index_range.end_key()) {
int64_t min_vector_id = 0, max_vector_id = 0;
VectorCodec::DecodeRangeToVectorId(region_range, min_vector_id, max_vector_id);
VectorIndexWrapper::SetVectorIndexFilter(vector_index, filters, min_vector_id, max_vector_id);
auto ret = VectorIndexWrapper::SetVectorIndexRangeFilter(vector_index, filters, min_vector_id, max_vector_id);
if (!ret.ok()) {
DINGO_LOG(ERROR) << fmt::format("[vector_index.wrapper][index_id({})] set vector index filter failed, error: {}",
Id(), ret.error_str());
return ret;
}
}

return vector_index->RangeSearch(vector_with_ids, radius, filters, reconstruct, parameter, results);
Expand All @@ -935,14 +945,13 @@ bool VectorIndexWrapper::IsPermanentHoldVectorIndex(int64_t region_id) {
return true;
}

butil::Status VectorIndexWrapper::SetVectorIndexFilter(
VectorIndexPtr vector_index,
std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& filters, // NOLINT
butil::Status VectorIndexWrapper::SetVectorIndexRangeFilter(
VectorIndexPtr vector_index, std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& filters,
int64_t min_vector_id, int64_t max_vector_id) {
if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_HNSW) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_FLAT) {
// filters.push_back(std::make_shared<VectorIndex::FlatRangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_FLAT ||
vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_BRUTEFORCE) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_IVF_FLAT) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
Expand All @@ -952,8 +961,16 @@ butil::Status VectorIndexWrapper::SetVectorIndexFilter(
} else if (vector_index->VectorIndexSubType() == pb::common::VECTOR_INDEX_TYPE_FLAT) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else {
// do nothing
return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT,
fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}",
pb::common::VectorIndexType_Name(vector_index->VectorIndexType()),
pb::common::VectorIndexType_Name(vector_index->VectorIndexSubType())));
}
} else {
return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT,
fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}",
pb::common::VectorIndexType_Name(vector_index->VectorIndexType()),
pb::common::VectorIndexType_Name(vector_index->VectorIndexSubType())));
}

return butil::Status::OK();
Expand Down
12 changes: 6 additions & 6 deletions src/vector/vector_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ class VectorIndex {

virtual butil::Status Load(const std::string& path);

virtual butil::Status Search(std::vector<pb::common::VectorWithId> vector_with_ids, uint32_t topk,
std::vector<std::shared_ptr<FilterFunctor>> filters, bool reconstruct,
virtual butil::Status Search(const std::vector<pb::common::VectorWithId>& vector_with_ids, uint32_t topk,
const std::vector<std::shared_ptr<FilterFunctor>>& filters, bool reconstruct,
const pb::common::VectorSearchParameter& parameter,
std::vector<pb::index::VectorWithDistanceResult>& results) = 0;

virtual butil::Status RangeSearch(std::vector<pb::common::VectorWithId> vector_with_ids, float radius,
std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters, bool reconstruct,
const pb::common::VectorSearchParameter& parameter,
virtual butil::Status RangeSearch(const std::vector<pb::common::VectorWithId>& vector_with_ids, float radius,
const std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& filters,
bool reconstruct, const pb::common::VectorSearchParameter& parameter,
std::vector<pb::index::VectorWithDistanceResult>& results) = 0;

virtual void LockWrite() = 0;
Expand Down Expand Up @@ -359,7 +359,7 @@ class VectorIndexWrapper : public std::enable_shared_from_this<VectorIndexWrappe
const pb::common::VectorSearchParameter& parameter,
std::vector<pb::index::VectorWithDistanceResult>& results);

static butil::Status SetVectorIndexFilter(
static butil::Status SetVectorIndexRangeFilter(
VectorIndexPtr vector_index,
std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& filters, // NOLINT
int64_t min_vector_id, int64_t max_vector_id);
Expand Down
8 changes: 4 additions & 4 deletions src/vector/vector_index_bruteforce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ butil::Status VectorIndexBruteforce::Add(const std::vector<pb::common::VectorWit

butil::Status VectorIndexBruteforce::Delete(const std::vector<int64_t>& /*delete_ids*/) { return butil::Status::OK(); }

butil::Status VectorIndexBruteforce::Search(std::vector<pb::common::VectorWithId> /*vector_with_ids*/,
uint32_t /*topk*/, std::vector<std::shared_ptr<FilterFunctor>> /*filters*/,
butil::Status VectorIndexBruteforce::Search(const std::vector<pb::common::VectorWithId>& /*vector_with_ids*/,
uint32_t /*topk*/, const std::vector<std::shared_ptr<FilterFunctor>>& /*filters*/,
bool, const pb::common::VectorSearchParameter&,
std::vector<pb::index::VectorWithDistanceResult>& /*results*/) {
return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT, "not support");
}

butil::Status VectorIndexBruteforce::RangeSearch(std::vector<pb::common::VectorWithId> /*vector_with_ids*/,
butil::Status VectorIndexBruteforce::RangeSearch(const std::vector<pb::common::VectorWithId>& /*vector_with_ids*/,
float /*radius*/,
std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> /*filters*/,
const std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& /*filters*/,
bool /*reconstruct*/,
const pb::common::VectorSearchParameter& /*parameter*/,
std::vector<pb::index::VectorWithDistanceResult>& /*results*/) {
Expand Down
8 changes: 4 additions & 4 deletions src/vector/vector_index_bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ class VectorIndexBruteforce : public VectorIndex {

butil::Status Delete(const std::vector<int64_t>& delete_ids) override;

butil::Status Search(std::vector<pb::common::VectorWithId> vector_with_ids, uint32_t topk,
std::vector<std::shared_ptr<FilterFunctor>> filters, bool reconstruct,
butil::Status Search(const std::vector<pb::common::VectorWithId>& vector_with_ids, uint32_t topk,
const std::vector<std::shared_ptr<FilterFunctor>>& filters, bool reconstruct,
const pb::common::VectorSearchParameter& parameter,
std::vector<pb::index::VectorWithDistanceResult>& results) override;

butil::Status RangeSearch(std::vector<pb::common::VectorWithId> vector_with_ids, float radius,
std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters, bool reconstruct,
butil::Status RangeSearch(const std::vector<pb::common::VectorWithId>& vector_with_ids, float radius,
const std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& filters, bool reconstruct,
const pb::common::VectorSearchParameter& parameter,
std::vector<pb::index::VectorWithDistanceResult>& results) override;

Expand Down
Loading

0 comments on commit 3b89c9b

Please sign in to comment.