Skip to content

Commit

Permalink
Merge pull request #80 from erikdubbelboer/QueryWithNegative
Browse files Browse the repository at this point in the history
Add QueryWithNegative
  • Loading branch information
philippgille authored Jun 26, 2024
2 parents e147c74 + d057669 commit 4b532fa
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 11 deletions.
132 changes: 122 additions & 10 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,67 @@ type Collection struct {
// versions in [DB.Export] and [DB.Import] as well!
}

// NegativeMode represents the mode to use for the negative text.
// It selects how QueryWithOptions applies a negative text or embedding to a
// query. The values handled by QueryWithOptions are NEGATIVE_MODE_FILTER and
// NEGATIVE_MODE_SUBTRACT; any other value is rejected with an error when a
// negative text or embedding is provided.
// See QueryOptions for more information.
type NegativeMode string

// Supported negative modes and the default threshold used by the filter mode.
const (
// NEGATIVE_MODE_FILTER filters out results based on the similarity between the
// negative embedding and the document embeddings.
// NegativeQueryOptions.FilterThreshold controls the threshold for filtering.
// Documents with similarity above the threshold will be removed from the results.
NEGATIVE_MODE_FILTER NegativeMode = "filter"

// NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query embedding
// and re-normalizes the result before searching.
// This is documented as the default behavior.
// NOTE(review): for that to hold, QueryWithOptions must treat an empty
// NegativeQueryOptions.Mode as NEGATIVE_MODE_SUBTRACT — confirm.
NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract"

// DEFAULT_NEGATIVE_FILTER_THRESHOLD is the default threshold for the negative
// filter. QueryWithOptions uses it in NEGATIVE_MODE_FILTER when
// NegativeQueryOptions.FilterThreshold is 0.
DEFAULT_NEGATIVE_FILTER_THRESHOLD = 0.5
)

// QueryOptions represents the options for a query.
// It is the argument of Collection.QueryWithOptions. At least one of QueryText
// and QueryEmbedding must be set, otherwise QueryWithOptions returns an error.
type QueryOptions struct {
// The text to search for.
// It is only embedded (with the collection's embedding function) when
// QueryEmbedding is empty.
QueryText string

// The embedding of the query to search for. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
// If both QueryText and QueryEmbedding are set, QueryEmbedding will be used.
QueryEmbedding []float32

// The number of results to return.
NResults int

// Conditional filtering on metadata.
// Passed through unchanged to the underlying search.
Where map[string]string

// Conditional filtering on documents.
// Passed through unchanged to the underlying search.
WhereDocument map[string]string

// Negative is the negative query options.
// They can be used to exclude certain results from the query.
// They only take effect when Negative.Text or Negative.Embedding is set.
Negative NegativeQueryOptions
}

// NegativeQueryOptions describes what a query should steer away from and how.
// It is used as QueryOptions.Negative. When both Text and Embedding are empty,
// QueryWithOptions does not evaluate Mode.
type NegativeQueryOptions struct {
// Mode is the mode to use for the negative text.
// Must be NEGATIVE_MODE_SUBTRACT or NEGATIVE_MODE_FILTER whenever Text or
// Embedding is set.
Mode NegativeMode

// Text is the text to exclude from the results.
// It is only embedded when Embedding is empty.
Text string

// Embedding is the embedding of the negative text. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
// If both Text and Embedding are set, Embedding will be used.
Embedding []float32

// FilterThreshold is the threshold for the negative filter. Used when Mode is NEGATIVE_MODE_FILTER.
// A value of 0 is replaced by DEFAULT_NEGATIVE_FILTER_THRESHOLD in that mode.
FilterThreshold float32
}

// We don't export this yet to keep the API surface to the bare minimum.
// Users create collections via [Client.CreateCollection].
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
Expand Down Expand Up @@ -336,12 +397,63 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
return nil, errors.New("queryText is empty")
}

queryVectors, err := c.embed(ctx, queryText)
queryVector, err := c.embed(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
}

return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
return c.QueryEmbedding(ctx, queryVector, nResults, where, whereDocument)
}

// QueryWithOptions performs an exhaustive nearest neighbor search on the collection.
//
//   - options: The options for the query. See QueryOptions for more information.
//
// The query vector is options.QueryEmbedding if set, otherwise the embedding of
// options.QueryText. If a negative text or embedding is provided, it is
// normalized and then applied according to options.Negative.Mode:
//
//   - NEGATIVE_MODE_SUBTRACT (also used when Mode is empty): the negative
//     vector is subtracted from the query vector and the result is normalized.
//   - NEGATIVE_MODE_FILTER: the negative vector and the filter threshold
//     (DEFAULT_NEGATIVE_FILTER_THRESHOLD if FilterThreshold is 0) are handed
//     to the search.
//
// An error is returned when neither QueryText nor QueryEmbedding is set, when
// creating an embedding fails, when the negative mode is unsupported, or when
// the underlying search fails.
func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions) ([]Result, error) {
	if options.QueryText == "" && len(options.QueryEmbedding) == 0 {
		return nil, errors.New("QueryText and QueryEmbedding options are empty")
	}

	var err error

	// A provided embedding takes precedence over the query text.
	queryVector := options.QueryEmbedding
	if len(queryVector) == 0 {
		queryVector, err = c.embed(ctx, options.QueryText)
		if err != nil {
			return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
		}
	}

	// Same precedence for the negative side: embedding first, then text.
	negativeFilterThreshold := options.Negative.FilterThreshold
	negativeVector := options.Negative.Embedding
	if len(negativeVector) == 0 && options.Negative.Text != "" {
		negativeVector, err = c.embed(ctx, options.Negative.Text)
		if err != nil {
			return nil, fmt.Errorf("couldn't create embedding of negative: %w", err)
		}
	}

	if len(negativeVector) != 0 {
		if !isNormalized(negativeVector) {
			negativeVector = normalizeVector(negativeVector)
		}

		// NEGATIVE_MODE_SUBTRACT is documented as the default, so the zero
		// value of Mode must select it instead of being rejected as unsupported.
		negativeMode := options.Negative.Mode
		if negativeMode == "" {
			negativeMode = NEGATIVE_MODE_SUBTRACT
		}

		switch negativeMode {
		case NEGATIVE_MODE_SUBTRACT:
			// The difference of two vectors is generally not unit length, so
			// it is normalized again before searching.
			queryVector = subtractVector(queryVector, negativeVector)
			queryVector = normalizeVector(queryVector)
		case NEGATIVE_MODE_FILTER:
			if negativeFilterThreshold == 0 {
				negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD
			}
		default:
			return nil, fmt.Errorf("unsupported negative mode: %q", options.Negative.Mode)
		}
	}

	result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocument)
	if err != nil {
		return nil, err
	}

	return result, nil
}

// QueryEmbedding performs an exhaustive nearest neighbor search on the collection.
Expand All @@ -354,6 +466,11 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
// Delegate to the shared implementation, passing no negative embedding (nil)
// and a negative filter threshold of 0.
// NOTE(review): presumably this disables all negative handling — confirm in
// getMostSimilarDocs.
return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocument)
}

// queryEmbedding performs an exhaustive nearest neighbor search on the collection.
func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
if len(queryEmbedding) == 0 {
return nil, errors.New("queryEmbedding is empty")
}
Expand Down Expand Up @@ -399,18 +516,13 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
}

// For the remaining documents, get the most similar docs.
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen)
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, negativeEmbeddings, negativeFilterThreshold, filteredDocs, resLen)
if err != nil {
return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
}

// As long as we don't filter by threshold, resLen should match len(nMaxDocs).
if resLen != len(nMaxDocs) {
return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs))
}

res := make([]Result, 0, resLen)
for i := 0; i < resLen; i++ {
res := make([]Result, 0, len(nMaxDocs))
for i := 0; i < len(nMaxDocs); i++ {
res = append(res, Result{
ID: nMaxDocs[i].docID,
Metadata: c.documents[nMaxDocs[i].docID].Metadata,
Expand Down
Loading

0 comments on commit 4b532fa

Please sign in to comment.