Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(search): implement vector query for sql/redisearch parser & transformer #2450

Merged
merged 12 commits into from
Aug 2, 2024
20 changes: 20 additions & 0 deletions src/search/common_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,26 @@ struct TreeTransformer {

return result;
}

template <typename T = double>
static StatusOr<std::vector<T>> Binary2Vector(std::string_view str) {
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
if (str.size() % sizeof(T) != 0) {
return {Status::NotOK, "data size is not a multiple of the target type size"};
}

std::vector<T> values;
const size_t type_size = sizeof(T);
values.reserve(str.size() / type_size);

while (!str.empty()) {
T value;
memcpy(&value, str.data(), type_size);
values.push_back(value);
str.remove_prefix(type_size);
}

return values;
}
};

} // namespace kqir
68 changes: 68 additions & 0 deletions src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,63 @@ struct NumericCompareExpr : BoolAtomExpr {
}
};

struct VectorLiteral : Literal {
std::vector<double> values;

explicit VectorLiteral(std::vector<double> &&values) : values(std::move(values)){};

std::string_view Name() const override { return "VectorLiteral"; }
std::string Dump() const override {
return fmt::format("[{}]", util::StringJoin(values, [](auto v) { return std::to_string(v); }));
}
std::string Content() const override { return Dump(); }

std::unique_ptr<Node> Clone() const override { return std::make_unique<VectorLiteral>(*this); }
};

struct VectorRangeExpr : BoolAtomExpr {
std::unique_ptr<FieldRef> field;
std::unique_ptr<NumericLiteral> range;
std::unique_ptr<VectorLiteral> vector;

VectorRangeExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&range,
std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), range(std::move(range)), vector(std::move(vector)) {}

std::string_view Name() const override { return "VectorRangeExpr"; }
std::string Dump() const override {
return fmt::format("{} vector_range {} {}", field->Dump(), range->Dump(), vector->Dump());
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
}

std::unique_ptr<Node> Clone() const override {
return std::make_unique<VectorRangeExpr>(Node::MustAs<FieldRef>(field->Clone()),
Node::MustAs<NumericLiteral>(range->Clone()),
Node::MustAs<VectorLiteral>(vector->Clone()));
}
};

struct VectorSearchExpr : BoolAtomExpr {
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
// TODO: Support pre-filter for hybrid query
std::unique_ptr<FieldRef> field;
std::unique_ptr<NumericLiteral> k;
std::unique_ptr<VectorLiteral> vector;

VectorSearchExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&k,
std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}

std::string_view Name() const override { return "VectorSearchExpr"; }
std::string Dump() const override {
return fmt::format("{} vector_search {} {}", field->Dump(), k->Dump(), vector->Dump());
}

std::unique_ptr<Node> Clone() const override {
return std::make_unique<VectorRangeExpr>(Node::MustAs<FieldRef>(field->Clone()),
Node::MustAs<NumericLiteral>(k->Clone()),
Node::MustAs<VectorLiteral>(vector->Clone()));
}
};

struct BoolLiteral : BoolAtomExpr, Literal {
bool val;

Expand Down Expand Up @@ -336,15 +393,22 @@ struct LimitClause : Node {
std::string Content() const override { return fmt::format("{}, {}", offset, count); }

std::unique_ptr<Node> Clone() const override { return std::make_unique<LimitClause>(*this); }
size_t Offset() const { return offset; }

size_t Count() const { return count; }
};

struct SortByClause : Node {
enum Order { ASC, DESC } order = ASC;
std::unique_ptr<FieldRef> field;
std::unique_ptr<VectorLiteral> vector = nullptr;

SortByClause(Order order, std::unique_ptr<FieldRef> &&field) : order(order), field(std::move(field)) {}
SortByClause(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), vector(std::move(vector)) {}

static constexpr const char *OrderToString(Order order) { return order == ASC ? "asc" : "desc"; }
bool IsKnn() const { return vector != nullptr; }
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved

std::string_view Name() const override { return "SortByClause"; }
std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); }
Expand All @@ -356,6 +420,10 @@ struct SortByClause : Node {
std::unique_ptr<Node> Clone() const override {
return std::make_unique<SortByClause>(order, Node::MustAs<FieldRef>(field->Clone()));
}

std::unique_ptr<FieldRef> GetFieldRef() { return std::move(field); }

std::unique_ptr<VectorLiteral> GetVectorLiteral() { return std::move(vector); }
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
};

struct SelectClause : Node {
Expand Down
36 changes: 36 additions & 0 deletions src/search/ir_sema_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,42 @@ struct SemaChecker {
} else {
v->field->info = &iter->second;
}
} else if (auto v = dynamic_cast<VectorSearchExpr *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)};
} else {
v->field->info = &iter->second;

auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
} else if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)};
} else {
v->field->info = &iter->second;

auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (meta->distance_metric == redis::DistanceMetric::L2 && v->range->val < 0) {
return {Status::NotOK, "range cannot be a negative number for l2 distance metric"};
}

if (meta->distance_metric == redis::DistanceMetric::COSINE && (v->range->val < 0 || v->range->val > 2)) {
return {Status::NotOK, "range has to be between 0 and 2 for cosine distance metric"};
}

if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
} else if (auto v = dynamic_cast<SelectClause *>(node)) {
for (const auto &n : v->fields) {
if (auto iter = current_index->fields.find(n->name); iter == current_index->fields.end()) {
Expand Down
16 changes: 13 additions & 3 deletions src/search/redis_query_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ namespace redis_query {

using namespace peg;

struct VectorRangeToken : string<'V', 'E', 'C', 'T', 'O', 'R', '_', 'R', 'A', 'N', 'G', 'E'> {};
struct KnnToken : string<'K', 'N', 'N'> {};
struct ArrowOp : string<'=', '>'> {};
struct Wildcard : one<'*'> {};

struct Field : seq<one<'@'>, Identifier> {};

struct Param : seq<one<'$'>, Identifier> {};
Expand All @@ -44,9 +49,10 @@ struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, WSPad<NumericRangePart>, one<']'>> {};

struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<TagList, NumericRange>>> {};
struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>, WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};

struct Wildcard : one<'*'> {};
struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, TagList, NumericRange>>> {};

struct QueryExpr;

Expand All @@ -64,7 +70,11 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
struct OrExprP : sor<OrExpr, AndExprP> {};

struct QueryExpr : seq<OrExprP> {};
struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};

struct QueryP : sor<PrefilterExpr, OrExprP> {};

struct QueryExpr : seq<QueryP> {};

} // namespace redis_query

Expand Down
64 changes: 45 additions & 19 deletions src/search/redis_query_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ namespace redis_query {
namespace ir = kqir;

template <typename Rule>
using TreeSelector =
parse_tree::selector<Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, ExclusiveNumber, FieldQuery, NotExpr,
AndExpr, OrExpr, Wildcard>>;
using TreeSelector = parse_tree::selector<
Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, VectorRange, ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
OrExpr, PrefilterExpr, KnnSearch, Wildcard, VectorRangeToken, KnnToken, ArrowOp>>;

template <typename Input>
StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
Expand All @@ -53,7 +53,31 @@ StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
struct Transformer : ir::TreeTransformer {
explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) {}

StatusOr<std::unique_ptr<VectorLiteral>> Transform2Vector(const TreeNode& node) {
std::string vector_str = GET_OR_RET(GetParam(node));

std::vector<double> values = GET_OR_RET(Binary2Vector<double>(vector_str));
if (values.empty()) {
return {Status::NotOK, "empty vector is invalid"};
}
return std::make_unique<ir::VectorLiteral>(std::move(values));
};

auto Transform(const TreeNode& node) -> StatusOr<std::unique_ptr<Node>> {
auto number_or_param = [this](const TreeNode& node) -> StatusOr<std::unique_ptr<NumericLiteral>> {
if (Is<Number>(node)) {
return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
} else if (Is<Param>(node)) {
auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
.Prefixed(fmt::format("parameter {} is not a number", node->string_view())));

return std::make_unique<ir::NumericLiteral>(val);
} else {
return {Status::NotOK,
fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)};
}
};

if (Is<Number>(node)) {
return Node::Create<ir::NumericLiteral>(*ParseFloat(node->string()));
} else if (Is<Wildcard>(node)) {
Expand Down Expand Up @@ -88,26 +112,12 @@ struct Transformer : ir::TreeTransformer {
} else {
return std::make_unique<ir::OrExpr>(std::move(exprs));
}
} else { // NumericRange
} else if (Is<NumericRange>(query)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

const auto& lhs = query->children[0];
const auto& rhs = query->children[1];

auto number_or_param = [this](const TreeNode& node) -> StatusOr<std::unique_ptr<NumericLiteral>> {
if (Is<Number>(node)) {
return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
} else if (Is<Param>(node)) {
auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
.Prefixed(fmt::format("parameter {} is not a number", node->string_view())));

return std::make_unique<ir::NumericLiteral>(val);
} else {
return {Status::NotOK,
fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)};
}
};

if (Is<ExclusiveNumber>(lhs)) {
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT,
std::make_unique<FieldRef>(field),
Expand Down Expand Up @@ -141,11 +151,27 @@ struct Transformer : ir::TreeTransformer {
} else {
return std::make_unique<ir::AndExpr>(std::move(exprs));
}
} else if (Is<VectorRange>(query)) {
return std::make_unique<VectorRangeExpr>(std::make_unique<FieldRef>(field),
GET_OR_RET(number_or_param(query->children[1])),
GET_OR_RET(Transform2Vector(query->children[2])));
}
} else if (Is<NotExpr>(node)) {
CHECK(node->children.size() == 1);

return Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<PrefilterExpr>(node)) {
CHECK(node->children.size() == 3);

// TODO(Beihao): Support Hybrid Query
// const auto& prefilter = node->children[0];
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);

return std::make_unique<VectorSearchExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(number_or_param(knn_search->children[1])),
GET_OR_RET(Transform2Vector(knn_search->children[3])));

} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

Expand Down
10 changes: 8 additions & 2 deletions src/search/sql_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ struct NumericAtomExpr : WSPad<sor<NumberOrParam, Identifier>> {};
struct NumericCompareOp : sor<string<'!', '='>, string<'<', '='>, string<'>', '='>, one<'=', '<', '>'>> {};
struct NumericCompareExpr : seq<NumericAtomExpr, NumericCompareOp, NumericAtomExpr> {};

struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, WSPad<Boolean>> {};
struct VectorCompareOp : string<'<', '-', '>'> {};
struct VectorLiteral : seq<WSPad<one<'['>>, Number, star<seq<WSPad<one<','>>>, Number>, WSPad<one<']'>>> {};
struct VectorCompareExpr : seq<WSPad<Identifier>, VectorCompareOp, WSPad<VectorLiteral>> {};
struct VectorRangeExpr : seq<VectorCompareExpr, one<'<'>, WSPad<NumberOrParam>> {};

struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, VectorRangeExpr, WSPad<Boolean>> {};

struct QueryExpr;

Expand Down Expand Up @@ -84,7 +89,8 @@ struct Limit : string<'l', 'i', 'm', 'i', 't'> {};

struct WhereClause : seq<Where, QueryExpr> {};
struct AscOrDesc : sor<Asc, Desc> {};
struct OrderByClause : seq<OrderBy, WSPad<Identifier>, opt<WSPad<AscOrDesc>>> {};
struct OrderByExpr : sor<WSPad<VectorCompareExpr>, seq<WSPad<Identifier>, opt<WSPad<AscOrDesc>>>> {};
PragmaTwice marked this conversation as resolved.
Show resolved Hide resolved
struct OrderByClause : seq<OrderBy, OrderByExpr> {};
struct LimitClause : seq<Limit, opt<seq<WSPad<UnsignedInteger>, one<','>>>, WSPad<UnsignedInteger>> {};

struct SearchStmt
Expand Down
Loading
Loading