-
-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add support for VectorStoreRetrieverMemory (#54)
Co-authored-by: David Miguel <[email protected]>
- Loading branch information
1 parent
3b5c0b2
commit 72cd1b1
Showing
14 changed files
with
210 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import '../../langchain.dart'; | ||
import 'utils.dart'; | ||
|
||
/// {@template vector_store_retriever_memory} | ||
/// VectorStoreRetriever-backed memory. | ||
/// {@endtemplate} | ||
class VectorStoreRetrieverMemory implements BaseMemory { | ||
/// {@macro vector_store_retriever_memory} | ||
VectorStoreRetrieverMemory({ | ||
required this.retriever, | ||
this.memoryKey = defaultMemoryKey, | ||
this.inputKey, | ||
this.excludeInputKeys = const {}, | ||
this.returnDocs = false, | ||
}); | ||
|
||
/// VectorStoreRetriever object to connect to. | ||
final VectorStoreRetriever retriever; | ||
|
||
/// Name of the key where the memories are in the map returned by | ||
/// [loadMemoryVariables]. | ||
final String memoryKey; | ||
|
||
/// The input key to use for the query to the vector store. | ||
/// | ||
/// If null, the input key is inferred from the prompt (the input key hat | ||
/// was filled in by the user (i.e. not a memory key)). | ||
final String? inputKey; | ||
|
||
/// Input keys to exclude in addition to memory key when constructing the | ||
/// document. | ||
final Set<String> excludeInputKeys; | ||
|
||
/// Whether or not to return the result of querying the database directly. | ||
/// If false, the page content of all the documents is returned as a single | ||
/// string. | ||
final bool returnDocs; | ||
|
||
/// Default key for [memoryKey]. | ||
static const String defaultMemoryKey = 'memory'; | ||
|
||
@override | ||
Set<String> get memoryKeys => {memoryKey}; | ||
|
||
@override | ||
Future<MemoryVariables> loadMemoryVariables([ | ||
final MemoryInputValues values = const {}, | ||
]) async { | ||
final promptInputKey = inputKey ?? getPromptInputKey(values, memoryKeys); | ||
final query = values[promptInputKey]; | ||
final docs = await retriever.getRelevantDocuments(query); | ||
return { | ||
memoryKey: returnDocs | ||
? docs | ||
: docs.map((final doc) => doc.pageContent).join('\n'), | ||
}; | ||
} | ||
|
||
@override | ||
Future<void> saveContext({ | ||
required final MemoryInputValues inputValues, | ||
required final MemoryOutputValues outputValues, | ||
}) async { | ||
final docs = _buildDocuments(inputValues, outputValues); | ||
await retriever.addDocuments(docs); | ||
} | ||
|
||
/// Builds the documents to save to the vector store from the given | ||
/// [inputValues] and [outputValues]. | ||
List<Document> _buildDocuments( | ||
final MemoryInputValues inputValues, | ||
final MemoryOutputValues outputValues, | ||
) { | ||
final excludeKeys = {memoryKey, ...excludeInputKeys}; | ||
final filteredInputs = { | ||
for (final entry in inputValues.entries) | ||
if (!excludeKeys.contains(entry.key)) entry.key: entry.value | ||
}; | ||
final inputsOutputs = {...filteredInputs, ...outputValues}; | ||
final pageContent = inputsOutputs.entries.map((final entry) { | ||
return '${entry.key}: ${entry.value}'; | ||
}).join('\n'); | ||
return [Document(pageContent: pageContent)]; | ||
} | ||
|
||
@override | ||
Future<void> clear() async { | ||
// Nothing to clear | ||
} | ||
} |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import 'package:langchain/langchain.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
void main() { | ||
group('VectorStoreRetrieverMemory tests', () { | ||
test('Test vector store memory', () async { | ||
final embeddings = _FakeEmbeddings(); | ||
final vectorStore = MemoryVectorStore(embeddings: embeddings); | ||
final memory = VectorStoreRetrieverMemory( | ||
retriever: vectorStore.asRetriever(), | ||
); | ||
|
||
final result1 = await memory.loadMemoryVariables({'input': 'foo'}); | ||
expect(result1[VectorStoreRetrieverMemory.defaultMemoryKey], ''); | ||
|
||
await memory.saveContext( | ||
inputValues: { | ||
'foo': 'bar', | ||
}, | ||
outputValues: { | ||
'bar': 'foo', | ||
}, | ||
); | ||
final result2 = await memory.loadMemoryVariables({'input': 'foo'}); | ||
expect( | ||
result2[VectorStoreRetrieverMemory.defaultMemoryKey], | ||
'foo: bar\nbar: foo', | ||
); | ||
}); | ||
|
||
test('Test returnDocs', () async { | ||
final embeddings = _FakeEmbeddings(); | ||
final vectorStore = MemoryVectorStore(embeddings: embeddings); | ||
final memory = VectorStoreRetrieverMemory( | ||
retriever: vectorStore.asRetriever(), | ||
returnDocs: true, | ||
); | ||
|
||
await memory.saveContext( | ||
inputValues: { | ||
'foo': 'bar', | ||
}, | ||
outputValues: { | ||
'bar': 'foo', | ||
}, | ||
); | ||
final result = await memory.loadMemoryVariables({'input': 'foo'}); | ||
const expectedDoc = Document(pageContent: 'foo: bar\nbar: foo'); | ||
expect( | ||
result[VectorStoreRetrieverMemory.defaultMemoryKey], | ||
[expectedDoc], | ||
); | ||
}); | ||
|
||
test('Test excludeInputKeys', () async { | ||
final embeddings = _FakeEmbeddings(); | ||
final vectorStore = MemoryVectorStore(embeddings: embeddings); | ||
final memory = VectorStoreRetrieverMemory( | ||
retriever: vectorStore.asRetriever(), | ||
excludeInputKeys: {'foo'}, | ||
); | ||
|
||
final result1 = await memory.loadMemoryVariables({'input': 'foo'}); | ||
expect(result1[VectorStoreRetrieverMemory.defaultMemoryKey], ''); | ||
|
||
await memory.saveContext( | ||
inputValues: { | ||
'foo': 'bar', | ||
}, | ||
outputValues: { | ||
'bar': 'foo', | ||
}, | ||
); | ||
final result2 = await memory.loadMemoryVariables({'input': 'foo'}); | ||
expect( | ||
result2[VectorStoreRetrieverMemory.defaultMemoryKey], | ||
'bar: foo', | ||
); | ||
}); | ||
}); | ||
} | ||
|
||
class _FakeEmbeddings implements Embeddings { | ||
@override | ||
Future<List<List<double>>> embedDocuments( | ||
final List<String> documents, | ||
) async { | ||
return documents.map(_embed).toList(growable: false); | ||
} | ||
|
||
@override | ||
Future<List<double>> embedQuery( | ||
final String query, | ||
) async { | ||
return _embed(query); | ||
} | ||
|
||
List<double> _embed(final String text) { | ||
return switch (text) { | ||
'foo' => [1.0, 1.0], | ||
'bar' => [-1.0, -1.0], | ||
'foo: bar\nbar: foo' => [1.0, -1.0], | ||
'bar: foo' => [-1.0, 1.0], | ||
_ => throw Exception('Unknown text: $text'), | ||
}; | ||
} | ||
} |
File renamed without changes.