feat: Add support for GoogleGenerativeAIEmbeddings (#362)
davidmigloz authored Apr 3, 2024
1 parent 68bfdb0 commit d4f888a
Showing 11 changed files with 306 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/_sidebar.md
Original file line number Diff line number Diff line change
@@ -74,6 +74,7 @@
- Integrations
- [OpenAI](/modules/retrieval/text_embedding/integrations/openai.md)
- [GCP Vertex AI](/modules/retrieval/text_embedding/integrations/gcp_vertex_ai.md)
- [Google AI](/modules/retrieval/text_embedding/integrations/google_ai.md)
- [Ollama](/modules/retrieval/text_embedding/integrations/ollama.md)
- [Mistral AI](/modules/retrieval/text_embedding/integrations/mistralai.md)
- [Together AI](/modules/retrieval/text_embedding/integrations/together_ai.md)
60 changes: 60 additions & 0 deletions docs/modules/retrieval/text_embedding/integrations/google_ai.md
@@ -0,0 +1,60 @@
# Google AI Embeddings

The embedding service in the [Gemini API](https://ai.google.dev/docs/embeddings_guide) generates state-of-the-art embeddings for words, phrases, and sentences. The resulting embeddings can then be used for NLP tasks such as semantic search, text classification, and clustering, among many others.

## Available models

- `embedding-001` (default)
  * Optimized for creating embeddings for text of up to 2048 tokens

The previous list of models may not be exhaustive or up-to-date. Check out the [Google AI documentation](https://ai.google.dev/models/gemini) for the latest list of available models.
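
For example, the model can be pinned explicitly via the `model` parameter (a minimal sketch; in this integration `model` defaults to `embedding-001`):

```dart
final embeddings = GoogleGenerativeAIEmbeddings(
  apiKey: 'your-api-key',
  model: 'embedding-001',
);
```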

### Task type

Google AI supports specifying a 'task type' when embedding documents. The task type is then used by the model to improve the quality of the embeddings.

This integration uses the following task types:
- `retrievalDocument`: for embedding documents
- `retrievalQuery`: for embedding queries
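
For example, the two task types map onto the integration's two methods (a minimal sketch based on the class added in this commit):

```dart
final embeddings = GoogleGenerativeAIEmbeddings(apiKey: 'your-api-key');

// Sent with taskType: retrievalDocument.
final docVectors = await embeddings.embedDocuments([
  const Document(pageContent: 'This is a test document.'),
]);

// Sent with taskType: retrievalQuery.
final queryVector = await embeddings.embedQuery('This is a test query.');

embeddings.close();
```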

## Usage

```dart
final apiKey = Platform.environment['GOOGLEAI_API_KEY'];
final embeddings = GoogleGenerativeAIEmbeddings(
apiKey: apiKey,
);

// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[0.05677966, 0.0030236526, -0.06441004, ...]]

// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [0.025963314, -0.06858828, -0.026590854, ...]

embeddings.close();
```

### Title

Google AI supports specifying a document title when embedding documents. The title is then used by the model to improve the quality of the embeddings.

To specify a document title, add the title to the document's metadata and then specify the metadata key in the `docTitleKey` parameter.

Example:
```dart
final embeddings = GoogleGenerativeAIEmbeddings(
apiKey: 'your-api-key',
);
final result = await embeddings.embedDocuments([
Document(
pageContent: 'Hello world',
metadata: {'title': 'Hello!'},
),
]);
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// ignore_for_file: avoid_print
import 'dart:io';

import 'package:langchain/langchain.dart';
import 'package:langchain_google/langchain_google.dart';

void main(final List<String> arguments) async {
final apiKey = Platform.environment['GOOGLEAI_API_KEY'];
final embeddings = GoogleGenerativeAIEmbeddings(
apiKey: apiKey,
);

// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[0.05677966, 0.0030236526, -0.06441004, ...]]

// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [0.025963314, -0.06858828, -0.026590854, ...]

embeddings.close();
}
Original file line number Diff line number Diff line change
@@ -146,7 +146,7 @@ class ChatGoogleGenerativeAI
/// - [ChatGoogleGenerativeAI.defaultOptions]
///
/// Advanced configuration options:
/// - `baseUrl`: the base URL to use. Defaults to Mistral AI's API URL. You can
/// - `baseUrl`: the base URL to use. Defaults to Google AI's API URL. You can
/// override this to use a different API URL, or to use a proxy.
/// - `headers`: global headers to send with every request. You can use
/// this to set custom headers, or to override the default headers.
3 changes: 2 additions & 1 deletion packages/langchain_google/lib/src/embeddings/embeddings.dart
@@ -1 +1,2 @@
export 'vertex_ai.dart';
export 'google_ai/google_ai_embeddings.dart';
export 'vertex_ai/vertex_ai_embeddings.dart';
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import 'package:collection/collection.dart';
import 'package:googleai_dart/googleai_dart.dart';
import 'package:http/http.dart' as http;
import 'package:langchain_core/documents.dart';
import 'package:langchain_core/embeddings.dart';
import 'package:langchain_core/utils.dart';

/// {@template google_generative_ai_embeddings}
/// Wrapper around the Google AI embedding models API.
///
/// Example:
/// ```dart
/// final embeddings = GoogleGenerativeAIEmbeddings(
/// apiKey: 'your-api-key',
/// );
/// final result = await embeddings.embedQuery('Hello world');
/// ```
///
/// Google AI documentation: https://ai.google.dev/
///
/// ### Available models
///
/// - `embedding-001`
/// * Optimized for creating embeddings for text of up to 2048 tokens
///
/// The previous list of models may not be exhaustive or up-to-date. Check out
/// the [Google AI documentation](https://ai.google.dev/models/gemini)
/// for the latest list of available models.
///
/// ### Task type
///
/// Google AI supports specifying a 'task type' when embedding documents.
/// The task type is then used by the model to improve the quality of the
/// embeddings.
///
/// This class uses the following task types:
/// - `retrievalDocument`: for embedding documents
/// - `retrievalQuery`: for embedding queries
///
/// ### Title
///
/// Google AI supports specifying a document title when embedding documents.
/// The title is then used by the model to improve the quality of the
/// embeddings.
///
/// To specify a document title, add the title to the document's metadata.
/// Then, specify the metadata key in the [docTitleKey] parameter.
///
/// Example:
/// ```dart
/// final embeddings = GoogleGenerativeAIEmbeddings(
/// apiKey: 'your-api-key',
/// );
/// final result = await embeddings.embedDocuments([
/// Document(
/// pageContent: 'Hello world',
/// metadata: {'title': 'Hello!'},
/// ),
/// ]);
/// ```
/// {@endtemplate}
class GoogleGenerativeAIEmbeddings implements Embeddings {
/// Create a new [GoogleGenerativeAIEmbeddings] instance.
///
/// Main configuration options:
/// - `apiKey`: your Google AI API key. You can find your API key in the
/// [Google AI Studio dashboard](https://makersuite.google.com/app/apikey).
/// - [GoogleGenerativeAIEmbeddings.model]
/// - [GoogleGenerativeAIEmbeddings.batchSize]
/// - [GoogleGenerativeAIEmbeddings.docTitleKey]
///
/// Advanced configuration options:
/// - `baseUrl`: the base URL to use. Defaults to Google AI's API URL. You can
/// override this to use a different API URL, or to use a proxy.
/// - `headers`: global headers to send with every request. You can use
/// this to set custom headers, or to override the default headers.
/// - `queryParams`: global query parameters to send with every request. You
/// can use this to set custom query parameters.
/// - `client`: the HTTP client to use. You can set your own HTTP client if
/// you need further customization (e.g. to use a Socks5 proxy).
GoogleGenerativeAIEmbeddings({
final String? apiKey,
final String? baseUrl,
final Map<String, String>? headers,
final Map<String, dynamic>? queryParams,
final http.Client? client,
this.model = 'embedding-001',
this.batchSize = 100,
this.docTitleKey = 'title',
}) : _client = GoogleAIClient(
apiKey: apiKey,
baseUrl: baseUrl,
headers: headers,
queryParams: queryParams,
client: client,
);

/// A client for interacting with Google AI API.
final GoogleAIClient _client;

/// The embeddings model to use.
///
/// You can find a list of available embedding models here:
/// https://ai.google.dev/models/gemini
final String model;

/// The maximum number of documents to embed in a single request.
final int batchSize;

/// The metadata key used to store the document's (optional) title.
final String docTitleKey;

/// Set or replace the API key.
set apiKey(final String value) => _client.apiKey = value;

/// Get the API key.
String get apiKey => _client.apiKey;

@override
Future<List<List<double>>> embedDocuments(
final List<Document> documents,
) async {
final batches = chunkArray(documents, chunkSize: batchSize);

final List<List<List<double>>> embeddings = await Future.wait(
batches.map((final batch) async {
final data = await _client.batchEmbedContents(
modelId: model,
request: BatchEmbedContentsRequest(
requests: batch.map((final doc) {
return EmbedContentRequest(
title: doc.metadata[docTitleKey],
content: Content(parts: [Part(text: doc.pageContent)]),
taskType: EmbedContentRequestTaskType.retrievalDocument,
model: 'models/$model',
);
}).toList(growable: false),
),
);
return data.embeddings
?.map((final p) => p.values)
.whereNotNull()
.toList(growable: false) ??
const [];
}),
);

return embeddings.expand((final e) => e).toList(growable: false);
}

@override
Future<List<double>> embedQuery(final String query) async {
final data = await _client.embedContent(
modelId: model,
request: EmbedContentRequest(
content: Content(parts: [Part(text: query)]),
taskType: EmbedContentRequestTaskType.retrievalQuery,
),
);
return data.embedding?.values ?? const [];
}

/// Closes the client and cleans up any resources associated with it.
void close() {
_client.endSession();
}
}
Original file line number Diff line number Diff line change
@@ -229,6 +229,6 @@ class VertexAIEmbeddings implements Embeddings {
return null;
}

return VertexAITextEmbeddingsModelTaskType.retrievalDocument;
return defaultTaskType;
}
}
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ import 'package:langchain_core/prompts.dart';
import 'package:langchain_google/langchain_google.dart';
import 'package:test/test.dart';

import '../utils/auth.dart';
import '../../utils/auth.dart';

void main() async {
final authHttpClient = await getAuthHttpClient();
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
@TestOn('vm')
library; // Uses dart:io

import 'dart:io';

import 'package:langchain_core/documents.dart';
import 'package:langchain_google/langchain_google.dart';
import 'package:test/test.dart';

void main() {
group('GoogleGenerativeAIEmbeddings tests', () {
late GoogleGenerativeAIEmbeddings embeddings;

setUp(() async {
embeddings = GoogleGenerativeAIEmbeddings(
apiKey: Platform.environment['GOOGLEAI_API_KEY'],
);
});

tearDown(() {
embeddings.close();
});

test('Test GoogleGenerativeAIEmbeddings.embedQuery', () async {
final res = await embeddings.embedQuery('Hello world');
expect(res.length, 768);
});

test('Test GoogleGenerativeAIEmbeddings.embedDocuments', () async {
final res = await embeddings.embedDocuments([
const Document(
id: '1',
pageContent: 'Hello world',
),
const Document(
id: '2',
pageContent: 'Bye bye',
),
]);
expect(res.length, 2);
expect(res[0].length, 768);
expect(res[1].length, 768);
});
});
}
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ import 'package:langchain_core/documents.dart';
import 'package:langchain_google/langchain_google.dart';
import 'package:test/test.dart';

import '../utils/auth.dart';
import '../../utils/auth.dart';

void main() async {
final authHttpClient = await getAuthHttpClient();
