Skip to content

Commit

Permalink
[feat][sdk] Support create vector index with scalar schema
Browse files Browse the repository at this point in the history
  • Loading branch information
wchuande authored and ketor committed Apr 12, 2024
1 parent df2a714 commit 2d3b5d1
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 41 deletions.
4 changes: 4 additions & 0 deletions src/example/sdk_vector_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,17 @@ static void PrepareVectorIndex() {
CHECK_NOTNULL(creator);
dingodb::ScopeGuard guard([&]() { delete creator; });

dingodb::sdk::VectorScalarSchema schema;
// NOTE: more scalar columns may be added here
schema.cols.push_back({g_scalar_col[0], dingodb::sdk::ScalarFieldType::kInt64, true});
Status create = creator->SetSchemaId(g_schema_id)
.SetName(g_index_name)
.SetReplicaNum(3)
.SetRangePartitions(g_range_partition_seperator_ids)
.SetFlatParam(g_flat_param)
.SetAutoIncrement(true)
.SetAutoIncrementStart(1)
.SetScalarSchema(schema)
.Create(g_index_id);
DINGO_LOG(INFO) << "Create index status: " << create.ToString() << ", index_id:" << g_index_id;
sleep(20);
Expand Down
43 changes: 30 additions & 13 deletions src/sdk/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cstdint>
#include <map>
#include <string>
#include <variant>
#include <vector>

#include "sdk/status.h"
Expand Down Expand Up @@ -120,6 +121,33 @@ struct BruteForceParam {
static VectorIndexType Type() { return VectorIndexType::kBruteForce; }
};

// Data type tag for a scalar field.
// The enumerators carry the same ordinal values they would get implicitly
// (kNone == 0 through kBytes == 9); they are spelled out so that reordering
// the list cannot silently renumber them.
enum class ScalarFieldType : uint8_t {
  kNone = 0,
  kBool = 1,
  kInt8 = 2,
  kInt16 = 3,
  kInt32 = 4,
  kInt64 = 5,
  kFloat32 = 6,
  kDouble = 7,
  kString = 8,
  kBytes = 9,
};

// Schema of a single scalar column of a vector index.
struct VectorScalarColumnSchema {
  // Column key (name) of the scalar field.
  std::string key;
  // Data type of the scalar field.
  ScalarFieldType type;
  // Forwarded to the `enable_speed_up` flag of the protobuf schema item.
  bool speed;

  // `key` is taken by const reference (it is only copied into the member), so
  // temporaries, string literals and const strings are accepted in addition
  // to the mutable lvalues the previous `std::string&` signature required.
  // `speed` defaults to false.
  VectorScalarColumnSchema(const std::string& key, ScalarFieldType type, bool speed = false)
      : key(key), type(type), speed(speed) {}
};

// Scalar schema of a vector index: the list of its scalar column schemas.
// TODO: maybe use a builder to build VectorScalarSchema
struct VectorScalarSchema {
// One entry per scalar column.
std::vector<VectorScalarColumnSchema> cols;
};

// Value type tag (unscoped enum with uint8_t storage).
// Ordinals match the implicit numbering of the enumerator list.
enum ValueType : uint8_t {
  kNoneValueType = 0,
  kFloat = 1,
  kUint8 = 2,
};

std::string ValueTypeToString(ValueType type);
Expand Down Expand Up @@ -158,19 +186,6 @@ struct Vector {
std::string ToString() const;
};

// Data type tag for a scalar field.
// NOTE(review): this looks like the previous position of the definition (the
// enclosing hunk shrinks from 19 to 6 lines); the same enum is also declared
// ahead of VectorScalarColumnSchema — confirm only one copy remains.
enum class ScalarFieldType : uint8_t {
kNone,
kBool,
kInt8,
kInt16,
kInt32,
kInt64,
kFloat32,
kDouble,
kString,
kBytes
};

struct ScalarField {
bool bool_data;
int32_t int_data;
Expand Down Expand Up @@ -467,6 +482,8 @@ class VectorIndexCreator {
// start_id should be greater than 0 when auto_increment is set to true
VectorIndexCreator& SetAutoIncrementStart(int64_t start_id);

// Sets the scalar schema (scalar column definitions) for the index to be created.
// Returns *this so calls can be chained.
VectorIndexCreator& SetScalarSchema(const VectorScalarSchema& schema);

Status Create(int64_t& out_index_id);

private:
Expand Down
77 changes: 68 additions & 9 deletions src/sdk/vector/vector_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,60 @@ static VectorIndexType InternalVectorIndexTypePB2VectorIndexType(pb::common::Vec
}
}

// Translates an SDK ScalarFieldType into the corresponding
// pb::common::ScalarFieldType enumerator.
// Any value outside the known enumerators hits CHECK(false), with the numeric
// value of `type` appended to the message.
static pb::common::ScalarFieldType ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType type) {
  using PbType = pb::common::ScalarFieldType;

  switch (type) {
    case ScalarFieldType::kNone:    return PbType::NONE;
    case ScalarFieldType::kBool:    return PbType::BOOL;
    case ScalarFieldType::kInt8:    return PbType::INT8;
    case ScalarFieldType::kInt16:   return PbType::INT16;
    case ScalarFieldType::kInt32:   return PbType::INT32;
    case ScalarFieldType::kInt64:   return PbType::INT64;
    case ScalarFieldType::kFloat32: return PbType::FLOAT32;
    case ScalarFieldType::kDouble:  return PbType::DOUBLE;
    case ScalarFieldType::kString:  return PbType::STRING;
    case ScalarFieldType::kBytes:   return PbType::BYTES;
    default:
      CHECK(false) << "unsupported scalar field type:" << static_cast<int>(type);
  }
}

// Translates a pb::common::ScalarFieldType into the corresponding SDK
// ScalarFieldType enumerator.
// Any value outside the handled enumerators hits CHECK(false), with the
// protobuf enum name of `type` appended to the message.
static ScalarFieldType InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType type) {
  using PbType = pb::common::ScalarFieldType;

  switch (type) {
    case PbType::NONE:    return ScalarFieldType::kNone;
    case PbType::BOOL:    return ScalarFieldType::kBool;
    case PbType::INT8:    return ScalarFieldType::kInt8;
    case PbType::INT16:   return ScalarFieldType::kInt16;
    case PbType::INT32:   return ScalarFieldType::kInt32;
    case PbType::INT64:   return ScalarFieldType::kInt64;
    case PbType::FLOAT32: return ScalarFieldType::kFloat32;
    case PbType::DOUBLE:  return ScalarFieldType::kDouble;
    case PbType::STRING:  return ScalarFieldType::kString;
    case PbType::BYTES:   return ScalarFieldType::kBytes;
    default:
      CHECK(false) << "unsupported scalar field type:" << pb::common::ScalarFieldType_Name(type);
  }
}

static void FillFlatParmeter(pb::common::VectorIndexParameter* parameter, const FlatParam& param) {
parameter->set_vector_index_type(pb::common::VECTOR_INDEX_TYPE_FLAT);
auto* flat = parameter->mutable_flat_parameter();
Expand Down Expand Up @@ -181,48 +235,41 @@ static pb::common::ValueType ValueType2InternalValueTypePB(ValueType value_type)

static pb::common::ScalarValue ScalarValue2InternalScalarValuePB(const sdk::ScalarValue& scalar_value) {
pb::common::ScalarValue result;
result.set_field_type(ScalarFieldType2InternalScalarFieldTypePB(scalar_value.type));

if (scalar_value.type == sdk::ScalarFieldType::kBool) {
result.set_field_type(pb::common::ScalarFieldType::BOOL);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_bool_data(field.bool_data);
}
} else if (scalar_value.type == sdk::ScalarFieldType::kInt8) {
result.set_field_type(pb::common::ScalarFieldType::INT8);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_int_data(field.int_data);
}
} else if (scalar_value.type == sdk::ScalarFieldType::kInt16) {
result.set_field_type(pb::common::ScalarFieldType::INT16);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_int_data(field.int_data);
}
} else if (scalar_value.type == sdk::ScalarFieldType::kInt32) {
result.set_field_type(pb::common::ScalarFieldType::INT32);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_int_data(field.int_data);
}
} else if (scalar_value.type == sdk::ScalarFieldType::kInt64) {
result.set_field_type(pb::common::ScalarFieldType::INT64);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_long_data(field.long_data);
}
} else if (scalar_value.type == sdk::ScalarFieldType::kFloat32) {
result.set_field_type(pb::common::ScalarFieldType::FLOAT32);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_float_data(field.float_data);
}
} else if (scalar_value.type == sdk::ScalarFieldType::kDouble) {
result.set_field_type(pb::common::ScalarFieldType::DOUBLE);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_double_data(field.double_data);
}
} else if (scalar_value.type == sdk::ScalarFieldType::kString) {
result.set_field_type(pb::common::ScalarFieldType::STRING);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_string_data(field.string_data);
}
} else if (scalar_value.type == sdk::ScalarFieldType::kBytes) {
result.set_field_type(pb::common::ScalarFieldType::BYTES);
for (const auto& field : scalar_value.fields) {
result.add_fields()->set_bytes_data(field.bytes_data);
}
Expand All @@ -231,6 +278,18 @@ static pb::common::ScalarValue ScalarValue2InternalScalarValuePB(const sdk::Scal
return result;
}

// Copies one SDK column schema into a protobuf ScalarSchemaItem: key, field
// type (converted to the protobuf enum) and the speed-up flag.
static void FillScalarSchemaItem(pb::common::ScalarSchemaItem* pb, const VectorScalarColumnSchema& schema) {
  pb->set_key(schema.key);

  const auto pb_field_type = ScalarFieldType2InternalScalarFieldTypePB(schema.type);
  pb->set_field_type(pb_field_type);

  const bool speed_up = schema.speed;
  pb->set_enable_speed_up(speed_up);
}

// Appends one protobuf field entry to `pb` for every column in `schema`,
// in the order the columns appear in `schema.cols`.
static void FillScalarSchema(pb::common::ScalarSchema* pb, const VectorScalarSchema& schema) {
  for (auto it = schema.cols.cbegin(); it != schema.cols.cend(); ++it) {
    FillScalarSchemaItem(pb->add_fields(), *it);
  }
}

static void FillVectorWithIdPB(pb::common::VectorWithId* pb, const VectorWithId& vector_with_id, bool with_id = true) {
if (with_id) {
pb->set_id(vector_with_id.id);
Expand Down
5 changes: 5 additions & 0 deletions src/sdk/vector/vector_index_creator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ VectorIndexCreator& VectorIndexCreator::SetAutoIncrement(bool auto_incr) {
return *this;
}

// Stores a copy of `schema` in the creator's internal data, replacing any
// previously set scalar schema, and returns *this for call chaining.
VectorIndexCreator& VectorIndexCreator::SetScalarSchema(const VectorScalarSchema& schema) {
  auto& stored_schema = data_->schema;
  stored_schema.emplace(schema);
  return *this;
}

VectorIndexCreator& VectorIndexCreator::SetAutoIncrementStart(int64_t start_id) {
data_->auto_incr = true;
data_->auto_incr_start = start_id;
Expand Down
7 changes: 7 additions & 0 deletions src/sdk/vector/vector_index_creator_internal_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class VectorIndexCreator::Data {
} else {
CHECK(false) << "unsupported index type, " << index_type;
}

if(schema.has_value()) {
VectorScalarSchema& s = schema.value();
FillScalarSchema(parameter->mutable_scalar_schema(), s);
}
}

const ClientStub& stub;
Expand All @@ -88,6 +93,8 @@ class VectorIndexCreator::Data {
bool auto_incr{false};
std::optional<int64_t> auto_incr_start;

std::optional<VectorScalarSchema> schema;

bool wait;
};

Expand Down
2 changes: 1 addition & 1 deletion test/unit_test/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ int main(int argc, char* argv[]) {
default_run_case += ":RegionTest.*";
default_run_case += ":StoreRpcControllerTest.*";
default_run_case += ":ThreadPoolActuatorTest.*";
default_run_case += ":VectorCommonTest.*";
default_run_case += ":SDKVectorCommonTest.*";
default_run_case += ":VectorIndexCacheKeyTest.*";
default_run_case += ":VectorIndexTest.*";
default_run_case += ":TxnBufferTest.*";
Expand Down
Loading

0 comments on commit 2d3b5d1

Please sign in to comment.