From 783e2ea41acc48537eec43e318d9ca0a1f173441 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Tue, 13 Aug 2024 21:32:28 +0200 Subject: [PATCH 1/3] add: -w/--keyword flag for retrieve/askdir + allow selecting multiple datasets --- pkg/client/client.go | 2 +- pkg/client/common.go | 2 +- pkg/client/default.go | 5 ++-- pkg/client/standalone.go | 4 +-- pkg/cmd/askdir.go | 3 +- pkg/cmd/retrieve.go | 38 +++++++++++++++++--------- pkg/datastore/retrieve.go | 11 ++++---- pkg/datastore/retrievers/keyword.go | 27 ++++++++++++++++++ pkg/datastore/retrievers/retrievers.go | 18 ++++++++++-- pkg/datastore/retrievers/routing.go | 7 +++-- pkg/datastore/retrievers/subquery.go | 16 +++++++++-- pkg/datastore/store/store.go | 2 +- pkg/flows/flows.go | 15 ++++++++-- pkg/server/routes.go | 4 +-- pkg/vectorstore/chromem/chromem.go | 15 +++++++++- pkg/vectorstore/vectorstores.go | 4 +-- 16 files changed, 131 insertions(+), 42 deletions(-) create mode 100644 pkg/datastore/retrievers/keyword.go diff --git a/pkg/client/client.go b/pkg/client/client.go index 30c30f5..8912b0f 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -33,7 +33,7 @@ type Client interface { AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) PrunePath(ctx context.Context, datasetID string, path string, keep []string) ([]index.File, error) DeleteDocuments(ctx context.Context, datasetID string, documentIDs ...string) error - Retrieve(ctx context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) + Retrieve(ctx context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) ExportDatasets(ctx context.Context, path string, datasets ...string) error ImportDatasets(ctx context.Context, path string, datasets ...string) error UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error) diff --git a/pkg/client/common.go b/pkg/client/common.go index 6f2851a..48bfb88 100644 --- a/pkg/client/common.go +++ b/pkg/client/common.go @@ -241,7 +241,7 @@ func AskDir(ctx context.Context, c Client, path string, query string, opts *Inge slog.Debug("Ingested files", "count", ingested, "path", abspath) // retrieve documents - return c.Retrieve(ctx, datasetID, query, *ropts) + return c.Retrieve(ctx, []string{datasetID}, query, *ropts) } func getOrCreateDataset(ctx context.Context, c Client, datasetID string, create bool) (*index.Dataset, error) { diff --git a/pkg/client/default.go b/pkg/client/default.go index c07bb18..4d56014 100644 --- a/pkg/client/default.go +++ b/pkg/client/default.go @@ -173,7 +173,7 @@ func (c *DefaultClient) DeleteDocuments(_ context.Context, datasetID string, doc return nil } -func (c *DefaultClient) Retrieve(_ context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) { +func (c *DefaultClient) Retrieve(_ context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) { q := types.Query{Prompt: query} if opts.TopK != 0 { @@ -185,7 +185,8 @@ func (c *DefaultClient) Retrieve(_ context.Context, datasetID string, query stri return nil, err } - resp, err := c.request(http.MethodPost, fmt.Sprintf("/datasets/%s/retrieve", datasetID), bytes.NewBuffer(data)) + // TODO: change to allow for multiple datasets + resp, err := c.request(http.MethodPost, 
fmt.Sprintf("/datasets/%s/retrieve", datasetIDs), bytes.NewBuffer(data)) if err != nil { return nil, err } diff --git a/pkg/client/standalone.go b/pkg/client/standalone.go index 0e8a1ef..a946aae 100644 --- a/pkg/client/standalone.go +++ b/pkg/client/standalone.go @@ -124,8 +124,8 @@ func (c *StandaloneClient) DeleteDocuments(ctx context.Context, datasetID string return nil } -func (c *StandaloneClient) Retrieve(ctx context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) { - return c.Datastore.Retrieve(ctx, datasetID, query, opts) +func (c *StandaloneClient) Retrieve(ctx context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) { + return c.Datastore.Retrieve(ctx, datasetIDs, query, opts) } func (c *StandaloneClient) AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) { diff --git a/pkg/cmd/askdir.go b/pkg/cmd/askdir.go index 527dd20..9a1abaa 100644 --- a/pkg/cmd/askdir.go +++ b/pkg/cmd/askdir.go @@ -49,7 +49,8 @@ func (s *ClientAskDir) Run(cmd *cobra.Command, args []string) error { } retrieveOpts := &datastore.RetrieveOpts{ - TopK: s.TopK, + TopK: s.TopK, + Keywords: s.Keywords, } if s.FlowsFile != "" { diff --git a/pkg/cmd/retrieve.go b/pkg/cmd/retrieve.go index 96a0435..615b1a3 100644 --- a/pkg/cmd/retrieve.go +++ b/pkg/cmd/retrieve.go @@ -13,14 +13,15 @@ import ( type ClientRetrieve struct { Client - Dataset string `usage:"Target Dataset ID" short:"d" default:"default" env:"KNOW_TARGET_DATASET"` - Archive string `usage:"Path to the archive file"` + Datasets []string `usage:"Target Dataset IDs" short:"d" default:"default" env:"KNOW_TARGET_DATASETS"` + Archive string `usage:"Path to the archive file"` ClientRetrieveOpts ClientFlowsConfig } type ClientRetrieveOpts struct { - TopK int `usage:"Number of sources to retrieve" short:"k" default:"10"` + TopK int `usage:"Number of sources to retrieve" short:"k" default:"10"` + Keywords []string `usage:"Keywords that retrieved documents must contain" short:"w"` } func (s *ClientRetrieve) Customize(cmd *cobra.Command) { @@ -35,15 +36,19 @@ func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error { return err } - datasetID := s.Dataset + datasetIDs := s.Datasets + if len(s.Datasets) == 0 { + datasetIDs = []string{"default"} + } query := args[0] retrieveOpts := datastore.RetrieveOpts{ - TopK: s.TopK, + TopK: s.TopK, + Keywords: s.Keywords, } if s.FlowsFile != "" { - slog.Debug("Loading retrieval flows from config", "flows_file", s.FlowsFile, "dataset", datasetID) + slog.Debug("Loading retrieval flows from config", "flows_file", s.FlowsFile, "dataset", datasetIDs) flowCfg, err := flowconfig.FromFile(s.FlowsFile) if err != nil { return err @@ -55,29 +60,36 @@ func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error { return err } } else { - flow, err = flowCfg.ForDataset(datasetID) // get flow for the dataset - if err != nil { - return err + if len(datasetIDs) == 1 { + flow, err = flowCfg.ForDataset(datasetIDs[0]) // get flow for the dataset + if err != nil { + return err + } + } else { + flow, err = flowCfg.GetDefaultFlowConfigEntry() + if err != nil { + return err + } } } if flow.Retrieval == nil { - slog.Info("No retrieval config in assigned flow", "flows_file", s.FlowsFile, "dataset", datasetID) + slog.Info("No retrieval config in assigned flow", "flows_file", s.FlowsFile, "dataset", datasetIDs) } else { rf, 
err := flow.Retrieval.AsRetrievalFlow() if err != nil { return err } retrieveOpts.RetrievalFlow = rf - slog.Debug("Loaded retrieval flow from config", "flows_file", s.FlowsFile, "dataset", datasetID) + slog.Debug("Loaded retrieval flow from config", "flows_file", s.FlowsFile, "dataset", datasetIDs) } } - retrievalResp, err := c.Retrieve(cmd.Context(), datasetID, query, retrieveOpts) + retrievalResp, err := c.Retrieve(cmd.Context(), datasetIDs, query, retrieveOpts) if err != nil { // An empty collection is not a hard error - the LLM session can "recover" from it if errors.Is(err, vserr.ErrCollectionEmpty) { - fmt.Printf("Dataset %q does not contain any documents\n", datasetID) + fmt.Printf("Dataset %q does not contain any documents\n", datasetIDs) return nil } return err diff --git a/pkg/datastore/retrieve.go b/pkg/datastore/retrieve.go index 5c87c87..5611738 100644 --- a/pkg/datastore/retrieve.go +++ b/pkg/datastore/retrieve.go @@ -12,11 +12,12 @@ import ( type RetrieveOpts struct { TopK int + Keywords []string RetrievalFlow *flows.RetrievalFlow } -func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string, opts RetrieveOpts) (*types.RetrievalResponse, error) { - slog.Debug("Retrieving content from dataset", "dataset", datasetID, "query", query) +func (s *Datastore) Retrieve(ctx context.Context, datasetIDs []string, query string, opts RetrieveOpts) (*types.RetrievalResponse, error) { + slog.Debug("Retrieving content from dataset", "dataset", datasetIDs, "query", query) retrievalFlow := opts.RetrievalFlow if retrievalFlow == nil { @@ -28,9 +29,9 @@ func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string } retrievalFlow.FillDefaults(topK) - return retrievalFlow.Run(ctx, s, query, datasetID) + return retrievalFlow.Run(ctx, s, query, datasetIDs, &flows.RetrievalFlowOpts{Keywords: opts.Keywords}) } -func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string) ([]vectorstore.Document, error) { - return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID) +func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string, keywords ...string) ([]vectorstore.Document, error) { + return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID, keywords...) 
} diff --git a/pkg/datastore/retrievers/keyword.go b/pkg/datastore/retrievers/keyword.go new file mode 100644 index 0000000..4d7f773 --- /dev/null +++ b/pkg/datastore/retrievers/keyword.go @@ -0,0 +1,27 @@ +package retrievers + +import ( + "regexp" + "strings" +) + +// regex pattern to match double-quoted substrings +var doubleQuotePattern = regexp.MustCompile(`"([^"]*)"`) + +// Extract double-quoted substrings from a string +func ExtractQuotedSubstrings(input string) []string { + + matches := doubleQuotePattern.FindAllStringSubmatch(input, -1) + + var substrings []string + for _, match := range matches { + if len(match) > 1 { + m := strings.TrimSpace(match[1]) + if m != "" { + substrings = append(substrings, m) + } + } + } + + return substrings +} diff --git a/pkg/datastore/retrievers/retrievers.go b/pkg/datastore/retrievers/retrievers.go index 3eaaf73..b1a70ab 100644 --- a/pkg/datastore/retrievers/retrievers.go +++ b/pkg/datastore/retrievers/retrievers.go @@ -11,7 +11,7 @@ import ( ) type Retriever interface { - Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) + Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, keywords ...string) ([]vs.Document, error) Name() string } @@ -42,11 +42,23 @@ func (r *BasicRetriever) Name() string { return BasicRetrieverName } -func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) { +func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, keywords ...string) ([]vs.Document, error) { + + if len(datasetIDs) > 1 { + return nil, fmt.Errorf("basic retriever does not support querying multiple datasets") + } + + var datasetID string + if len(datasetIDs) == 0 { + datasetID = "default" + } else { + datasetID = datasetIDs[0] + } + log := slog.With("retriever", r.Name()) if r.TopK <= 0 { log.Debug("[BasicRetriever] TopK not set, using default", "default", defaults.TopK) r.TopK = defaults.TopK } - return store.SimilaritySearch(ctx, query, r.TopK, datasetID) + return store.SimilaritySearch(ctx, query, r.TopK, datasetID, keywords...) } diff --git a/pkg/datastore/retrievers/routing.go b/pkg/datastore/retrievers/routing.go index c47f7d4..81158ee 100644 --- a/pkg/datastore/retrievers/routing.go +++ b/pkg/datastore/retrievers/routing.go @@ -35,10 +35,11 @@ type routingResp struct { Result string `json:"result"` } -func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) { +func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, keywords ...string) ([]vs.Document, error) { log := slog.With("component", "RoutingRetriever") - log.Debug("Ignoring input datasetID in routing retriever, as it chooses one by itself", "query", query, "inputDataset", datasetID) + // TODO: properly handle the datasetIDs input + log.Debug("Ignoring input datasetIDs in routing retriever, as it chooses one by itself", "query", query, "inputDataset", datasetIDs) if r.TopK <= 0 { log.Debug("TopK not set, using default", "default", defaults.TopK) @@ -91,5 +92,5 @@ func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, quer slog.Debug("Routing query to dataset", "query", query, "dataset", resp.Result) - return store.SimilaritySearch(ctx, query, r.TopK, resp.Result) + return store.SimilaritySearch(ctx, query, r.TopK, resp.Result, keywords...) 
} diff --git a/pkg/datastore/retrievers/subquery.go b/pkg/datastore/retrievers/subquery.go index 940f696..ab3eb88 100644 --- a/pkg/datastore/retrievers/subquery.go +++ b/pkg/datastore/retrievers/subquery.go @@ -37,7 +37,19 @@ type subqueryResp struct { Results []string `json:"results"` } -func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) { +func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, keywords ...string) ([]vs.Document, error) { + + if len(datasetIDs) > 1 { + return nil, fmt.Errorf("basic retriever does not support querying multiple datasets") + } + + var datasetID string + if len(datasetIDs) == 0 { + datasetID = "default" + } else { + datasetID = datasetIDs[0] + } + m, err := llm.NewFromConfig(s.Model) if err != nil { return nil, err @@ -72,7 +84,7 @@ func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, quer var resultDocs []vs.Document for _, q := range queries { - docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID) + docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID, keywords...) if err != nil { return nil, err } diff --git a/pkg/datastore/store/store.go b/pkg/datastore/store/store.go index 71ef58c..21594b3 100644 --- a/pkg/datastore/store/store.go +++ b/pkg/datastore/store/store.go @@ -9,5 +9,5 @@ import ( type Store interface { ListDatasets(ctx context.Context) ([]index.Dataset, error) GetDataset(ctx context.Context, datasetID string) (*index.Dataset, error) - SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]vs.Document, error) + SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, keywords ...string) ([]vs.Document, error) } diff --git a/pkg/flows/flows.go b/pkg/flows/flows.go index 77068ce..f1c2b2e 100644 --- a/pkg/flows/flows.go +++ b/pkg/flows/flows.go @@ -114,7 +114,15 @@ func (f *RetrievalFlow) FillDefaults(topK int) { } } -func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string, datasetID string) (*dstypes.RetrievalResponse, error) { +type RetrievalFlowOpts struct { + Keywords []string +} + +func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string, datasetIDs []string, opts *RetrievalFlowOpts) (*dstypes.RetrievalResponse, error) { + if opts == nil { + opts = &RetrievalFlowOpts{} + } + queries := []string{query} for _, m := range f.QueryModifiers { mq, err := m.ModifyQueries(queries) @@ -131,11 +139,12 @@ func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string Responses: make(map[string][]vs.Document, len(queries)), } for _, q := range queries { - docs, err := f.Retriever.Retrieve(ctx, store, q, datasetID) + + docs, err := f.Retriever.Retrieve(ctx, store, q, datasetIDs, opts.Keywords...) 
if err != nil { return nil, fmt.Errorf("failed to retrieve documents for query %q using retriever %q: %w", q, f.Retriever.Name(), err) } - slog.Debug("Retrieved documents", "num_documents", len(docs), "query", q, "dataset", datasetID, "retriever", f.Retriever.Name()) + slog.Debug("Retrieved documents", "num_documents", len(docs), "query", q, "datasets", datasetIDs, "retriever", f.Retriever.Name()) response.Responses[q] = docs } diff --git a/pkg/server/routes.go b/pkg/server/routes.go index 5f5daf5..b8cfc27 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -80,8 +80,8 @@ func (s *Server) RetrieveFromDS(c *gin.Context) { return } - // TODO: support retrieval flows - docs, err := s.Retrieve(c, id, query.Prompt, datastore.RetrieveOpts{TopK: z.Dereference(query.TopK)}) + // TODO: support retrieval flows and keywords + docs, err := s.Retrieve(c, []string{id}, query.Prompt, datastore.RetrieveOpts{TopK: z.Dereference(query.TopK)}) if err != nil { slog.Error("Failed to retrieve documents", "error", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) diff --git a/pkg/vectorstore/chromem/chromem.go b/pkg/vectorstore/chromem/chromem.go index fce7f8b..66eddd9 100644 --- a/pkg/vectorstore/chromem/chromem.go +++ b/pkg/vectorstore/chromem/chromem.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "strconv" + "strings" "github.com/google/uuid" "github.com/gptscript-ai/knowledge/pkg/env" @@ -108,7 +109,7 @@ func convertStringMapToAnyMap(m map[string]string) map[string]any { return convertedMap } -func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]vs.Document, error) { +func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, keywords ...string) ([]vs.Document, error) { col := s.db.GetCollection(collection, s.embeddingFunc) if col == nil { return nil, fmt.Errorf("%w: %q", errors.ErrCollectionNotFound, collection) @@ -134,7 +135,17 @@ func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments var sDocs []vs.Document + slog.Debug("filtering documents by keywords", "keywords", keywords) + +resultLoop: for _, qrd := range qr { + for _, keyword := range keywords { + if !strings.Contains(qrd.Content, keyword) { + slog.Debug("Document does not contain keyword", "keyword", keyword, "documentID", qrd.ID) + continue resultLoop + } + } + sDocs = append(sDocs, vs.Document{ Metadata: convertStringMapToAnyMap(qrd.Metadata), SimilarityScore: qrd.Similarity, @@ -142,6 +153,8 @@ func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments }) } + slog.Debug("Found similar documents", "numDocuments", len(sDocs), "numUnfilteredDocuments", len(qr)) + return sDocs, nil } diff --git a/pkg/vectorstore/vectorstores.go b/pkg/vectorstore/vectorstores.go index 891a836..b3e8c9f 100644 --- a/pkg/vectorstore/vectorstores.go +++ b/pkg/vectorstore/vectorstores.go @@ -6,8 +6,8 @@ import ( type VectorStore interface { CreateCollection(ctx context.Context, collection string) error - AddDocuments(ctx context.Context, docs []Document, collection string) ([]string, error) // @return documentIDs, error - SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]Document, error) //nolint:lll + AddDocuments(ctx context.Context, docs []Document, collection string) ([]string, error) // @return documentIDs, error + SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, keywords ...string) ([]Document, 
error) //nolint:lll RemoveCollection(ctx context.Context, collection string) error RemoveDocument(ctx context.Context, documentID string, collection string, where, whereDocument map[string]string) error From 7ac9bb3e3f720f2c682f7e53f96553d7d2f8bf4f Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Tue, 13 Aug 2024 21:33:51 +0200 Subject: [PATCH 2/3] add: keyword filter search via -w/--keyword flags on retrieve/askdir + --dataset/-d flag can be used multiple times to select multiple datasets --- pkg/cmd/retrieve.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/cmd/retrieve.go b/pkg/cmd/retrieve.go index 615b1a3..0943667 100644 --- a/pkg/cmd/retrieve.go +++ b/pkg/cmd/retrieve.go @@ -13,7 +13,7 @@ import ( type ClientRetrieve struct { Client - Datasets []string `usage:"Target Dataset IDs" short:"d" default:"default" env:"KNOW_TARGET_DATASETS"` + Datasets []string `usage:"Target Dataset IDs" short:"d" default:"default" env:"KNOW_TARGET_DATASETS" name:"dataset"` Archive string `usage:"Path to the archive file"` ClientRetrieveOpts ClientFlowsConfig @@ -21,7 +21,7 @@ type ClientRetrieve struct { type ClientRetrieveOpts struct { TopK int `usage:"Number of sources to retrieve" short:"k" default:"10"` - Keywords []string `usage:"Keywords that retrieved documents must contain" short:"w"` + Keywords []string `usage:"Keywords that retrieved documents must contain" short:"w" name:"keyword" env:"KNOW_RETRIEVE_KEYWORDS"` } func (s *ClientRetrieve) Customize(cmd *cobra.Command) { From 28f66bb52760ea2f55f169fc07baac7455ea166e Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Wed, 14 Aug 2024 15:42:54 +0200 Subject: [PATCH 3/3] change: use chromem's new (pending) WhereDocument filters --- go.mod | 2 +- go.sum | 4 +-- pkg/datastore/retrieve.go | 39 ++++++++++++++++++++++++-- pkg/datastore/retrievers/retrievers.go | 7 +++-- pkg/datastore/retrievers/routing.go | 5 ++-- pkg/datastore/retrievers/subquery.go | 5 ++-- pkg/datastore/store/store.go | 3 +- pkg/flows/flows.go | 6 ++-- pkg/vectorstore/chromem/chromem.go | 30 ++++++-------------- pkg/vectorstore/vectorstores.go | 7 +++-- 10 files changed, 68 insertions(+), 40 deletions(-) diff --git a/go.mod b/go.mod index 3546331..503b146 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.22.4 replace ( github.com/hupe1980/golc => github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7 // nbformat extension github.com/ledongthuc/pdf => github.com/iwilltry42/pdf v0.0.0-20240517145113-99fbaebc5dd3 // fix for reading some PDFs: https://github.com/ledongthuc/pdf/pull/36 + https://github.com/iwilltry42/pdf/pull/2 - github.com/philippgille/chromem-go => github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583 // OpenAI Compat Fixes + github.com/philippgille/chromem-go => github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8 // OpenAI Compat Fixes github.com/tmc/langchaingo => github.com/StrongMonkey/langchaingo v0.0.0-20240617180437-9af4bee04c8b // Context-Aware Markdown Splitting ) diff --git a/go.sum b/go.sum index 4488e46..b276667 100644 --- a/go.sum +++ b/go.sum @@ -201,8 +201,8 @@ github.com/hupe1980/go-tiktoken v0.0.9 h1:qNs/XGTe7UHDUaFkU+jAPbhGzyi9BusOpxrNC8 github.com/hupe1980/go-tiktoken v0.0.9/go.mod h1:NME6d8hrE+Jo+kLUZHhXShYV8e40hYkm4BbSLQKtvAo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod 
h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583 h1:xTsr6cysGZGpu9xYaLiYItFu47Lh54jC49OwYX7fE2M= -github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= +github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8 h1:Tob2qUvv7zEeVNDb4kNhAmboaj0zUYUlZ+fcJg/ru14= +github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7 h1:2AzzbKVW1iP2F+ovqJKq801l6tgxYPt9m2zFKbs+i/Y= github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7/go.mod h1:w692KzkSTSvXROfyu+jYauNXB4YaL1s8zHPDMnNW88o= github.com/iwilltry42/pdf v0.0.0-20240517145113-99fbaebc5dd3 h1:rCVwFT7Q+HxpijWfSzKTYX4pCDMS7oy/I/WzU30VXyI= diff --git a/pkg/datastore/retrieve.go b/pkg/datastore/retrieve.go index 5611738..29d351c 100644 --- a/pkg/datastore/retrieve.go +++ b/pkg/datastore/retrieve.go @@ -3,6 +3,7 @@ package datastore import ( "context" "github.com/gptscript-ai/knowledge/pkg/datastore/types" + "github.com/philippgille/chromem-go" "log/slog" "github.com/gptscript-ai/knowledge/pkg/datastore/defaults" @@ -29,9 +30,41 @@ func (s *Datastore) Retrieve(ctx context.Context, datasetIDs []string, query str } retrievalFlow.FillDefaults(topK) - return retrievalFlow.Run(ctx, s, query, datasetIDs, &flows.RetrievalFlowOpts{Keywords: opts.Keywords}) + var whereDocs []chromem.WhereDocument + if len(opts.Keywords) > 0 { + whereDoc := chromem.WhereDocument{ + Operator: chromem.WhereDocumentOperatorOr, + WhereDocuments: []chromem.WhereDocument{}, + } + whereDocNot := chromem.WhereDocument{ + Operator: chromem.WhereDocumentOperatorAnd, + WhereDocuments: []chromem.WhereDocument{}, + } + for _, kw := range opts.Keywords { + if kw[0] == '-' { + whereDocNot.WhereDocuments = append(whereDocNot.WhereDocuments, chromem.WhereDocument{ + Operator: chromem.WhereDocumentOperatorNotContains, + Value: kw[1:], + }) + } else { + whereDoc.WhereDocuments = append(whereDoc.WhereDocuments, chromem.WhereDocument{ + Operator: chromem.WhereDocumentOperatorContains, + Value: kw, + }) + } + } + if len(whereDoc.WhereDocuments) > 0 { + whereDocs = append(whereDocs, whereDoc) + } + if len(whereDocNot.WhereDocuments) > 0 { + whereDocs = append(whereDocs, whereDocNot) + } + + } + + return retrievalFlow.Run(ctx, s, query, datasetIDs, &flows.RetrievalFlowOpts{Where: nil, WhereDocument: whereDocs}) } -func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string, keywords ...string) ([]vectorstore.Document, error) { - return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID, keywords...) 
+func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vectorstore.Document, error) { + return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID, where, whereDocument) } diff --git a/pkg/datastore/retrievers/retrievers.go b/pkg/datastore/retrievers/retrievers.go index b1a70ab..a789710 100644 --- a/pkg/datastore/retrievers/retrievers.go +++ b/pkg/datastore/retrievers/retrievers.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/gptscript-ai/knowledge/pkg/datastore/store" + "github.com/philippgille/chromem-go" "log/slog" "github.com/gptscript-ai/knowledge/pkg/datastore/defaults" @@ -11,7 +12,7 @@ import ( ) type Retriever interface { - Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, keywords ...string) ([]vs.Document, error) + Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) Name() string } @@ -42,7 +43,7 @@ func (r *BasicRetriever) Name() string { return BasicRetrieverName } -func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, keywords ...string) ([]vs.Document, error) { +func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) { if len(datasetIDs) > 1 { return nil, fmt.Errorf("basic retriever does not support querying multiple datasets") @@ -60,5 +61,5 @@ func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query log.Debug("[BasicRetriever] TopK not set, using default", "default", defaults.TopK) r.TopK = defaults.TopK } - return store.SimilaritySearch(ctx, query, r.TopK, datasetID, keywords...) + return store.SimilaritySearch(ctx, query, r.TopK, datasetID, where, whereDocument) } diff --git a/pkg/datastore/retrievers/routing.go b/pkg/datastore/retrievers/routing.go index 81158ee..0fd3b21 100644 --- a/pkg/datastore/retrievers/routing.go +++ b/pkg/datastore/retrievers/routing.go @@ -8,6 +8,7 @@ import ( "github.com/gptscript-ai/knowledge/pkg/datastore/store" "github.com/gptscript-ai/knowledge/pkg/llm" vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" + "github.com/philippgille/chromem-go" "log/slog" ) @@ -35,7 +36,7 @@ type routingResp struct { Result string `json:"result"` } -func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, keywords ...string) ([]vs.Document, error) { +func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) { log := slog.With("component", "RoutingRetriever") // TODO: properly handle the datasetIDs input @@ -92,5 +93,5 @@ func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, quer slog.Debug("Routing query to dataset", "query", query, "dataset", resp.Result) - return store.SimilaritySearch(ctx, query, r.TopK, resp.Result, keywords...) 
+ return store.SimilaritySearch(ctx, query, r.TopK, resp.Result, where, whereDocument) } diff --git a/pkg/datastore/retrievers/subquery.go b/pkg/datastore/retrievers/subquery.go index ab3eb88..b682b81 100644 --- a/pkg/datastore/retrievers/subquery.go +++ b/pkg/datastore/retrievers/subquery.go @@ -7,6 +7,7 @@ import ( "github.com/gptscript-ai/knowledge/pkg/datastore/store" "github.com/gptscript-ai/knowledge/pkg/llm" vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" + "github.com/philippgille/chromem-go" "log/slog" "strings" ) @@ -37,7 +38,7 @@ type subqueryResp struct { Results []string `json:"results"` } -func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, keywords ...string) ([]vs.Document, error) { +func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) { if len(datasetIDs) > 1 { return nil, fmt.Errorf("basic retriever does not support querying multiple datasets") @@ -84,7 +85,7 @@ func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, quer var resultDocs []vs.Document for _, q := range queries { - docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID, keywords...) + docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID, where, whereDocument) if err != nil { return nil, err } diff --git a/pkg/datastore/store/store.go b/pkg/datastore/store/store.go index 21594b3..bd7d9da 100644 --- a/pkg/datastore/store/store.go +++ b/pkg/datastore/store/store.go @@ -4,10 +4,11 @@ import ( "context" "github.com/gptscript-ai/knowledge/pkg/index" vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" + "github.com/philippgille/chromem-go" ) type Store interface { ListDatasets(ctx context.Context) ([]index.Dataset, error) GetDataset(ctx context.Context, datasetID string) (*index.Dataset, error) - SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, keywords ...string) ([]vs.Document, error) + SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) } diff --git a/pkg/flows/flows.go b/pkg/flows/flows.go index f1c2b2e..34b805d 100644 --- a/pkg/flows/flows.go +++ b/pkg/flows/flows.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/gptscript-ai/knowledge/pkg/datastore/store" + "github.com/philippgille/chromem-go" "io" "log/slog" "slices" @@ -115,7 +116,8 @@ func (f *RetrievalFlow) FillDefaults(topK int) { } type RetrievalFlowOpts struct { - Keywords []string + Where map[string]string + WhereDocument []chromem.WhereDocument } func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string, datasetIDs []string, opts *RetrievalFlowOpts) (*dstypes.RetrievalResponse, error) { @@ -140,7 +142,7 @@ func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string } for _, q := range queries { - docs, err := f.Retriever.Retrieve(ctx, store, q, datasetIDs, opts.Keywords...) 
+ docs, err := f.Retriever.Retrieve(ctx, store, q, datasetIDs, opts.Where, opts.WhereDocument) if err != nil { return nil, fmt.Errorf("failed to retrieve documents for query %q using retriever %q: %w", q, f.Retriever.Name(), err) } diff --git a/pkg/vectorstore/chromem/chromem.go b/pkg/vectorstore/chromem/chromem.go index 66eddd9..41a437f 100644 --- a/pkg/vectorstore/chromem/chromem.go +++ b/pkg/vectorstore/chromem/chromem.go @@ -3,18 +3,16 @@ package chromem import ( "context" "fmt" + "github.com/google/uuid" + "github.com/gptscript-ai/knowledge/pkg/env" + vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" "github.com/gptscript-ai/knowledge/pkg/vectorstore/errors" + "github.com/philippgille/chromem-go" "log/slog" "maps" "os" "path/filepath" "strconv" - "strings" - - "github.com/google/uuid" - "github.com/gptscript-ai/knowledge/pkg/env" - vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" - "github.com/philippgille/chromem-go" ) // VsChromemEmbeddingParallelThread can be set as an environment variable to control the number of parallel API calls to create embedding for documents. Default is 100 @@ -109,7 +107,7 @@ func convertStringMapToAnyMap(m map[string]string) map[string]any { return convertedMap } -func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, keywords ...string) ([]vs.Document, error) { +func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) { col := s.db.GetCollection(collection, s.embeddingFunc) if col == nil { return nil, fmt.Errorf("%w: %q", errors.ErrCollectionNotFound, collection) @@ -124,7 +122,9 @@ func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments slog.Debug("Reduced number of documents to search for", "numDocuments", numDocuments) } - qr, err := col.Query(ctx, query, numDocuments, nil, nil) + slog.Debug("filtering documents", "where", where, "whereDocument", whereDocument) + + qr, err := col.Query(ctx, query, numDocuments, where, whereDocument) if err != nil { return nil, err } @@ -135,17 +135,7 @@ func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments var sDocs []vs.Document - slog.Debug("filtering documents by keywords", "keywords", keywords) - -resultLoop: for _, qrd := range qr { - for _, keyword := range keywords { - if !strings.Contains(qrd.Content, keyword) { - slog.Debug("Document does not contain keyword", "keyword", keyword, "documentID", qrd.ID) - continue resultLoop - } - } - sDocs = append(sDocs, vs.Document{ Metadata: convertStringMapToAnyMap(qrd.Metadata), SimilarityScore: qrd.Similarity, @@ -153,8 +143,6 @@ resultLoop: }) } - slog.Debug("Found similar documents", "numDocuments", len(sDocs), "numUnfilteredDocuments", len(qr)) - return sDocs, nil } @@ -162,7 +150,7 @@ func (s *Store) RemoveCollection(_ context.Context, collection string) error { return s.db.DeleteCollection(collection) } -func (s *Store) RemoveDocument(ctx context.Context, documentID string, collection string, where, whereDocument map[string]string) error { +func (s *Store) RemoveDocument(ctx context.Context, documentID string, collection string, where map[string]string, whereDocument []chromem.WhereDocument) error { col := s.db.GetCollection(collection, s.embeddingFunc) if col == nil { return fmt.Errorf("%w: %q", errors.ErrCollectionNotFound, collection) diff --git a/pkg/vectorstore/vectorstores.go 
b/pkg/vectorstore/vectorstores.go index b3e8c9f..db2145d 100644 --- a/pkg/vectorstore/vectorstores.go +++ b/pkg/vectorstore/vectorstores.go @@ -2,14 +2,15 @@ package vectorstore import ( "context" + "github.com/philippgille/chromem-go" ) type VectorStore interface { CreateCollection(ctx context.Context, collection string) error - AddDocuments(ctx context.Context, docs []Document, collection string) ([]string, error) // @return documentIDs, error - SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, keywords ...string) ([]Document, error) //nolint:lll + AddDocuments(ctx context.Context, docs []Document, collection string) ([]string, error) // @return documentIDs, error + SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]Document, error) //nolint:lll RemoveCollection(ctx context.Context, collection string) error - RemoveDocument(ctx context.Context, documentID string, collection string, where, whereDocument map[string]string) error + RemoveDocument(ctx context.Context, documentID string, collection string, where map[string]string, whereDocument []chromem.WhereDocument) error ImportCollectionsFromFile(ctx context.Context, path string, collections ...string) error ExportCollectionsToFile(ctx context.Context, path string, collections ...string) error
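
Note on the keyword semantics introduced above: patch 3 translates the -w/--keyword values into chromem WhereDocument filters inside pkg/datastore/retrieve.go. The following is a minimal standalone sketch of that mapping, not part of the patch itself; it assumes the pending chromem-go fork referenced by the go.mod replace directive (which exposes WhereDocument and the Contains/NotContains operators). Plain keywords become Contains filters joined by OR, while keywords prefixed with '-' become NotContains filters joined by AND, and both groups must match.

package main

import (
	"fmt"

	"github.com/philippgille/chromem-go"
)

// keywordsToWhereDocument mirrors the mapping done in Datastore.Retrieve:
// positive keywords are OR-ed as Contains filters, "-"-prefixed keywords are
// AND-ed as NotContains filters. Empty groups are dropped.
func keywordsToWhereDocument(keywords []string) []chromem.WhereDocument {
	include := chromem.WhereDocument{Operator: chromem.WhereDocumentOperatorOr}
	exclude := chromem.WhereDocument{Operator: chromem.WhereDocumentOperatorAnd}

	for _, kw := range keywords {
		if len(kw) > 1 && kw[0] == '-' {
			exclude.WhereDocuments = append(exclude.WhereDocuments, chromem.WhereDocument{
				Operator: chromem.WhereDocumentOperatorNotContains,
				Value:    kw[1:],
			})
			continue
		}
		include.WhereDocuments = append(include.WhereDocuments, chromem.WhereDocument{
			Operator: chromem.WhereDocumentOperatorContains,
			Value:    kw,
		})
	}

	var whereDocs []chromem.WhereDocument
	if len(include.WhereDocuments) > 0 {
		whereDocs = append(whereDocs, include)
	}
	if len(exclude.WhereDocuments) > 0 {
		whereDocs = append(whereDocs, exclude)
	}
	return whereDocs
}

func main() {
	// Either "kubernetes" or "k8s" must appear in a chunk, and "helm" must not.
	fmt.Printf("%+v\n", keywordsToWhereDocument([]string{"kubernetes", "k8s", "-helm"}))
}

In CLI terms this means a retrieve call with -w kubernetes -w k8s -w -helm should only return chunks that contain "kubernetes" or "k8s" and do not contain "helm"; how negated keywords interact with quoting in a given shell is left to the user.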