diff --git a/src/benchmark/dataset.cc b/src/benchmark/dataset.cc index 448dca214..45ec5186f 100644 --- a/src/benchmark/dataset.cc +++ b/src/benchmark/dataset.cc @@ -580,7 +580,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = item["id"].GetInt64() + 1; scalar_value.fields.push_back(field); @@ -589,7 +589,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = item["wiki_id"].GetInt64(); scalar_value.fields.push_back(field); @@ -598,7 +598,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = item["paragraph_id"].GetInt64(); scalar_value.fields.push_back(field); @@ -607,7 +607,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = item["langs"].GetInt64(); scalar_value.fields.push_back(field); @@ -616,7 +616,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = item["title"].GetString(); scalar_value.fields.push_back(field); @@ -624,7 +624,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect } { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = item["url"].GetString(); scalar_value.fields.push_back(field); @@ -632,7 +632,7 @@ bool Wikipedia2212Dataset::ParseTrainData(const rapidjson::Value& obj, sdk::Vect } { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = item["text"].GetString(); scalar_value.fields.push_back(field); @@ -674,35 +674,35 @@ Dataset::TestEntryPtr Wikipedia2212Dataset::ParseTestData(const rapidjson::Value const auto& value = kv[1]; if (key == "title") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = value; scalar_value.fields.push_back(field); vector_with_id.scalar_data["title"] = scalar_value; } else if (key == "text") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = value; scalar_value.fields.push_back(field); vector_with_id.scalar_data["text"] = scalar_value; } else if (key == "langs") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = std::stoll(value); scalar_value.fields.push_back(field); vector_with_id.scalar_data["langs"] = scalar_value; } else if (key == "paragraph_id") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = std::stoll(value); scalar_value.fields.push_back(field); vector_with_id.scalar_data["paragraph_id"] = scalar_value; } else if (key == "wiki_id") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = std::stoll(value); scalar_value.fields.push_back(field); @@ -714,14 +714,14 @@ Dataset::TestEntryPtr Wikipedia2212Dataset::ParseTestData(const rapidjson::Value for (const auto& m : item["filter"].GetObject()) { if (m.value.IsString()) { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = m.value.GetString(); scalar_value.fields.push_back(field); vector_with_id.scalar_data[m.name.GetString()] = scalar_value; } else if (m.value.IsInt64()) { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = m.value.GetInt64(); scalar_value.fields.push_back(field); @@ -770,7 +770,7 @@ bool BeirBioasqDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorW { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = std::stoll(item["_id"].GetString()) + 1; scalar_value.fields.push_back(field); @@ -779,7 +779,7 @@ bool BeirBioasqDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorW { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = item["title"].GetString(); scalar_value.fields.push_back(field); @@ -788,7 +788,7 @@ bool BeirBioasqDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorW { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = item["text"].GetString(); scalar_value.fields.push_back(field); @@ -830,14 +830,14 @@ Dataset::TestEntryPtr BeirBioasqDataset::ParseTestData(const rapidjson::Value& o const auto& value = kv[1]; if (key == "title") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = value; scalar_value.fields.push_back(field); vector_with_id.scalar_data["title"] = scalar_value; } else if (key == "text") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = value; scalar_value.fields.push_back(field); @@ -849,14 +849,14 @@ Dataset::TestEntryPtr BeirBioasqDataset::ParseTestData(const rapidjson::Value& o for (const auto& m : item["filter"].GetObject()) { if (m.value.IsString()) { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = m.value.GetString(); scalar_value.fields.push_back(field); vector_with_id.scalar_data[m.name.GetString()] = scalar_value; } else if (m.value.IsInt64()) { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = m.value.GetInt64(); scalar_value.fields.push_back(field); @@ -916,7 +916,7 @@ bool MiraclDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorWithI { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = id; scalar_value.fields.push_back(field); @@ -925,7 +925,7 @@ bool MiraclDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorWithI { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = item["title"].GetString(); scalar_value.fields.push_back(field); @@ -934,7 +934,7 @@ bool MiraclDataset::ParseTrainData(const rapidjson::Value& obj, sdk::VectorWithI { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = item["text"].GetString(); scalar_value.fields.push_back(field); @@ -976,14 +976,14 @@ Dataset::TestEntryPtr MiraclDataset::ParseTestData(const rapidjson::Value& obj) const auto& value = kv[1]; if (key == "title") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = value; scalar_value.fields.push_back(field); vector_with_id.scalar_data["title"] = scalar_value; } else if (key == "text") { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = value; scalar_value.fields.push_back(field); @@ -995,14 +995,14 @@ Dataset::TestEntryPtr MiraclDataset::ParseTestData(const rapidjson::Value& obj) for (const auto& m : item["filter"].GetObject()) { if (m.value.IsString()) { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kString; + scalar_value.type = sdk::Type::kSTRING; sdk::ScalarField field; field.string_data = m.value.GetString(); scalar_value.fields.push_back(field); vector_with_id.scalar_data[m.name.GetString()] = scalar_value; } else if (m.value.IsInt64()) { sdk::ScalarValue scalar_value; - scalar_value.type = sdk::ScalarFieldType::kInt64; + scalar_value.type = sdk::Type::kINT64; sdk::ScalarField field; field.long_data = m.value.GetInt64(); scalar_value.fields.push_back(field); diff --git a/src/example/sdk_vector_example.cc b/src/example/sdk_vector_example.cc index cbdb9c0f6..784acf521 100644 --- a/src/example/sdk_vector_example.cc +++ b/src/example/sdk_vector_example.cc @@ -22,6 +22,7 @@ #include "glog/logging.h" #include "sdk/client.h" #include "sdk/status.h" +#include "sdk/types.h" #include "sdk/vector.h" #include "sdk/vector/vector_common.h" #include "sdk/vector/vector_get_index_metrics_task.h" @@ -40,7 +41,10 @@ static int32_t g_dimension = 2; static dingodb::sdk::FlatParam g_flat_param(g_dimension, dingodb::sdk::MetricType::kL2); static std::vector g_vector_ids; static dingodb::sdk::VectorClient* g_vector_client; -static std::vector g_scalar_col{"id", "name"}; + +static const dingodb::sdk::Type kDefaultType = dingodb::sdk::Type::kINT64; +static std::vector g_scalar_col{"id", "fake_id"}; +static std::vector g_scalar_col_typ{kDefaultType, dingodb::sdk::Type::kDOUBLE}; static void PrepareVectorIndex() { dingodb::sdk::VectorIndexCreator* creator; @@ -51,7 +55,8 @@ static void PrepareVectorIndex() { dingodb::sdk::VectorScalarSchema schema; // NOTE: may be add more - schema.cols.push_back({g_scalar_col[0], dingodb::sdk::ScalarFieldType::kInt64, true}); + schema.cols.push_back({g_scalar_col[0], g_scalar_col_typ[0], true}); + schema.cols.push_back({g_scalar_col[1], g_scalar_col_typ[1], true}); Status create = creator->SetSchemaId(g_schema_id) .SetName(g_index_name) .SetReplicaNum(3) @@ -147,14 +152,26 @@ static void VectorAdd(bool use_index_name = false) { dingodb::sdk::VectorWithId tmp(id, std::move(tmp_vector)); { - dingodb::sdk::ScalarValue scalar_value; - scalar_value.type = dingodb::sdk::ScalarFieldType::kInt64; - - dingodb::sdk::ScalarField field; - field.long_data = id; - scalar_value.fields.push_back(field); - - tmp.scalar_data.insert(std::make_pair(g_scalar_col[0], scalar_value)); + { + dingodb::sdk::ScalarValue scalar_value; + scalar_value.type = kDefaultType; + + dingodb::sdk::ScalarField field; + field.long_data = id; + scalar_value.fields.push_back(field); + + tmp.scalar_data.insert(std::make_pair(g_scalar_col[0], scalar_value)); + } + { + dingodb::sdk::ScalarValue scalar_value; + scalar_value.type = dingodb::sdk::kDOUBLE; + + dingodb::sdk::ScalarField field; + field.double_data = id; + scalar_value.fields.push_back(field); + + tmp.scalar_data.insert(std::make_pair(g_scalar_col[1], scalar_value)); + } } vectors.push_back(std::move(tmp)); @@ -235,14 +252,15 @@ static void VectorSearchUseExpr(bool use_index_name = false) { init = init + 0.1; } - dingodb::sdk::SearchParam param; { - param.topk = 100; - param.with_scalar_data = true; - param.extra_params.insert(std::make_pair(dingodb::sdk::kParallelOnQueries, 10)); + dingodb::sdk::SearchParam param; + { + param.topk = 100; + param.with_scalar_data = true; + param.extra_params.insert(std::make_pair(dingodb::sdk::kParallelOnQueries, 10)); - std::string json_str = - R"({ + std::string json_str = + R"({ "type": "operator", "operator": "and", "arguments": [ @@ -264,28 +282,132 @@ static void VectorSearchUseExpr(bool use_index_name = false) { } )"; - param.langchain_expr_json = json_str; - } + param.langchain_expr_json = json_str; + } - Status tmp; - std::vector result; - if (use_index_name) { - tmp = g_vector_client->SearchByIndexName(g_schema_id, g_index_name, param, target_vectors, result); - } else { - tmp = g_vector_client->SearchByIndexId(g_index_id, param, target_vectors, result); + Status tmp; + std::vector result; + if (use_index_name) { + tmp = g_vector_client->SearchByIndexName(g_schema_id, g_index_name, param, target_vectors, result); + } else { + tmp = g_vector_client->SearchByIndexId(g_index_id, param, target_vectors, result); + } + + DINGO_LOG(INFO) << "vector search expr status: " << tmp.ToString(); + for (const auto& r : result) { + DINGO_LOG(INFO) << "vector search expr result: " << r.ToString(); + } + + for (auto& search_result : result) { + for (auto& distance : search_result.vector_datas) { + const auto& vector_id = distance.vector_data.id; + CHECK_GE(vector_id, 5); + CHECK_LT(vector_id, 20); + } + } } - DINGO_LOG(INFO) << "vector search expr status: " << tmp.ToString(); - for (const auto& r : result) { - DINGO_LOG(INFO) << "vector search expr result: " << r.ToString(); + { + // schema type convert + dingodb::sdk::SearchParam param; + { + param.topk = 100; + param.with_scalar_data = true; + param.extra_params.insert(std::make_pair(dingodb::sdk::kParallelOnQueries, 10)); + + // fake_id schema type is double, int64 can convert to double + std::string json_str = + R"({ + "type": "operator", + "operator": "and", + "arguments": [ + { + "type": "comparator", + "comparator": "gte", + "attribute": "fake_id", + "value": 5, + "value_type": "INT64" + }, + { + "type": "comparator", + "comparator": "lt", + "attribute": "fake_id", + "value": 20, + "value_type": "INT64" + } + ] + } + )"; + + param.langchain_expr_json = json_str; + } + + Status tmp; + std::vector result; + if (use_index_name) { + tmp = g_vector_client->SearchByIndexName(g_schema_id, g_index_name, param, target_vectors, result); + } else { + tmp = g_vector_client->SearchByIndexId(g_index_id, param, target_vectors, result); + } + + DINGO_LOG(INFO) << "vector search expr with schema convert status: " << tmp.ToString(); + for (const auto& r : result) { + DINGO_LOG(INFO) << "vector search expr with schema convert result: " << r.ToString(); + } + + for (auto& search_result : result) { + for (auto& distance : search_result.vector_datas) { + const auto& vector_id = distance.vector_data.id; + CHECK_GE(vector_id, 5); + CHECK_LT(vector_id, 20); + } + } } + { + // schema type convert + dingodb::sdk::SearchParam param; + { + param.topk = 100; + param.with_scalar_data = true; + param.extra_params.insert(std::make_pair(dingodb::sdk::kParallelOnQueries, 10)); + + // id schema type is int64, double can't convert to int64 + std::string json_str = + R"({ + "type": "operator", + "operator": "and", + "arguments": [ + { + "type": "comparator", + "comparator": "gte", + "attribute": "id", + "value": 5, + "value_type": "DOUBLE" + }, + { + "type": "comparator", + "comparator": "lt", + "attribute": "id", + "value": 20, + "value_type": "DOUBLE" + } + ] + } + )"; + + param.langchain_expr_json = json_str; + } - for (auto& search_result : result) { - for (auto& distance : search_result.vector_datas) { - const auto& vector_id = distance.vector_data.id; - CHECK_GE(vector_id, 5); - CHECK_LT(vector_id, 20); + Status tmp; + std::vector result; + if (use_index_name) { + tmp = g_vector_client->SearchByIndexName(g_schema_id, g_index_name, param, target_vectors, result); + } else { + tmp = g_vector_client->SearchByIndexId(g_index_id, param, target_vectors, result); } + + DINGO_LOG(INFO) << "vector search expr with schema can't convert status: " << tmp.ToString(); + CHECK(!tmp.ok()); } } @@ -398,7 +520,7 @@ static void VectorScanQuery(bool use_index_name = false) { int64_t filter_id = 5; { dingodb::sdk::ScalarValue scalar_value; - scalar_value.type = dingodb::sdk::ScalarFieldType::kInt64; + scalar_value.type = kDefaultType; dingodb::sdk::ScalarField field; field.long_data = filter_id; diff --git a/src/sdk/CMakeLists.txt b/src/sdk/CMakeLists.txt index 00cc7d9c3..10b46f959 100644 --- a/src/sdk/CMakeLists.txt +++ b/src/sdk/CMakeLists.txt @@ -49,7 +49,6 @@ add_library(sdk expression/langchain_expr_encoder.cc expression/langchain_expr_factory.cc expression/langchain_expr.cc - expression/types.cc # TODO: use libary ${PROJECT_SOURCE_DIR}/src/coordinator/coordinator_interaction.cc ${PROJECT_SOURCE_DIR}/src/common/role.cc diff --git a/src/sdk/expression/common.h b/src/sdk/expression/common.h deleted file mode 100644 index d9ef71d53..000000000 --- a/src/sdk/expression/common.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2023 dingodb.com, Inc. All Rights Reserved -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DINGODB_SDK_EXPRESSION_COMMON_H_ -#define DINGODB_SDK_EXPRESSION_COMMON_H_ - -#include "glog/logging.h" -#include "proto/common.pb.h" -#include "sdk/expression/types.h" - -namespace dingodb { -namespace sdk { -namespace expression { - -static pb::common::Schema::Type Type2InternalSchemaTypePB(Type type) { - switch (type) { - case STRING: - return pb::common::Schema::STRING; - case DOUBLE: - return pb::common::Schema::DOUBLE; - case BOOL: - return pb::common::Schema::BOOL; - case INT64: - return pb::common::Schema::LONG; - default: - CHECK(false) << "Unimplement convert type: " << type; - } -} - -} // namespace expression -} // namespace sdk -} // namespace dingodb -#endif // DINGODB_SDK_EXPRESSION_COMMON_H_ diff --git a/src/sdk/expression/langchain_expr.cc b/src/sdk/expression/langchain_expr.cc index 5fd472dff..adf34a00f 100644 --- a/src/sdk/expression/langchain_expr.cc +++ b/src/sdk/expression/langchain_expr.cc @@ -16,8 +16,8 @@ #include +#include "glog/logging.h" #include "sdk/expression/langchain_expr_visitor.h" -#include "sdk/expression/types.h" namespace dingodb { namespace sdk { @@ -120,20 +120,22 @@ std::any Val::Accept(LangchainExprVisitor* visitor, void* target) { return visit std::string Val::ToString() const { std::ostringstream oss; - oss << "Val(Type: " << TypeToString(type); + oss << "Val(Type: " << TypeToString(type) << ", Name: " + name; switch (type) { - case STRING: - oss << ", Value: " << std::any_cast>(value); + case kBOOL: + oss << ", Value: " << std::any_cast>(value); break; - case DOUBLE: - oss << ", Value: " << std::any_cast>(value); + case kINT64: + oss << ", Value: " << std::any_cast>(value); break; - case BOOL: - oss << ", Value: " << std::any_cast>(value); + case kDOUBLE: + oss << ", Value: " << std::any_cast>(value); break; - case INT64: - oss << ", Value: " << std::any_cast>(value); + case kSTRING: + oss << ", Value: " << std::any_cast>(value); break; + default: + CHECK(false) << "Unknown type: " << static_cast(type); } oss << ")"; diff --git a/src/sdk/expression/langchain_expr.h b/src/sdk/expression/langchain_expr.h index edc96a17a..9d9fc7d41 100644 --- a/src/sdk/expression/langchain_expr.h +++ b/src/sdk/expression/langchain_expr.h @@ -24,7 +24,7 @@ #include #include "common/logging.h" -#include "sdk/expression/types.h" +#include "sdk/types.h" namespace dingodb { namespace sdk { @@ -176,9 +176,9 @@ class Var : public LangchainExpr { Type type; }; -class Val : public LangchainExpr { +class Val : public Var { public: - Val(std::any value, Type type) : value(value), type(type) {} + Val(std::string name, Type type, std::any value) : Var(std::move(name), type), value(value) {} ~Val() override = default; @@ -187,7 +187,6 @@ class Val : public LangchainExpr { std::string ToString() const override; std::any value; - Type type; }; } // namespace expression diff --git a/src/sdk/expression/langchain_expr_encoder.cc b/src/sdk/expression/langchain_expr_encoder.cc index 98e7e4a4e..add5c9cf0 100644 --- a/src/sdk/expression/langchain_expr_encoder.cc +++ b/src/sdk/expression/langchain_expr_encoder.cc @@ -22,10 +22,9 @@ #include "glog/logging.h" #include "sdk/common/param_config.h" #include "sdk/expression/coding.h" -#include "sdk/expression/common.h" #include "sdk/expression/encodes.h" #include "sdk/expression/langchain_expr.h" -#include "sdk/expression/types.h" +#include "sdk/types_util.h" namespace dingodb { namespace sdk { @@ -91,7 +90,7 @@ std::any LangChainExprEncoder::VisitOrOperatorExpr(OrOperatorExpr* expr, void* t } std::any LangChainExprEncoder::VisitNotOperatorExpr(NotOperatorExpr* expr, void* target) { - // TODO: check + CHECK_EQ(expr->args.size(), 1); Visit(expr->args[0].get(), target); std::string* dst = static_cast(target); dst->append(sizeof(NOT), NOT); @@ -101,14 +100,14 @@ std::any LangChainExprEncoder::VisitNotOperatorExpr(NotOperatorExpr* expr, void* Byte GetEncode(Type type) { switch (type) { - case STRING: - return TYPE_STRING; - case DOUBLE: - return TYPE_DOUBLE; - case BOOL: + case kBOOL: return TYPE_BOOL; - case INT64: + case kINT64: return TYPE_INT64; + case kDOUBLE: + return TYPE_DOUBLE; + case kSTRING: + return TYPE_STRING; default: CHECK(false) << "unknown type: " << static_cast(type); } @@ -244,21 +243,21 @@ static void Encode(int64_t value, std::string* dst) { std::any LangChainExprEncoder::VisitVal(Val* expr, void* target) { std::string* dst = static_cast(target); switch (expr->type) { - case STRING: { - std::string v = std::any_cast>(expr->value); + case kSTRING: { + std::string v = std::any_cast>(expr->value); Encode(v, dst); break; } - case DOUBLE: { - Encode(std::any_cast>(expr->value), dst); + case kDOUBLE: { + Encode(std::any_cast>(expr->value), dst); break; } - case BOOL: { - Encode(std::any_cast>(expr->value), dst); + case kBOOL: { + Encode(std::any_cast>(expr->value), dst); break; } - case INT64: { - Encode(std::any_cast>(expr->value), dst); + case kINT64: { + Encode(std::any_cast>(expr->value), dst); break; } default: diff --git a/src/sdk/expression/langchain_expr_encoder.h b/src/sdk/expression/langchain_expr_encoder.h index e4e3a1fa6..8aa414c6d 100644 --- a/src/sdk/expression/langchain_expr_encoder.h +++ b/src/sdk/expression/langchain_expr_encoder.h @@ -22,7 +22,7 @@ #include "proto/common.pb.h" #include "sdk/expression/langchain_expr.h" #include "sdk/expression/langchain_expr_visitor.h" -#include "sdk/expression/types.h" +#include "sdk/types.h" namespace dingodb { namespace sdk { diff --git a/src/sdk/expression/langchain_expr_factory.cc b/src/sdk/expression/langchain_expr_factory.cc index 1358c9af2..b11ab70c8 100644 --- a/src/sdk/expression/langchain_expr_factory.cc +++ b/src/sdk/expression/langchain_expr_factory.cc @@ -15,17 +15,22 @@ #include "sdk/expression/langchain_expr_factory.h" #include +#include +#include "fmt/core.h" #include "sdk/common/param_config.h" #include "sdk/expression/langchain_expr.h" #include "sdk/status.h" +#include "sdk/types.h" +#include "sdk/types_util.h" namespace dingodb { namespace sdk { namespace expression { namespace { -Status CreateOperatorExpr(const nlohmann::json& j, std::shared_ptr& expr) { +Status CreateOperatorExpr(LangchainExprFactory* expr_factory, const nlohmann::json& j, + std::shared_ptr& expr) { std::shared_ptr tmp; std::string operator_type = j.at("operator"); @@ -43,7 +48,7 @@ Status CreateOperatorExpr(const nlohmann::json& j, std::shared_ptr sub_expr; - DINGO_RETURN_NOT_OK(LangchainExprFactory::CreateExpr(json_str, sub_expr)); + DINGO_RETURN_NOT_OK(expr_factory->CreateExpr(json_str, sub_expr)); tmp->AddArgument(sub_expr); } @@ -51,7 +56,8 @@ Status CreateOperatorExpr(const nlohmann::json& j, std::shared_ptr& expr) { +Status CreateComparatorExpr(LangchainExprFactory* expr_factory, const nlohmann::json& j, + std::shared_ptr& expr) { std::shared_ptr tmp; std::string comparator_type = j.at("comparator"); @@ -74,22 +80,42 @@ Status CreateComparatorExpr(const nlohmann::json& j, std::shared_ptrvar = std::make_shared(std::move(name), STRING); - tmp->val = std::make_shared(j.at("value").get(), STRING); + type = kSTRING; } else if (value_type == "INT64") { - tmp->var = std::make_shared(std::move(name), INT64); - tmp->val = std::make_shared(j.at("value").get(), INT64); + type = kINT64; } else if (value_type == "DOUBLE") { - tmp->var = std::make_shared(std::move(name), DOUBLE); - tmp->val = std::make_shared(j.at("value").get(), DOUBLE); + type = kDOUBLE; } else if (value_type == "BOOL") { - tmp->var = std::make_shared(std::move(name), BOOL); - tmp->val = std::make_shared(j.at("value").get(), BOOL); + type = kBOOL; } else { return Status::InvalidArgument("Unknown value type: " + value_type); } + DINGO_RETURN_NOT_OK(expr_factory->MaybeRemapType(name, type)); + + switch (type) { + case kBOOL: + tmp->var = std::make_shared(name, kBOOL); + tmp->val = std::make_shared(name, kBOOL, j.at("value").get>()); + break; + case kINT64: + tmp->var = std::make_shared(name, kINT64); + tmp->val = std::make_shared(name, kINT64, j.at("value").get>()); + break; + case kDOUBLE: + tmp->var = std::make_shared(name, kDOUBLE); + tmp->val = std::make_shared(name, kDOUBLE, j.at("value").get>()); + break; + case kSTRING: + tmp->var = std::make_shared(name, kSTRING); + tmp->val = std::make_shared(name, kSTRING, j.at("value").get>()); + break; + default: + CHECK(false) << "Unknown value type: " << value_type; + } + expr = std::move(tmp); return Status::OK(); } @@ -102,9 +128,9 @@ Status LangchainExprFactory::CreateExpr(const std::string& expr_json_str, std::s nlohmann::json j = nlohmann::json::parse(expr_json_str); std::string type = j.at("type"); if (type == "operator") { - DINGO_RETURN_NOT_OK(CreateOperatorExpr(j, tmp)); + DINGO_RETURN_NOT_OK(CreateOperatorExpr(this, j, tmp)); } else if (type == "comparator") { - DINGO_RETURN_NOT_OK(CreateComparatorExpr(j, tmp)); + DINGO_RETURN_NOT_OK(CreateComparatorExpr(this, j, tmp)); } else { return Status::InvalidArgument("Unknown expression type: " + type); } @@ -114,6 +140,41 @@ Status LangchainExprFactory::CreateExpr(const std::string& expr_json_str, std::s VLOG(kSdkVlogLevel) << "expr_json_str: " << expr_json_str << " expr: " << expr->ToString(); return Status::OK(); } + +Status LangchainExprFactory::MaybeRemapType(const std::string& name, Type& type) { + (void)name; + (void)type; + return Status::OK(); +} + +SchemaLangchainExprFactory::SchemaLangchainExprFactory(const pb::common::ScalarSchema& schema) { + for (const auto& schema_item : schema.fields()) { + CHECK(attribute_type_ + .insert(std::make_pair(schema_item.key(), InternalScalarFieldTypePB2Type(schema_item.field_type()))) + .second); + } +} + +Status SchemaLangchainExprFactory::MaybeRemapType(const std::string& name, Type& type) { + auto iter = attribute_type_.find(name); + if (iter != attribute_type_.end()) { + Type schema_type = iter->second; + if (kTypeConversionMatrix[type][schema_type]) { + type = schema_type; + } else { + std::string err_msg = fmt::format("attribute: {}, type: {}, can't convert to schema type: {}", name, + TypeToString(type), TypeToString(schema_type)); + DINGO_LOG(WARNING) << err_msg; + return Status::InvalidArgument(err_msg); + } + } else { + // TODO: if not found in schema, should we return not ok? + DINGO_LOG(INFO) << "attribute: " << name << " type:" << TypeToString(type) << " not found in schema"; + } + + return Status::OK(); +} + } // namespace expression } // namespace sdk diff --git a/src/sdk/expression/langchain_expr_factory.h b/src/sdk/expression/langchain_expr_factory.h index e7bdb0e34..3ab4f49c3 100644 --- a/src/sdk/expression/langchain_expr_factory.h +++ b/src/sdk/expression/langchain_expr_factory.h @@ -27,9 +27,23 @@ namespace expression { class LangchainExprFactory { public: LangchainExprFactory() = default; - ~LangchainExprFactory() = default; + virtual ~LangchainExprFactory() = default; - static Status CreateExpr(const std::string& expr_json_str, std::shared_ptr& expr); + Status CreateExpr(const std::string& expr_json_str, std::shared_ptr& expr); + + virtual Status MaybeRemapType(const std::string& name, Type& type); +}; + +class SchemaLangchainExprFactory : public LangchainExprFactory { + public: + SchemaLangchainExprFactory(const pb::common::ScalarSchema& schema); + + ~SchemaLangchainExprFactory() override = default; + + Status MaybeRemapType(const std::string& name, Type& type) override; + + private: + std::unordered_map attribute_type_; }; } // namespace expression diff --git a/src/sdk/expression/types.cc b/src/sdk/expression/types.cc deleted file mode 100644 index 168864af8..000000000 --- a/src/sdk/expression/types.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2023 dingodb.com, Inc. All Rights Reserved -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "sdk/expression/types.h" - -namespace dingodb { -namespace sdk { -namespace expression { - -std::string TypeToString(Type type) { - switch (type) { - case STRING: - return "kString"; - case DOUBLE: - return "kDouble"; - case BOOL: - return "kBool"; - case INT64: - return "kInt64"; - default: - return "Unknown ValueType"; - } -} - -} // namespace expression -} // namespace sdk -} // namespace dingodb \ No newline at end of file diff --git a/src/sdk/expression/types.h b/src/sdk/types.h similarity index 50% rename from src/sdk/expression/types.h rename to src/sdk/types.h index 6609e44f5..bf48bbc96 100644 --- a/src/sdk/expression/types.h +++ b/src/sdk/types.h @@ -1,3 +1,4 @@ + // Copyright (c) 2023 dingodb.com, Inc. All Rights Reserved // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,49 +13,73 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef DINGODB_SDK_EXPRESSION_TYPES_H_ -#define DINGODB_SDK_EXPRESSION_TYPES_H_ +#ifndef DINGODB_SDK_TYPES_H_ +#define DINGODB_SDK_TYPES_H_ #include #include namespace dingodb { namespace sdk { -namespace expression { -enum Type : uint8_t { STRING, DOUBLE, BOOL, INT64 }; +enum Type : uint8_t { + kBOOL = 0, + kINT64 = 1, + kDOUBLE = 2, + kSTRING = 3, + // This must be the last line + kTypeEnd +}; + +static const bool kTypeConversionMatrix[kTypeEnd][kTypeEnd] = { + {true, false, false, false}, // kBOOL can be converted to kBOOL + {false, true, true, false}, // kINT64 can be converted to kINT64, kDOUBLE + {false, false, true, false}, // kDOUBLE can be converted to kDOUBLE + {false, false, false, true} // kSTRING can be converted to kSTRING +}; -std::string TypeToString(Type type); +inline std::string TypeToString(Type type) { + switch (type) { + case kBOOL: + return "Bool"; + case kINT64: + return "Int64"; + case kDOUBLE: + return "Double"; + case kSTRING: + return "String"; + default: + return "Unknown ValueType"; + } +} template struct TypeTraits; template <> -struct TypeTraits { - using Type = std::string; +struct TypeTraits { + using Type = bool; }; template <> -struct TypeTraits { - using Type = double; +struct TypeTraits { + using Type = int64_t; }; template <> -struct TypeTraits { - using Type = bool; +struct TypeTraits { + using Type = double; }; template <> -struct TypeTraits { - using Type = int64_t; +struct TypeTraits { + using Type = std::string; }; template using TypeOf = typename TypeTraits::Type; -} // namespace expression } // namespace sdk - } // namespace dingodb -#endif // DINGODB_SDK_EXPRESSION_TYPES_H_ \ No newline at end of file +#endif // DINGODB_SDK_TYPES_H_ \ No newline at end of file diff --git a/src/sdk/types_util.h b/src/sdk/types_util.h new file mode 100644 index 000000000..2fc96d45f --- /dev/null +++ b/src/sdk/types_util.h @@ -0,0 +1,87 @@ +// Copyright (c) 2023 dingodb.com, Inc. All Rights Reserved +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DINGODB_SDK_TYPES_UTIL_H_ +#define DINGODB_SDK_TYPES_UTIL_H_ + +#include "glog/logging.h" +#include "proto/common.pb.h" +#include "sdk/types.h" + +namespace dingodb { +namespace sdk { + +static pb::common::Schema::Type Type2InternalSchemaTypePB(Type type) { + switch (type) { + case kBOOL: + return pb::common::Schema::BOOL; + case kINT64: + return pb::common::Schema::LONG; + case kDOUBLE: + return pb::common::Schema::DOUBLE; + case kSTRING: + return pb::common::Schema::STRING; + default: + CHECK(false) << "Unimplement convert type: " << type; + } +} + +static pb::common::ScalarFieldType Type2InternalScalarFieldTypePB(Type type) { + switch (type) { + case kBOOL: + return pb::common::ScalarFieldType::BOOL; + case kINT64: + return pb::common::ScalarFieldType::INT64; + case kDOUBLE: + return pb::common::ScalarFieldType::DOUBLE; + case kSTRING: + return pb::common::ScalarFieldType::STRING; + default: + CHECK(false) << "Unimplement convert type: " << type; + } +} + +static Type InternalSchemaTypePB2Type(pb::common::Schema::Type type) { + switch (type) { + case pb::common::Schema::BOOL: + return kBOOL; + case pb::common::Schema::LONG: + return kINT64; + case pb::common::Schema::DOUBLE: + return kDOUBLE; + case pb::common::Schema::STRING: + return kSTRING; + default: + CHECK(false) << "unsupported schema type:" << pb::common::Schema::Type_Name(type); + } +} + +static Type InternalScalarFieldTypePB2Type(pb::common::ScalarFieldType type) { + switch (type) { + case pb::common::ScalarFieldType::BOOL: + return kBOOL; + case pb::common::ScalarFieldType::INT64: + return kINT64; + case pb::common::ScalarFieldType::DOUBLE: + return kDOUBLE; + case pb::common::ScalarFieldType::STRING: + return kSTRING; + default: + CHECK(false) << "unsupported scalar field type:" << pb::common::ScalarFieldType_Name(type); + } +} + +} // namespace sdk +} // namespace dingodb +#endif // DINGODB_SDK_TYPES_UTIL_H_ diff --git a/src/sdk/vector.h b/src/sdk/vector.h index 7cd1c18ed..12649b6dd 100644 --- a/src/sdk/vector.h +++ b/src/sdk/vector.h @@ -22,6 +22,7 @@ #include #include "sdk/status.h" +#include "sdk/types.h" namespace dingodb { namespace sdk { @@ -121,26 +122,12 @@ struct BruteForceParam { static VectorIndexType Type() { return VectorIndexType::kBruteForce; } }; -enum class ScalarFieldType : uint8_t { - kNone, - kBool, - kInt8, - kInt16, - kInt32, - kInt64, - kFloat32, - kDouble, - kString, - kBytes -}; - struct VectorScalarColumnSchema { std::string key; - ScalarFieldType type; + Type type; bool speed; - VectorScalarColumnSchema(std::string& key, ScalarFieldType type, bool speed = false) - : key(key), type(type), speed(speed) {} + VectorScalarColumnSchema(std::string& key, Type type, bool speed = false) : key(key), type(type), speed(speed) {} }; // TODO: maybe use builder to build VectorScalarSchema @@ -186,18 +173,16 @@ struct Vector { std::string ToString() const; }; +// TODO: maybe use std::variant, when swig support struct ScalarField { bool bool_data; - int32_t int_data; int64_t long_data; - float float_data; double double_data; std::string string_data; - std::string bytes_data; }; struct ScalarValue { - ScalarFieldType type; + Type type; std::vector fields; }; diff --git a/src/sdk/vector/vector_common.h b/src/sdk/vector/vector_common.h index 3eebd0695..7275b1e38 100644 --- a/src/sdk/vector/vector_common.h +++ b/src/sdk/vector/vector_common.h @@ -20,6 +20,8 @@ #include "glog/logging.h" #include "proto/common.pb.h" #include "proto/meta.pb.h" +#include "sdk/types.h" +#include "sdk/types_util.h" #include "sdk/vector.h" #include "vector/codec.h" @@ -100,60 +102,6 @@ static VectorIndexType InternalVectorIndexTypePB2VectorIndexType(pb::common::Vec } } -static pb::common::ScalarFieldType ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType type) { - switch (type) { - case ScalarFieldType::kNone: - return pb::common::ScalarFieldType::NONE; - case ScalarFieldType::kBool: - return pb::common::ScalarFieldType::BOOL; - case ScalarFieldType::kInt8: - return pb::common::ScalarFieldType::INT8; - case ScalarFieldType::kInt16: - return pb::common::ScalarFieldType::INT16; - case ScalarFieldType::kInt32: - return pb::common::ScalarFieldType::INT32; - case ScalarFieldType::kInt64: - return pb::common::ScalarFieldType::INT64; - case ScalarFieldType::kFloat32: - return pb::common::ScalarFieldType::FLOAT32; - case ScalarFieldType::kDouble: - return pb::common::ScalarFieldType::DOUBLE; - case ScalarFieldType::kString: - return pb::common::ScalarFieldType::STRING; - case ScalarFieldType::kBytes: - return pb::common::ScalarFieldType::BYTES; - default: - CHECK(false) << "unsupported scalar field type:" << static_cast(type); - } -} - -static ScalarFieldType InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType type) { - switch (type) { - case pb::common::ScalarFieldType::NONE: - return ScalarFieldType::kNone; - case pb::common::ScalarFieldType::BOOL: - return ScalarFieldType::kBool; - case pb::common::ScalarFieldType::INT8: - return ScalarFieldType::kInt8; - case pb::common::ScalarFieldType::INT16: - return ScalarFieldType::kInt16; - case pb::common::ScalarFieldType::INT32: - return ScalarFieldType::kInt32; - case pb::common::ScalarFieldType::INT64: - return ScalarFieldType::kInt64; - case pb::common::ScalarFieldType::FLOAT32: - return ScalarFieldType::kFloat32; - case pb::common::ScalarFieldType::DOUBLE: - return ScalarFieldType::kDouble; - case pb::common::ScalarFieldType::STRING: - return ScalarFieldType::kString; - case pb::common::ScalarFieldType::BYTES: - return ScalarFieldType::kBytes; - default: - CHECK(false) << "unsupported scalar field type:" << pb::common::ScalarFieldType_Name(type); - } -} - static void FillFlatParmeter(pb::common::VectorIndexParameter* parameter, const FlatParam& param) { parameter->set_vector_index_type(pb::common::VECTOR_INDEX_TYPE_FLAT); auto* flat = parameter->mutable_flat_parameter(); @@ -235,44 +183,26 @@ static pb::common::ValueType ValueType2InternalValueTypePB(ValueType value_type) static pb::common::ScalarValue ScalarValue2InternalScalarValuePB(const sdk::ScalarValue& scalar_value) { pb::common::ScalarValue result; - result.set_field_type(ScalarFieldType2InternalScalarFieldTypePB(scalar_value.type)); + result.set_field_type(Type2InternalScalarFieldTypePB(scalar_value.type)); - if (scalar_value.type == sdk::ScalarFieldType::kBool) { + if (scalar_value.type == kBOOL) { for (const auto& field : scalar_value.fields) { result.add_fields()->set_bool_data(field.bool_data); } - } else if (scalar_value.type == sdk::ScalarFieldType::kInt8) { - for (const auto& field : scalar_value.fields) { - result.add_fields()->set_int_data(field.int_data); - } - } else if (scalar_value.type == sdk::ScalarFieldType::kInt16) { - for (const auto& field : scalar_value.fields) { - result.add_fields()->set_int_data(field.int_data); - } - } else if (scalar_value.type == sdk::ScalarFieldType::kInt32) { - for (const auto& field : scalar_value.fields) { - result.add_fields()->set_int_data(field.int_data); - } - } else if (scalar_value.type == sdk::ScalarFieldType::kInt64) { + } else if (scalar_value.type == kINT64) { for (const auto& field : scalar_value.fields) { result.add_fields()->set_long_data(field.long_data); } - } else if (scalar_value.type == sdk::ScalarFieldType::kFloat32) { - for (const auto& field : scalar_value.fields) { - result.add_fields()->set_float_data(field.float_data); - } - } else if (scalar_value.type == sdk::ScalarFieldType::kDouble) { + } else if (scalar_value.type == kDOUBLE) { for (const auto& field : scalar_value.fields) { result.add_fields()->set_double_data(field.double_data); } - } else if (scalar_value.type == sdk::ScalarFieldType::kString) { + } else if (scalar_value.type == kSTRING) { for (const auto& field : scalar_value.fields) { result.add_fields()->set_string_data(field.string_data); } - } else if (scalar_value.type == sdk::ScalarFieldType::kBytes) { - for (const auto& field : scalar_value.fields) { - result.add_fields()->set_bytes_data(field.bytes_data); - } + } else { + LOG(WARNING) << "unsupported scalar value type:" << scalar_value.type; } return result; @@ -280,7 +210,7 @@ static pb::common::ScalarValue ScalarValue2InternalScalarValuePB(const sdk::Scal static void FillScalarSchemaItem(pb::common::ScalarSchemaItem* pb, const VectorScalarColumnSchema& schema) { pb->set_key(schema.key); - pb->set_field_type(ScalarFieldType2InternalScalarFieldTypePB(schema.type)); + pb->set_field_type(Type2InternalScalarFieldTypePB(schema.type)); pb->set_enable_speed_up(schema.speed); } diff --git a/src/sdk/vector/vector_index.cc b/src/sdk/vector/vector_index.cc index 80444a184..39dd4ba33 100644 --- a/src/sdk/vector/vector_index.cc +++ b/src/sdk/vector/vector_index.cc @@ -74,6 +74,21 @@ const pb::common::Range& VectorIndex::GetPartitionRange(int64_t part_id) const { return iter->second; } +bool VectorIndex::HasScalarSchema() const { + if (index_def_with_id_.index_definition().index_parameter().vector_index_parameter().has_scalar_schema()) { + const auto& schema = + index_def_with_id_.index_definition().index_parameter().vector_index_parameter().scalar_schema(); + if (schema.fields_size() > 0) { + return true; + } + } + return false; +} + +const pb::common::ScalarSchema& VectorIndex::GetScalarSchema() const { + return index_def_with_id_.index_definition().index_parameter().vector_index_parameter().scalar_schema(); +} + std::string VectorIndex::ToString(bool verbose) const { std::ostringstream oss; for (const auto& [start_key, part_id] : start_key_to_part_id_) { diff --git a/src/sdk/vector/vector_index.h b/src/sdk/vector/vector_index.h index 4416b9064..f06337a23 100644 --- a/src/sdk/vector/vector_index.h +++ b/src/sdk/vector/vector_index.h @@ -50,6 +50,11 @@ class VectorIndex { bool IsStale() { return stale_.load(std::memory_order_relaxed); } + bool HasScalarSchema() const; + + // the caller must make sure the vector index is not destroyed and HasScalarSchema + const pb::common::ScalarSchema& GetScalarSchema() const; + std::string ToString(bool verbose = false) const; private: diff --git a/src/sdk/vector/vector_param.cc b/src/sdk/vector/vector_param.cc index 9961c06cd..df5de21dc 100644 --- a/src/sdk/vector/vector_param.cc +++ b/src/sdk/vector/vector_param.cc @@ -48,33 +48,6 @@ std::string Vector::ToString() const { ValueTypeToString(value_type), float_ss.str(), binary_ss.str()); } -std::string ScalarFieldToString(ScalarFieldType type) { - switch (type) { - case ScalarFieldType::kNone: - return "None"; - case ScalarFieldType::kBool: - return "Bool"; - case ScalarFieldType::kInt8: - return "Int8"; - case ScalarFieldType::kInt16: - return "Int16"; - case ScalarFieldType::kInt32: - return "Int32"; - case ScalarFieldType::kInt64: - return "Int64"; - case ScalarFieldType::kFloat32: - return "Float32"; - case ScalarFieldType::kDouble: - return "Double"; - case ScalarFieldType::kString: - return "String"; - case ScalarFieldType::kBytes: - return "Bytes"; - default: - return "Unknown"; - } -} - std::string VectorWithId::ToString() const { return fmt::format("VectorWithId {{ id: {}, vector: {} }}", id, vector.ToString()); } diff --git a/src/sdk/vector/vector_search_task.cc b/src/sdk/vector/vector_search_task.cc index 272eeb972..081063245 100644 --- a/src/sdk/vector/vector_search_task.cc +++ b/src/sdk/vector/vector_search_task.cc @@ -59,7 +59,14 @@ Status VectorSearchTask::Init() { FillInternalSearchParams(&search_parameter_, vector_index_->GetVectorIndexType(), search_param_); if (!search_param_.langchain_expr_json.empty()) { std::shared_ptr expr; - DINGO_RETURN_NOT_OK(expression::LangchainExprFactory::CreateExpr(search_param_.langchain_expr_json, expr)); + + std::unique_ptr expr_factory; + if (vector_index_->HasScalarSchema()) { + expr_factory = std::make_unique(vector_index_->GetScalarSchema()); + } else { + expr_factory = std::make_unique(); + } + DINGO_RETURN_NOT_OK(expr_factory->CreateExpr(search_param_.langchain_expr_json, expr)); expression::LangChainExprEncoder encoder; *(search_parameter_.mutable_vector_coprocessor()) = encoder.EncodeToCoprocessor(expr.get()); diff --git a/test/unit_test/sdk/expression/test_langchain_expr_encoder.cc b/test/unit_test/sdk/expression/test_langchain_expr_encoder.cc index 1088406ba..6ec4755cd 100644 --- a/test/unit_test/sdk/expression/test_langchain_expr_encoder.cc +++ b/test/unit_test/sdk/expression/test_langchain_expr_encoder.cc @@ -32,6 +32,11 @@ class SDKLangChainExprEncoder : public ::testing::Test { std::shared_ptr encoder; }; +static Status CreateExpr(const std::string& json_str, std::shared_ptr& expr) { + LangchainExprFactory expr_factory; + return expr_factory.CreateExpr(json_str, expr); +} + TEST_F(SDKLangChainExprEncoder, DoubleGt) { std::string json_str = R"({ @@ -44,7 +49,7 @@ TEST_F(SDKLangChainExprEncoder, DoubleGt) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -64,7 +69,7 @@ TEST_F(SDKLangChainExprEncoder, Int64Gt) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -84,7 +89,7 @@ TEST_F(SDKLangChainExprEncoder, StringEq) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -104,7 +109,7 @@ TEST_F(SDKLangChainExprEncoder, BoolEqFalse) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -124,7 +129,7 @@ TEST_F(SDKLangChainExprEncoder, BoolEqTrue) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -171,7 +176,7 @@ TEST_F(SDKLangChainExprEncoder, AndOperator) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -205,7 +210,7 @@ TEST_F(SDKLangChainExprEncoder, OrOperator) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -231,7 +236,7 @@ TEST_F(SDKLangChainExprEncoder, NotOperator) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -270,7 +275,7 @@ TEST_F(SDKLangChainExprEncoder, NotOperatorWithOperator) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); @@ -304,7 +309,7 @@ TEST_F(SDKLangChainExprEncoder, NotOperatorWithNestedComparator) { )"; std::shared_ptr expr; - Status s = LangchainExprFactory::CreateExpr(json_str, expr); + Status s = CreateExpr(json_str, expr); EXPECT_TRUE(s.ok()); std::string bytes = encoder->EncodeToFilter(expr.get()); diff --git a/test/unit_test/sdk/vector/test_vector_common.cc b/test/unit_test/sdk/vector/test_vector_common.cc index 4afafc196..6c52a8f8d 100644 --- a/test/unit_test/sdk/vector/test_vector_common.cc +++ b/test/unit_test/sdk/vector/test_vector_common.cc @@ -297,32 +297,6 @@ TEST(SDKVectorCommonTest, TestFillSearchIvfPqParamPB) { EXPECT_EQ(pb.recall_num(), 5); } -TEST(SDKVectorCommonTest, ScalarFieldType2InternalScalarFieldTypePB) { - EXPECT_EQ(pb::common::ScalarFieldType::NONE, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kNone)); - EXPECT_EQ(pb::common::ScalarFieldType::BOOL, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kBool)); - EXPECT_EQ(pb::common::ScalarFieldType::INT8, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kInt8)); - EXPECT_EQ(pb::common::ScalarFieldType::INT16, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kInt16)); - EXPECT_EQ(pb::common::ScalarFieldType::INT32, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kInt32)); - EXPECT_EQ(pb::common::ScalarFieldType::INT64, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kInt64)); - EXPECT_EQ(pb::common::ScalarFieldType::FLOAT32, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kFloat32)); - EXPECT_EQ(pb::common::ScalarFieldType::DOUBLE, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kDouble)); - EXPECT_EQ(pb::common::ScalarFieldType::STRING, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kString)); - EXPECT_EQ(pb::common::ScalarFieldType::BYTES, ScalarFieldType2InternalScalarFieldTypePB(ScalarFieldType::kBytes)); -} - -TEST(SDKVectorCommonTest, InternalScalarFieldTypePB2ScalarFieldType) { - EXPECT_EQ(ScalarFieldType::kNone, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::NONE)); - EXPECT_EQ(ScalarFieldType::kBool, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::BOOL)); - EXPECT_EQ(ScalarFieldType::kInt8, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::INT8)); - EXPECT_EQ(ScalarFieldType::kInt16, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::INT16)); - EXPECT_EQ(ScalarFieldType::kInt32, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::INT32)); - EXPECT_EQ(ScalarFieldType::kInt64, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::INT64)); - EXPECT_EQ(ScalarFieldType::kFloat32, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::FLOAT32)); - EXPECT_EQ(ScalarFieldType::kDouble, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::DOUBLE)); - EXPECT_EQ(ScalarFieldType::kString, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::STRING)); - EXPECT_EQ(ScalarFieldType::kBytes, InternalScalarFieldTypePB2ScalarFieldType(pb::common::ScalarFieldType::BYTES)); -} - TEST(FillSearchHnswParamPBTest, TestFillSearchHnswParamPB) { SearchParam param; param.extra_params[SearchExtraParamType::kEfSearch] = 20;