Skip to content

Commit

Permalink
feat(openai_dart): Support different embedding response formats (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored Nov 1, 2023
1 parent fa5d032 commit 4f676e8
Show file tree
Hide file tree
Showing 14 changed files with 630 additions and 58 deletions.
16 changes: 15 additions & 1 deletion packages/openai_dart/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ final res = await client.createEmbedding(
input: EmbeddingInput.string('The food was delicious and the waiter...'),
),
);
print(res.data.first.embedding);
print(res.data.first.embeddingVector);
// [0.002253932, -0.009333183, 0.01574578, -0.007790351, -0.004711035, ...]
```

Expand All @@ -251,6 +251,20 @@ print(res.data.first.embedding);
- `EmbeddingInput.arrayString(['input'])`: batch of string inputs.
- `EmbeddingInput.array([[...]])`: batch of tokenized inputs.

You can also request the embedding vector encoded as a base64 string:

```dart
final res = await client.createEmbedding(
request: CreateEmbeddingRequest(
model: EmbeddingModel.string('text-embedding-ada-002'),
input: EmbeddingInput.string('The food was delicious and the waiter...'),
encodingFormat: EmbeddingEncodingFormat.base64,
),
);
print(res.data.first.embeddingVectorBase64);
// tLYTOzXqGLxL/YA8M0b/uwdfmrsdNXM8iJIfvEOOHL3IJeK7Ok3rv...
```

### Fine-tuning

Manage fine-tuning jobs to tailor a model to your specific training data.
Expand Down
2 changes: 1 addition & 1 deletion packages/openai_dart/example/openai_dart_example.dart
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Future<void> _embeddings(final OpenAIClient client) async {
input: EmbeddingInput.string('The food was delicious and the waiter...'),
),
);
print(res.data.first.embedding);
print(res.data.first.embeddingVector);
}

Future<void> _fineTuning(final OpenAIClient client) async {
Expand Down
1 change: 1 addition & 0 deletions packages/openai_dart/lib/openai_dart.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
library;

export 'src/client.dart';
export 'src/extensions.dart';
export 'src/generated/client.dart' show OpenAIClientException;
export 'src/generated/schema/schema.dart';
31 changes: 31 additions & 0 deletions packages/openai_dart/lib/src/extensions.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import 'generated/schema/schema.dart';

extension EmbeddingX on Embedding {
/// The embedding vector as a list of doubles.
///
/// You can only use this field if you created the embedding with
/// [CreateEmbeddingRequest.encodingFormat] set to
/// [EmbeddingEncodingFormat.float].
List<double> get embeddingVector {
return embedding.mapOrNull(
string: (final s) => throw ArgumentError(
'`encodingFormat` is not set to `float` in the create embedding request',
),
arrayNumber: (final a) => a.value,
)!;
}

/// The embedding vector as a base64-encoded string.
///
/// You can only use this field if you created the embedding with
/// [CreateEmbeddingRequest.encodingFormat] set to
/// [EmbeddingEncodingFormat.base64].
String get embeddingVectorBase64 {
return embedding.mapOrNull(
string: (final s) => s.value,
arrayNumber: (final a) => throw ArgumentError(
'`encodingFormat` is not set to `base64` in the create embedding request',
),
)!;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ class _ChatCompletionStopConverter
if (data is String) {
return ChatCompletionStop.string(data);
}
if (data is List<String>) {
return ChatCompletionStop.arrayString(data);
if (data is List && data.every((item) => item is String)) {
return ChatCompletionStop.arrayString(data.cast());
}
throw Exception('Unexpected value for ChatCompletionStop: $data');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,14 @@ class _CompletionPromptConverter
if (data is String) {
return CompletionPrompt.string(data);
}
if (data is List<String>) {
return CompletionPrompt.arrayString(data);
if (data is List && data.every((item) => item is String)) {
return CompletionPrompt.arrayString(data.cast());
}
if (data is List<int>) {
return CompletionPrompt.arrayInteger(data);
if (data is List && data.every((item) => item is int)) {
return CompletionPrompt.arrayInteger(data.cast());
}
if (data is List<List<int>>) {
return CompletionPrompt.array(data);
if (data is List && data.every((item) => item is List<int>)) {
return CompletionPrompt.array(data.cast());
}
return CompletionPrompt.string('<|endoftext|>');
}
Expand Down Expand Up @@ -401,8 +401,8 @@ class _CompletionStopConverter
if (data is String) {
return CompletionStop.string(data);
}
if (data is List<String>) {
return CompletionStop.arrayString(data);
if (data is List && data.every((item) => item is String)) {
return CompletionStop.arrayString(data.cast());
}
throw Exception('Unexpected value for CompletionStop: $data');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ class _EmbeddingInputConverter
if (data is String) {
return EmbeddingInput.string(data);
}
if (data is List<String>) {
return EmbeddingInput.arrayString(data);
if (data is List && data.every((item) => item is String)) {
return EmbeddingInput.arrayString(data.cast());
}
if (data is List<int>) {
return EmbeddingInput.arrayInteger(data);
if (data is List && data.every((item) => item is int)) {
return EmbeddingInput.arrayInteger(data.cast());
}
if (data is List<List<int>>) {
return EmbeddingInput.array(data);
if (data is List && data.every((item) => item is List<int>)) {
return EmbeddingInput.array(data.cast());
}
throw Exception('Unexpected value for EmbeddingInput: $data');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ class _ModerationInputConverter
if (data is String) {
return ModerationInput.string(data);
}
if (data is List<String>) {
return ModerationInput.arrayString(data);
if (data is List && data.every((item) => item is String)) {
return ModerationInput.arrayString(data.cast());
}
throw Exception('Unexpected value for ModerationInput: $data');
}
Expand Down
49 changes: 48 additions & 1 deletion packages/openai_dart/lib/src/generated/schema/embedding.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Embedding with _$Embedding {
required int index,

/// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
required List<double> embedding,
@_EmbeddingVectorConverter() required EmbeddingVector embedding,

/// The object type, which is always "embedding".
required String object,
Expand All @@ -46,3 +46,50 @@ class Embedding with _$Embedding {
};
}
}

// ==========================================
// CLASS: EmbeddingVector
// ==========================================

/// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
@freezed
sealed class EmbeddingVector with _$EmbeddingVector {
const EmbeddingVector._();

const factory EmbeddingVector.string(
String value,
) = _UnionEmbeddingVectorString;

const factory EmbeddingVector.arrayNumber(
List<double> value,
) = _UnionEmbeddingVectorArrayNumber;

/// Object construction from a JSON representation
factory EmbeddingVector.fromJson(Map<String, dynamic> json) =>
_$EmbeddingVectorFromJson(json);
}

/// Custom JSON converter for [EmbeddingVector]
class _EmbeddingVectorConverter
implements JsonConverter<EmbeddingVector, Object?> {
const _EmbeddingVectorConverter();

@override
EmbeddingVector fromJson(Object? data) {
if (data is String) {
return EmbeddingVector.string(data);
}
if (data is List && data.every((item) => item is double)) {
return EmbeddingVector.arrayNumber(data.cast());
}
throw Exception('Unexpected value for EmbeddingVector: $data');
}

@override
Object? toJson(EmbeddingVector data) {
return switch (data) {
_UnionEmbeddingVectorString(value: final v) => v,
_UnionEmbeddingVectorArrayNumber(value: final v) => v,
};
}
}
Loading

0 comments on commit 4f676e8

Please sign in to comment.