Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(search): Hnsw Vector Search Optimizaton Pass #2466

Merged
merged 7 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/search/executors/filter_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <variant>

#include "parse_util.h"
#include "search/hnsw_indexer.h"
#include "search/ir.h"
#include "search/plan_executor.h"
#include "search/search_encoding.h"
Expand All @@ -44,6 +45,9 @@ struct QueryExprEvaluator {
if (auto v = dynamic_cast<NotExpr *>(e)) {
return Visit(v);
}
if (auto v = dynamic_cast<VectorRangeExpr *>(e)) {
return Visit(v);
}
if (auto v = dynamic_cast<NumericCompareExpr *>(e)) {
return Visit(v);
}
Expand Down Expand Up @@ -112,6 +116,24 @@ struct QueryExprEvaluator {
__builtin_unreachable();
}
}

StatusOr<bool> Visit(VectorRangeExpr *v) const {
auto val = GET_OR_RET(ctx->Retrieve(row, v->field->info));

CHECK(val.Is<kqir::NumericArray>());
auto l_values = val.Get<kqir::NumericArray>();
auto r_values = v->vector->values;
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();

redis::VectorItem left, right;
GET_OR_RET(redis::VectorItem::Create({}, l_values, meta, &left));
GET_OR_RET(redis::VectorItem::Create({}, r_values, meta, &right));

auto dist = GET_OR_RET(redis::ComputeSimilarity(left, right));
auto effective_range = v->range->val * (1 + meta->epsilon);

return (dist >= -abs(effective_range) && dist <= abs(effective_range));
}
};

struct FilterExecutor : ExecutorNode {
Expand Down
19 changes: 9 additions & 10 deletions src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,24 +265,19 @@ struct VectorRangeExpr : BoolAtomExpr {
};

struct VectorKnnExpr : BoolAtomExpr {
// TODO: Support pre-filter for hybrid query
std::unique_ptr<FieldRef> field;
std::unique_ptr<NumericLiteral> k;
std::unique_ptr<VectorLiteral> vector;
size_t k;

VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&k,
std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}
VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector, size_t k)
: field(std::move(field)), vector(std::move(vector)), k(k) {}

std::string_view Name() const override { return "VectorKnnExpr"; }
std::string Dump() const override {
return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), vector->Dump());
}
std::string Dump() const override { return fmt::format("KNN k={}, {} <-> {}", k, field->Dump(), vector->Dump()); }

std::unique_ptr<Node> Clone() const override {
return std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(field->Clone()),
Node::MustAs<NumericLiteral>(k->Clone()),
Node::MustAs<VectorLiteral>(vector->Clone()));
Node::MustAs<VectorLiteral>(vector->Clone()), k);
}
};

Expand Down Expand Up @@ -425,6 +420,10 @@ struct SortByClause : Node {
std::unique_ptr<Node> Clone() const override {
return std::make_unique<SortByClause>(order, Node::MustAs<FieldRef>(field->Clone()));
}

std::unique_ptr<FieldRef> TakeFieldRef() { return std::move(field); }

std::unique_ptr<VectorLiteral> TakeVectorLiteral() { return std::move(vector); }
};

struct SelectClause : Node {
Expand Down
29 changes: 29 additions & 0 deletions src/search/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<TagContainExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorKnnExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<VectorRangeExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<StringLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
Expand All @@ -69,6 +75,10 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<TagFieldScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<HnswVectorFieldRangeScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<HnswVectorFieldKnnScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Filter>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Limit>(std::move(node))) {
Expand Down Expand Up @@ -125,6 +135,8 @@ struct Visitor : Pass {

virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericLiteral> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorLiteral> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericCompareExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->num = VisitAs<NumericLiteral>(std::move(node->num));
Expand All @@ -137,6 +149,19 @@ struct Visitor : Pass {
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorKnnExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorRangeExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->range = VisitAs<NumericLiteral>(std::move(node->range));
node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) {
for (auto &n : node->inners) {
n = TransformAs<QueryExpr>(std::move(n));
Expand Down Expand Up @@ -173,6 +198,10 @@ struct Visitor : Pass {

virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagFieldScan> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<HnswVectorFieldRangeScan> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<HnswVectorFieldKnnScan> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<Filter> node) {
node->source = TransformAs<PlanOperator>(std::move(node->source));
node->filter_expr = TransformAs<QueryExpr>(std::move(node->filter_expr));
Expand Down
2 changes: 1 addition & 1 deletion src/search/ir_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct TagFieldScan : FieldScan {

struct HnswVectorFieldKnnScan : FieldScan {
kqir::NumericArray vector;
uint16_t k;
uint32_t k;

HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray vector, uint16_t k)
: FieldScan(std::move(field)), vector(std::move(vector)), k(k) {}
Expand Down
13 changes: 8 additions & 5 deletions src/search/ir_sema_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,14 @@ struct SemaChecker {
GET_OR_RET(Check(v->query_expr.get()));
if (v->limit) GET_OR_RET(Check(v->limit.get()));
if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) {
return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"};
if (v->sort_by && v->sort_by->IsVectorField()) {
if (!v->limit) {
return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"};
}
// TODO: allow hybrid query
if (auto b = dynamic_cast<BoolLiteral *>(v->query_expr.get()); b == nullptr) {
return {Status::NotOK, "KNN search cannot be combined with other query expressions"};
}
}
} else {
return {Status::NotOK, fmt::format("index `{}` not found", index_name)};
Expand Down Expand Up @@ -129,9 +135,6 @@ struct SemaChecker {
return {Status::NotOK,
fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)};
}
if (v->k->val <= 0) {
return {Status::NotOK, fmt::format("KNN search parameter `k` must be greater than 0")};
}
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
Expand Down
10 changes: 10 additions & 0 deletions src/search/passes/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ struct CostModel {
if (auto v = dynamic_cast<const FullIndexScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const HnswVectorFieldKnnScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const HnswVectorFieldRangeScan *>(node)) {
return Visit(v);
}
if (auto v = dynamic_cast<const NumericFieldScan *>(node)) {
return Visit(v);
}
Expand Down Expand Up @@ -74,6 +80,10 @@ struct CostModel {

static size_t Visit(const TagFieldScan *node) { return 10; }

static size_t Visit(const HnswVectorFieldKnnScan *node) { return 3; }

static size_t Visit(const HnswVectorFieldRangeScan *node) { return 4; }

static size_t Visit(const Filter *node) { return Transform(node->source.get()) + 1; }

static size_t Visit(const Merge *node) {
Expand Down
23 changes: 23 additions & 0 deletions src/search/passes/index_selection.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ struct IndexSelection : Visitor {
if (auto v = dynamic_cast<OrExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
return VisitExpr(v);
}
if (auto v = dynamic_cast<NumericCompareExpr *>(node)) {
return VisitExpr(v);
}
Expand Down Expand Up @@ -153,6 +159,23 @@ struct IndexSelection : Visitor {
return MakeFullIndexFilter(node);
}

std::unique_ptr<PlanOperator> VisitExpr(VectorRangeExpr *node) const {
if (node->field->info->HasIndex()) {
return std::make_unique<HnswVectorFieldRangeScan>(node->field->CloneAs<FieldRef>(), node->vector->values,
node->range->val);
}

return MakeFullIndexFilter(node);
}

std::unique_ptr<PlanOperator> VisitExpr(VectorKnnExpr *node) const {
if (node->field->info->HasIndex()) {
return std::make_unique<HnswVectorFieldKnnScan>(node->field->CloneAs<FieldRef>(), node->vector->values, node->k);
}

return MakeFullIndexFilter(node);
}

template <typename Expr>
std::unique_ptr<PlanOperator> VisitExprImpl(Expr *node) {
struct AggregatedNodes {
Expand Down
4 changes: 3 additions & 1 deletion src/search/passes/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "search/passes/simplify_and_or_expr.h"
#include "search/passes/simplify_boolean.h"
#include "search/passes/sort_limit_fuse.h"
#include "search/passes/sort_limit_to_knn.h"
#include "type_util.h"

namespace kqir {
Expand Down Expand Up @@ -86,7 +87,8 @@ struct PassManager {
}

static PassSequence ExprPasses() {
return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{});
return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{},
SortByWithLimitToKnnExpr{}, SimplifyAndOrExpr{});
}
static PassSequence NumericPasses() { return Create(IntervalAnalysis{true}, SimplifyAndOrExpr{}, SimplifyBoolean{}); }
static PassSequence PlanPasses() { return Create(LowerToPlan{}, IndexSelection{}, SortLimitFuse{}); }
Expand Down
50 changes: 50 additions & 0 deletions src/search/passes/sort_limit_to_knn.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <memory>

#include "search/ir.h"
#include "search/ir_pass.h"
#include "search/ir_plan.h"

namespace kqir {

struct SortByWithLimitToKnnExpr : Visitor {
std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
node = Node::MustAs<SearchExpr>(Visitor::Visit(std::move(node)));

// TODO: allow hybrid query
if (node->sort_by && node->sort_by->IsVectorField() && node->limit) {
if (auto b = dynamic_cast<BoolLiteral*>(node->query_expr.get()); b && b->val) {
node->query_expr =
std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(node->sort_by->TakeFieldRef()),
Node::MustAs<VectorLiteral>(node->sort_by->TakeVectorLiteral()),
node->limit->Offset() + node->limit->Count());
node->sort_by.reset();
}
}

return node;
}
};

} // namespace kqir
5 changes: 3 additions & 2 deletions src/search/redis_query_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ struct Tag : sor<Identifier, StringL, Param> {};
struct TagList : seq<one<'{'>, WSPad<Tag>, star<seq<one<'|'>, WSPad<Tag>>>, one<'}'>> {};

struct NumberOrParam : sor<Number, Param> {};
struct UintOrParam : sor<UnsignedInteger, Param> {};

struct Inf : seq<opt<one<'+', '-'>>, string<'i', 'n', 'f'>> {};
struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, WSPad<NumericRangePart>, one<']'>> {};

struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<UintOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>, WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};

struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, TagList, NumericRange>>> {};
Expand All @@ -70,7 +71,7 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
struct OrExprP : sor<OrExpr, AndExprP> {};

struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};
struct PrefilterExpr : seq<WSPad<Wildcard>, ArrowOp, WSPad<KnnSearch>> {};

struct QueryP : sor<PrefilterExpr, OrExprP> {};

Expand Down
16 changes: 10 additions & 6 deletions src/search/redis_query_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ir = kqir;

template <typename Rule>
using TreeSelector = parse_tree::selector<
Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
Rule, parse_tree::store_content::on<Number, UnsignedInteger, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, VectorRange, ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
OrExpr, PrefilterExpr, KnnSearch, Wildcard, VectorRangeToken, KnnToken, ArrowOp>>;

Expand Down Expand Up @@ -161,17 +161,21 @@ struct Transformer : ir::TreeTransformer {

return Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<PrefilterExpr>(node)) {
// TODO: allow hybrid query
CHECK(node->children.size() == 3);

// TODO(Beihao): Support Hybrid Query
// const auto& prefilter = node->children[0];
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);

return std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(number_or_param(knn_search->children[1])),
GET_OR_RET(Transform2Vector(knn_search->children[3])));
size_t k = 0;
if (Is<UnsignedInteger>(knn_search->children[1])) {
k = *ParseInt(knn_search->children[1]->string());
} else {
k = *ParseInt(GET_OR_RET(GetParam(node)));
}

return std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(Transform2Vector(knn_search->children[3])), k);
} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

Expand Down
1 change: 0 additions & 1 deletion src/search/sql_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ struct Transformer : ir::TreeTransformer {
return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"};
}
} else if (Is<VectorRangeExpr>(node)) {
// TODO(Beihao): Handle distance metrics for operator
CHECK(node->children.size() == 2);
const auto& vector_comp_expr = node->children[0];
CHECK(vector_comp_expr->children.size() == 3);
Expand Down
Loading
Loading