refactor(llms): Make all LLM options fields nullable and add copyWith
davidmigloz authored Jan 10, 2024
1 parent 6a3b646 commit 57eceb9
Showing 38 changed files with 681 additions and 199 deletions.
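
The thread running through all 38 files: every field on the provider `*Options` classes becomes nullable, per-call options override the instance's `defaultOptions`, and each options class gains a `copyWith` method. A minimal sketch of the pattern, using hypothetical `FakeChatOptions`/`FakeChatModel` names rather than the package's actual classes:

```dart
// Sketch only: illustrative names, not the langchain.dart API surface.
class FakeChatOptions {
  const FakeChatOptions({this.model, this.temperature});

  // Nullable fields distinguish "not set" from an explicit value.
  final String? model;
  final double? temperature;

  // copyWith returns a new instance with selected fields replaced.
  FakeChatOptions copyWith({String? model, double? temperature}) {
    return FakeChatOptions(
      model: model ?? this.model,
      temperature: temperature ?? this.temperature,
    );
  }
}

class FakeChatModel {
  const FakeChatModel({this.defaultOptions = const FakeChatOptions()});

  final FakeChatOptions defaultOptions;

  // Per-call options win, defaults fill the gaps, a missing model is an error.
  String resolveModel([FakeChatOptions? options]) =>
      options?.model ??
      defaultOptions.model ??
      (throw ArgumentError('Set a model in defaultOptions or per call.'));
}

void main() {
  const model = FakeChatModel(
    defaultOptions: FakeChatOptions(model: 'model-a', temperature: 0.2),
  );
  print(model.resolveModel()); // model-a
  print(model.resolveModel(FakeChatOptions(model: 'model-b'))); // model-b
}
```

Making the fields nullable is what lets "not specified" fall through to the next layer of defaults, as the diffs below show.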
2 changes: 1 addition & 1 deletion analysis_options.yaml
@@ -30,7 +30,7 @@ linter:
     - avoid_null_checks_in_equality_operators
     - avoid_positional_boolean_parameters
     - avoid_print
-    - avoid_redundant_argument_values
+    # - avoid_redundant_argument_values # Sometimes is useful to be explicit
     - avoid_relative_lib_imports
     - avoid_renaming_method_parameters
     - avoid_return_types_on_setters
6 changes: 2 additions & 4 deletions docs/expression_language/cookbook/tools.md
@@ -4,10 +4,7 @@ Tools are also runnables, and can therefore be used within a chain:

 ```dart
 final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
-final model = ChatOpenAI(
-  apiKey: openaiApiKey,
-  defaultOptions: const ChatOpenAIOptions(temperature: 0),
-);
+final model = ChatOpenAI(apiKey: openaiApiKey);
 const stringOutputParser = StringOutputParser();
 final promptTemplate = ChatPromptTemplate.fromTemplate('''
@@ -29,6 +26,7 @@ final chain = Runnable.getMapFromInput() |
 final res = await chain.invoke(
   'If I had 3 apples and you had 5 apples but we ate 3. '
   'If we cut the remaining apples in half, how many pieces would we have?',
+  options: const ChatOpenAIOptions(temperature: 0),
 );
 print(res);
 // 10.0
4 changes: 2 additions & 2 deletions examples/browser_summarizer/pubspec.lock
@@ -465,10 +465,10 @@ packages:
     dependency: transitive
     description:
       name: uuid
-      sha256: "22c94e5ad1e75f9934b766b53c742572ee2677c56bc871d850a57dad0f82127f"
+      sha256: "8c951c9cb6504b2aa6b3666e6de504032d9baec24bf4cbabd3eea9edd73d4d77"
       url: "https://pub.dev"
     source: hosted
-    version: "4.2.2"
+    version: "4.3.2"
   vector_math:
     dependency: transitive
     description:
@@ -10,12 +10,7 @@ void main(final List<String> arguments) async {

 Future<void> _calculator() async {
   final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
-  final model = ChatOpenAI(
-    apiKey: openaiApiKey,
-    defaultOptions: const ChatOpenAIOptions(
-      temperature: 0,
-    ),
-  );
+  final model = ChatOpenAI(apiKey: openaiApiKey);
   const stringOutputParser = StringOutputParser();

   final promptTemplate = ChatPromptTemplate.fromTemplate('''
@@ -37,6 +32,7 @@ MATH EXPRESSION:''');
   final res = await chain.invoke(
     'If I had 3 apples and you had 5 apples but we ate 3. '
     'If we cut the remaining apples in half, how many pieces would we have?',
+    options: const ChatOpenAIOptions(temperature: 0),
   );
   print(res);
   // 10.0
4 changes: 2 additions & 2 deletions examples/docs_examples/pubspec.lock
@@ -370,10 +370,10 @@ packages:
     dependency: transitive
     description:
       name: uuid
-      sha256: bb55f38968b9427ce5dcdb8aaaa41049282195e0cfa4cf48593572fa3d1f36bc
+      sha256: "8c951c9cb6504b2aa6b3666e6de504032d9baec24bf4cbabd3eea9edd73d4d77"
       url: "https://pub.dev"
     source: hosted
-    version: "4.3.1"
+    version: "4.3.2"
   vector_math:
     dependency: transitive
     description:
4 changes: 2 additions & 2 deletions examples/hello_world_backend/pubspec.lock
@@ -306,10 +306,10 @@ packages:
     dependency: transitive
     description:
       name: uuid
-      sha256: bb55f38968b9427ce5dcdb8aaaa41049282195e0cfa4cf48593572fa3d1f36bc
+      sha256: "8c951c9cb6504b2aa6b3666e6de504032d9baec24bf4cbabd3eea9edd73d4d77"
       url: "https://pub.dev"
     source: hosted
-    version: "4.3.1"
+    version: "4.3.2"
   vector_math:
     dependency: transitive
     description:
4 changes: 2 additions & 2 deletions examples/hello_world_cli/pubspec.lock
@@ -266,10 +266,10 @@ packages:
     dependency: transitive
     description:
       name: uuid
-      sha256: bb55f38968b9427ce5dcdb8aaaa41049282195e0cfa4cf48593572fa3d1f36bc
+      sha256: "8c951c9cb6504b2aa6b3666e6de504032d9baec24bf4cbabd3eea9edd73d4d77"
       url: "https://pub.dev"
     source: hosted
-    version: "4.3.1"
+    version: "4.3.2"
   vector_math:
     dependency: transitive
     description:
4 changes: 2 additions & 2 deletions examples/hello_world_flutter/pubspec.lock
@@ -324,10 +324,10 @@ packages:
     dependency: transitive
     description:
       name: uuid
-      sha256: "22c94e5ad1e75f9934b766b53c742572ee2677c56bc871d850a57dad0f82127f"
+      sha256: "8c951c9cb6504b2aa6b3666e6de504032d9baec24bf4cbabd3eea9edd73d4d77"
       url: "https://pub.dev"
     source: hosted
-    version: "4.2.2"
+    version: "4.3.2"
   vector_math:
     dependency: transitive
     description:
2 changes: 1 addition & 1 deletion packages/googleai_dart/pubspec.yaml
@@ -27,7 +27,7 @@ dev_dependencies:
   build_runner: ^2.4.6
   freezed: ^2.4.5
   json_serializable: ^6.7.1
-  # openapi_spec: ^0.7.8
+  # openapi_spec: ^0.7.8
   openapi_spec:
     git:
       url: https://github.com/davidmigloz/openapi_spec.git
31 changes: 31 additions & 0 deletions packages/langchain/lib/src/model_io/language_models/base.dart
@@ -1,3 +1,5 @@
+import 'package:meta/meta.dart';
+
 import '../../core/core.dart';
 import '../chat_models/models/models.dart';
 import '../prompts/models/models.dart';
@@ -83,4 +85,33 @@ abstract class BaseLanguageModel<Input extends Object,

   @override
   String toString() => modelType;
+
+  /// Throws an error if the model id is not specified.
+  @protected
+  Never throwNullModelError() {
+    throw ArgumentError('''
+Null model in $runtimeType.
+You need to specify the id of model to use either in `$runtimeType.defaultOptions`
+or in the options passed when invoking the model.
+Example:
+```
+// In defaultOptions
+final model = $runtimeType(
+  defaultOptions: ${runtimeType}Options(
+    model: 'model-id',
+  ),
+);
+// Or when invoking the model
+final res = await model.invoke(
+  prompt,
+  options: ${runtimeType}Options(
+    model: 'model-id',
+  ),
+);
+```
+''');
+  }
 }
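
Because throwNullModelError is typed as returning `Never`, Dart's flow analysis accepts it as the last operand of a null-coalescing chain: an expression like `options?.model ?? defaultOptions.model ?? throwNullModelError()` type-checks as a non-nullable String, which is exactly how the provider packages below use it. A self-contained sketch of the same mechanism (illustrative names, not the package API):

```dart
// Sketch: a Never-returning helper can terminate a `??` chain,
// making the whole expression a non-nullable String.
Never throwNullModelError() =>
    throw ArgumentError('Specify a model in defaultOptions or per call.');

String resolveModel(String? perCall, String? fallback) =>
    perCall ?? fallback ?? throwNullModelError();

void main() {
  print(resolveModel('model-a', null)); // model-a
  print(resolveModel(null, 'model-b')); // model-b
  // resolveModel(null, null) would throw an ArgumentError.
}
```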
@@ -146,7 +146,9 @@ class ChatGoogleGenerativeAI
     final Map<String, String>? headers,
     final Map<String, dynamic>? queryParams,
     final http.Client? client,
-    this.defaultOptions = const ChatGoogleGenerativeAIOptions(),
+    this.defaultOptions = const ChatGoogleGenerativeAIOptions(
+      model: 'gemini-pro',
+    ),
   }) : _client = GoogleAIClient(
          apiKey: apiKey,
          baseUrl: baseUrl,
Expand All @@ -173,7 +175,8 @@ class ChatGoogleGenerativeAI
final ChatGoogleGenerativeAIOptions? options,
}) async {
final id = _uuid.v4();
final model = options?.model ?? defaultOptions.model;
final model =
options?.model ?? defaultOptions.model ?? throwNullModelError();
final completion = await _client.generateContent(
modelId: model,
request: _generateCompletionRequest(messages, options: options),
@@ -237,7 +240,7 @@
     final ChatGoogleGenerativeAIOptions? options,
   }) async {
     final tokens = await _client.countTokens(
-      modelId: options?.model ?? defaultOptions.model,
+      modelId: options?.model ?? defaultOptions.model ?? throwNullModelError(),
       request: CountTokensRequest(
         contents: promptValue.toChatMessages().toContentList(),
       ),
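Since `model` is now nullable, the wrapper pins gemini-pro in its defaultOptions so existing callers keep working, while a different model can still be chosen per invocation. A hedged usage sketch (assuming the langchain/langchain_google packages and a real API key; exact invoke/PromptValue signatures may differ by version):

```dart
import 'package:langchain/langchain.dart';
import 'package:langchain_google/langchain_google.dart';

Future<void> main() async {
  final chatModel = ChatGoogleGenerativeAI(apiKey: 'GOOGLE_API_KEY');

  // Uses the pinned default model, gemini-pro.
  final res1 = await chatModel.invoke(PromptValue.string('Hi!'));

  // Overrides the model for this call only.
  final res2 = await chatModel.invoke(
    PromptValue.string('Hi!'),
    options: const ChatGoogleGenerativeAIOptions(model: 'gemini-pro-vision'),
  );
  print(res1);
  print(res2);
}
```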
@@ -19,7 +19,7 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
   /// The LLM to use.
   ///
   /// You can find a list of available models here: https://ai.google.dev/models
-  final String model;
+  final String? model;

   /// The maximum cumulative probability of tokens to consider when sampling.
   /// The model uses combined Top-k and nucleus sampling. Tokens are sorted
@@ -76,6 +76,30 @@ class ChatGoogleGenerativeAIOptions extends ChatModelOptions {
   /// is no safety setting for a given category provided in the list, the API will use
   /// the default safety setting for that category.
   final List<ChatGoogleGenerativeAISafetySetting>? safetySettings;
+
+  /// Creates a copy of this [ChatGoogleGenerativeAIOptions] object with the given fields
+  /// replaced with the new values.
+  ChatGoogleGenerativeAIOptions copyWith({
+    final String? model,
+    final double? topP,
+    final int? topK,
+    final int? candidateCount,
+    final int? maxOutputTokens,
+    final double? temperature,
+    final List<String>? stopSequences,
+    final List<ChatGoogleGenerativeAISafetySetting>? safetySettings,
+  }) {
+    return ChatGoogleGenerativeAIOptions(
+      model: model ?? this.model,
+      topP: topP ?? this.topP,
+      topK: topK ?? this.topK,
+      candidateCount: candidateCount ?? this.candidateCount,
+      maxOutputTokens: maxOutputTokens ?? this.maxOutputTokens,
+      temperature: temperature ?? this.temperature,
+      stopSequences: stopSequences ?? this.stopSequences,
+      safetySettings: safetySettings ?? this.safetySettings,
+    );
+  }
 }

 /// {@template chat_google_generative_ai_safety_setting}
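copyWith makes it cheap to derive per-call options from a shared base without mutating it. A short usage sketch (the values are illustrative):

```dart
const base = ChatGoogleGenerativeAIOptions(
  model: 'gemini-pro',
  temperature: 0.2,
);

// Derive a variant; unspecified fields keep the base values.
final creative = base.copyWith(temperature: 0.9);

print(creative.model); // gemini-pro
print(creative.temperature); // 0.9
// Note: with the `?? this.field` pattern, copyWith(model: null)
// keeps the old value rather than clearing it.
```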
@@ -116,7 +116,10 @@ class ChatVertexAI extends BaseChatModel<ChatVertexAIOptions> {
     required final String project,
     final String location = 'us-central1',
     final String? rootUrl,
-    this.defaultOptions = const ChatVertexAIOptions(),
+    this.defaultOptions = const ChatVertexAIOptions(
+      publisher: 'google',
+      model: 'chat-bison',
+    ),
   }) : client = VertexAIGenAIClient(
          httpClient: httpClient,
          project: project,
@@ -158,25 +161,33 @@ class ChatVertexAI extends BaseChatModel<ChatVertexAIOptions> {
     final examples = (options?.examples ?? defaultOptions.examples)
         ?.map((final e) => e.toVertexAIChatExample())
         .toList(growable: false);
+    final model =
+        options?.model ?? defaultOptions.model ?? throwNullModelError();

     final result = await client.chat.predict(
       context: context,
       examples: examples,
       messages: vertexMessages,
-      publisher: options?.publisher ?? defaultOptions.publisher,
-      model: options?.model ?? defaultOptions.model,
+      publisher: options?.publisher ??
+          defaultOptions.publisher ??
+          ArgumentError.checkNotNull(
+            defaultOptions.publisher,
+            'VertexAIOptions.publisher',
+          ),
+      model: model,
       parameters: VertexAITextChatModelRequestParams(
         maxOutputTokens:
-            options?.maxOutputTokens ?? defaultOptions.maxOutputTokens,
-        temperature: options?.temperature ?? defaultOptions.temperature,
-        topP: options?.topP ?? defaultOptions.topP,
-        topK: options?.topK ?? defaultOptions.topK,
-        stopSequences: options?.stopSequences ?? defaultOptions.stopSequences,
+            options?.maxOutputTokens ?? defaultOptions.maxOutputTokens ?? 1024,
+        temperature: options?.temperature ?? defaultOptions.temperature ?? 0.2,
+        topP: options?.topP ?? defaultOptions.topP ?? 0.95,
+        topK: options?.topK ?? defaultOptions.topK ?? 40,
+        stopSequences:
+            options?.stopSequences ?? defaultOptions.stopSequences ?? const [],
         candidateCount:
-            options?.candidateCount ?? defaultOptions.candidateCount,
+            options?.candidateCount ?? defaultOptions.candidateCount ?? 1,
       ),
     );
-    return result.toChatResult(id, options?.model ?? defaultOptions.model);
+    return result.toChatResult(id, model);
   }

   /// Tokenizes the given prompt using tiktoken.
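ChatVertexAI now resolves each sampling parameter through three tiers: the call-site options, then defaultOptions, then a hard-coded literal (1024 max output tokens, temperature 0.2, topP 0.95, topK 40, one candidate). A standalone sketch of that resolution order (illustrative function, not the package API):

```dart
// Sketch of the three-tier resolution ChatVertexAI now uses:
// per-call option -> instance default -> hard-coded literal.
double resolveTemperature({
  double? callOption,
  double? instanceDefault,
}) =>
    callOption ?? instanceDefault ?? 0.2;

void main() {
  print(resolveTemperature(callOption: 0.9)); // 0.9
  print(resolveTemperature(instanceDefault: 0.5)); // 0.5
  print(resolveTemperature()); // 0.2 (literal fallback)
}
```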