diff --git a/packages/langchain/lib/src/documents/embeddings/base.dart b/packages/langchain/lib/src/documents/embeddings/base.dart index 62791037..b6986699 100644 --- a/packages/langchain/lib/src/documents/embeddings/base.dart +++ b/packages/langchain/lib/src/documents/embeddings/base.dart @@ -1,3 +1,5 @@ +import '../models/models.dart'; + /// {@template embeddings} /// Interface for embedding models. /// {@endtemplate} @@ -6,7 +8,7 @@ abstract interface class Embeddings { const Embeddings(); /// Embed search docs. - Future<List<List<double>>> embedDocuments(final List<String> texts); + Future<List<List<double>>> embedDocuments(final List<Document> documents); /// Embed query text. Future<List<double>> embedQuery(final String query); diff --git a/packages/langchain/lib/src/documents/embeddings/cache.dart b/packages/langchain/lib/src/documents/embeddings/cache.dart index dfd72548..487b3f4d 100644 --- a/packages/langchain/lib/src/documents/embeddings/cache.dart +++ b/packages/langchain/lib/src/documents/embeddings/cache.dart @@ -5,6 +5,7 @@ import 'package:crypto/crypto.dart'; import 'package:uuid/uuid.dart'; import '../../storage/storage.dart'; +import '../models/models.dart'; import 'base.dart'; /// {@template cache_backed_embeddings} @@ -70,18 +71,22 @@ class CacheBackedEmbeddings implements Embeddings { } @override - Future<List<List<double>>> embedDocuments(final List<String> texts) async { + Future<List<List<double>>> embedDocuments( + final List<Document> documents, + ) async { + final texts = + documents.map((final doc) => doc.pageContent).toList(growable: false); final vectors = await documentEmbeddingsStore.get(texts); final missingIndices = [ for (var i = 0; i < texts.length; i++) if (vectors[i] == null) i, ]; - final missingTexts = - missingIndices.map((final i) => texts[i]).toList(growable: false); + final missingDocs = + missingIndices.map((final i) => documents[i]).toList(growable: false); - if (missingTexts.isNotEmpty) { + if (missingDocs.isNotEmpty) { final missingVectors = - await underlyingEmbeddings.embedDocuments(missingTexts); + await underlyingEmbeddings.embedDocuments(missingDocs); final missingVectorPairs = missingIndices .map((final i) => (texts[i], missingVectors[i])) .toList(growable: false); diff --git a/packages/langchain/lib/src/documents/embeddings/fake.dart b/packages/langchain/lib/src/documents/embeddings/fake.dart index 3a23a815..e823bfc4 100644 --- a/packages/langchain/lib/src/documents/embeddings/fake.dart +++ b/packages/langchain/lib/src/documents/embeddings/fake.dart @@ -3,6 +3,7 @@ import 'dart:math'; import 'package:crypto/crypto.dart'; +import '../models/models.dart'; import 'base.dart'; /// {@template fake_embeddings} @@ -26,8 +27,12 @@ class FakeEmbeddings implements Embeddings { final bool deterministic; @override - Future<List<List<double>>> embedDocuments(final List<String> texts) async { - return texts.map(_getEmbeddings).toList(growable: false); + Future<List<List<double>>> embedDocuments( + final List<Document> documents, + ) async { + return documents + .map((final d) => _getEmbeddings(d.pageContent)) + .toList(growable: false); } @override diff --git a/packages/langchain/lib/src/documents/vector_stores/base.dart b/packages/langchain/lib/src/documents/vector_stores/base.dart index e3fb9c19..7408b1cc 100644 --- a/packages/langchain/lib/src/documents/vector_stores/base.dart +++ b/packages/langchain/lib/src/documents/vector_stores/base.dart @@ -24,10 +24,8 @@ abstract class VectorStore { Future<List<String>> addDocuments({ required final List<Document> documents, }) async { - final texts = - documents.map((final doc) => doc.pageContent).toList(growable: false); return addVectors( - vectors: await embeddings.embedDocuments(texts), + vectors: await embeddings.embedDocuments(documents), documents: documents, ); } diff --git a/packages/langchain/test/chains/retrieval_qa_test.dart b/packages/langchain/test/chains/retrieval_qa_test.dart index 0a3a57fc..40d3ca42 100644 --- a/packages/langchain/test/chains/retrieval_qa_test.dart +++ b/packages/langchain/test/chains/retrieval_qa_test.dart @@ -92,7 +92,7 @@ class _FakeEmbeddings implements Embeddings { @override Future<List<List<double>>> embedDocuments( - final List<String> documents, + final List<Document> documents, ) async { return List.generate(documents.length, (final i) => [0, 1 / i]); } diff --git a/packages/langchain/test/documents/embeddings/cache.dart b/packages/langchain/test/documents/embeddings/cache.dart index 57e1c9ac..c29f4e41 100644 --- a/packages/langchain/test/documents/embeddings/cache.dart +++ b/packages/langchain/test/documents/embeddings/cache.dart @@ -24,14 +24,20 @@ void main() async { () async { final preStoreRes = await store.get(['testDoc']); expect(preStoreRes.first, isNull); - final res1 = await cacheBackedEmbeddings.embedDocuments(['testDoc']); + final res1 = await cacheBackedEmbeddings.embedDocuments( + [const Document(pageContent: 'testDoc')], + ); final storeRes1 = await store.get(['testDoc']); expect(res1, storeRes1); - final res2 = await cacheBackedEmbeddings.embedDocuments(['testDoc']); + final res2 = await cacheBackedEmbeddings.embedDocuments( + [const Document(pageContent: 'testDoc')], + ); expect(res2, storeRes1); final newDocStoreRes = await store.get(['newDoc']); expect(newDocStoreRes.first, isNull); - final res3 = await cacheBackedEmbeddings.embedDocuments(['newDoc']); + final res3 = await cacheBackedEmbeddings.embedDocuments( + [const Document(pageContent: 'newDoc')], + ); final storeRes3 = await store.get(['newDoc']); expect(res3, storeRes3); }); @@ -59,10 +65,16 @@ void main() async { test( 'embedDocuments returns correct embeddings, and fills missing embeddings', () async { - final res1 = await cacheBackedEmbeddings.embedDocuments(['testDoc']); - final res2 = await cacheBackedEmbeddings.embedDocuments(['testDoc']); + final res1 = await cacheBackedEmbeddings.embedDocuments( + [const Document(pageContent: 'testDoc')], + ); + final res2 = await cacheBackedEmbeddings.embedDocuments( + [const Document(pageContent: 'testDoc')], + ); expect(res1, res2); - final res3 = await cacheBackedEmbeddings.embedDocuments(['newDoc']); + final res3 = await cacheBackedEmbeddings.embedDocuments( + [const Document(pageContent: 'newDoc')], + ); expect(res3, isNot(res2)); }); diff --git a/packages/langchain/test/documents/embeddings/fake.dart b/packages/langchain/test/documents/embeddings/fake.dart index 19381f1a..2d4592e2 100644 --- a/packages/langchain/test/documents/embeddings/fake.dart +++ b/packages/langchain/test/documents/embeddings/fake.dart @@ -12,8 +12,8 @@ void main() async { test('Embeds a document with the same embedding vector for the same text', () async { final embeddings = FakeEmbeddings(size: 3); - const document1 = 'This is a document.'; - const document2 = 'This is a document.'; + const document1 = Document(pageContent: 'This is a document.'); + const document2 = Document(pageContent: 'This is a document.'); final embedding1 = (await embeddings.embedDocuments([document1])).first; final embedding2 = (await embeddings.embedDocuments([document2])).first; @@ -53,8 +53,8 @@ void main() async { test('If deterministic is false, embeddings are different', () async { final embeddings = FakeEmbeddings(size: 3, deterministic: false); - const document1 = 'This is a document.'; - const document2 = 'This is a document.'; + const document1 = Document(pageContent: 'This is a document.'); + const document2 = Document(pageContent: 'This is a document.'); final embedding1 = (await embeddings.embedDocuments([document1])).first; final embedding2 = (await embeddings.embedDocuments([document2])).first; diff --git a/packages/langchain/test/documents/vector_stores/memory_test.dart b/packages/langchain/test/documents/vector_stores/memory_test.dart index ac7aa1bb..89df1c6b 100644 --- a/packages/langchain/test/documents/vector_stores/memory_test.dart +++ b/packages/langchain/test/documents/vector_stores/memory_test.dart @@ -243,10 +243,10 @@ class _FakeEmbeddings implements Embeddings { @override Future<List<List<double>>> embedDocuments( - final List<String> documents, + final List<Document> documents, ) async { return [ - for (final document in documents) embedText(document), + for (final document in documents) embedText(document.pageContent), ]; } diff --git a/packages/langchain/test/memory/vector_store_test.dart b/packages/langchain/test/memory/vector_store_test.dart index aeddaac2..aa07333e 100644 --- a/packages/langchain/test/memory/vector_store_test.dart +++ b/packages/langchain/test/memory/vector_store_test.dart @@ -83,9 +83,11 @@ void main() { class _FakeEmbeddings implements Embeddings { @override Future<List<List<double>>> embedDocuments( - final List<String> documents, + final List<Document> documents, ) async { - return documents.map(_embed).toList(growable: false); + return documents + .map((final doc) => _embed(doc.pageContent)) + .toList(growable: false); } @override diff --git a/packages/langchain_google/lib/src/embeddings/vertex_ai.dart b/packages/langchain_google/lib/src/embeddings/vertex_ai.dart index f71d6795..00c75889 100644 --- a/packages/langchain_google/lib/src/embeddings/vertex_ai.dart +++ b/packages/langchain_google/lib/src/embeddings/vertex_ai.dart @@ -123,21 +123,21 @@ class VertexAIEmbeddings implements Embeddings { @override Future<List<List<double>>> embedDocuments( - final List<String> documents, + final List<Document> documents, ) async { - final subDocs = chunkArray(documents, chunkSize: batchSize); + final batches = chunkArray(documents, chunkSize: batchSize); final embeddings = await Future.wait( - subDocs.map((final docsBatch) async { + batches.map((final batch) async { final data = await client.textEmbeddings.predict( - content: docsBatch + content: batch .map( (final doc) => VertexAITextEmbeddingsModelContent( taskType: _getTaskType( defaultTaskType: VertexAITextEmbeddingsModelTaskType.retrievalDocument, ), - content: doc, + content: doc.pageContent, ), ) .toList(growable: false), diff --git a/packages/langchain_google/test/embeddings/vertex_ai_test.dart b/packages/langchain_google/test/embeddings/vertex_ai_test.dart index c9bee55b..9eaea504 100644 --- a/packages/langchain_google/test/embeddings/vertex_ai_test.dart +++ b/packages/langchain_google/test/embeddings/vertex_ai_test.dart @@ -3,6 +3,7 @@ library; // Uses dart:io import 'dart:io'; +import 'package:langchain/langchain.dart'; import 'package:langchain_google/langchain_google.dart'; import 'package:test/test.dart'; @@ -26,7 +27,16 @@ void main() async { project: Platform.environment['VERTEX_AI_PROJECT_ID']!, batchSize: 1, ); - final res = await embeddings.embedDocuments(['Hello world', 'Bye bye']); + final res = await embeddings.embedDocuments([ + const Document( + id: '1', + pageContent: 'Hello world', + ), + const Document( + id: '2', + pageContent: 'Bye bye', + ), + ]); expect(res.length, 2); expect(res[0].length, 768); expect(res[1].length, 768); diff --git a/packages/langchain_google/test/vector_stores/matching_engine.dart b/packages/langchain_google/test/vector_stores/matching_engine_test.dart similarity index 96% rename from packages/langchain_google/test/vector_stores/matching_engine.dart rename to packages/langchain_google/test/vector_stores/matching_engine_test.dart index 9ecee647..e89ec69a 100644 --- a/packages/langchain_google/test/vector_stores/matching_engine.dart +++ b/packages/langchain_google/test/vector_stores/matching_engine_test.dart @@ -14,6 +14,7 @@ void main() async { final embeddings = VertexAIEmbeddings( authHttpClient: authHttpClient, project: Platform.environment['VERTEX_AI_PROJECT_ID']!, + model: 'textembedding-gecko-multilingual', ); final vectorStore = VertexAIMatchingEngine( authHttpClient: authHttpClient, @@ -52,7 +53,7 @@ void main() async { expect(res.length, 1); expect( res.first.id, - 'faq_621656c96b5ff317d867d019', + 'blog_62fced7e440f2d026f7d442e', ); }); diff --git a/packages/langchain_openai/lib/src/embeddings/openai.dart b/packages/langchain_openai/lib/src/embeddings/openai.dart index 0193ea9a..b031fac3 100644 --- a/packages/langchain_openai/lib/src/embeddings/openai.dart +++ b/packages/langchain_openai/lib/src/embeddings/openai.dart @@ -39,16 +39,17 @@ class OpenAIEmbeddings implements Embeddings { @override Future<List<List<double>>> embedDocuments( - final List<String> documents, + final List<Document> documents, ) async { // TODO use tiktoken to chunk documents that exceed the context length of the model - final subPrompts = chunkArray(documents, chunkSize: batchSize); + final batches = chunkArray(documents, chunkSize: batchSize); final embeddings = await Future.wait( - subPrompts.map((final input) async { + batches.map((final batch) async { final data = await _client.createEmbeddings( model: model, - input: input, + input: + batch.map((final doc) => doc.pageContent).toList(growable: false), ); return data.data.map((final d) => d.embeddings); }), diff --git a/packages/langchain_openai/test/embeddings/openai_test.dart b/packages/langchain_openai/test/embeddings/openai_test.dart index 75d62e73..3a2fab31 100644 --- a/packages/langchain_openai/test/embeddings/openai_test.dart +++ b/packages/langchain_openai/test/embeddings/openai_test.dart @@ -3,6 +3,7 @@ library; // Uses dart:io import 'dart:io'; +import 'package:langchain/langchain.dart'; import 'package:langchain_openai/langchain_openai.dart'; import 'package:test/test.dart'; @@ -18,7 +19,16 @@ void main() { test('Test OpenAIEmbeddings.embedDocuments', () async { final embeddings = OpenAIEmbeddings(apiKey: openaiApiKey, batchSize: 1); - final res = await embeddings.embedDocuments(['Hello world', 'Bye bye']); + final res = await embeddings.embedDocuments([ + const Document( + id: '1', + pageContent: 'Hello world', + ), + const Document( + id: '2', + pageContent: 'Bye bye', + ), + ]); expect(res.length, 2); expect(res[0].length, 1536); expect(res[1].length, 1536);