Skip to content

Commit

Permalink
feat(search): implement vector query for sql/redisearch parser & tran…
Browse files Browse the repository at this point in the history
…sformer (#2450)

Co-authored-by: Twice <[email protected]>
  • Loading branch information
Beihao-Zhou and PragmaTwice authored Aug 2, 2024
1 parent a3863f9 commit 0f5f18e
Show file tree
Hide file tree
Showing 11 changed files with 362 additions and 49 deletions.
20 changes: 20 additions & 0 deletions src/search/common_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,26 @@ struct TreeTransformer {

return result;
}

template <typename T = double>
static StatusOr<std::vector<T>> Binary2Vector(std::string_view str) {
if (str.size() % sizeof(T) != 0) {
return {Status::NotOK, "data size is not a multiple of the target type size"};
}

std::vector<T> values;
const size_t type_size = sizeof(T);
values.reserve(str.size() / type_size);

while (!str.empty()) {
T value;
memcpy(&value, str.data(), type_size);
values.push_back(value);
str.remove_prefix(type_size);
}

return values;
}
};

} // namespace kqir
71 changes: 70 additions & 1 deletion src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,63 @@ struct NumericCompareExpr : BoolAtomExpr {
}
};

struct VectorLiteral : Literal {
std::vector<double> values;

explicit VectorLiteral(std::vector<double> &&values) : values(std::move(values)){};

std::string_view Name() const override { return "VectorLiteral"; }
std::string Dump() const override {
return fmt::format("[{}]", util::StringJoin(values, [](auto v) { return std::to_string(v); }));
}
std::string Content() const override { return Dump(); }

std::unique_ptr<Node> Clone() const override { return std::make_unique<VectorLiteral>(*this); }
};

struct VectorRangeExpr : BoolAtomExpr {
std::unique_ptr<FieldRef> field;
std::unique_ptr<NumericLiteral> range;
std::unique_ptr<VectorLiteral> vector;

VectorRangeExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&range,
std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), range(std::move(range)), vector(std::move(vector)) {}

std::string_view Name() const override { return "VectorRangeExpr"; }
std::string Dump() const override {
return fmt::format("{} <-> {} < {}", field->Dump(), vector->Dump(), range->Dump());
}

std::unique_ptr<Node> Clone() const override {
return std::make_unique<VectorRangeExpr>(Node::MustAs<FieldRef>(field->Clone()),
Node::MustAs<NumericLiteral>(range->Clone()),
Node::MustAs<VectorLiteral>(vector->Clone()));
}
};

struct VectorKnnExpr : BoolAtomExpr {
// TODO: Support pre-filter for hybrid query
std::unique_ptr<FieldRef> field;
std::unique_ptr<NumericLiteral> k;
std::unique_ptr<VectorLiteral> vector;

VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&k,
std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}

std::string_view Name() const override { return "VectorKnnExpr"; }
std::string Dump() const override {
return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), vector->Dump());
}

std::unique_ptr<Node> Clone() const override {
return std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(field->Clone()),
Node::MustAs<NumericLiteral>(k->Clone()),
Node::MustAs<VectorLiteral>(vector->Clone()));
}
};

struct BoolLiteral : BoolAtomExpr, Literal {
bool val;

Expand Down Expand Up @@ -336,18 +393,30 @@ struct LimitClause : Node {
std::string Content() const override { return fmt::format("{}, {}", offset, count); }

std::unique_ptr<Node> Clone() const override { return std::make_unique<LimitClause>(*this); }
size_t Offset() const { return offset; }

size_t Count() const { return count; }
};

struct SortByClause : Node {
enum Order { ASC, DESC } order = ASC;
std::unique_ptr<FieldRef> field;
std::unique_ptr<VectorLiteral> vector = nullptr;

SortByClause(Order order, std::unique_ptr<FieldRef> &&field) : order(order), field(std::move(field)) {}
SortByClause(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector)
: field(std::move(field)), vector(std::move(vector)) {}

static constexpr const char *OrderToString(Order order) { return order == ASC ? "asc" : "desc"; }
bool IsVectorField() const { return vector != nullptr; }

std::string_view Name() const override { return "SortByClause"; }
std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); }
std::string Dump() const override {
if (!IsVectorField()) {
return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order));
}
return fmt::format("sortby {} <-> {}", field->Dump(), vector->Dump());
}
std::string Content() const override { return OrderToString(order); }

NodeIterator ChildBegin() override { return NodeIterator(field.get()); };
Expand Down
63 changes: 63 additions & 0 deletions src/search/ir_sema_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ struct SemaChecker {
GET_OR_RET(Check(v->query_expr.get()));
if (v->limit) GET_OR_RET(Check(v->limit.get()));
if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) {
return {Status::NotOK, "expect a LIMIT clause for vector field to construct a KNN search"};
}
} else {
return {Status::NotOK, fmt::format("index `{}` not found", index_name)};
}
Expand All @@ -60,8 +63,25 @@ struct SemaChecker {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.IsSortable()) {
return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)};
} else if (auto is_vector = iter->second.MetadataAs<redis::HnswVectorFieldMetadata>() != nullptr;
is_vector != v->IsVectorField()) {
std::string not_str = is_vector ? "" : "not ";
return {Status::NotOK,
fmt::format("field `{}` is {}a vector field according to metadata and does {}expect a vector parameter",
v->field->name, not_str, not_str)};
} else {
v->field->info = &iter->second;
if (v->IsVectorField()) {
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (!v->field->info->HasIndex()) {
return {Status::NotOK,
fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)};
}
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
}
} else if (auto v = dynamic_cast<AndExpr *>(node)) {
for (const auto &n : v->inners) {
Expand Down Expand Up @@ -97,6 +117,49 @@ struct SemaChecker {
} else {
v->field->info = &iter->second;
}
} else if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)};
} else {
v->field->info = &iter->second;

if (!v->field->info->HasIndex()) {
return {Status::NotOK,
fmt::format("field `{}` is marked as NOINDEX and cannot be used for KNN search", v->field->name)};
}
if (v->k->val <= 0) {
return {Status::NotOK, fmt::format("KNN search parameter `k` must be greater than 0")};
}
auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
} else if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
return {Status::NotOK, fmt::format("field `{}` is not a vector field", v->field->name)};
} else {
v->field->info = &iter->second;

auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (meta->distance_metric == redis::DistanceMetric::L2 && v->range->val < 0) {
return {Status::NotOK, "range cannot be a negative number for l2 distance metric"};
}

if (meta->distance_metric == redis::DistanceMetric::COSINE && (v->range->val < 0 || v->range->val > 2)) {
return {Status::NotOK, "range has to be between 0 and 2 for cosine distance metric"};
}

if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
fmt::format("vector should be of size `{}` for field `{}`", meta->dim, v->field->name)};
}
}
} else if (auto v = dynamic_cast<SelectClause *>(node)) {
for (const auto &n : v->fields) {
if (auto iter = current_index->fields.find(n->name); iter == current_index->fields.end()) {
Expand Down
16 changes: 13 additions & 3 deletions src/search/redis_query_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ namespace redis_query {

using namespace peg;

struct VectorRangeToken : string<'V', 'E', 'C', 'T', 'O', 'R', '_', 'R', 'A', 'N', 'G', 'E'> {};
struct KnnToken : string<'K', 'N', 'N'> {};
struct ArrowOp : string<'=', '>'> {};
struct Wildcard : one<'*'> {};

struct Field : seq<one<'@'>, Identifier> {};

struct Param : seq<one<'$'>, Identifier> {};
Expand All @@ -44,9 +49,10 @@ struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, WSPad<NumericRangePart>, one<']'>> {};

struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<TagList, NumericRange>>> {};
struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>, WSPad<Field>, WSPad<Param>, one<']'>> {};
struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>, WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};

struct Wildcard : one<'*'> {};
struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, TagList, NumericRange>>> {};

struct QueryExpr;

Expand All @@ -64,7 +70,11 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
struct OrExprP : sor<OrExpr, AndExprP> {};

struct QueryExpr : seq<OrExprP> {};
struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};

struct QueryP : sor<PrefilterExpr, OrExprP> {};

struct QueryExpr : seq<QueryP> {};

} // namespace redis_query

Expand Down
64 changes: 45 additions & 19 deletions src/search/redis_query_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ namespace redis_query {
namespace ir = kqir;

template <typename Rule>
using TreeSelector =
parse_tree::selector<Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, ExclusiveNumber, FieldQuery, NotExpr,
AndExpr, OrExpr, Wildcard>>;
using TreeSelector = parse_tree::selector<
Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, VectorRange, ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
OrExpr, PrefilterExpr, KnnSearch, Wildcard, VectorRangeToken, KnnToken, ArrowOp>>;

template <typename Input>
StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
Expand All @@ -53,7 +53,31 @@ StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
struct Transformer : ir::TreeTransformer {
explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) {}

StatusOr<std::unique_ptr<VectorLiteral>> Transform2Vector(const TreeNode& node) {
std::string vector_str = GET_OR_RET(GetParam(node));

std::vector<double> values = GET_OR_RET(Binary2Vector<double>(vector_str));
if (values.empty()) {
return {Status::NotOK, "empty vector is invalid"};
}
return std::make_unique<ir::VectorLiteral>(std::move(values));
};

auto Transform(const TreeNode& node) -> StatusOr<std::unique_ptr<Node>> {
auto number_or_param = [this](const TreeNode& node) -> StatusOr<std::unique_ptr<NumericLiteral>> {
if (Is<Number>(node)) {
return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
} else if (Is<Param>(node)) {
auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
.Prefixed(fmt::format("parameter {} is not a number", node->string_view())));

return std::make_unique<ir::NumericLiteral>(val);
} else {
return {Status::NotOK,
fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)};
}
};

if (Is<Number>(node)) {
return Node::Create<ir::NumericLiteral>(*ParseFloat(node->string()));
} else if (Is<Wildcard>(node)) {
Expand Down Expand Up @@ -88,26 +112,12 @@ struct Transformer : ir::TreeTransformer {
} else {
return std::make_unique<ir::OrExpr>(std::move(exprs));
}
} else { // NumericRange
} else if (Is<NumericRange>(query)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

const auto& lhs = query->children[0];
const auto& rhs = query->children[1];

auto number_or_param = [this](const TreeNode& node) -> StatusOr<std::unique_ptr<NumericLiteral>> {
if (Is<Number>(node)) {
return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
} else if (Is<Param>(node)) {
auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
.Prefixed(fmt::format("parameter {} is not a number", node->string_view())));

return std::make_unique<ir::NumericLiteral>(val);
} else {
return {Status::NotOK,
fmt::format("expected a number or a parameter in numeric comparison but got {}", node->type)};
}
};

if (Is<ExclusiveNumber>(lhs)) {
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT,
std::make_unique<FieldRef>(field),
Expand Down Expand Up @@ -141,11 +151,27 @@ struct Transformer : ir::TreeTransformer {
} else {
return std::make_unique<ir::AndExpr>(std::move(exprs));
}
} else if (Is<VectorRange>(query)) {
return std::make_unique<VectorRangeExpr>(std::make_unique<FieldRef>(field),
GET_OR_RET(number_or_param(query->children[1])),
GET_OR_RET(Transform2Vector(query->children[2])));
}
} else if (Is<NotExpr>(node)) {
CHECK(node->children.size() == 1);

return Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<PrefilterExpr>(node)) {
CHECK(node->children.size() == 3);

// TODO(Beihao): Support Hybrid Query
// const auto& prefilter = node->children[0];
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);

return std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
GET_OR_RET(number_or_param(knn_search->children[1])),
GET_OR_RET(Transform2Vector(knn_search->children[3])));

} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;

Expand Down
2 changes: 2 additions & 0 deletions src/search/search_encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata {

HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {}

bool IsSortable() const override { return true; }

void Encode(std::string *dst) const override {
IndexFieldMetadata::Encode(dst);
PutFixed8(dst, uint8_t(vector_type));
Expand Down
11 changes: 9 additions & 2 deletions src/search/sql_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ struct NumericAtomExpr : WSPad<sor<NumberOrParam, Identifier>> {};
struct NumericCompareOp : sor<string<'!', '='>, string<'<', '='>, string<'>', '='>, one<'=', '<', '>'>> {};
struct NumericCompareExpr : seq<NumericAtomExpr, NumericCompareOp, NumericAtomExpr> {};

struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, WSPad<Boolean>> {};
struct VectorCompareOp : string<'<', '-', '>'> {};
struct VectorLiteral : seq<WSPad<one<'['>>, Number, star<seq<WSPad<one<','>>>, Number>, WSPad<one<']'>>> {};
struct VectorCompareExpr : seq<WSPad<Identifier>, VectorCompareOp, WSPad<VectorLiteral>> {};
struct VectorRangeExpr : seq<VectorCompareExpr, one<'<'>, WSPad<NumberOrParam>> {};

struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, VectorRangeExpr, WSPad<Boolean>> {};

struct QueryExpr;

Expand Down Expand Up @@ -84,7 +89,9 @@ struct Limit : string<'l', 'i', 'm', 'i', 't'> {};

struct WhereClause : seq<Where, QueryExpr> {};
struct AscOrDesc : sor<Asc, Desc> {};
struct OrderByClause : seq<OrderBy, WSPad<Identifier>, opt<WSPad<AscOrDesc>>> {};
struct SortableFieldExpr : seq<WSPad<Identifier>, opt<AscOrDesc>> {};
struct OrderByExpr : sor<WSPad<VectorCompareExpr>, WSPad<SortableFieldExpr>> {};
struct OrderByClause : seq<OrderBy, OrderByExpr> {};
struct LimitClause : seq<Limit, opt<seq<WSPad<UnsignedInteger>, one<','>>>, WSPad<UnsignedInteger>> {};

struct SearchStmt
Expand Down
Loading

0 comments on commit 0f5f18e

Please sign in to comment.