Skip to content

Commit

Permalink
[feat][index] Add ivf flat vector index.
Browse files Browse the repository at this point in the history
  • Loading branch information
Haijun Yu authored and ketor committed Sep 21, 2023
1 parent 93f227b commit 79a8ae1
Show file tree
Hide file tree
Showing 20 changed files with 2,205 additions and 110 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ cmake-build-release
.DS_Store
tags
.cache
.run

# eclipse
.cproject
Expand Down
1 change: 1 addition & 0 deletions cmake/openblas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ ExternalProject_Add(
-DBUILD_WITHOUT_CBLAS=ON
-DBUILD_STATIC_LIBS=ON
-DBUILD_TESTING=OFF
-DC_LAPACK=ON
${EXTERNAL_OPTIONAL_ARGS}
LIST_SEPARATOR |
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${OPENBLAS_INSTALL_DIR}
Expand Down
1 change: 1 addition & 0 deletions proto/error.proto
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ enum Errno {
EVECTOR_INDEX_EXIST = 70017;
EVECTOR_INDEX_SWITCHING = 70018;
EVECTOR_INDEX_NOT_READY = 70019;
EVECTOR_NOT_TRAIN = 70020;

// file [80000, 90000)
EFILE_NOT_FOUND_READER = 80000;
Expand Down
8 changes: 8 additions & 0 deletions src/client/coordinator_client_function_meta.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ DECLARE_bool(with_auto_increment);
DECLARE_string(vector_index_type);
DECLARE_bool(auto_split);
DECLARE_uint32(part_count);
DECLARE_uint32(ncentroids);

void SendGetSchemas(std::shared_ptr<dingodb::CoordinatorInteraction> coordinator_interaction) {
dingodb::pb::meta::GetSchemasRequest request;
Expand Down Expand Up @@ -711,6 +712,8 @@ void SendCreateIndex(std::shared_ptr<dingodb::CoordinatorInteraction> coordinato
vector_index_parameter->set_vector_index_type(::dingodb::pb::common::VectorIndexType::VECTOR_INDEX_TYPE_HNSW);
} else if (FLAGS_vector_index_type == "flat") {
vector_index_parameter->set_vector_index_type(::dingodb::pb::common::VectorIndexType::VECTOR_INDEX_TYPE_FLAT);
} else if (FLAGS_vector_index_type == "ivf_flat") {
vector_index_parameter->set_vector_index_type(::dingodb::pb::common::VectorIndexType::VECTOR_INDEX_TYPE_IVF_FLAT);
} else {
DINGO_LOG(WARNING) << "vector_index_type is invalid, now only support hnsw and flat";
return;
Expand Down Expand Up @@ -748,6 +751,11 @@ void SendCreateIndex(std::shared_ptr<dingodb::CoordinatorInteraction> coordinato
auto* flat_index_parameter = vector_index_parameter->mutable_flat_parameter();
flat_index_parameter->set_dimension(FLAGS_dimension);
flat_index_parameter->set_metric_type(::dingodb::pb::common::MetricType::METRIC_TYPE_COSINE);
} else if (FLAGS_vector_index_type == "ivf_flat") {
auto* ivf_flat_index_parameter = vector_index_parameter->mutable_ivf_flat_parameter();
ivf_flat_index_parameter->set_dimension(FLAGS_dimension);
ivf_flat_index_parameter->set_metric_type(::dingodb::pb::common::MetricType::METRIC_TYPE_COSINE);
ivf_flat_index_parameter->set_ncentroids(FLAGS_ncentroids);
}

index_definition->set_version(1);
Expand Down
3 changes: 2 additions & 1 deletion src/client/dingodb_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ DEFINE_int64(max_elements, 0, "max_elements");
DEFINE_int64(dimension, 0, "dimension");
DEFINE_int64(efconstruction, 0, "efconstruction");
DEFINE_int64(nlinks, 0, "nlinks");
DEFINE_uint32(ncentroids, 10, "ncentroids default : 10");
DEFINE_uint32(part_count, 1, "partition count");
DEFINE_bool(with_auto_increment, true, "with_auto_increment");
DEFINE_string(vector_index_type, "", "vector_index_type");
DEFINE_string(vector_index_type, "", "vector_index_type:flat, hnsw, ivf_flat");
DEFINE_int32(round_num, 1, "Round of requests");
DEFINE_string(store_addrs, "", "server addrs");
DEFINE_string(raft_addrs, "127.0.0.1:10101:0,127.0.0.1:10102:0,127.0.0.1:10103:0", "raft addrs");
Expand Down
3 changes: 3 additions & 0 deletions src/common/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ class Constant {

static const uint32_t kBuildVectorIndexBatchSize = 8192;

static constexpr int32_t kCreateIvfFlatParamNcentroids = 2048;
static constexpr int32_t kSearchIvfFlatParamNprobe = 80;

// split region
static constexpr int kSplitDoSnapshotRetryTimes = 5;
inline static const std::string kSplitStrategy = "PRE_CREATE_REGION";
Expand Down
60 changes: 41 additions & 19 deletions src/vector/vector_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <memory>

#include "bthread/bthread.h"
#include "butil/compiler_specific.h"
#include "butil/status.h"
#include "common/constant.h"
#include "fmt/core.h"
Expand Down Expand Up @@ -312,28 +313,12 @@ bool VectorIndexWrapper::IsExceedsMaxElements() {
}

bool VectorIndexWrapper::NeedToRebuild() {
if (Type() == pb::common::VECTOR_INDEX_TYPE_FLAT) {
return false;
}
auto vector_index = GetOwnVectorIndex();
if (vector_index == nullptr) {
return false;
}
uint64_t element_count = 0, deleted_count = 0;
auto status = vector_index->GetCount(element_count);
if (!status.ok()) {
return false;
}
status = vector_index->GetDeletedCount(deleted_count);
if (!status.ok()) {
return false;
}

if (element_count == 0 || deleted_count == 0) {
return false;
}

return (deleted_count > 0 && deleted_count > element_count / 2);
return vector_index->NeedToRebuild();
}

bool VectorIndexWrapper::NeedToSave(uint64_t last_save_log_behind) {
Expand Down Expand Up @@ -417,6 +402,23 @@ butil::Status VectorIndexWrapper::Add(const std::vector<pb::common::VectorWithId
}

auto status = vector_index->Add(vector_with_ids);
if (BAIDU_UNLIKELY(pb::error::Errno::EVECTOR_NOT_TRAIN == status.error_code())) {
std::vector<float> train_datas;
train_datas.reserve(vector_index->GetDimension() * vector_with_ids.size());
for (const auto& vector_with_id : vector_with_ids) {
train_datas.insert(train_datas.end(), vector_with_id.vector().float_values().begin(),
vector_with_id.vector().float_values().end());
}
status = vector_index->Train(train_datas);
if (BAIDU_LIKELY(status.ok())) {
// try again
status = vector_index->Add(vector_with_ids);
} else {
DINGO_LOG(ERROR) << fmt::format("[vector_index.wrapper][index_id({})] train failed size : {}", Id(),
train_datas.size());
return status;
}
}
if (status.ok()) {
write_key_count_ += vector_with_ids.size();
}
Expand Down Expand Up @@ -444,6 +446,23 @@ butil::Status VectorIndexWrapper::Upsert(const std::vector<pb::common::VectorWit
}

auto status = vector_index->Upsert(vector_with_ids);
if (BAIDU_UNLIKELY(pb::error::Errno::EVECTOR_NOT_TRAIN == status.error_code())) {
std::vector<float> train_datas;
train_datas.reserve(vector_index->GetDimension() * vector_with_ids.size());
for (const auto& vector_with_id : vector_with_ids) {
train_datas.insert(train_datas.end(), vector_with_id.vector().float_values().begin(),
vector_with_id.vector().float_values().end());
}
status = vector_index->Train(train_datas);
if (BAIDU_LIKELY(status.ok())) {
// try again
status = vector_index->Upsert(vector_with_ids);
} else {
DINGO_LOG(ERROR) << fmt::format("[vector_index.wrapper][index_id({})] train failed size : {}", Id(),
train_datas.size());
return status;
}
}
if (status.ok()) {
write_key_count_ += vector_with_ids.size();
}
Expand Down Expand Up @@ -480,7 +499,8 @@ butil::Status VectorIndexWrapper::Delete(const std::vector<uint64_t>& delete_ids
butil::Status VectorIndexWrapper::Search(std::vector<pb::common::VectorWithId> vector_with_ids, uint32_t topk,
const pb::common::Range& region_range,
std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters,
std::vector<pb::index::VectorWithDistanceResult>& results, bool reconstruct) {
std::vector<pb::index::VectorWithDistanceResult>& results, bool reconstruct,
const pb::common::VectorSearchParameter& parameter) {
if (!IsReady()) {
DINGO_LOG(WARNING) << fmt::format("[vector_index.wrapper][index_id({})] vector index is not ready.", Id());
return butil::Status(pb::error::EVECTOR_INDEX_NOT_FOUND, "vector index %lu is not ready.", Id());
Expand All @@ -500,10 +520,12 @@ butil::Status VectorIndexWrapper::Search(std::vector<pb::common::VectorWithId> v
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_FLAT) {
filters.push_back(std::make_shared<VectorIndex::FlatRangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_IVF_FLAT) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
}
}

return vector_index->Search(vector_with_ids, topk, filters, results, reconstruct);
return vector_index->Search(vector_with_ids, topk, filters, results, reconstruct, parameter);
}

} // namespace dingodb
37 changes: 33 additions & 4 deletions src/vector/vector_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <optional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "bthread/types.h"
Expand Down Expand Up @@ -127,7 +128,7 @@ class VectorIndex {
// List filter just for flat
class FlatListFilterFunctor : public FilterFunctor {
public:
FlatListFilterFunctor(std::vector<uint64_t>&& vector_ids)
explicit FlatListFilterFunctor(std::vector<uint64_t>&& vector_ids)
: vector_ids_(std::forward<std::vector<uint64_t>>(vector_ids)) {}
FlatListFilterFunctor(const FlatListFilterFunctor&) = delete;
FlatListFilterFunctor(FlatListFilterFunctor&&) = delete;
Expand Down Expand Up @@ -157,6 +158,27 @@ class VectorIndex {
std::vector<faiss::idx_t>* id_map_{nullptr};
};

// List filter just for ivf flat
class IvfFlatListFilterFunctor : public FilterFunctor {
public:
explicit IvfFlatListFilterFunctor(std::vector<uint64_t>&& vector_ids)
: vector_ids_(std::forward<std::vector<uint64_t>>(vector_ids)) {
// highly optimized code, do not modify it
array_indexs_.rehash(vector_ids_.size());
array_indexs_.insert(vector_ids_.begin(), vector_ids_.end());
}
IvfFlatListFilterFunctor(const IvfFlatListFilterFunctor&) = delete;
IvfFlatListFilterFunctor(IvfFlatListFilterFunctor&&) = delete;
IvfFlatListFilterFunctor& operator=(const IvfFlatListFilterFunctor&) = delete;
IvfFlatListFilterFunctor& operator=(IvfFlatListFilterFunctor&&) = delete;

bool Check(uint64_t index) override { return array_indexs_.find(index) != array_indexs_.end(); }

private:
std::vector<uint64_t> vector_ids_;
std::unordered_set<uint64_t> array_indexs_;
};

virtual int32_t GetDimension() = 0;
virtual butil::Status GetCount([[maybe_unused]] uint64_t& count);
virtual butil::Status GetDeletedCount([[maybe_unused]] uint64_t& deleted_count);
Expand All @@ -176,11 +198,17 @@ class VectorIndex {
virtual butil::Status Search([[maybe_unused]] std::vector<pb::common::VectorWithId> vector_with_ids,
[[maybe_unused]] uint32_t topk,
[[maybe_unused]] std::vector<std::shared_ptr<FilterFunctor>> filters,
std::vector<pb::index::VectorWithDistanceResult>& results,
[[maybe_unused]] bool reconstruct = false) = 0;
std::vector<pb::index::VectorWithDistanceResult>& results, // NOLINT
[[maybe_unused]] bool reconstruct = false,
[[maybe_unused]] const pb::common::VectorSearchParameter& parameter = {}) = 0;

virtual void LockWrite() = 0;
virtual void UnlockWrite() = 0;
virtual butil::Status Train(const std::vector<float>& train_datas) = 0;
virtual butil::Status Train(const std::vector<pb::common::VectorWithId>& vectors) = 0;
virtual bool NeedToRebuild() = 0;
virtual bool NeedTrain() { return false; }
virtual bool IsTrained() { return true; }

uint64_t Id() const { return id; }

Expand Down Expand Up @@ -301,7 +329,8 @@ class VectorIndexWrapper : public std::enable_shared_from_this<VectorIndexWrappe
butil::Status Search(std::vector<pb::common::VectorWithId> vector_with_ids, uint32_t topk,
const pb::common::Range& region_range,
std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters,
std::vector<pb::index::VectorWithDistanceResult>& results, bool reconstruct = false);
std::vector<pb::index::VectorWithDistanceResult>& results, bool reconstruct = false,
const pb::common::VectorSearchParameter& parameter = {});

private:
// vector index id
Expand Down
Loading

0 comments on commit 79a8ae1

Please sign in to comment.