-
-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(embeddings): Add support for CacheBackedEmbeddings (#131)
- Loading branch information
1 parent
f06920d
commit 27d8b77
Showing
5 changed files
with
256 additions
and
4 deletions.
There are no files selected for viewing
140 changes: 140 additions & 0 deletions
140
packages/langchain/lib/src/documents/embeddings/cache.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import 'dart:convert'; | ||
import 'dart:typed_data'; | ||
|
||
import 'package:crypto/crypto.dart'; | ||
import 'package:uuid/uuid.dart'; | ||
|
||
import '../../storage/storage.dart'; | ||
import 'base.dart'; | ||
|
||
/// {@template cache_backed_embeddings} | ||
/// Wrapper around an embedder that caches embeddings in a key-value store to | ||
/// avoid recomputing embeddings for the same text. | ||
/// | ||
/// When embedding a new document, the method first checks the cache for the | ||
/// embeddings. If the embeddings are not found, the method uses the underlying | ||
/// embedder to embed the documents and stores the results in the cache. | ||
/// | ||
/// The factory constructor [CacheBackedEmbeddings.fromByteStore] can be used | ||
/// to create a cache backed embeddings that uses a [EncoderBackedStore] which | ||
/// generates the keys for the cache by hashing the text. | ||
/// | ||
/// The [CacheBackedEmbeddings.embedQuery] method does not support caching at | ||
/// the moment. | ||
/// {@endtemplate} | ||
class CacheBackedEmbeddings implements Embeddings { | ||
/// {@macro cache_backed_embeddings} | ||
const CacheBackedEmbeddings({ | ||
required this.underlyingEmbeddings, | ||
required this.documentEmbeddingsStore, | ||
}); | ||
|
||
/// The embedder to use for computing embeddings. | ||
final Embeddings underlyingEmbeddings; | ||
|
||
/// The store to use for caching embeddings. | ||
final BaseStore<String, List<double>> documentEmbeddingsStore; | ||
|
||
/// Create a cache backed embeddings that uses a [EncoderBackedStore] which | ||
/// generates the keys for the cache by hashing the text. | ||
/// | ||
/// - [underlyingEmbeddings] is the embedder to use for computing embeddings. | ||
/// - [documentEmbeddingsStore] is the store to use for caching embeddings. | ||
/// - [namespace] is the namespace to use for the cache. This namespace is | ||
/// used to avoid collisions of the same text embedded using different | ||
/// embeddings models. For example, you can set it to the name of the | ||
/// embedding model used. | ||
/// | ||
/// Example: | ||
/// ```dart | ||
/// final cacheBackedEmbeddings = CacheBackedEmbeddings.fromByteStore( | ||
/// underlyingEmbeddings: OpenAIEmbeddings(apiKey: openaiApiKey), | ||
/// documentEmbeddingsStore: InMemoryStore(), | ||
/// namespace: 'text-embedding-ada-002', | ||
/// ); | ||
factory CacheBackedEmbeddings.fromByteStore({ | ||
required final Embeddings underlyingEmbeddings, | ||
required final BaseStore<String, Uint8List> documentEmbeddingsStore, | ||
final String namespace = '', | ||
}) { | ||
return CacheBackedEmbeddings( | ||
underlyingEmbeddings: underlyingEmbeddings, | ||
documentEmbeddingsStore: EncoderBackedStore( | ||
store: documentEmbeddingsStore, | ||
encoder: EmbeddingsByteStoreEncoder(namespace: namespace), | ||
), | ||
); | ||
} | ||
|
||
@override | ||
Future<List<List<double>>> embedDocuments(final List<String> texts) async { | ||
final vectors = await documentEmbeddingsStore.get(texts); | ||
final missingIndices = [ | ||
for (var i = 0; i < texts.length; i++) | ||
if (vectors[i] == null) i, | ||
]; | ||
final missingTexts = | ||
missingIndices.map((final i) => texts[i]).toList(growable: false); | ||
|
||
if (missingTexts.isNotEmpty) { | ||
final missingVectors = | ||
await underlyingEmbeddings.embedDocuments(missingTexts); | ||
final missingVectorPairs = missingIndices | ||
.map((final i) => (texts[i], missingVectors[i])) | ||
.toList(growable: false); | ||
await documentEmbeddingsStore.set(missingVectorPairs); | ||
for (var i = 0; i < missingIndices.length; i++) { | ||
vectors[missingIndices[i]] = missingVectors[i]; | ||
} | ||
} | ||
return vectors.cast(); | ||
} | ||
|
||
/// Embed query text. | ||
/// | ||
/// This method does not support caching at the moment. | ||
/// | ||
/// Support for caching queries is easily to implement, but might make | ||
/// sense to hold off to see the most common patterns. | ||
/// | ||
/// If the cache has an eviction policy, we may need to be a bit more careful | ||
/// about sharing the cache between documents and queries. Generally, | ||
/// one is OK evicting query caches, but document caches should be kept. | ||
@override | ||
Future<List<double>> embedQuery(final String query) { | ||
return underlyingEmbeddings.embedQuery(query); | ||
} | ||
} | ||
|
||
class EmbeddingsByteStoreEncoder | ||
implements StoreEncoder<String, List<double>, String, Uint8List> { | ||
const EmbeddingsByteStoreEncoder({ | ||
this.namespace = '', | ||
this.uuid = const Uuid(), | ||
}); | ||
|
||
final String namespace; | ||
final Uuid uuid; | ||
|
||
@override | ||
String encodeKey(final String key) { | ||
final keyHash = sha1.convert(utf8.encode(key)).toString(); | ||
return uuid.v5(Uuid.NAMESPACE_URL, keyHash); | ||
} | ||
|
||
@override | ||
Uint8List encodeValue(final List<double> value) { | ||
return utf8.encoder.convert(json.encode(value)); | ||
} | ||
|
||
@override | ||
String decodeKey(final String encodedKey) => throw UnimplementedError( | ||
'Decoding keys is not supported for the _ByteStoreEncoder.', | ||
); | ||
|
||
@override | ||
List<double> decodeValue(final Uint8List encodedValue) { | ||
// ignore: avoid_dynamic_calls | ||
return json.decode(utf8.decode(encodedValue)).cast<double>(); | ||
} | ||
} |
1 change: 1 addition & 0 deletions
1
packages/langchain/lib/src/documents/embeddings/embeddings.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
export 'base.dart'; | ||
export 'cache.dart'; | ||
export 'fake.dart'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
111 changes: 111 additions & 0 deletions
111
packages/langchain/test/documents/embeddings/cache.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import 'dart:convert'; | ||
import 'dart:typed_data'; | ||
|
||
import 'package:crypto/crypto.dart'; | ||
import 'package:langchain/langchain.dart'; | ||
import 'package:test/test.dart'; | ||
import 'package:uuid/uuid.dart'; | ||
|
||
void main() async { | ||
group('CacheBackedEmbeddings', () { | ||
late InMemoryStore<String, List<double>> store; | ||
late CacheBackedEmbeddings cacheBackedEmbeddings; | ||
|
||
setUp(() { | ||
store = InMemoryStore(); | ||
cacheBackedEmbeddings = CacheBackedEmbeddings( | ||
underlyingEmbeddings: FakeEmbeddings(deterministic: false), | ||
documentEmbeddingsStore: store, | ||
); | ||
}); | ||
|
||
test( | ||
'embedDocuments returns correct embeddings, and fills missing embeddings', | ||
() async { | ||
final preStoreRes = await store.get(['testDoc']); | ||
expect(preStoreRes.first, isNull); | ||
final res1 = await cacheBackedEmbeddings.embedDocuments(['testDoc']); | ||
final storeRes1 = await store.get(['testDoc']); | ||
expect(res1, storeRes1); | ||
final res2 = await cacheBackedEmbeddings.embedDocuments(['testDoc']); | ||
expect(res2, storeRes1); | ||
final newDocStoreRes = await store.get(['newDoc']); | ||
expect(newDocStoreRes.first, isNull); | ||
final res3 = await cacheBackedEmbeddings.embedDocuments(['newDoc']); | ||
final storeRes3 = await store.get(['newDoc']); | ||
expect(res3, storeRes3); | ||
}); | ||
|
||
test('embedQuery is not cached', () async { | ||
final result = await cacheBackedEmbeddings.embedQuery('testQuery'); | ||
final storeResult = await store.get(['testQuery']); | ||
expect(result.first, isNotNull); | ||
expect(storeResult.first, isNull); | ||
}); | ||
}); | ||
|
||
group('CacheBackedEmbeddings.fromByteStore', () { | ||
late InMemoryStore<String, Uint8List> store; | ||
late CacheBackedEmbeddings cacheBackedEmbeddings; | ||
|
||
setUp(() { | ||
store = InMemoryStore(); | ||
cacheBackedEmbeddings = CacheBackedEmbeddings.fromByteStore( | ||
underlyingEmbeddings: FakeEmbeddings(), | ||
documentEmbeddingsStore: store, | ||
); | ||
}); | ||
|
||
test( | ||
'embedDocuments returns correct embeddings, and fills missing embeddings', | ||
() async { | ||
final res1 = await cacheBackedEmbeddings.embedDocuments(['testDoc']); | ||
final res2 = await cacheBackedEmbeddings.embedDocuments(['testDoc']); | ||
expect(res1, res2); | ||
final res3 = await cacheBackedEmbeddings.embedDocuments(['newDoc']); | ||
expect(res3, isNot(res2)); | ||
}); | ||
|
||
test('embedQuery is not cached', () async { | ||
final result = await cacheBackedEmbeddings.embedQuery('testQuery'); | ||
final storeResult = await store.get(['testQuery']); | ||
expect(result.first, isNotNull); | ||
expect(storeResult.first, isNull); | ||
}); | ||
}); | ||
|
||
group('EmbeddingsByteStoreEncoder tests', () { | ||
const namespace = 'test'; | ||
const uuid = Uuid(); | ||
const key = 'key'; | ||
final keyHash = sha1.convert(utf8.encode(key)).toString(); | ||
final expectedEncodedKey = uuid.v5(Uuid.NAMESPACE_URL, keyHash); | ||
final value = [0.1, 0.2, 0.3]; | ||
final expectedEncodedValue = | ||
Uint8List.fromList(utf8.encode(json.encode(value))); | ||
|
||
const encoder = EmbeddingsByteStoreEncoder(namespace: namespace); | ||
|
||
test('encodeKey returns encoded key', () { | ||
final result = encoder.encodeKey(key); | ||
expect(result, expectedEncodedKey); | ||
}); | ||
|
||
test('encodeValue returns encoded value', () { | ||
final result = encoder.encodeValue(value); | ||
expect(result, expectedEncodedValue); | ||
}); | ||
|
||
test('decodeKey throws UnimplementedError', () { | ||
expect( | ||
() => encoder.decodeKey('anyKey'), | ||
throwsA(isA<UnimplementedError>()), | ||
); | ||
}); | ||
|
||
test('decodeValue returns decoded value', () { | ||
final result = encoder.decodeValue(expectedEncodedValue); | ||
expect(result, value); | ||
}); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters