Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

add: -w/--keyword flag for retrieve/askdir + allow selecting multiple datasets #68

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ toolchain go1.22.4
replace (
github.com/hupe1980/golc => github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7 // nbformat extension
github.com/ledongthuc/pdf => github.com/iwilltry42/pdf v0.0.0-20240517145113-99fbaebc5dd3 // fix for reading some PDFs: https://github.com/ledongthuc/pdf/pull/36 + https://github.com/iwilltry42/pdf/pull/2
github.com/philippgille/chromem-go => github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583 // OpenAI Compat Fixes
github.com/philippgille/chromem-go => github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8 // OpenAI Compat Fixes
github.com/tmc/langchaingo => github.com/StrongMonkey/langchaingo v0.0.0-20240617180437-9af4bee04c8b // Context-Aware Markdown Splitting
)

4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -201,8 +201,8 @@ github.com/hupe1980/go-tiktoken v0.0.9 h1:qNs/XGTe7UHDUaFkU+jAPbhGzyi9BusOpxrNC8
github.com/hupe1980/go-tiktoken v0.0.9/go.mod h1:NME6d8hrE+Jo+kLUZHhXShYV8e40hYkm4BbSLQKtvAo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583 h1:xTsr6cysGZGpu9xYaLiYItFu47Lh54jC49OwYX7fE2M=
github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8 h1:Tob2qUvv7zEeVNDb4kNhAmboaj0zUYUlZ+fcJg/ru14=
github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7 h1:2AzzbKVW1iP2F+ovqJKq801l6tgxYPt9m2zFKbs+i/Y=
github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7/go.mod h1:w692KzkSTSvXROfyu+jYauNXB4YaL1s8zHPDMnNW88o=
github.com/iwilltry42/pdf v0.0.0-20240517145113-99fbaebc5dd3 h1:rCVwFT7Q+HxpijWfSzKTYX4pCDMS7oy/I/WzU30VXyI=
2 changes: 1 addition & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ type Client interface {
AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error)
PrunePath(ctx context.Context, datasetID string, path string, keep []string) ([]index.File, error)
DeleteDocuments(ctx context.Context, datasetID string, documentIDs ...string) error
Retrieve(ctx context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error)
Retrieve(ctx context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error)
ExportDatasets(ctx context.Context, path string, datasets ...string) error
ImportDatasets(ctx context.Context, path string, datasets ...string) error
UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error)
2 changes: 1 addition & 1 deletion pkg/client/common.go
Original file line number Diff line number Diff line change
@@ -241,7 +241,7 @@ func AskDir(ctx context.Context, c Client, path string, query string, opts *Inge
slog.Debug("Ingested files", "count", ingested, "path", abspath)

// retrieve documents
return c.Retrieve(ctx, datasetID, query, *ropts)
return c.Retrieve(ctx, []string{datasetID}, query, *ropts)
}

func getOrCreateDataset(ctx context.Context, c Client, datasetID string, create bool) (*index.Dataset, error) {
5 changes: 3 additions & 2 deletions pkg/client/default.go
Original file line number Diff line number Diff line change
@@ -173,7 +173,7 @@ func (c *DefaultClient) DeleteDocuments(_ context.Context, datasetID string, doc
return nil
}

func (c *DefaultClient) Retrieve(_ context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
func (c *DefaultClient) Retrieve(_ context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
q := types.Query{Prompt: query}

if opts.TopK != 0 {
@@ -185,7 +185,8 @@ func (c *DefaultClient) Retrieve(_ context.Context, datasetID string, query stri
return nil, err
}

resp, err := c.request(http.MethodPost, fmt.Sprintf("/datasets/%s/retrieve", datasetID), bytes.NewBuffer(data))
// TODO: change to allow for multiple datasets
resp, err := c.request(http.MethodPost, fmt.Sprintf("/datasets/%s/retrieve", datasetIDs), bytes.NewBuffer(data))
if err != nil {
return nil, err
}
4 changes: 2 additions & 2 deletions pkg/client/standalone.go
Original file line number Diff line number Diff line change
@@ -124,8 +124,8 @@ func (c *StandaloneClient) DeleteDocuments(ctx context.Context, datasetID string
return nil
}

func (c *StandaloneClient) Retrieve(ctx context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
return c.Datastore.Retrieve(ctx, datasetID, query, opts)
func (c *StandaloneClient) Retrieve(ctx context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
return c.Datastore.Retrieve(ctx, datasetIDs, query, opts)
}

func (c *StandaloneClient) AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
3 changes: 2 additions & 1 deletion pkg/cmd/askdir.go
Original file line number Diff line number Diff line change
@@ -49,7 +49,8 @@ func (s *ClientAskDir) Run(cmd *cobra.Command, args []string) error {
}

retrieveOpts := &datastore.RetrieveOpts{
TopK: s.TopK,
TopK: s.TopK,
Keywords: s.Keywords,
}

if s.FlowsFile != "" {
38 changes: 25 additions & 13 deletions pkg/cmd/retrieve.go
Original file line number Diff line number Diff line change
@@ -13,14 +13,15 @@ import (

type ClientRetrieve struct {
Client
Dataset string `usage:"Target Dataset ID" short:"d" default:"default" env:"KNOW_TARGET_DATASET"`
Archive string `usage:"Path to the archive file"`
Datasets []string `usage:"Target Dataset IDs" short:"d" default:"default" env:"KNOW_TARGET_DATASETS" name:"dataset"`
Archive string `usage:"Path to the archive file"`
ClientRetrieveOpts
ClientFlowsConfig
}

type ClientRetrieveOpts struct {
TopK int `usage:"Number of sources to retrieve" short:"k" default:"10"`
TopK int `usage:"Number of sources to retrieve" short:"k" default:"10"`
Keywords []string `usage:"Keywords that retrieved documents must contain" short:"w" name:"keyword" env:"KNOW_RETRIEVE_KEYWORDS"`
}

func (s *ClientRetrieve) Customize(cmd *cobra.Command) {
@@ -35,15 +36,19 @@ func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error {
return err
}

datasetID := s.Dataset
datasetIDs := s.Datasets
if len(s.Datasets) == 0 {
datasetIDs = []string{"default"}
}
query := args[0]

retrieveOpts := datastore.RetrieveOpts{
TopK: s.TopK,
TopK: s.TopK,
Keywords: s.Keywords,
}

if s.FlowsFile != "" {
slog.Debug("Loading retrieval flows from config", "flows_file", s.FlowsFile, "dataset", datasetID)
slog.Debug("Loading retrieval flows from config", "flows_file", s.FlowsFile, "dataset", datasetIDs)
flowCfg, err := flowconfig.FromFile(s.FlowsFile)
if err != nil {
return err
@@ -55,29 +60,36 @@ func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error {
return err
}
} else {
flow, err = flowCfg.ForDataset(datasetID) // get flow for the dataset
if err != nil {
return err
if len(datasetIDs) == 1 {
flow, err = flowCfg.ForDataset(datasetIDs[0]) // get flow for the dataset
if err != nil {
return err
}
} else {
flow, err = flowCfg.GetDefaultFlowConfigEntry()
if err != nil {
return err
}
}
}

if flow.Retrieval == nil {
slog.Info("No retrieval config in assigned flow", "flows_file", s.FlowsFile, "dataset", datasetID)
slog.Info("No retrieval config in assigned flow", "flows_file", s.FlowsFile, "dataset", datasetIDs)
} else {
rf, err := flow.Retrieval.AsRetrievalFlow()
if err != nil {
return err
}
retrieveOpts.RetrievalFlow = rf
slog.Debug("Loaded retrieval flow from config", "flows_file", s.FlowsFile, "dataset", datasetID)
slog.Debug("Loaded retrieval flow from config", "flows_file", s.FlowsFile, "dataset", datasetIDs)
}
}

retrievalResp, err := c.Retrieve(cmd.Context(), datasetID, query, retrieveOpts)
retrievalResp, err := c.Retrieve(cmd.Context(), datasetIDs, query, retrieveOpts)
if err != nil {
// An empty collection is not a hard error - the LLM session can "recover" from it
if errors.Is(err, vserr.ErrCollectionEmpty) {
fmt.Printf("Dataset %q does not contain any documents\n", datasetID)
fmt.Printf("Dataset %q does not contain any documents\n", datasetIDs)
return nil
}
return err
44 changes: 39 additions & 5 deletions pkg/datastore/retrieve.go
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ package datastore
import (
"context"
"github.com/gptscript-ai/knowledge/pkg/datastore/types"
"github.com/philippgille/chromem-go"
"log/slog"

"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
@@ -12,11 +13,12 @@ import (

type RetrieveOpts struct {
TopK int
Keywords []string
RetrievalFlow *flows.RetrievalFlow
}

func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string, opts RetrieveOpts) (*types.RetrievalResponse, error) {
slog.Debug("Retrieving content from dataset", "dataset", datasetID, "query", query)
func (s *Datastore) Retrieve(ctx context.Context, datasetIDs []string, query string, opts RetrieveOpts) (*types.RetrievalResponse, error) {
slog.Debug("Retrieving content from dataset", "dataset", datasetIDs, "query", query)

retrievalFlow := opts.RetrievalFlow
if retrievalFlow == nil {
@@ -28,9 +30,41 @@ func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string
}
retrievalFlow.FillDefaults(topK)

return retrievalFlow.Run(ctx, s, query, datasetID)
var whereDocs []chromem.WhereDocument
if len(opts.Keywords) > 0 {
whereDoc := chromem.WhereDocument{
Operator: chromem.WhereDocumentOperatorOr,
WhereDocuments: []chromem.WhereDocument{},
}
whereDocNot := chromem.WhereDocument{
Operator: chromem.WhereDocumentOperatorAnd,
WhereDocuments: []chromem.WhereDocument{},
}
for _, kw := range opts.Keywords {
if kw[0] == '-' {
whereDocNot.WhereDocuments = append(whereDocNot.WhereDocuments, chromem.WhereDocument{
Operator: chromem.WhereDocumentOperatorNotContains,
Value: kw[1:],
})
} else {
whereDoc.WhereDocuments = append(whereDoc.WhereDocuments, chromem.WhereDocument{
Operator: chromem.WhereDocumentOperatorContains,
Value: kw,
})
}
}
if len(whereDoc.WhereDocuments) > 0 {
whereDocs = append(whereDocs, whereDoc)
}
if len(whereDocNot.WhereDocuments) > 0 {
whereDocs = append(whereDocs, whereDocNot)
}

}

return retrievalFlow.Run(ctx, s, query, datasetIDs, &flows.RetrievalFlowOpts{Where: nil, WhereDocument: whereDocs})
}

func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string) ([]vectorstore.Document, error) {
return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID)
func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vectorstore.Document, error) {
return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID, where, whereDocument)
}
27 changes: 27 additions & 0 deletions pkg/datastore/retrievers/keyword.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package retrievers

import (
"regexp"
"strings"
)

// regex pattern to match double-quoted substrings
var doubleQuotePattern = regexp.MustCompile(`"([^"]*)"`)

// Extract double-quoted substrings from a string
func ExtractQuotedSubstrings(input string) []string {

matches := doubleQuotePattern.FindAllStringSubmatch(input, -1)

var substrings []string
for _, match := range matches {
if len(match) > 1 {
m := strings.TrimSpace(match[1])
if m != "" {
substrings = append(substrings, m)
}
}
}

return substrings
}
19 changes: 16 additions & 3 deletions pkg/datastore/retrievers/retrievers.go
Original file line number Diff line number Diff line change
@@ -4,14 +4,15 @@ import (
"context"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/philippgille/chromem-go"
"log/slog"

"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
)

type Retriever interface {
Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error)
Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
Name() string
}

@@ -42,11 +43,23 @@ func (r *BasicRetriever) Name() string {
return BasicRetrieverName
}

func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) {

if len(datasetIDs) > 1 {
return nil, fmt.Errorf("basic retriever does not support querying multiple datasets")
}

var datasetID string
if len(datasetIDs) == 0 {
datasetID = "default"
} else {
datasetID = datasetIDs[0]
}

log := slog.With("retriever", r.Name())
if r.TopK <= 0 {
log.Debug("[BasicRetriever] TopK not set, using default", "default", defaults.TopK)
r.TopK = defaults.TopK
}
return store.SimilaritySearch(ctx, query, r.TopK, datasetID)
return store.SimilaritySearch(ctx, query, r.TopK, datasetID, where, whereDocument)
}
8 changes: 5 additions & 3 deletions pkg/datastore/retrievers/routing.go
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ import (
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/gptscript-ai/knowledge/pkg/llm"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"github.com/philippgille/chromem-go"
"log/slog"
)

@@ -35,10 +36,11 @@ type routingResp struct {
Result string `json:"result"`
}

func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) {
log := slog.With("component", "RoutingRetriever")

log.Debug("Ignoring input datasetID in routing retriever, as it chooses one by itself", "query", query, "inputDataset", datasetID)
// TODO: properly handle the datasetIDs input
log.Debug("Ignoring input datasetIDs in routing retriever, as it chooses one by itself", "query", query, "inputDataset", datasetIDs)

if r.TopK <= 0 {
log.Debug("TopK not set, using default", "default", defaults.TopK)
@@ -91,5 +93,5 @@ func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, quer

slog.Debug("Routing query to dataset", "query", query, "dataset", resp.Result)

return store.SimilaritySearch(ctx, query, r.TopK, resp.Result)
return store.SimilaritySearch(ctx, query, r.TopK, resp.Result, where, whereDocument)
}
17 changes: 15 additions & 2 deletions pkg/datastore/retrievers/subquery.go
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ import (
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/gptscript-ai/knowledge/pkg/llm"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"github.com/philippgille/chromem-go"
"log/slog"
"strings"
)
@@ -37,7 +38,19 @@ type subqueryResp struct {
Results []string `json:"results"`
}

func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) {

if len(datasetIDs) > 1 {
return nil, fmt.Errorf("basic retriever does not support querying multiple datasets")
}

var datasetID string
if len(datasetIDs) == 0 {
datasetID = "default"
} else {
datasetID = datasetIDs[0]
}

m, err := llm.NewFromConfig(s.Model)
if err != nil {
return nil, err
@@ -72,7 +85,7 @@ func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, quer

var resultDocs []vs.Document
for _, q := range queries {
docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID)
docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID, where, whereDocument)
if err != nil {
return nil, err
}
3 changes: 2 additions & 1 deletion pkg/datastore/store/store.go
Original file line number Diff line number Diff line change
@@ -4,10 +4,11 @@ import (
"context"
"github.com/gptscript-ai/knowledge/pkg/index"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"github.com/philippgille/chromem-go"
)

type Store interface {
ListDatasets(ctx context.Context) ([]index.Dataset, error)
GetDataset(ctx context.Context, datasetID string) (*index.Dataset, error)
SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]vs.Document, error)
SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
}
17 changes: 14 additions & 3 deletions pkg/flows/flows.go
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/philippgille/chromem-go"
"io"
"log/slog"
"slices"
@@ -114,7 +115,16 @@ func (f *RetrievalFlow) FillDefaults(topK int) {
}
}

func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string, datasetID string) (*dstypes.RetrievalResponse, error) {
type RetrievalFlowOpts struct {
Where map[string]string
WhereDocument []chromem.WhereDocument
}

func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string, datasetIDs []string, opts *RetrievalFlowOpts) (*dstypes.RetrievalResponse, error) {
if opts == nil {
opts = &RetrievalFlowOpts{}
}

queries := []string{query}
for _, m := range f.QueryModifiers {
mq, err := m.ModifyQueries(queries)
@@ -131,11 +141,12 @@ func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string
Responses: make(map[string][]vs.Document, len(queries)),
}
for _, q := range queries {
docs, err := f.Retriever.Retrieve(ctx, store, q, datasetID)

docs, err := f.Retriever.Retrieve(ctx, store, q, datasetIDs, opts.Where, opts.WhereDocument)
if err != nil {
return nil, fmt.Errorf("failed to retrieve documents for query %q using retriever %q: %w", q, f.Retriever.Name(), err)
}
slog.Debug("Retrieved documents", "num_documents", len(docs), "query", q, "dataset", datasetID, "retriever", f.Retriever.Name())
slog.Debug("Retrieved documents", "num_documents", len(docs), "query", q, "datasets", datasetIDs, "retriever", f.Retriever.Name())
response.Responses[q] = docs
}

4 changes: 2 additions & 2 deletions pkg/server/routes.go
Original file line number Diff line number Diff line change
@@ -80,8 +80,8 @@ func (s *Server) RetrieveFromDS(c *gin.Context) {
return
}

// TODO: support retrieval flows
docs, err := s.Retrieve(c, id, query.Prompt, datastore.RetrieveOpts{TopK: z.Dereference(query.TopK)})
// TODO: support retrieval flows and keywords
docs, err := s.Retrieve(c, []string{id}, query.Prompt, datastore.RetrieveOpts{TopK: z.Dereference(query.TopK)})
if err != nil {
slog.Error("Failed to retrieve documents", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
17 changes: 9 additions & 8 deletions pkg/vectorstore/chromem/chromem.go
Original file line number Diff line number Diff line change
@@ -3,17 +3,16 @@ package chromem
import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/gptscript-ai/knowledge/pkg/env"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"github.com/gptscript-ai/knowledge/pkg/vectorstore/errors"
"github.com/philippgille/chromem-go"
"log/slog"
"maps"
"os"
"path/filepath"
"strconv"

"github.com/google/uuid"
"github.com/gptscript-ai/knowledge/pkg/env"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"github.com/philippgille/chromem-go"
)

// VsChromemEmbeddingParallelThread can be set as an environment variable to control the number of parallel API calls to create embedding for documents. Default is 100
@@ -108,7 +107,7 @@ func convertStringMapToAnyMap(m map[string]string) map[string]any {
return convertedMap
}

func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]vs.Document, error) {
func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) {
col := s.db.GetCollection(collection, s.embeddingFunc)
if col == nil {
return nil, fmt.Errorf("%w: %q", errors.ErrCollectionNotFound, collection)
@@ -123,7 +122,9 @@ func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments
slog.Debug("Reduced number of documents to search for", "numDocuments", numDocuments)
}

qr, err := col.Query(ctx, query, numDocuments, nil, nil)
slog.Debug("filtering documents", "where", where, "whereDocument", whereDocument)

qr, err := col.Query(ctx, query, numDocuments, where, whereDocument)
if err != nil {
return nil, err
}
@@ -149,7 +150,7 @@ func (s *Store) RemoveCollection(_ context.Context, collection string) error {
return s.db.DeleteCollection(collection)
}

func (s *Store) RemoveDocument(ctx context.Context, documentID string, collection string, where, whereDocument map[string]string) error {
func (s *Store) RemoveDocument(ctx context.Context, documentID string, collection string, where map[string]string, whereDocument []chromem.WhereDocument) error {
col := s.db.GetCollection(collection, s.embeddingFunc)
if col == nil {
return fmt.Errorf("%w: %q", errors.ErrCollectionNotFound, collection)
7 changes: 4 additions & 3 deletions pkg/vectorstore/vectorstores.go
Original file line number Diff line number Diff line change
@@ -2,14 +2,15 @@ package vectorstore

import (
"context"
"github.com/philippgille/chromem-go"
)

type VectorStore interface {
CreateCollection(ctx context.Context, collection string) error
AddDocuments(ctx context.Context, docs []Document, collection string) ([]string, error) // @return documentIDs, error
SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]Document, error) //nolint:lll
AddDocuments(ctx context.Context, docs []Document, collection string) ([]string, error) // @return documentIDs, error
SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]Document, error) //nolint:lll
RemoveCollection(ctx context.Context, collection string) error
RemoveDocument(ctx context.Context, documentID string, collection string, where, whereDocument map[string]string) error
RemoveDocument(ctx context.Context, documentID string, collection string, where map[string]string, whereDocument []chromem.WhereDocument) error

ImportCollectionsFromFile(ctx context.Context, path string, collections ...string) error
ExportCollectionsToFile(ctx context.Context, path string, collections ...string) error