Skip to content

Commit

Permalink
[feat][sdk] Impl vector scan query
Browse files Browse the repository at this point in the history
  • Loading branch information
wchuande authored and ketor committed Mar 11, 2024
1 parent 302c4ef commit 32978fd
Show file tree
Hide file tree
Showing 13 changed files with 529 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/example/sdk_vector_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,53 @@ static void VectorGetBorder(bool use_index_name = false) {
}
}

static void VectorScanQuery(bool use_index_name = false) {
{
// forward
dingodb::sdk::ScanQueryParam param;
param.vector_id_start = g_vector_ids[0];
param.vector_id_end = g_vector_ids[g_vector_ids.size() - 1];
param.max_scan_count = 2;

dingodb::sdk::ScanQueryResult result;
Status tmp;
if (use_index_name) {
tmp = g_vector_client->ScanQueryByIndexName(g_schema_id, g_index_name, param, result);
} else {
tmp = g_vector_client->ScanQueryByIndexId(g_index_id, param, result);
}

DINGO_LOG(INFO) << "vector scan query:" << tmp.ToString() << ", result:" << result.ToString();
if (tmp.ok()) {
CHECK_EQ(result.vectors[0].id, g_vector_ids[0]);
CHECK_EQ(result.vectors[1].id, g_vector_ids[1]);
}
}

{
// backward
dingodb::sdk::ScanQueryParam param;
param.vector_id_start = g_vector_ids[g_vector_ids.size() - 1];
param.vector_id_end = g_vector_ids[0];
param.max_scan_count = 2;
param.is_reverse = true;

dingodb::sdk::ScanQueryResult result;
Status tmp;
if (use_index_name) {
tmp = g_vector_client->ScanQueryByIndexName(g_schema_id, g_index_name, param, result);
} else {
tmp = g_vector_client->ScanQueryByIndexId(g_index_id, param, result);
}

DINGO_LOG(INFO) << "vector scan query:" << tmp.ToString() << ", result:" << result.ToString();
if (tmp.ok()) {
CHECK_EQ(result.vectors[0].id, g_vector_ids[g_vector_ids.size() - 1]);
CHECK_EQ(result.vectors[1].id, g_vector_ids[g_vector_ids.size() - 2]);
}
}
}

static void VectorDelete(bool use_index_name = false) {
Status tmp;
std::vector<dingodb::sdk::DeleteResult> result;
Expand Down Expand Up @@ -295,6 +342,7 @@ int main(int argc, char* argv[]) {
VectorSearch();
VectorQuey();
VectorGetBorder();
VectorScanQuery();
VectorDelete();
VectorSearch();

Expand All @@ -309,6 +357,7 @@ int main(int argc, char* argv[]) {
VectorSearch(true);
VectorQuey(true);
VectorGetBorder(true);
VectorScanQuery(true);
VectorDelete(true);
VectorSearch(true);

Expand Down
10 changes: 10 additions & 0 deletions src/pysdk/dingosdk.swg
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ namespace dingodb {
%feature("docstring") VectorClient::BatchQueryByIndexName "return Status, QueryResult out_result"
%feature("docstring") VectorClient::GetBorderByIndexId "return Status, int64_t out_vector_id"
%feature("docstring") VectorClient::GetBorderByIndexName "return Status, int64_t out_vector_id"
%feature("docstring") VectorClient::ScanQueryByIndexId "return Status, ScanQueryResult out_result"
%feature("docstring") VectorClient::ScanQueryByIndexName "return Status, ScanQueryResult out_result"

%typemap(in, numinputs=0) Client** (Client* temp){
temp = NULL;
Expand Down Expand Up @@ -164,6 +166,14 @@ namespace dingodb {
$result = SWIG_AppendOutput($result, obj);
}

%typemap(in, numinputs=0) ScanQueryResult& (ScanQueryResult temp) {
$1 = &temp;
}
%typemap(argout) ScanQueryResult& {
PyObject* obj = SWIG_NewPointerObj(new ScanQueryResult(*$1), SWIGTYPE_p_dingodb__sdk__ScanQueryResult, SWIG_POINTER_OWN);
$result = SWIG_AppendOutput($result, obj);
}

}

}
Expand Down
36 changes: 36 additions & 0 deletions src/pysdk/pysdk_vector_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,40 @@ def vector_get_border(use_index_name=False):
if tmp.ok():
assert vector_id == g_vector_ids[0]

def vector_scan_query(use_index_name=False):
# forward
param = dingosdk.ScanQueryParam()
param.vector_id_start = g_vector_ids[0]
param.vector_id_end = g_vector_ids[-1]
param.max_scan_count = 2

if use_index_name:
tmp, result = g_vector_client.ScanQueryByIndexName(g_schema_id, g_index_name, param)
else:
tmp, result = g_vector_client.ScanQueryByIndexId(g_index_id, param)

print(f"vector scan query status:{tmp.ToString()}, result: {result.ToString()}")
if tmp.ok():
assert result.vectors[0].id == g_vector_ids[0]
assert result.vectors[1].id == g_vector_ids[1]

# backward
param = dingosdk.ScanQueryParam()
param.vector_id_start = g_vector_ids[-1]
param.vector_id_end = g_vector_ids[0]
param.max_scan_count = 2
param.is_reverse = True

if use_index_name:
tmp, result = g_vector_client.ScanQueryByIndexName(g_schema_id, g_index_name, param)
else:
tmp, result = g_vector_client.ScanQueryByIndexId(g_index_id, param)

print(f"vector scan query status:{tmp.ToString()}, result: {result.ToString()}")
if tmp.ok():
assert result.vectors[0].id == g_vector_ids[-1]
assert result.vectors[1].id == g_vector_ids[-2]

def vector_delete(use_index_name=False):
if use_index_name:
tmp, result = g_vector_client.DeleteByIndexName(g_schema_id, g_index_name, g_vector_ids)
Expand All @@ -167,6 +201,7 @@ def vector_delete(use_index_name=False):
vector_search()
vector_query()
vector_get_border()
vector_scan_query()
vector_delete()
post_clean()

Expand All @@ -175,6 +210,7 @@ def vector_delete(use_index_name=False):
vector_search(True)
vector_query(True)
vector_get_border(True)
vector_scan_query(True)
vector_delete(True)
post_clean(True)

1 change: 1 addition & 0 deletions src/sdk/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ add_library(sdk
vector/vector_batch_query_task.cc
vector/vector_delete_task.cc
vector/vector_get_border_task.cc
vector/vector_scan_query_task.cc
vector/vector_search_task.cc
utils/thread_pool_actuator.cc
common/param_config.cc
Expand Down
62 changes: 62 additions & 0 deletions src/sdk/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,64 @@ struct QueryResult {
std::string ToString() const;
};

struct ScanQueryParam {
int64_t vector_id_start;
// the end id of scan
// if is_reverse is true, vector_id_end must be less than vector_id_start
// if is_reverse is false, vector_id_end must be greater than vector_id_start
// the real range is [start, end], include start and end
// if vector_id_end == 0, scan to the end of the region
int64_t vector_id_end{0};
int64_t max_scan_count;
bool is_reverse{false};

bool with_vector_data{true};
bool with_scalar_data{false};
// If with_scalar_data is true, selected_keys is used to select scalar data, and if this parameter is null, all scalar
// data will be returned.
std::vector<std::string> selected_keys;
bool with_table_data{false}; // Default false, if true, response without table data

// TODO: support use_scalar_filter
bool use_scalar_filter{false};
// std::map<std::string, ScalarValue> scalar_data;

explicit ScanQueryParam() = default;

ScanQueryParam(ScanQueryParam&& other) noexcept
: vector_id_start(other.vector_id_start),
vector_id_end(other.vector_id_end),
max_scan_count(other.max_scan_count),
is_reverse(other.is_reverse),
with_vector_data(other.with_vector_data),
with_scalar_data(other.with_scalar_data),
selected_keys(std::move(other.selected_keys)),
with_table_data(other.with_table_data),
use_scalar_filter(other.use_scalar_filter) {}

ScanQueryParam& operator=(ScanQueryParam&& other) noexcept {
if (this != &other) {
vector_id_start = other.vector_id_start;
vector_id_end = other.vector_id_end;
max_scan_count = other.max_scan_count;
is_reverse = other.is_reverse;
with_vector_data = other.with_vector_data;
with_scalar_data = other.with_scalar_data;
selected_keys = std::move(other.selected_keys);
with_table_data = other.with_table_data;
use_scalar_filter = other.use_scalar_filter;
// You can add more fields here if you add more fields to the struct
}
return *this;
}
};

struct ScanQueryResult {
std::vector<VectorWithId> vectors;

std::string ToString() const;
};

class VectorIndexCreator {
public:
~VectorIndexCreator();
Expand Down Expand Up @@ -434,6 +492,10 @@ class VectorClient {
Status GetBorderByIndexId(int64_t index_id, bool is_max, int64_t& out_vector_id);
Status GetBorderByIndexName(int64_t schema_id, const std::string& index_name, bool is_max, int64_t& out_vector_id);

Status ScanQueryByIndexId(int64_t index_id, const ScanQueryParam& query_param, ScanQueryResult& out_result);
Status ScanQueryByIndexName(int64_t schema_id, const std::string& index_name, const ScanQueryParam& query_param,
ScanQueryResult& out_result);

private:
friend class Client;

Expand Down
1 change: 1 addition & 0 deletions src/sdk/vector/index_service_rpc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ DEFINE_INDEX_SERVICE_RPC(VectorSearch);
DEFINE_INDEX_SERVICE_RPC(VectorDelete);
DEFINE_INDEX_SERVICE_RPC(VectorBatchQuery);
DEFINE_INDEX_SERVICE_RPC(VectorGetBorderId);
DEFINE_INDEX_SERVICE_RPC(VectorScanQuery);

} // namespace sdk
} // namespace dingodb
1 change: 1 addition & 0 deletions src/sdk/vector/index_service_rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ DECLARE_INDEX_SERVICE_RPC(VectorSearch);
DECLARE_INDEX_SERVICE_RPC(VectorDelete);
DECLARE_INDEX_SERVICE_RPC(VectorBatchQuery);
DECLARE_INDEX_SERVICE_RPC(VectorGetBorderId);
DECLARE_INDEX_SERVICE_RPC(VectorScanQuery);

} // namespace sdk
} // namespace dingodb
Expand Down
16 changes: 16 additions & 0 deletions src/sdk/vector/vector_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "sdk/vector/vector_delete_task.h"
#include "sdk/vector/vector_get_border_task.h"
#include "sdk/vector/vector_index_cache.h"
#include "sdk/vector/vector_scan_query_task.h"
#include "sdk/vector/vector_search_task.h"

namespace dingodb {
Expand Down Expand Up @@ -109,6 +110,21 @@ Status VectorClient::GetBorderByIndexName(int64_t schema_id, const std::string&
return task.Run();
}

Status VectorClient::ScanQueryByIndexId(int64_t index_id, const ScanQueryParam& query_param,
ScanQueryResult& out_result) {
VectorScanQueryTask task(stub_, index_id, query_param, out_result);
return task.Run();
}

Status VectorClient::ScanQueryByIndexName(int64_t schema_id, const std::string& index_name,
const ScanQueryParam& query_param, ScanQueryResult& out_result) {
int64_t index_id{0};
DINGO_RETURN_NOT_OK(
stub_.GetVectorIndexCache()->GetIndexIdByKey(EncodeVectorIndexCacheKey(schema_id, index_name), index_id));
CHECK_GT(index_id, 0);
VectorScanQueryTask task(stub_, index_id, query_param, out_result);
return task.Run();
}
} // namespace sdk

} // namespace dingodb
7 changes: 7 additions & 0 deletions src/sdk/vector/vector_get_border_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ void VectorGetBorderTask::SubTaskCallback(Status status, VectorGetBorderPartTask
int64_t result_vecotr_id = sub_task->GetResult();
target_vector_id_ =
is_max_ ? std::max(target_vector_id_, result_vecotr_id) : std::min(target_vector_id_, result_vecotr_id);

next_part_ids_.erase(sub_task->part_id_);
}

if (sub_tasks_count_.fetch_sub(1) == 1) {
Expand All @@ -103,6 +105,11 @@ void VectorGetBorderPartTask::DoAsync() {
return;
}

{
std::unique_lock<std::shared_mutex> w(rw_lock_);
result_vector_id_ = is_max_ ? -1 : INT64_MAX;
}

controllers_.clear();
rpcs_.clear();

Expand Down
12 changes: 12 additions & 0 deletions src/sdk/vector/vector_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ std::string QueryResult::ToString() const {
return oss.str();
}

std::string ScanQueryResult::ToString() const {
std::ostringstream oss;
oss << "ScanQueryResult: {";
oss << "vectors: [";
for (const auto& vector : vectors) {
oss << vector.ToString() << ", ";
}
oss << "]";
oss << "}";
return oss.str();
}

std::string VectorIndexTypeToString(VectorIndexType type) {
switch (type) {
case VectorIndexType::kNoneIndexType:
Expand Down
Loading

0 comments on commit 32978fd

Please sign in to comment.