Skip to content

Commit

Permalink
feat: WhereDocument filter with $and, $or, $contains and $not_contain…
Browse files Browse the repository at this point in the history
…s filters
  • Loading branch information
iwilltry42 committed Aug 14, 2024
1 parent a194428 commit 3c24b09
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 60 deletions.
35 changes: 17 additions & 18 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"path/filepath"
"slices"
"sync"
)

Expand Down Expand Up @@ -64,7 +63,7 @@ type QueryOptions struct {
Where map[string]string

// Conditional filtering on documents.
WhereDocument map[string]string
WhereDocuments []WhereDocument

// Negative is the negative query options.
// They can be used to exclude certain results from the query.
Expand Down Expand Up @@ -296,19 +295,19 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
// - ids: The ids of the documents to delete. If empty, all documents are deleted.
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {
func (c *Collection) Delete(_ context.Context, where map[string]string, whereDocuments []WhereDocument, ids ...string) error {
// must have at least one of where, whereDocument or ids
if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
if len(where) == 0 && len(whereDocuments) == 0 && len(ids) == 0 {
return fmt.Errorf("must have at least one of where, whereDocument or ids")
}

if len(c.documents) == 0 {
return nil
}

for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return errors.New("unsupported whereDocument operator")
for _, whereDocument := range whereDocuments {
if err := whereDocument.Validate(); err != nil {
return fmt.Errorf("invalid whereDocument %#v: %w", whereDocument, err)
}
}

Expand All @@ -317,9 +316,9 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s
c.documentsLock.Lock()
defer c.documentsLock.Unlock()

if where != nil || whereDocument != nil {
if where != nil || len(whereDocuments) > 0 {
// metadata + content filters
filteredDocs := filterDocs(c.documents, where, whereDocument)
filteredDocs := filterDocs(c.documents, where, whereDocuments)
for _, doc := range filteredDocs {
docIDs = append(docIDs, doc.ID)
}
Expand Down Expand Up @@ -376,7 +375,7 @@ type Result struct {
// There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where map[string]string, whereDocument []WhereDocument) ([]Result, error) {
if queryText == "" {
return nil, errors.New("queryText is empty")
}
Expand Down Expand Up @@ -432,7 +431,7 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions)
}
}

result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocument)
result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocuments)
if err != nil {
return nil, err
}
Expand All @@ -449,12 +448,12 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions)
// There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocument)
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where map[string]string, whereDocuments []WhereDocument) ([]Result, error) {
return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocuments)
}

// queryEmbedding performs an exhaustive nearest neighbor search on the collection.
func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where map[string]string, whereDocuments []WhereDocument) ([]Result, error) {
if len(queryEmbedding) == 0 {
return nil, errors.New("queryEmbedding is empty")
}
Expand All @@ -472,14 +471,14 @@ func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativ
}

// Validate whereDocument operators
for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return nil, errors.New("unsupported operator")
for _, whereDocument := range whereDocuments {
if err := whereDocument.Validate(); err != nil {
return nil, fmt.Errorf("invalid whereDocument %#v: %w", whereDocument, err)
}
}

// Filter docs by metadata and content
filteredDocs := filterDocs(c.documents, where, whereDocument)
filteredDocs := filterDocs(c.documents, where, whereDocuments)

// No need to continue if the filters got rid of all documents
if len(filteredDocs) == 0 {
Expand Down
9 changes: 5 additions & 4 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"slices"
"strconv"
"strings"
"testing"
)

Expand Down Expand Up @@ -372,10 +373,10 @@ func TestCollection_QueryError(t *testing.T) {
{
name: "Bad content filter",
query: func() error {
_, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
_, err := c.Query(context.Background(), "foo", 1, nil, []WhereDocument{{Operator: "invalid", Value: "foo"}})
return err
},
expErr: "unsupported operator",
expErr: "unsupported where document operator invalid",
},
}

Expand All @@ -384,7 +385,7 @@ func TestCollection_QueryError(t *testing.T) {
err := tc.query()
if err == nil {
t.Fatal("expected error, got nil")
} else if err.Error() != tc.expErr {
} else if !strings.Contains(err.Error(), tc.expErr) {
t.Fatal("expected", tc.expErr, "got", err)
}
})
Expand Down Expand Up @@ -502,7 +503,7 @@ func TestCollection_Delete(t *testing.T) {
checkCount(1)

// Test 3 - Remove document by content
err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"})
err = c.Delete(context.Background(), nil, []WhereDocument{WhereDocument{Operator: WhereDocumentOperatorContains, Value: "hallo welt"}})
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down
104 changes: 83 additions & 21 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"sync"
)

var supportedFilters = []string{"$contains", "$not_contains"}

type docSim struct {
docID string
similarity float32
Expand Down Expand Up @@ -84,7 +82,7 @@ func (d *maxDocSims) values() []docSim {

// filterDocs filters a map of documents by metadata and content.
// It does this concurrently.
func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document {
func filterDocs(docs map[string]*Document, where map[string]string, whereDocuments []WhereDocument) []*Document {
filteredDocs := make([]*Document, 0, len(docs))
filteredDocsLock := sync.Mutex{}

Expand All @@ -104,7 +102,7 @@ func filterDocs(docs map[string]*Document, where, whereDocument map[string]strin
go func() {
defer wg.Done()
for doc := range docChan {
if documentMatchesFilters(doc, where, whereDocument) {
if documentMatchesFilters(doc, where, whereDocuments) {
filteredDocsLock.Lock()
filteredDocs = append(filteredDocs, doc)
filteredDocsLock.Unlock()
Expand All @@ -128,9 +126,84 @@ func filterDocs(docs map[string]*Document, where, whereDocument map[string]strin
return filteredDocs
}

type WhereDocumentOperator string

const (
WhereDocumentOperatorContains WhereDocumentOperator = "$contains"
WhereDocumentOperatorNotContains WhereDocumentOperator = "$not_contains"
WhereDocumentOperatorOr WhereDocumentOperator = "$or"
WhereDocumentOperatorAnd WhereDocumentOperator = "$and"
)

type WhereDocument struct {
Operator WhereDocumentOperator
Value string
WhereDocuments []WhereDocument
}

func (wd *WhereDocument) Validate() error {

if !slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorContains, WhereDocumentOperatorNotContains, WhereDocumentOperatorOr, WhereDocumentOperatorAnd}, wd.Operator) {
return fmt.Errorf("unsupported where document operator %s", wd.Operator)
}

if wd.Operator == "" {
return fmt.Errorf("where document operator is empty")
}

// $contains and $not_contains require a string value
if slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorContains, WhereDocumentOperatorNotContains}, wd.Operator) {
if wd.Value == "" {
return fmt.Errorf("where document operator %s requires a value", wd.Operator)
}
}

// $or requires sub-filters
if slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorOr, WhereDocumentOperatorAnd}, wd.Operator) {
if len(wd.WhereDocuments) == 0 {
return fmt.Errorf("where document operator %s must have at least one sub-filter", wd.Operator)
}
}

for _, wd := range wd.WhereDocuments {
if err := wd.Validate(); err != nil {
return err
}
}

return nil
}

// Matches checks if a document matches the WhereDocument filter(s)
// There is no error checking on the WhereDocument struct, so it must be validated before calling this function.
func (wd *WhereDocument) Matches(doc *Document) bool {
switch wd.Operator {
case WhereDocumentOperatorContains:
return strings.Contains(doc.Content, wd.Value)
case WhereDocumentOperatorNotContains:
return !strings.Contains(doc.Content, wd.Value)
case WhereDocumentOperatorOr:
for _, subFilter := range wd.WhereDocuments {
if subFilter.Matches(doc) {
return true
}
}
return false
case WhereDocumentOperatorAnd:
for _, subFilter := range wd.WhereDocuments {
if !subFilter.Matches(doc) {
return false
}
}
return true
default:
return false
}
}

// documentMatchesFilters checks if a document matches the given filters.
// When calling this function, the whereDocument keys must already be validated!
func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool {
// When calling this function, the whereDocument structs must already be validated!
func documentMatchesFilters(document *Document, where map[string]string, whereDocuments []WhereDocument) bool {
// A document's metadata must have *all* the fields in the where clause.
for k, v := range where {
// TODO: Do we want to check for existence of the key? I.e. should
Expand All @@ -141,21 +214,10 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string]
}
}

// A document must satisfy *all* filters, until we support the `$or` operator.
for k, v := range whereDocument {
switch k {
case "$contains":
if !strings.Contains(document.Content, v) {
return false
}
case "$not_contains":
if strings.Contains(document.Content, v) {
return false
}
default:
// No handling (error) required because we already validated the
// operators. This simplifies the concurrency logic (no err var
// and lock, no context to cancel).
// A document must satisfy *all* WhereDocument filters (that's basically a top-level $and operator)
for _, whereDocument := range whereDocuments {
if !whereDocument.Matches(document) {
return false
}
}

Expand Down
Loading

0 comments on commit 3c24b09

Please sign in to comment.