Skip to content

Commit

Permalink
refactor(embeddings)!: Change embedDocuments input to List<Document> (
Browse files Browse the repository at this point in the history
#153)

Currently `Embeddings.embedDocuments` method takes a `List<String>` which is not ideal because:
1. It is not consistent with the name of the method
2. Some embedding models allow to pass some metadata to improve the quality of the generated embedding (e.g. VertexAI allows to pass a 'title' property)

This PR refactors the method to expect a `List<Document>` instead. Making it more aligned with the method name and allowing to pass metadata to the embedding models.
  • Loading branch information
davidmigloz committed Sep 5, 2023
1 parent 4f7161d commit 1b5d6fb
Show file tree
Hide file tree
Showing 14 changed files with 84 additions and 38 deletions.
4 changes: 3 additions & 1 deletion packages/langchain/lib/src/documents/embeddings/base.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import '../models/models.dart';

/// {@template embeddings}
/// Interface for embedding models.
/// {@endtemplate}
Expand All @@ -6,7 +8,7 @@ abstract interface class Embeddings {
const Embeddings();

/// Embed search docs.
Future<List<List<double>>> embedDocuments(final List<String> texts);
Future<List<List<double>>> embedDocuments(final List<Document> documents);

/// Embed query text.
Future<List<double>> embedQuery(final String query);
Expand Down
15 changes: 10 additions & 5 deletions packages/langchain/lib/src/documents/embeddings/cache.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import 'package:crypto/crypto.dart';
import 'package:uuid/uuid.dart';

import '../../storage/storage.dart';
import '../models/models.dart';
import 'base.dart';

/// {@template cache_backed_embeddings}
Expand Down Expand Up @@ -70,18 +71,22 @@ class CacheBackedEmbeddings implements Embeddings {
}

@override
Future<List<List<double>>> embedDocuments(final List<String> texts) async {
Future<List<List<double>>> embedDocuments(
final List<Document> documents,
) async {
final texts =
documents.map((final doc) => doc.pageContent).toList(growable: false);
final vectors = await documentEmbeddingsStore.get(texts);
final missingIndices = [
for (var i = 0; i < texts.length; i++)
if (vectors[i] == null) i,
];
final missingTexts =
missingIndices.map((final i) => texts[i]).toList(growable: false);
final missingDocs =
missingIndices.map((final i) => documents[i]).toList(growable: false);

if (missingTexts.isNotEmpty) {
if (missingDocs.isNotEmpty) {
final missingVectors =
await underlyingEmbeddings.embedDocuments(missingTexts);
await underlyingEmbeddings.embedDocuments(missingDocs);
final missingVectorPairs = missingIndices
.map((final i) => (texts[i], missingVectors[i]))
.toList(growable: false);
Expand Down
9 changes: 7 additions & 2 deletions packages/langchain/lib/src/documents/embeddings/fake.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import 'dart:math';

import 'package:crypto/crypto.dart';

import '../models/models.dart';
import 'base.dart';

/// {@template fake_embeddings}
Expand All @@ -26,8 +27,12 @@ class FakeEmbeddings implements Embeddings {
final bool deterministic;

@override
Future<List<List<double>>> embedDocuments(final List<String> texts) async {
return texts.map(_getEmbeddings).toList(growable: false);
Future<List<List<double>>> embedDocuments(
final List<Document> documents,
) async {
return documents
.map((final d) => _getEmbeddings(d.pageContent))
.toList(growable: false);
}

@override
Expand Down
4 changes: 1 addition & 3 deletions packages/langchain/lib/src/documents/vector_stores/base.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ abstract class VectorStore {
Future<List<String>> addDocuments({
required final List<Document> documents,
}) async {
final texts =
documents.map((final doc) => doc.pageContent).toList(growable: false);
return addVectors(
vectors: await embeddings.embedDocuments(texts),
vectors: await embeddings.embedDocuments(documents),
documents: documents,
);
}
Expand Down
2 changes: 1 addition & 1 deletion packages/langchain/test/chains/retrieval_qa_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class _FakeEmbeddings implements Embeddings {

@override
Future<List<List<double>>> embedDocuments(
final List<String> documents,
final List<Document> documents,
) async {
return List.generate(documents.length, (final i) => [0, 1 / i]);
}
Expand Down
24 changes: 18 additions & 6 deletions packages/langchain/test/documents/embeddings/cache.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,20 @@ void main() async {
() async {
final preStoreRes = await store.get(['testDoc']);
expect(preStoreRes.first, isNull);
final res1 = await cacheBackedEmbeddings.embedDocuments(['testDoc']);
final res1 = await cacheBackedEmbeddings.embedDocuments(
[const Document(pageContent: 'testDoc')],
);
final storeRes1 = await store.get(['testDoc']);
expect(res1, storeRes1);
final res2 = await cacheBackedEmbeddings.embedDocuments(['testDoc']);
final res2 = await cacheBackedEmbeddings.embedDocuments(
[const Document(pageContent: 'testDoc')],
);
expect(res2, storeRes1);
final newDocStoreRes = await store.get(['newDoc']);
expect(newDocStoreRes.first, isNull);
final res3 = await cacheBackedEmbeddings.embedDocuments(['newDoc']);
final res3 = await cacheBackedEmbeddings.embedDocuments(
[const Document(pageContent: 'newDoc')],
);
final storeRes3 = await store.get(['newDoc']);
expect(res3, storeRes3);
});
Expand Down Expand Up @@ -59,10 +65,16 @@ void main() async {
test(
'embedDocuments returns correct embeddings, and fills missing embeddings',
() async {
final res1 = await cacheBackedEmbeddings.embedDocuments(['testDoc']);
final res2 = await cacheBackedEmbeddings.embedDocuments(['testDoc']);
final res1 = await cacheBackedEmbeddings.embedDocuments(
[const Document(pageContent: 'testDoc')],
);
final res2 = await cacheBackedEmbeddings.embedDocuments(
[const Document(pageContent: 'testDoc')],
);
expect(res1, res2);
final res3 = await cacheBackedEmbeddings.embedDocuments(['newDoc']);
final res3 = await cacheBackedEmbeddings.embedDocuments(
[const Document(pageContent: 'newDoc')],
);
expect(res3, isNot(res2));
});

Expand Down
8 changes: 4 additions & 4 deletions packages/langchain/test/documents/embeddings/fake.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ void main() async {
test('Embeds a document with the same embedding vector for the same text',
() async {
final embeddings = FakeEmbeddings(size: 3);
const document1 = 'This is a document.';
const document2 = 'This is a document.';
const document1 = Document(pageContent: 'This is a document.');
const document2 = Document(pageContent: 'This is a document.');

final embedding1 = (await embeddings.embedDocuments([document1])).first;
final embedding2 = (await embeddings.embedDocuments([document2])).first;
Expand Down Expand Up @@ -53,8 +53,8 @@ void main() async {

test('If deterministic is false, embeddings are different', () async {
final embeddings = FakeEmbeddings(size: 3, deterministic: false);
const document1 = 'This is a document.';
const document2 = 'This is a document.';
const document1 = Document(pageContent: 'This is a document.');
const document2 = Document(pageContent: 'This is a document.');

final embedding1 = (await embeddings.embedDocuments([document1])).first;
final embedding2 = (await embeddings.embedDocuments([document2])).first;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ class _FakeEmbeddings implements Embeddings {

@override
Future<List<List<double>>> embedDocuments(
final List<String> documents,
final List<Document> documents,
) async {
return [
for (final document in documents) embedText(document),
for (final document in documents) embedText(document.pageContent),
];
}

Expand Down
6 changes: 4 additions & 2 deletions packages/langchain/test/memory/vector_store_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ void main() {
class _FakeEmbeddings implements Embeddings {
@override
Future<List<List<double>>> embedDocuments(
final List<String> documents,
final List<Document> documents,
) async {
return documents.map(_embed).toList(growable: false);
return documents
.map((final doc) => _embed(doc.pageContent))
.toList(growable: false);
}

@override
Expand Down
10 changes: 5 additions & 5 deletions packages/langchain_google/lib/src/embeddings/vertex_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,21 @@ class VertexAIEmbeddings implements Embeddings {

@override
Future<List<List<double>>> embedDocuments(
final List<String> documents,
final List<Document> documents,
) async {
final subDocs = chunkArray(documents, chunkSize: batchSize);
final batches = chunkArray(documents, chunkSize: batchSize);

final embeddings = await Future.wait(
subDocs.map((final docsBatch) async {
batches.map((final batch) async {
final data = await client.textEmbeddings.predict(
content: docsBatch
content: batch
.map(
(final doc) => VertexAITextEmbeddingsModelContent(
taskType: _getTaskType(
defaultTaskType:
VertexAITextEmbeddingsModelTaskType.retrievalDocument,
),
content: doc,
content: doc.pageContent,
),
)
.toList(growable: false),
Expand Down
12 changes: 11 additions & 1 deletion packages/langchain_google/test/embeddings/vertex_ai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ library; // Uses dart:io

import 'dart:io';

import 'package:langchain/langchain.dart';
import 'package:langchain_google/langchain_google.dart';
import 'package:test/test.dart';

Expand All @@ -26,7 +27,16 @@ void main() async {
project: Platform.environment['VERTEX_AI_PROJECT_ID']!,
batchSize: 1,
);
final res = await embeddings.embedDocuments(['Hello world', 'Bye bye']);
final res = await embeddings.embedDocuments([
const Document(
id: '1',
pageContent: 'Hello world',
),
const Document(
id: '2',
pageContent: 'Bye bye',
),
]);
expect(res.length, 2);
expect(res[0].length, 768);
expect(res[1].length, 768);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ void main() async {
final embeddings = VertexAIEmbeddings(
authHttpClient: authHttpClient,
project: Platform.environment['VERTEX_AI_PROJECT_ID']!,
model: 'textembedding-gecko-multilingual',
);
final vectorStore = VertexAIMatchingEngine(
authHttpClient: authHttpClient,
Expand Down Expand Up @@ -52,7 +53,7 @@ void main() async {
expect(res.length, 1);
expect(
res.first.id,
'faq_621656c96b5ff317d867d019',
'blog_62fced7e440f2d026f7d442e',
);
});

Expand Down
9 changes: 5 additions & 4 deletions packages/langchain_openai/lib/src/embeddings/openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ class OpenAIEmbeddings implements Embeddings {

@override
Future<List<List<double>>> embedDocuments(
final List<String> documents,
final List<Document> documents,
) async {
// TODO use tiktoken to chunk documents that exceed the context length of the model
final subPrompts = chunkArray(documents, chunkSize: batchSize);
final batches = chunkArray(documents, chunkSize: batchSize);

final embeddings = await Future.wait(
subPrompts.map((final input) async {
batches.map((final batch) async {
final data = await _client.createEmbeddings(
model: model,
input: input,
input:
batch.map((final doc) => doc.pageContent).toList(growable: false),
);
return data.data.map((final d) => d.embeddings);
}),
Expand Down
12 changes: 11 additions & 1 deletion packages/langchain_openai/test/embeddings/openai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ library; // Uses dart:io

import 'dart:io';

import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';
import 'package:test/test.dart';

Expand All @@ -18,7 +19,16 @@ void main() {

test('Test OpenAIEmbeddings.embedDocuments', () async {
final embeddings = OpenAIEmbeddings(apiKey: openaiApiKey, batchSize: 1);
final res = await embeddings.embedDocuments(['Hello world', 'Bye bye']);
final res = await embeddings.embedDocuments([
const Document(
id: '1',
pageContent: 'Hello world',
),
const Document(
id: '2',
pageContent: 'Bye bye',
),
]);
expect(res.length, 2);
expect(res[0].length, 1536);
expect(res[1].length, 1536);
Expand Down

0 comments on commit 1b5d6fb

Please sign in to comment.