Skip to content

Commit

Permalink
Add unit test for query errors
Browse files Browse the repository at this point in the history
  • Loading branch information
philippgille committed Mar 17, 2024
1 parent a8b7e80 commit 98516b1
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,85 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
}
}

func TestCollection_QueryError(t *testing.T) {
// Create collection
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
if c == nil {
t.Fatal("expected collection, got nil")
}
// Add a document
err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"})
if err != nil {
t.Fatal("expected nil, got", err)
}

tt := []struct {
name string
query func() error
expErr string
}{
{
name: "Empty query",
query: func() error {
_, err := c.Query(context.Background(), "", 1, nil, nil)
return err
},
expErr: "queryText is empty",
},
{
name: "Negative limit",
query: func() error {
_, err := c.Query(context.Background(), "foo", -1, nil, nil)
return err
},
expErr: "nResults must be > 0",
},
{
name: "Zero limit",
query: func() error {
_, err := c.Query(context.Background(), "foo", 0, nil, nil)
return err
},
expErr: "nResults must be > 0",
},
{
name: "Limit greater than number of documents",
query: func() error {
_, err := c.Query(context.Background(), "foo", 2, nil, nil)
return err
},
expErr: "nResults must be <= the number of documents in the collection",
},
{
name: "Bad content filter",
query: func() error {
_, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
return err
},
expErr: "unsupported operator",
},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
err := tc.query()
if err.Error() != tc.expErr {
t.Fatal("expected", tc.expErr, "got", err)
}
})
}
}

func TestCollection_Count(t *testing.T) {
// Create collection
db := NewDB()
Expand Down

0 comments on commit 98516b1

Please sign in to comment.