Skip to content

Commit

Permalink
feat(embeddings): Add support for CacheBackedEmbeddings (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored Aug 19, 2023
1 parent f06920d commit 27d8b77
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 4 deletions.
140 changes: 140 additions & 0 deletions packages/langchain/lib/src/documents/embeddings/cache.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import 'dart:convert';
import 'dart:typed_data';

import 'package:crypto/crypto.dart';
import 'package:uuid/uuid.dart';

import '../../storage/storage.dart';
import 'base.dart';

/// {@template cache_backed_embeddings}
/// Wrapper around an embedder that caches embeddings in a key-value store to
/// avoid recomputing embeddings for the same text.
///
/// When embedding a new document, the method first checks the cache for the
/// embeddings. If the embeddings are not found, the method uses the underlying
/// embedder to embed the documents and stores the results in the cache.
///
/// The factory constructor [CacheBackedEmbeddings.fromByteStore] can be used
/// to create a cache-backed embedder that uses an [EncoderBackedStore], which
/// generates the keys for the cache by hashing the text.
///
/// The [CacheBackedEmbeddings.embedQuery] method does not support caching at
/// the moment.
/// {@endtemplate}
class CacheBackedEmbeddings implements Embeddings {
  /// {@macro cache_backed_embeddings}
  const CacheBackedEmbeddings({
    required this.underlyingEmbeddings,
    required this.documentEmbeddingsStore,
  });

  /// The embedder to use for computing embeddings.
  final Embeddings underlyingEmbeddings;

  /// The store to use for caching embeddings.
  final BaseStore<String, List<double>> documentEmbeddingsStore;

  /// Create a cache backed embeddings that uses an [EncoderBackedStore] which
  /// generates the keys for the cache by hashing the text.
  ///
  /// - [underlyingEmbeddings] is the embedder to use for computing embeddings.
  /// - [documentEmbeddingsStore] is the store to use for caching embeddings.
  /// - [namespace] is the namespace to use for the cache. This namespace is
  ///   used to avoid collisions of the same text embedded using different
  ///   embeddings models. For example, you can set it to the name of the
  ///   embedding model used.
  ///
  /// Example:
  /// ```dart
  /// final cacheBackedEmbeddings = CacheBackedEmbeddings.fromByteStore(
  ///   underlyingEmbeddings: OpenAIEmbeddings(apiKey: openaiApiKey),
  ///   documentEmbeddingsStore: InMemoryStore(),
  ///   namespace: 'text-embedding-ada-002',
  /// );
  /// ```
  factory CacheBackedEmbeddings.fromByteStore({
    required final Embeddings underlyingEmbeddings,
    required final BaseStore<String, Uint8List> documentEmbeddingsStore,
    final String namespace = '',
  }) {
    return CacheBackedEmbeddings(
      underlyingEmbeddings: underlyingEmbeddings,
      documentEmbeddingsStore: EncoderBackedStore(
        store: documentEmbeddingsStore,
        encoder: EmbeddingsByteStoreEncoder(namespace: namespace),
      ),
    );
  }

  /// Embeds [texts], reusing cached vectors where available.
  ///
  /// Looks up every text in [documentEmbeddingsStore]. Texts without a cached
  /// vector are embedded in a single call to [underlyingEmbeddings] and the
  /// new vectors are written back to the store before returning.
  ///
  /// Returns one vector per entry of [texts], in the same order.
  @override
  Future<List<List<double>>> embedDocuments(final List<String> texts) async {
    final vectors = await documentEmbeddingsStore.get(texts);
    // Positions in `texts` (and `vectors`) that had no cached embedding.
    final missingIndices = [
      for (var i = 0; i < texts.length; i++)
        if (vectors[i] == null) i,
    ];

    if (missingIndices.isNotEmpty) {
      final missingTexts = [for (final i in missingIndices) texts[i]];
      final missingVectors =
          await underlyingEmbeddings.embedDocuments(missingTexts);

      // `missingVectors` is ordered like `missingTexts`, so it is indexed by
      // the position `j` within `missingIndices`, NOT by the original text
      // index. Using the text index here would pair texts with the wrong
      // vector (or go out of range) whenever a cached text precedes a
      // missing one.
      final missingVectorPairs = [
        for (var j = 0; j < missingIndices.length; j++)
          (texts[missingIndices[j]], missingVectors[j]),
      ];
      await documentEmbeddingsStore.set(missingVectorPairs);

      for (var j = 0; j < missingIndices.length; j++) {
        vectors[missingIndices[j]] = missingVectors[j];
      }
    }
    // Every slot is now populated, so the nullable element type can be
    // dropped.
    return vectors.cast();
  }

  /// Embed query text.
  ///
  /// This method does not support caching at the moment; it delegates
  /// directly to [underlyingEmbeddings].
  ///
  /// Support for caching queries is easy to implement, but it might make
  /// sense to hold off to see the most common patterns.
  ///
  /// If the cache has an eviction policy, we may need to be a bit more careful
  /// about sharing the cache between documents and queries. Generally,
  /// one is OK evicting query caches, but document caches should be kept.
  @override
  Future<List<double>> embedQuery(final String query) {
    return underlyingEmbeddings.embedQuery(query);
  }
}

/// Encoder that lets embeddings be cached in a byte store.
///
/// Keys (the embedded texts) are hashed into a UUID string; values (the
/// embedding vectors) are serialized as UTF-8 encoded JSON arrays.
class EmbeddingsByteStoreEncoder
    implements StoreEncoder<String, List<double>, String, Uint8List> {
  /// Creates an encoder with an optional [namespace] and [uuid] generator.
  const EmbeddingsByteStoreEncoder({
    this.namespace = '',
    this.uuid = const Uuid(),
  });

  /// Namespace associated with this encoder.
  ///
  /// NOTE(review): [encodeKey] does not currently mix this value into the
  /// generated key, so two encoders with different namespaces produce the
  /// same key for the same text. Confirm whether that is intended before
  /// relying on the namespace to separate caches.
  final String namespace;

  /// UUID generator used to derive the encoded key.
  final Uuid uuid;

  /// Returns a version-5 UUID derived from the SHA-1 hex digest of [key].
  @override
  String encodeKey(final String key) {
    final keyHash = sha1.convert(utf8.encode(key)).toString();
    return uuid.v5(Uuid.NAMESPACE_URL, keyHash);
  }

  /// Serializes [value] as a JSON array and returns its UTF-8 bytes.
  @override
  Uint8List encodeValue(final List<double> value) {
    return utf8.encoder.convert(json.encode(value));
  }

  /// Always throws: the key encoding is a one-way hash.
  @override
  String decodeKey(final String encodedKey) => throw UnimplementedError(
        'Decoding keys is not supported for the _ByteStoreEncoder.',
      );

  /// Parses UTF-8 encoded JSON produced by [encodeValue] back into a vector.
  ///
  /// Elements are converted through [num.toDouble] so that JSON integer
  /// literals (e.g. `0` or `1`) are accepted as well, and the list is built
  /// eagerly so malformed data fails here rather than on first element access.
  @override
  List<double> decodeValue(final Uint8List encodedValue) {
    final decoded = json.decode(utf8.decode(encodedValue)) as List<dynamic>;
    return [for (final item in decoded) (item as num).toDouble()];
  }
}
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export 'base.dart';
export 'cache.dart';
export 'fake.dart';
6 changes: 3 additions & 3 deletions packages/langchain/lib/src/storage/encoder_backed.dart
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class EncoderBackedStore<K, V, EK, EV> implements BaseStore<K, V> {
final BaseStore<EK, EV> store;

/// The encoder/decoder for keys and values.
final Encoder<K, V, EK, EV> encoder;
final StoreEncoder<K, V, EK, EV> encoder;

@override
Future<List<V?>> get(final List<K> keys) async {
Expand Down Expand Up @@ -60,9 +60,9 @@ class EncoderBackedStore<K, V, EK, EV> implements BaseStore<K, V> {
/// {@template encoder}
/// Encoder/decoder for keys and values.
/// {@endtemplate}
abstract class Encoder<K, V, EK, EV> {
abstract interface class StoreEncoder<K, V, EK, EV> {
/// {@macro encoder}
const Encoder();
const StoreEncoder();

/// Encodes a key.
EK encodeKey(final K key);
Expand Down
111 changes: 111 additions & 0 deletions packages/langchain/test/documents/embeddings/cache.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import 'dart:convert';
import 'dart:typed_data';

import 'package:crypto/crypto.dart';
import 'package:langchain/langchain.dart';
import 'package:test/test.dart';
import 'package:uuid/uuid.dart';

/// Tests for [CacheBackedEmbeddings] and [EmbeddingsByteStoreEncoder].
void main() {
  group('CacheBackedEmbeddings', () {
    late InMemoryStore<String, List<double>> store;
    late CacheBackedEmbeddings cacheBackedEmbeddings;

    setUp(() {
      store = InMemoryStore();
      cacheBackedEmbeddings = CacheBackedEmbeddings(
        underlyingEmbeddings: FakeEmbeddings(deterministic: false),
        documentEmbeddingsStore: store,
      );
    });

    test(
        'embedDocuments returns correct embeddings, and fills missing embeddings',
        () async {
      // Nothing is cached before the first call.
      final preStoreRes = await store.get(['testDoc']);
      expect(preStoreRes.first, isNull);

      // First call computes the embedding and writes it to the store.
      final res1 = await cacheBackedEmbeddings.embedDocuments(['testDoc']);
      final storeRes1 = await store.get(['testDoc']);
      expect(res1, storeRes1);

      // Second call must return the stored vector.
      final res2 = await cacheBackedEmbeddings.embedDocuments(['testDoc']);
      expect(res2, storeRes1);

      // A text not seen before is embedded and cached as well.
      final newDocStoreRes = await store.get(['newDoc']);
      expect(newDocStoreRes.first, isNull);
      final res3 = await cacheBackedEmbeddings.embedDocuments(['newDoc']);
      final storeRes3 = await store.get(['newDoc']);
      expect(res3, storeRes3);
    });

    test('embedQuery is not cached', () async {
      final result = await cacheBackedEmbeddings.embedQuery('testQuery');
      final storeResult = await store.get(['testQuery']);
      expect(result.first, isNotNull);
      expect(storeResult.first, isNull);
    });
  });

  group('CacheBackedEmbeddings.fromByteStore', () {
    late InMemoryStore<String, Uint8List> store;
    late CacheBackedEmbeddings cacheBackedEmbeddings;

    setUp(() {
      store = InMemoryStore();
      cacheBackedEmbeddings = CacheBackedEmbeddings.fromByteStore(
        underlyingEmbeddings: FakeEmbeddings(),
        documentEmbeddingsStore: store,
      );
    });

    test(
        'embedDocuments returns correct embeddings, and fills missing embeddings',
        () async {
      final res1 = await cacheBackedEmbeddings.embedDocuments(['testDoc']);
      final res2 = await cacheBackedEmbeddings.embedDocuments(['testDoc']);
      expect(res1, res2);
      final res3 = await cacheBackedEmbeddings.embedDocuments(['newDoc']);
      expect(res3, isNot(res2));
    });

    test('embedQuery is not cached', () async {
      final result = await cacheBackedEmbeddings.embedQuery('testQuery');
      // The byte store is keyed by the encoded key, not the raw text, so
      // check both: looking up only the raw text could never find an entry.
      final encodedKey =
          const EmbeddingsByteStoreEncoder().encodeKey('testQuery');
      final storeResult = await store.get(['testQuery', encodedKey]);
      expect(result.first, isNotNull);
      expect(storeResult, everyElement(isNull));
    });
  });

  group('EmbeddingsByteStoreEncoder tests', () {
    const namespace = 'test';
    const uuid = Uuid();
    const key = 'key';
    final keyHash = sha1.convert(utf8.encode(key)).toString();
    final expectedEncodedKey = uuid.v5(Uuid.NAMESPACE_URL, keyHash);
    final value = [0.1, 0.2, 0.3];
    final expectedEncodedValue =
        Uint8List.fromList(utf8.encode(json.encode(value)));

    const encoder = EmbeddingsByteStoreEncoder(namespace: namespace);

    test('encodeKey returns encoded key', () {
      final result = encoder.encodeKey(key);
      expect(result, expectedEncodedKey);
    });

    test('encodeValue returns encoded value', () {
      final result = encoder.encodeValue(value);
      expect(result, expectedEncodedValue);
    });

    test('decodeKey throws UnimplementedError', () {
      expect(
        () => encoder.decodeKey('anyKey'),
        throwsA(isA<UnimplementedError>()),
      );
    });

    test('decodeValue returns decoded value', () {
      final result = encoder.decodeValue(expectedEncodedValue);
      expect(result, value);
    });
  });
}
2 changes: 1 addition & 1 deletion packages/langchain/test/storage/encoder_backed.dart
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void main() {
});
}

class SampleEncoder extends Encoder<int, String, String, String> {
class SampleEncoder implements StoreEncoder<int, String, String, String> {
@override
String encodeKey(final int key) => '$key';

Expand Down

0 comments on commit 27d8b77

Please sign in to comment.