Skip to content

Commit

Permalink
feat(vertex_ai): Add support for stopSequence and candidateCount (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored Sep 4, 2023
1 parent 421d36b commit eab7d96
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 6 deletions.
28 changes: 25 additions & 3 deletions packages/vertex_ai/lib/src/gen_ai/models/text.dart
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class VertexAITextModelRequestParams {
this.maxOutputTokens = 1024,
this.topP = 0.95,
this.topK = 40,
this.stopSequence = const [],
this.candidateCount = 1,
});

/// The temperature is used for sampling during response generation, which
Expand Down Expand Up @@ -108,13 +110,24 @@ class VertexAITextModelRequestParams {
/// Range: `[1, 40]`
final int topK;

/// Specifies a list of strings that tells the model to stop generating text
/// if one of the strings is encountered in the response. If a string appears
/// multiple times in the response, then the response truncates where it's
/// first encountered. The strings are case-sensitive.
final List<String> stopSequence;

/// The number of response variations to return.
final int candidateCount;

/// Converts this object to a [Map].
///
/// The keys mirror the generation parameter names sent in a Vertex AI
/// text-model prediction request.
// NOTE(review): the Vertex AI REST API documents this parameter as
// 'stopSequences' (plural) — verify this key name against the API spec;
// the mapper test expects 'stopSequence' as written here.
Map<String, dynamic> toMap() {
return {
'temperature': temperature,
'maxOutputTokens': maxOutputTokens,
'topP': topP,
'topK': topK,
'stopSequence': stopSequence,
'candidateCount': candidateCount,
};
}

Expand All @@ -125,22 +138,31 @@ class VertexAITextModelRequestParams {
temperature == other.temperature &&
maxOutputTokens == other.maxOutputTokens &&
topP == other.topP &&
topK == other.topK;
topK == other.topK &&
const ListEquality<String>().equals(
stopSequence,
other.stopSequence,
) &&
candidateCount == other.candidateCount;

@override
int get hashCode =>
temperature.hashCode ^
maxOutputTokens.hashCode ^
topP.hashCode ^
topK.hashCode;
topK.hashCode ^
const ListEquality<String>().hash(stopSequence) ^
candidateCount.hashCode;

@override
String toString() {
return 'VertexAITextModelRequestParams{'
'temperature: $temperature, '
'maxOutputTokens: $maxOutputTokens, '
'topP: $topP, '
'topK: $topK}';
'topK: $topK, '
'stopSequence: $stopSequence, '
'candidateCount: $candidateCount}';
}
}

Expand Down
28 changes: 25 additions & 3 deletions packages/vertex_ai/lib/src/gen_ai/models/text_chat.dart
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class VertexAITextChatModelRequestParams {
this.maxOutputTokens = 1024,
this.topP = 0.95,
this.topK = 40,
this.stopSequence = const [],
this.candidateCount = 1,
});

/// The temperature is used for sampling during response generation, which
Expand Down Expand Up @@ -154,13 +156,24 @@ class VertexAITextChatModelRequestParams {
/// Range: `[1, 40]`
final int topK;

/// Specifies a list of strings that tells the model to stop generating text
/// if one of the strings is encountered in the response. If a string appears
/// multiple times in the response, then the response truncates where it's
/// first encountered. The strings are case-sensitive.
final List<String> stopSequence;

/// The number of response variations to return.
final int candidateCount;

/// Converts this object to a [Map].
///
/// The keys mirror the generation parameter names sent in a Vertex AI
/// chat-model prediction request.
// NOTE(review): the Vertex AI REST API documents this parameter as
// 'stopSequences' (plural) — verify this key name against the API spec;
// the mapper test expects 'stopSequence' as written here.
Map<String, dynamic> toMap() {
return {
'temperature': temperature,
'maxOutputTokens': maxOutputTokens,
'topP': topP,
'topK': topK,
'stopSequence': stopSequence,
'candidateCount': candidateCount,
};
}

Expand All @@ -171,22 +184,31 @@ class VertexAITextChatModelRequestParams {
temperature == other.temperature &&
maxOutputTokens == other.maxOutputTokens &&
topP == other.topP &&
topK == other.topK;
topK == other.topK &&
const ListEquality<String>().equals(
stopSequence,
other.stopSequence,
) &&
candidateCount == other.candidateCount;

@override
int get hashCode =>
temperature.hashCode ^
maxOutputTokens.hashCode ^
topP.hashCode ^
topK.hashCode;
topK.hashCode ^
const ListEquality<String>().hash(stopSequence) ^
candidateCount.hashCode;

@override
String toString() {
return 'VertexAITextChatModelRequestParams{'
'temperature: $temperature, '
'maxOutputTokens: $maxOutputTokens, '
'topP: $topP, '
'topK: $topK}';
'topK: $topK, '
'stopSequence: $stopSequence, '
'candidateCount: $candidateCount}';
}
}

Expand Down
4 changes: 4 additions & 0 deletions packages/vertex_ai/test/gen_ai/mappers/chat_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ void main() {
maxOutputTokens: 256,
topP: 0.1,
topK: 30,
stopSequence: ['STOP'],
candidateCount: 10,
),
);
final expected = GoogleCloudAiplatformV1PredictRequest(
Expand Down Expand Up @@ -62,6 +64,8 @@ void main() {
'maxOutputTokens': 256,
'topP': 0.1,
'topK': 30,
'stopSequence': ['STOP'],
'candidateCount': 10,
},
);

Expand Down
4 changes: 4 additions & 0 deletions packages/vertex_ai/test/gen_ai/mappers/text_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ void main() {
maxOutputTokens: 256,
topP: 0.1,
topK: 30,
stopSequence: ['STOP'],
candidateCount: 10,
),
);
final expected = GoogleCloudAiplatformV1PredictRequest(
Expand All @@ -26,6 +28,8 @@ void main() {
'maxOutputTokens': 256,
'topP': 0.1,
'topK': 30,
'stopSequence': ['STOP'],
'candidateCount': 10,
},
);

Expand Down

0 comments on commit eab7d96

Please sign in to comment.