feat(embeddings): Support task type in VertexAIEmbeddings (#151)
VertexAI embeddings models released in August 2023 support a new "task type" parameter that indicates the kind of task the embeddings will be used for. This helps the model produce better-quality embeddings.

If you are using a model that supports it, LangChain.dart will use 'RETRIEVAL_DOCUMENT' when embedding a document and 'RETRIEVAL_QUERY' when embedding a query.
davidmigloz committed Sep 5, 2023
1 parent eab7d96 commit 8a2199e
Showing 3 changed files with 75 additions and 2 deletions.
16 changes: 16 additions & 0 deletions packages/langchain_google/lib/src/chat_models/vertex_ai.dart
@@ -67,6 +67,22 @@ import 'models/models.dart';
/// constant `ChatVertexAI.cloudPlatformScope`)
///
/// See: https://cloud.google.com/vertex-ai/docs/generative-ai/access-control
+///
+/// ## Available models
+///
+/// - `chat-bison`
+///   * Max input tokens: 4096
+///   * Max output tokens: 1024
+///   * Training data: Up to Feb 2023
+///   * Max turns: 2500
+/// - `chat-bison-32k`
+///   * Max input and output tokens combined: 32k
+///   * Training data: Up to Aug 2023
+///   * Max turns: 2500
+///
+/// The list of models above may not be exhaustive or up to date. Check out
+/// the [Vertex AI documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models)
+/// for the latest list of available models.
/// {@endtemplate}
class ChatVertexAI extends BaseChatModel<ChatVertexAIOptions> {
/// {@macro chat_vertex_ai}
47 changes: 45 additions & 2 deletions packages/langchain_google/lib/src/embeddings/vertex_ai.dart
@@ -63,6 +63,19 @@ import 'package:vertex_ai/vertex_ai.dart';
/// constant `VertexAIEmbeddings.cloudPlatformScope`)
///
/// See: https://cloud.google.com/vertex-ai/docs/generative-ai/access-control
+///
+/// ## Available models
+///
+/// - `textembedding-gecko`
+///   * Max input tokens: 3072
+///   * Output: 768-dimensional vector embeddings
+/// - `textembedding-gecko-multilingual`: supports over 100 non-English languages
+///   * Max input tokens: 3072
+///   * Output: 768-dimensional vector embeddings
+///
+/// The list of models above may not be exhaustive or up to date. Check out
+/// the [Vertex AI documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models)
+/// for the latest list of available models.
/// {@endtemplate}
class VertexAIEmbeddings implements Embeddings {
/// {@macro vertex_ai_embeddings}
@@ -117,7 +130,17 @@ class VertexAIEmbeddings implements Embeddings {
final embeddings = await Future.wait(
subDocs.map((final docsBatch) async {
final data = await client.textEmbeddings.predict(
-content: docsBatch,
+content: docsBatch
+    .map(
+      (final doc) => VertexAITextEmbeddingsModelContent(
+        taskType: _getTaskType(
+          defaultTaskType:
+              VertexAITextEmbeddingsModelTaskType.retrievalDocument,
+        ),
+        content: doc,
+      ),
+    )
+    .toList(growable: false),
publisher: publisher,
model: model,
);
@@ -133,10 +156,30 @@ class VertexAIEmbeddings implements Embeddings {
@override
Future<List<double>> embedQuery(final String query) async {
final data = await client.textEmbeddings.predict(
-content: [query],
+content: [
+  VertexAITextEmbeddingsModelContent(
+    taskType: _getTaskType(
+      defaultTaskType: VertexAITextEmbeddingsModelTaskType.retrievalQuery,
+    ),
+    content: query,
+  ),
+],
publisher: publisher,
model: model,
);
return data.predictions.first.values;
}

+VertexAITextEmbeddingsModelTaskType? _getTaskType({
+  required final VertexAITextEmbeddingsModelTaskType defaultTaskType,
+}) {
+  // Models released before August 2023 do not support taskType.
+  // Currently 'textembedding-gecko' points to 'textembedding-gecko@001'.
+  // Ref: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/model-versioning
+  if (model == 'textembedding-gecko' || model == 'textembedding-gecko@001') {
+    return null;
+  }
+
+  return defaultTaskType;
+}
}
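The fallback described in the commit message (a document or query task type for models that support it, nothing for older models) can be sketched in isolation as follows. This is a minimal, self-contained illustration and not the package's code: the `TaskType` enum and the `taskTypeFor` function are hypothetical stand-ins, and the set of legacy model names is assumed from the two identifiers checked in the diff.

```dart
// Hypothetical stand-in for the two task types named in the commit message.
enum TaskType { retrievalDocument, retrievalQuery }

/// Returns the task type to send for [model], or null when the model is
/// assumed not to accept a task type.
TaskType? taskTypeFor(String model, {required TaskType defaultTaskType}) {
  // Assumption taken from the diff: both names refer to the first-generation
  // model, which predates task type support.
  const legacyModels = {'textembedding-gecko', 'textembedding-gecko@001'};
  if (legacyModels.contains(model)) return null;

  // Newer models get the caller's default: document for embedDocuments,
  // query for embedQuery.
  return defaultTaskType;
}

void main() {
  // Legacy model: no task type is sent.
  print(taskTypeFor(
    'textembedding-gecko',
    defaultTaskType: TaskType.retrievalQuery,
  ));

  // Any other model: the caller's default is passed through.
  print(taskTypeFor(
    'textembedding-gecko-multilingual',
    defaultTaskType: TaskType.retrievalQuery,
  ));
}
```

Passing the default in from the call site keeps the model check in one place while letting each caller state its own intent.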
14 changes: 14 additions & 0 deletions packages/langchain_google/lib/src/llms/vertex_ai.dart
@@ -67,6 +67,20 @@ import 'models/models.dart';
/// constant `VertexAI.cloudPlatformScope`)
///
/// See: https://cloud.google.com/vertex-ai/docs/generative-ai/access-control
+///
+/// ## Available models
+///
+/// - `text-bison`
+///   * Max input tokens: 8192
+///   * Max output tokens: 1024
+///   * Training data: Up to Feb 2023
+/// - `text-bison-32k`
+///   * Max input and output tokens combined: 32k
+///   * Training data: Up to Aug 2023
+///
+/// The list of models above may not be exhaustive or up to date. Check out
+/// the [Vertex AI documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models)
+/// for the latest list of available models.
/// {@endtemplate}
class VertexAI extends BaseLLM<VertexAIOptions> {
/// {@macro vertex_ai}
