diff --git a/collection.go b/collection.go index ec641a1..0e2e7f6 100644 --- a/collection.go +++ b/collection.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "path/filepath" - "slices" "sync" ) @@ -64,7 +63,7 @@ type QueryOptions struct { Where map[string]string // Conditional filtering on documents. - WhereDocument map[string]string + WhereDocuments []WhereDocument // Negative is the negative query options. // They can be used to exclude certain results from the query. @@ -296,9 +295,9 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. // - ids: The ids of the documents to delete. If empty, all documents are deleted. -func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error { +func (c *Collection) Delete(_ context.Context, where map[string]string, whereDocuments []WhereDocument, ids ...string) error { // must have at least one of where, whereDocument or ids - if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 { + if len(where) == 0 && len(whereDocuments) == 0 && len(ids) == 0 { return fmt.Errorf("must have at least one of where, whereDocument or ids") } @@ -306,9 +305,9 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s return nil } - for k := range whereDocument { - if !slices.Contains(supportedFilters, k) { - return errors.New("unsupported whereDocument operator") + for _, whereDocument := range whereDocuments { + if err := whereDocument.Validate(); err != nil { + return fmt.Errorf("invalid whereDocument %#v: %w", whereDocument, err) } } @@ -317,9 +316,9 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s c.documentsLock.Lock() defer c.documentsLock.Unlock() - if where != nil || whereDocument != nil { + if where != nil || len(whereDocuments) > 0 { // metadata + content filters - filteredDocs := filterDocs(c.documents, where, whereDocument) + filteredDocs := filterDocs(c.documents, where, whereDocuments) for _, doc := range filteredDocs { docIDs = append(docIDs, doc.ID) } @@ -376,7 +375,7 @@ type Result struct { // There can be fewer results if a filter is applied. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. -func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { +func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where map[string]string, whereDocument []WhereDocument) ([]Result, error) { if queryText == "" { return nil, errors.New("queryText is empty") } @@ -432,7 +431,7 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions) } } - result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocument) + result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocuments) if err != nil { return nil, err } @@ -449,12 +448,12 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions) // There can be fewer results if a filter is applied. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. -func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { - return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocument) +func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where map[string]string, whereDocuments []WhereDocument) ([]Result, error) { + return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocuments) } // queryEmbedding performs an exhaustive nearest neighbor search on the collection. -func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { +func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where map[string]string, whereDocuments []WhereDocument) ([]Result, error) { if len(queryEmbedding) == 0 { return nil, errors.New("queryEmbedding is empty") } @@ -472,14 +471,14 @@ func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativ } // Validate whereDocument operators - for k := range whereDocument { - if !slices.Contains(supportedFilters, k) { - return nil, errors.New("unsupported operator") + for _, whereDocument := range whereDocuments { + if err := whereDocument.Validate(); err != nil { + return nil, fmt.Errorf("invalid whereDocument %#v: %w", whereDocument, err) } } // Filter docs by metadata and content - filteredDocs := filterDocs(c.documents, where, whereDocument) + filteredDocs := filterDocs(c.documents, where, whereDocuments) // No need to continue if the filters got rid of all documents if len(filteredDocs) == 0 { diff --git a/collection_test.go b/collection_test.go index 6ec2738..ec4f270 100644 --- a/collection_test.go +++ b/collection_test.go @@ -7,6 +7,7 @@ import ( "os" "slices" "strconv" + "strings" "testing" ) @@ -372,10 +373,10 @@ func TestCollection_QueryError(t *testing.T) { { name: "Bad content filter", query: func() error { - _, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"}) + _, err := c.Query(context.Background(), "foo", 1, nil, []WhereDocument{{Operator: "invalid", Value: "foo"}}) return err }, - expErr: "unsupported operator", + expErr: "unsupported where document operator invalid", }, } @@ -384,7 +385,7 @@ func TestCollection_QueryError(t *testing.T) { err := tc.query() if err == nil { t.Fatal("expected error, got nil") - } else if err.Error() != tc.expErr { + } else if !strings.Contains(err.Error(), tc.expErr) { t.Fatal("expected", tc.expErr, "got", err) } }) @@ -502,7 +503,7 @@ func TestCollection_Delete(t *testing.T) { checkCount(1) // Test 3 - Remove document by content - err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"}) + err = c.Delete(context.Background(), nil, []WhereDocument{WhereDocument{Operator: WhereDocumentOperatorContains, Value: "hallo welt"}}) if err != nil { t.Fatal("expected nil, got", err) } diff --git a/query.go b/query.go index 34d406e..fda01cc 100644 --- a/query.go +++ b/query.go @@ -11,8 +11,6 @@ import ( "sync" ) -var supportedFilters = []string{"$contains", "$not_contains"} - type docSim struct { docID string similarity float32 @@ -84,7 +82,7 @@ func (d *maxDocSims) values() []docSim { // filterDocs filters a map of documents by metadata and content. // It does this concurrently. -func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document { +func filterDocs(docs map[string]*Document, where map[string]string, whereDocuments []WhereDocument) []*Document { filteredDocs := make([]*Document, 0, len(docs)) filteredDocsLock := sync.Mutex{} @@ -104,7 +102,7 @@ func filterDocs(docs map[string]*Document, where, whereDocument map[string]strin go func() { defer wg.Done() for doc := range docChan { - if documentMatchesFilters(doc, where, whereDocument) { + if documentMatchesFilters(doc, where, whereDocuments) { filteredDocsLock.Lock() filteredDocs = append(filteredDocs, doc) filteredDocsLock.Unlock() @@ -128,9 +126,84 @@ func filterDocs(docs map[string]*Document, where, whereDocument map[string]strin return filteredDocs } +type WhereDocumentOperator string + +const ( + WhereDocumentOperatorContains WhereDocumentOperator = "$contains" + WhereDocumentOperatorNotContains WhereDocumentOperator = "$not_contains" + WhereDocumentOperatorOr WhereDocumentOperator = "$or" + WhereDocumentOperatorAnd WhereDocumentOperator = "$and" +) + +type WhereDocument struct { + Operator WhereDocumentOperator + Value string + WhereDocuments []WhereDocument +} + +func (wd *WhereDocument) Validate() error { + + if !slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorContains, WhereDocumentOperatorNotContains, WhereDocumentOperatorOr, WhereDocumentOperatorAnd}, wd.Operator) { + return fmt.Errorf("unsupported where document operator %s", wd.Operator) + } + + if wd.Operator == "" { + return fmt.Errorf("where document operator is empty") + } + + // $contains and $not_contains require a string value + if slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorContains, WhereDocumentOperatorNotContains}, wd.Operator) { + if wd.Value == "" { + return fmt.Errorf("where document operator %s requires a value", wd.Operator) + } + } + + // $or requires sub-filters + if slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorOr, WhereDocumentOperatorAnd}, wd.Operator) { + if len(wd.WhereDocuments) == 0 { + return fmt.Errorf("where document operator %s must have at least one sub-filter", wd.Operator) + } + } + + for _, wd := range wd.WhereDocuments { + if err := wd.Validate(); err != nil { + return err + } + } + + return nil +} + +// Matches checks if a document matches the WhereDocument filter(s) +// There is no error checking on the WhereDocument struct, so it must be validated before calling this function. +func (wd *WhereDocument) Matches(doc *Document) bool { + switch wd.Operator { + case WhereDocumentOperatorContains: + return strings.Contains(doc.Content, wd.Value) + case WhereDocumentOperatorNotContains: + return !strings.Contains(doc.Content, wd.Value) + case WhereDocumentOperatorOr: + for _, subFilter := range wd.WhereDocuments { + if subFilter.Matches(doc) { + return true + } + } + return false + case WhereDocumentOperatorAnd: + for _, subFilter := range wd.WhereDocuments { + if !subFilter.Matches(doc) { + return false + } + } + return true + default: + return false + } +} + // documentMatchesFilters checks if a document matches the given filters. -// When calling this function, the whereDocument keys must already be validated! -func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool { +// When calling this function, the whereDocument structs must already be validated! +func documentMatchesFilters(document *Document, where map[string]string, whereDocuments []WhereDocument) bool { // A document's metadata must have *all* the fields in the where clause. for k, v := range where { // TODO: Do we want to check for existence of the key? I.e. should @@ -141,21 +214,10 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] } } - // A document must satisfy *all* filters, until we support the `$or` operator. - for k, v := range whereDocument { - switch k { - case "$contains": - if !strings.Contains(document.Content, v) { - return false - } - case "$not_contains": - if strings.Contains(document.Content, v) { - return false - } - default: - // No handling (error) required because we already validated the - // operators. This simplifies the concurrency logic (no err var - // and lock, no context to cancel). + // A document must satisfy *all* WhereDocument filters (that's basically a top-level $and operator) + for _, whereDocument := range whereDocuments { + if !whereDocument.Matches(document) { + return false } } diff --git a/query_test.go b/query_test.go index 104ed8f..1f42ba8 100644 --- a/query_test.go +++ b/query_test.go @@ -25,12 +25,24 @@ func TestFilterDocs(t *testing.T) { Embedding: []float32{0.2, 0.3, 0.4}, Content: "hallo welt", }, + "3": { + ID: "3", + Content: "bonjour and hello foo baz bom", + }, + "4": { + ID: "4", + Content: "bonjour and hello foo bar baz", + }, + "5": { + ID: "5", + Content: "bonjour and hello spam eggs", + }, } tt := []struct { name string where map[string]string - whereDocument map[string]string + whereDocument []WhereDocument want []*Document }{ { @@ -48,60 +60,81 @@ func TestFilterDocs(t *testing.T) { { name: "content contains all", where: nil, - whereDocument: map[string]string{"$contains": "llo"}, - want: []*Document{docs["1"], docs["2"]}, + whereDocument: []WhereDocument{{Operator: WhereDocumentOperatorContains, Value: "llo"}}, + want: []*Document{docs["1"], docs["2"], docs["3"], docs["4"], docs["5"]}, }, { name: "content contains one", where: nil, - whereDocument: map[string]string{"$contains": "hallo"}, + whereDocument: []WhereDocument{{Operator: WhereDocumentOperatorContains, Value: "hallo"}}, want: []*Document{docs["2"]}, }, { name: "content contains none", where: nil, - whereDocument: map[string]string{"$contains": "bonjour"}, + whereDocument: []WhereDocument{{Operator: WhereDocumentOperatorContains, Value: "salute"}}, want: nil, }, { name: "content not_contains all", where: nil, - whereDocument: map[string]string{"$not_contains": "bonjour"}, + whereDocument: []WhereDocument{{Operator: WhereDocumentOperatorNotContains, Value: "bonjour"}}, want: []*Document{docs["1"], docs["2"]}, }, { name: "content not_contains one", where: nil, - whereDocument: map[string]string{"$not_contains": "hello"}, + whereDocument: []WhereDocument{{Operator: WhereDocumentOperatorNotContains, Value: "hello"}}, want: []*Document{docs["2"]}, }, { name: "meta and content match", where: map[string]string{"language": "de"}, - whereDocument: map[string]string{"$contains": "hallo"}, + whereDocument: []WhereDocument{{Operator: WhereDocumentOperatorContains, Value: "hallo"}}, want: []*Document{docs["2"]}, }, { name: "meta + contains + not_contains", where: map[string]string{"language": "de"}, - whereDocument: map[string]string{"$contains": "hallo", "$not_contains": "bonjour"}, + whereDocument: []WhereDocument{{Operator: WhereDocumentOperatorContains, Value: "hallo"}, {Operator: WhereDocumentOperatorNotContains, Value: "bonjour"}}, want: []*Document{docs["2"]}, }, + { + name: "contains or (contains and not_contains", + whereDocument: []WhereDocument{ + {Operator: WhereDocumentOperatorOr, WhereDocuments: []WhereDocument{ + {Operator: WhereDocumentOperatorContains, Value: "bar"}, + {Operator: WhereDocumentOperatorAnd, WhereDocuments: []WhereDocument{ + {Operator: WhereDocumentOperatorContains, Value: "bonjour"}, + {Operator: WhereDocumentOperatorNotContains, Value: "foo"}, + }, + }, + }}, + }, + want: []*Document{docs["4"], docs["5"]}, + }, + } + + // To avoid issues with checking equality of concurrently produced slices, we sort by ID + sortDocs := func(d []*Document) { + slices.SortFunc(d, func(d1, d2 *Document) int { + if d1.ID < d2.ID { + return -1 + } + if d1.ID > d2.ID { + return 1 + } + return 0 + }) } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { got := filterDocs(docs, tc.where, tc.whereDocument) + sortDocs(got) + sortDocs(tc.want) if !reflect.DeepEqual(got, tc.want) { - // If len is 2, the order might be different (function under test - // is concurrent and order is not guaranteed). - if len(got) == 2 && len(tc.want) == 2 { - slices.Reverse(got) - if reflect.DeepEqual(got, tc.want) { - return - } - } t.Fatalf("got %v; want %v", got, tc.want) } })