-
-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add support for GoogleGenerativeAIEmbeddings (#362)
- Loading branch information
1 parent
68bfdb0
commit d4f888a
Showing
11 changed files
with
306 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
docs/modules/retrieval/text_embedding/integrations/google_ai.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Google AI Embeddings | ||
|
||
The embedding service in the [Gemini API](https://ai.google.dev/docs/embeddings_guide) generates state-of-the-art embeddings for words, phrases, and sentences. The resulting embeddings can then be used for NLP tasks, such as semantic search, text classification and clustering among many others. | ||
|
||
## Available models | ||
|
||
- `embedding-001` (default) | ||
* Optimized for creating embeddings for text of up to 2048 tokens | ||
|
||
The previous list of models may not be exhaustive or up-to-date. Check out the [Google AI documentation](https://ai.google.dev/models/gemini) for the latest list of available models. | ||
|
||
### Task type | ||
|
||
Google AI support specifying a 'task type' when embedding documents. The task type is then used by the model to improve the quality of the embeddings. | ||
|
||
This integration uses the specifies the following task type: | ||
- `retrievalDocument`: for embedding documents | ||
- `retrievalQuery`: for embedding queries | ||
|
||
## Usage | ||
|
||
```dart | ||
final apiKey = Platform.environment['GOOGLEAI_API_KEY']; | ||
final embeddings = GoogleGenerativeAIEmbeddings( | ||
apiKey: apiKey, | ||
); | ||
// Embedding a document | ||
const doc = Document(pageContent: 'This is a test document.'); | ||
final res1 = await embeddings.embedDocuments([doc]); | ||
print(res1); | ||
// [[0.05677966, 0.0030236526, -0.06441004, ...]] | ||
// Embedding a retrieval query | ||
const text = 'This is a test query.'; | ||
final res2 = await embeddings.embedQuery(text); | ||
print(res2); | ||
// [0.025963314, -0.06858828, -0.026590854, ...] | ||
embeddings.close(); | ||
``` | ||
|
||
### Title | ||
|
||
Google AI support specifying a document title when embedding documents. The title is then used by the model to improve the quality of the embeddings. | ||
|
||
To specify a document title, add the title to the document's metadata. Then, specify the metadata key in the [docTitleKey] parameter. | ||
|
||
Example: | ||
```dart | ||
final embeddings = GoogleGenerativeAIEmbeddings( | ||
apiKey: 'your-api-key', | ||
); | ||
final result = await embeddings.embedDocuments([ | ||
Document( | ||
pageContent: 'Hello world', | ||
metadata: {'title': 'Hello!'}, | ||
), | ||
]); | ||
``` |
27 changes: 27 additions & 0 deletions
27
examples/docs_examples/bin/modules/retrieval/text_embedding/integrations/google_ai.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// ignore_for_file: avoid_print | ||
import 'dart:io'; | ||
|
||
import 'package:langchain/langchain.dart'; | ||
import 'package:langchain_google/langchain_google.dart'; | ||
import 'package:langchain_openai/langchain_openai.dart'; | ||
|
||
void main(final List<String> arguments) async { | ||
final apiKey = Platform.environment['GOOGLEAI_API_KEY']; | ||
final embeddings = GoogleGenerativeAIEmbeddings( | ||
apiKey: apiKey, | ||
); | ||
|
||
// Embedding a document | ||
const doc = Document(pageContent: 'This is a test document.'); | ||
final res1 = await embeddings.embedDocuments([doc]); | ||
print(res1); | ||
// [[0.05677966, 0.0030236526, -0.06441004, ...]] | ||
|
||
// Embedding a retrieval query | ||
const text = 'This is a test query.'; | ||
final res2 = await embeddings.embedQuery(text); | ||
print(res2); | ||
// [0.025963314, -0.06858828, -0.026590854, ...] | ||
|
||
embeddings.close(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
export 'vertex_ai.dart'; | ||
export 'google_ai/google_ai_embeddings.dart'; | ||
export 'vertex_ai/vertex_ai_embeddings.dart'; |
167 changes: 167 additions & 0 deletions
167
packages/langchain_google/lib/src/embeddings/google_ai/google_ai_embeddings.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import 'package:collection/collection.dart'; | ||
import 'package:googleai_dart/googleai_dart.dart'; | ||
import 'package:http/http.dart' as http; | ||
import 'package:langchain_core/documents.dart'; | ||
import 'package:langchain_core/embeddings.dart'; | ||
import 'package:langchain_core/utils.dart'; | ||
|
||
/// {@template google_generative_ai_embeddings} | ||
/// Wrapper around Google AI embedding models API | ||
/// | ||
/// Example: | ||
/// ```dart | ||
/// final embeddings = GoogleGenerativeAIEmbeddings( | ||
/// apiKey: 'your-api-key', | ||
/// ); | ||
/// final result = await embeddings.embedQuery('Hello world'); | ||
/// ``` | ||
/// | ||
/// Google AI documentation: https://ai.google.dev/ | ||
/// | ||
/// ### Available models | ||
/// | ||
/// - `embedding-001` | ||
/// * Optimized for creating embeddings for text of up to 2048 tokens | ||
/// | ||
/// The previous list of models may not be exhaustive or up-to-date. Check out | ||
/// the [Google AI documentation](https://ai.google.dev/models/gemini) | ||
/// for the latest list of available models. | ||
/// | ||
/// ### Task type | ||
/// | ||
/// Google AI support specifying a 'task type' when embedding documents. | ||
/// The task type is then used by the model to improve the quality of the | ||
/// embeddings. | ||
/// | ||
/// This class uses the specifies the following task type: | ||
/// - `retrievalDocument`: for embedding documents | ||
/// - `retrievalQuery`: for embedding queries | ||
/// | ||
/// ### Title | ||
/// | ||
/// Google AI support specifying a document title when embedding documents. | ||
/// The title is then used by the model to improve the quality of the | ||
/// embeddings. | ||
/// | ||
/// To specify a document title, add the title to the document's metadata. | ||
/// Then, specify the metadata key in the [docTitleKey] parameter. | ||
/// | ||
/// Example: | ||
/// ```dart | ||
/// final embeddings = GoogleGenerativeAIEmbeddings( | ||
/// apiKey: 'your-api-key', | ||
/// ); | ||
/// final result = await embeddings.embedDocuments([ | ||
/// Document( | ||
/// pageContent: 'Hello world', | ||
/// metadata: {'title': 'Hello!'}, | ||
/// ), | ||
/// ]); | ||
/// ``` | ||
/// {@endtemplate} | ||
class GoogleGenerativeAIEmbeddings implements Embeddings { | ||
/// Create a new [GoogleGenerativeAIEmbeddings] instance. | ||
/// | ||
/// Main configuration options: | ||
/// - `apiKey`: your Google AI API key. You can find your API key in the | ||
/// [Google AI Studio dashboard](https://makersuite.google.com/app/apikey). | ||
/// - [GoogleGenerativeAIEmbeddings.model] | ||
/// - [GoogleGenerativeAIEmbeddings.batchSize] | ||
/// - [GoogleGenerativeAIEmbeddings.docTitleKey] | ||
/// | ||
/// Advance configuration options: | ||
/// - `baseUrl`: the base URL to use. Defaults to Google AI's API URL. You can | ||
/// override this to use a different API URL, or to use a proxy. | ||
/// - `headers`: global headers to send with every request. You can use | ||
/// this to set custom headers, or to override the default headers. | ||
/// - `queryParams`: global query parameters to send with every request. You | ||
/// can use this to set custom query parameters. | ||
/// - `client`: the HTTP client to use. You can set your own HTTP client if | ||
/// you need further customization (e.g. to use a Socks5 proxy). | ||
GoogleGenerativeAIEmbeddings({ | ||
final String? apiKey, | ||
final String? baseUrl, | ||
final Map<String, String>? headers, | ||
final Map<String, dynamic>? queryParams, | ||
final http.Client? client, | ||
this.model = 'embedding-001', | ||
this.batchSize = 100, | ||
this.docTitleKey = 'title', | ||
}) : _client = GoogleAIClient( | ||
apiKey: apiKey, | ||
baseUrl: baseUrl, | ||
headers: headers, | ||
queryParams: queryParams, | ||
client: client, | ||
); | ||
|
||
/// A client for interacting with Google AI API. | ||
final GoogleAIClient _client; | ||
|
||
/// The embeddings model to use. | ||
/// | ||
/// You can find a list of available embedding models here: | ||
/// https://ai.google.dev/models/gemini | ||
final String model; | ||
|
||
/// The maximum number of documents to embed in a single request. | ||
final int batchSize; | ||
|
||
/// The metadata key used to store the document's (optional) title. | ||
final String docTitleKey; | ||
|
||
/// Set or replace the API key. | ||
set apiKey(final String value) => _client.apiKey = value; | ||
|
||
/// Get the API key. | ||
String get apiKey => _client.apiKey; | ||
|
||
@override | ||
Future<List<List<double>>> embedDocuments( | ||
final List<Document> documents, | ||
) async { | ||
final batches = chunkArray(documents, chunkSize: batchSize); | ||
|
||
final List<List<List<double>>> embeddings = await Future.wait( | ||
batches.map((final batch) async { | ||
final data = await _client.batchEmbedContents( | ||
modelId: model, | ||
request: BatchEmbedContentsRequest( | ||
requests: batch.map((final doc) { | ||
return EmbedContentRequest( | ||
title: doc.metadata[docTitleKey], | ||
content: Content(parts: [Part(text: doc.pageContent)]), | ||
taskType: EmbedContentRequestTaskType.retrievalDocument, | ||
model: 'models/$model', | ||
); | ||
}).toList(growable: false), | ||
), | ||
); | ||
return data.embeddings | ||
?.map((final p) => p.values) | ||
.whereNotNull() | ||
.toList(growable: false) ?? | ||
const []; | ||
}), | ||
); | ||
|
||
return embeddings.expand((final e) => e).toList(growable: false); | ||
} | ||
|
||
@override | ||
Future<List<double>> embedQuery(final String query) async { | ||
final data = await _client.embedContent( | ||
modelId: model, | ||
request: EmbedContentRequest( | ||
content: Content(parts: [Part(text: query)]), | ||
taskType: EmbedContentRequestTaskType.retrievalQuery, | ||
), | ||
); | ||
return data.embedding?.values ?? const []; | ||
} | ||
|
||
/// Closes the client and cleans up any resources associated with it. | ||
void close() { | ||
_client.endSession(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 45 additions & 0 deletions
45
packages/langchain_google/test/embeddings/google_ai/google_ai_embeddings_test.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
@TestOn('vm') | ||
library; // Uses dart:io | ||
|
||
import 'dart:io'; | ||
|
||
import 'package:langchain_core/documents.dart'; | ||
import 'package:langchain_google/langchain_google.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
void main() { | ||
group('GoogleGenerativeAIEmbeddings tests', () { | ||
late GoogleGenerativeAIEmbeddings embeddings; | ||
|
||
setUp(() async { | ||
embeddings = GoogleGenerativeAIEmbeddings( | ||
apiKey: Platform.environment['GOOGLEAI_API_KEY'], | ||
); | ||
}); | ||
|
||
tearDown(() { | ||
embeddings.close(); | ||
}); | ||
|
||
test('Test GoogleGenerativeAIEmbeddings.embedQuery', () async { | ||
final res = await embeddings.embedQuery('Hello world'); | ||
expect(res.length, 768); | ||
}); | ||
|
||
test('Test GoogleGenerativeAIEmbeddings.embedDocuments', () async { | ||
final res = await embeddings.embedDocuments([ | ||
const Document( | ||
id: '1', | ||
pageContent: 'Hello world', | ||
), | ||
const Document( | ||
id: '2', | ||
pageContent: 'Bye bye', | ||
), | ||
]); | ||
expect(res.length, 2); | ||
expect(res[0].length, 768); | ||
expect(res[1].length, 768); | ||
}); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters