Skip to content

Commit

Permalink
feat(vector-stores): Allow to pass vector search config (#135)
Browse files Browse the repository at this point in the history
- Made VectorStoreSearchType a sealed class (similarity and mrr)
- Each can have specific configuration (e.g. scoreThreshold)
- Implemented scoreThreshold in MemoryVectorStore and VertexAIMatchingEngine
  • Loading branch information
davidmigloz authored Aug 22, 2023
1 parent b211ab4 commit 5b8fa5a
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import '../combine_documents/stuff.dart';
import '../llm_chain.dart';

const _promptTemplate = '''
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ class VectorStoreRetriever<V extends VectorStore> implements BaseRetriever {
/// {@macro vector_store_retriever}
const VectorStoreRetriever({
required this.vectorStore,
this.searchType = VectorStoreSearchType.similarity,
this.k = 4,
this.searchType = const VectorStoreSimilaritySearch(),
});

/// The vector store to retrieve documents from.
Expand All @@ -19,12 +18,9 @@ class VectorStoreRetriever<V extends VectorStore> implements BaseRetriever {
/// The type of search to perform.
final VectorStoreSearchType searchType;

/// The number of documents to return.
final int k;

@override
Future<List<Document>> getRelevantDocuments(final String query) {
return vectorStore.search(query: query, searchType: searchType, k: k);
return vectorStore.search(query: query, searchType: searchType);
}

/// Runs more documents through the embeddings and add to the vector store.
Expand Down
85 changes: 43 additions & 42 deletions packages/langchain/lib/src/documents/vector_stores/base.dart
Original file line number Diff line number Diff line change
Expand Up @@ -81,29 +81,36 @@ abstract class VectorStore {
/// Returns docs most similar to query using specified search type.
///
/// - [query] is the query to search for.
/// - [searchType] is the type of search to perform.
/// - [k] is the number of documents to return.
/// - [searchType] is the type of search to perform, either
/// [VectorStoreSearchType.similarity] (default) or
/// [VectorStoreSearchType.mmr].
Future<List<Document>> search({
required final String query,
required final VectorStoreSearchType searchType,
final int k = 4,
}) {
return switch (searchType) {
VectorStoreSearchType.similarity => similaritySearch(query: query, k: k),
VectorStoreSearchType.mmr =>
maxMarginalRelevanceSearch(query: query, k: k),
final VectorStoreSimilaritySearch config => similaritySearch(
query: query,
config: config,
),
final VectorStoreMMRSearch config =>
maxMarginalRelevanceSearch(query: query, config: config),
};
}

/// Returns docs most similar to query using similarity.
///
/// - [query] is the query to search for.
/// - [k] is the number of documents to return.
/// - [query] the query to search for.
/// - [config] the configuration for the search.
Future<List<Document>> similaritySearch({
required final String query,
final int k = 4,
final VectorStoreSimilaritySearch config =
const VectorStoreSimilaritySearch(),
}) async {
final docsWithScores = await similaritySearchWithScores(query: query, k: k);
final docsWithScores = await similaritySearchWithScores(
query: query,
config: config,
);
return docsWithScores
.map((final docWithScore) => docWithScore.$1)
.toList(growable: false);
Expand All @@ -112,13 +119,16 @@ abstract class VectorStore {
/// Returns docs most similar to embedding vector using similarity.
///
/// - [embedding] is the embedding vector to look up documents similar to.
/// - [k] is the number of documents to return.
/// - [config] the configuration for the search.
Future<List<Document>> similaritySearchByVector({
required final List<double> embedding,
final int k = 4,
final VectorStoreSimilaritySearch config =
const VectorStoreSimilaritySearch(),
}) async {
final docsWithScores =
await similaritySearchByVectorWithScores(embedding: embedding, k: k);
final docsWithScores = await similaritySearchByVectorWithScores(
embedding: embedding,
config: config,
);
return docsWithScores
.map((final docWithScore) => docWithScore.$1)
.toList(growable: false);
Expand All @@ -128,29 +138,31 @@ abstract class VectorStore {
/// 0 is dissimilar, 1 is most similar.
///
/// - [query] is the query to search for.
/// - [k] is the number of documents to return.
/// - [config] the configuration for the search.
///
/// Returns a list of tuples of documents and their similarity scores.
Future<List<(Document, double score)>> similaritySearchWithScores({
required final String query,
final int k = 4,
final VectorStoreSimilaritySearch config =
const VectorStoreSimilaritySearch(),
}) async {
return similaritySearchByVectorWithScores(
embedding: await embeddings.embedQuery(query),
k: k,
config: config,
);
}

/// Returns docs and relevance scores in the range [0, 1],
/// 0 is dissimilar, 1 is most similar.
///
/// - [query] is the query to search for.
/// - [k] is the number of documents to return.
/// - [config] the configuration for the search.
///
/// Returns a list of tuples of documents and their similarity scores.
Future<List<(Document, double scores)>> similaritySearchByVectorWithScores({
required final List<double> embedding,
final int k = 4,
final VectorStoreSimilaritySearch config =
const VectorStoreSimilaritySearch(),
});

/// Returns docs selected using the maximal marginal relevance algorithm (MMR)
Expand All @@ -160,22 +172,14 @@ abstract class VectorStore {
/// AND diversity among selected documents.
///
/// - [query] is the query to search for.
/// - [k] is the number of documents to return.
/// - [fetchK] is the number of documents to pass to MMR algorithm.
/// - [lambdaMult] is a umber between 0 and 1 that determines the degree of
/// diversity among the results with 0 corresponding to maximum diversity
/// and 1 to minimum diversity.
/// - [config] the configuration for the search.
Future<List<Document>> maxMarginalRelevanceSearch({
required final String query,
final int k = 4,
final int fetchK = 20,
final double lambdaMult = 0.5,
final VectorStoreMMRSearch config = const VectorStoreMMRSearch(),
}) async {
return maxMarginalRelevanceSearchByVector(
embedding: await embeddings.embedQuery(query),
k: k,
fetchK: fetchK,
lambdaMult: lambdaMult,
config: config,
);
}

Expand All @@ -186,29 +190,26 @@ abstract class VectorStore {
/// AND diversity among selected documents.
///
/// - [embedding] is the embedding vector to look up documents similar to.
/// - [k] is the number of documents to return.
/// - [fetchK] is the number of documents to pass to MMR algorithm.
/// - [lambdaMult] is a umber between 0 and 1 that determines the degree of
/// diversity among the results with 0 corresponding to maximum diversity
/// and 1 to minimum diversity.
/// - [config] the configuration for the search.
List<Document> maxMarginalRelevanceSearchByVector({
required final List<double> embedding,
final int k = 4,
final int fetchK = 20,
final double lambdaMult = 0.5,
final VectorStoreMMRSearch config = const VectorStoreMMRSearch(),
}) {
throw UnimplementedError('MRR not supported for this vector store');
}

/// Returns a retriever that uses this vector store.
/// Returns a [VectorStoreRetriever] that uses this vector store.
///
/// - [searchType] is the type of search to perform, either
/// [VectorStoreSearchType.similarity] (default) or
/// [VectorStoreSearchType.mmr].
VectorStoreRetriever asRetriever({
final VectorStoreSearchType searchType = VectorStoreSearchType.similarity,
final int k = 4,
final VectorStoreSearchType searchType =
const VectorStoreSimilaritySearch(),
}) {
return VectorStoreRetriever(
vectorStore: this,
searchType: searchType,
k: k,
);
}
}
13 changes: 10 additions & 3 deletions packages/langchain/lib/src/documents/vector_stores/memory.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import 'package:meta/meta.dart';
import '../embeddings/base.dart';
import '../models/models.dart';
import 'base.dart';
import 'models/models.dart';

/// Vector store that stores vectors in memory.
///
Expand Down Expand Up @@ -117,9 +118,10 @@ class MemoryVectorStore extends VectorStore {
@override
Future<List<(Document, double)>> similaritySearchByVectorWithScores({
required final List<double> embedding,
final int k = 4,
final VectorStoreSimilaritySearch config =
const VectorStoreSimilaritySearch(),
}) async {
final searches = memoryVectors
var searches = memoryVectors
.asMap()
.map(
(final key, final value) => MapEntry(
Expand All @@ -129,7 +131,12 @@ class MemoryVectorStore extends VectorStore {
)
.entries
.sorted((final a, final b) => (a.value > b.value ? -1 : 1))
.take(k);
.take(config.k);

if (config.scoreThreshold != null) {
searches = searches
.where((final search) => search.value >= config.scoreThreshold!);
}

return searches
.map(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,75 @@
enum VectorStoreSearchType {
/// Search for documents that are similar to the query.
/// Eg. using Cosine similarity.
similarity,

/// Maximal marginal relevance.
///
/// Maximal marginal relevance optimizes for similarity to query
/// AND diversity among selected documents.
mmr,
/// {@template vector_store_search_type}
/// Vector store search type.
/// {@endtemplate}
sealed class VectorStoreSearchType {
/// {@macro vector_store_search_type}
const VectorStoreSearchType({
required this.k,
});

/// The number of documents to return.
final int k;

/// Similarity search.
factory VectorStoreSearchType.similarity({
final int k = 4,
final double? scoreThreshold,
}) {
return VectorStoreSimilaritySearch(
k: k,
scoreThreshold: scoreThreshold,
);
}

/// Maximal Marginal Relevance (MMR) search.
factory VectorStoreSearchType.mmr({
final int k = 4,
final int fetchK = 20,
final double lambdaMult = 0.5,
}) {
return VectorStoreMMRSearch(
k: k,
fetchK: fetchK,
lambdaMult: lambdaMult,
);
}
}

/// {@template vector_store_similarity_search}
/// Similarity search.
/// Eg. using Cosine similarity.
/// {@endtemplate}
class VectorStoreSimilaritySearch extends VectorStoreSearchType {
/// {@macro vector_store_similarity_search}
const VectorStoreSimilaritySearch({
super.k = 4,
this.scoreThreshold,
});

/// The minimum relevance score a document must have to be returned.
/// Range: [0, 1].
final double? scoreThreshold;
}

/// {@template vector_store_mmr_search}
/// Maximal Marginal Relevance (MMR) search .
///
/// Maximal marginal relevance optimizes for similarity to query
/// AND diversity among selected documents.
/// {@endtemplate}
class VectorStoreMMRSearch extends VectorStoreSearchType {
/// {@macro vector_store_mmr_search}
const VectorStoreMMRSearch({
super.k = 4,
this.fetchK = 20,
this.lambdaMult = 0.5,
});

/// The number of documents to pass to MMR algorithm.
final int fetchK;

/// Number between 0 and 1 that determines the degree of diversity among the
/// results with 0 corresponding to maximum diversity and 1 to minimum
/// diversity.
final double lambdaMult;
}
3 changes: 1 addition & 2 deletions packages/langchain/test/chains/retrieval_qa_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ void main() {
final res = await retrievalQA({'query': query});

const expectedRes = '''
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
what's this
Expand Down
20 changes: 16 additions & 4 deletions packages/langchain/test/documents/vector_stores/memory_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ void main() {
embeddings: embeddings,
);

final results = await store.similaritySearch(query: 'chao', k: 1);
final results = await store.similaritySearch(
query: 'chao',
config: const VectorStoreSimilaritySearch(k: 1),
);

expect(results.length, 1);
expect(results.first.id, '3');
Expand All @@ -30,7 +33,10 @@ void main() {
embeddings: embeddings,
);

final results = await store.similaritySearch(query: 'chao', k: 1);
final results = await store.similaritySearch(
query: 'chao',
config: const VectorStoreSimilaritySearch(k: 1),
);

expect(results.length, 1);
expect(results.first.id, '3');
Expand Down Expand Up @@ -61,7 +67,10 @@ void main() {
],
);

final results = await store.similaritySearch(query: 'chao', k: 1);
final results = await store.similaritySearch(
query: 'chao',
config: const VectorStoreSimilaritySearch(k: 1),
);

expect(results.length, 1);
expect(results.first.id, '3');
Expand All @@ -81,7 +90,10 @@ void main() {
);
await store.delete(ids: ['3']);

final results = await store.similaritySearch(query: 'chao', k: 1);
final results = await store.similaritySearch(
query: 'chao',
config: const VectorStoreSimilaritySearch(k: 1),
);

expect(results.length, 1);
expect(results.first.id, '2');
Expand Down
Loading

0 comments on commit 5b8fa5a

Please sign in to comment.