feat(llms)!: Move all model config options to OpenAIOptions (#232)
davidmigloz authored Nov 21, 2023
1 parent dfaee16 commit 16e3e8e
Showing 20 changed files with 274 additions and 206 deletions.
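This is a breaking change: per-call model configuration moves from the `OpenAI` constructor into an `OpenAIOptions` object passed as `defaultOptions`. A minimal migration sketch, assuming the `openaiApiKey` variable used throughout the docs below; the "before" line is taken from the docs diffs in this commit:

```dart
// Before (old API): model config was passed straight to the constructor.
// final llm = OpenAI(apiKey: openaiApiKey, temperature: 0.9);

// After: model config lives in OpenAIOptions, passed as defaultOptions.
final llm = OpenAI(
  apiKey: openaiApiKey,
  defaultOptions: const OpenAIOptions(temperature: 0.9),
);
```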
5 changes: 4 additions & 1 deletion docs/modules/chains/chains.md
@@ -38,7 +38,10 @@ and returns the response from an LLM.
To use the `LLMChain`, first create a prompt template.

```dart
final llm = OpenAI(apiKey: openaiApiKey, temperature: 0.9);
final llm = OpenAI(
apiKey: openaiApiKey,
defaultOptions: const OpenAIOptions(temperature: 0.9),
);
final prompt = PromptTemplate.fromTemplate(
'What is a good name for a company that makes {product}?',
);
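// A hedged continuation of the docs snippet above (not part of this diff):
// wiring the configured LLM and prompt into an LLMChain and running it,
// assuming the LLMChain(llm:, prompt:) constructor and chain.run() used
// elsewhere in these docs.
final chain = LLMChain(llm: llm, prompt: prompt);
final companyName = await chain.run('colorful socks');
print(companyName);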
5 changes: 4 additions & 1 deletion docs/modules/chains/foundational/llm.md
@@ -12,7 +12,10 @@ to LLM and returns the LLM output.
## Get started

```dart
final llm = OpenAI(apiKey: openaiApiKey, temperature: 0.9);
final llm = OpenAI(
apiKey: openaiApiKey,
defaultOptions: const OpenAIOptions(temperature: 0.9),
);
final prompt = PromptTemplate.fromTemplate(
'What is a good name for a company that makes {product}?',
);
5 changes: 4 additions & 1 deletion docs/modules/chains/getting_started.md
@@ -137,7 +137,10 @@ Now, we can try running the chain that we called:

```dart
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final llm = OpenAI(apiKey: openaiApiKey, temperature: 0.9);
final llm = OpenAI(
apiKey: openaiApiKey,
defaultOptions: const OpenAIOptions(temperature: 0.9),
);
final prompt1 = PromptTemplate.fromTemplate(
'What is a good name for a company that makes {product}?',
5 changes: 4 additions & 1 deletion docs/modules/chains/how_to/custom_chain.md
@@ -50,7 +50,10 @@ Now, we can try running the chain that we called:

```dart
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final llm = OpenAI(apiKey: openaiApiKey, temperature: 0.9);
final llm = OpenAI(
apiKey: openaiApiKey,
defaultOptions: const OpenAIOptions(temperature: 0.9),
);
final prompt1 = PromptTemplate.fromTemplate(
'What is a good name for a company that makes {product}?',
5 changes: 4 additions & 1 deletion docs/modules/memory/memory.md
@@ -85,7 +85,10 @@ print(await memory.loadMemoryVariables());
Finally, let’s take a look at using this in a chain:

```dart
final llm = OpenAI(apiKey: openaiApiKey, temperature: 0);
final llm = OpenAI(
apiKey: openaiApiKey,
defaultOptions: const OpenAIOptions(temperature: 0),
);
final conversation = ConversationChain(
llm: llm,
memory: ConversationBufferMemory(),
@@ -4,7 +4,10 @@ This tutorial goes over how to track your token usage for specific calls. It is
currently only implemented for the OpenAI API.

```dart
final openai = OpenAI(apiKey: openaiApiKey, temperature: 0.9);
final openai = OpenAI(
apiKey: openaiApiKey,
defaultOptions: const OpenAIOptions(temperature: 0.9),
);
final result = await openai.generate('Tell me a joke');
final usage = result.usage;
print(usage?.promptTokens); // 4
@@ -98,7 +98,10 @@ Now that we have created a custom prompt template, we can use it to generate
prompts for our task.

```dart
final openai = OpenAI(apiKey: openaiApiKey, temperature: 0.9);
final openai = OpenAI(
apiKey: openaiApiKey,
defaultOptions: const OpenAIOptions(temperature: 0.9),
);
const fnExplainer = FunctionExplainerPromptTemplate();
final prompt = fnExplainer.formatPrompt({
10 changes: 8 additions & 2 deletions packages/langchain/lib/src/model_io/chat_models/fake.dart
@@ -29,7 +29,10 @@ class FakeChatModel extends SimpleChatModel {
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final ChatModelOptions? options,
}) async {
return promptValue
.toString()
.split(' ')
@@ -77,7 +80,10 @@ class FakeEchoChatModel extends SimpleChatModel {
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final ChatModelOptions? options,
}) async {
return promptValue
.toString()
.split(' ')
@@ -54,7 +54,10 @@ abstract class BaseLanguageModel<Input extends Object,
/// model.
///
/// - [promptValue] The prompt to tokenize.
Future<List<int>> tokenize(final PromptValue promptValue);
Future<List<int>> tokenize(
final PromptValue promptValue, {
final Options? options,
});

/// Returns the number of tokens resulting from [tokenize] the given prompt.
///
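`tokenize` now threads the model's options type through as an optional named parameter. A short sketch of counting tokens under the new signature, assuming `llm` is any `BaseLanguageModel` implementation and that omitting `options` falls back to the model's defaults:

```dart
// Build a PromptValue via PromptTemplate.formatPrompt (as in the docs above),
// then tokenize it; the `options` argument can simply be omitted.
final promptValue =
    PromptTemplate.fromTemplate('Tell me a joke').formatPrompt(const {});
final tokens = await llm.tokenize(promptValue);
print('Prompt encodes to ${tokens.length} tokens');
```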
15 changes: 12 additions & 3 deletions packages/langchain/lib/src/model_io/llms/fake.dart
@@ -29,7 +29,10 @@ class FakeListLLM extends SimpleLLM {
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final LLMOptions? options,
}) async {
return promptValue
.toString()
.split(' ')
@@ -76,7 +79,10 @@ class FakeEchoLLM extends SimpleLLM {
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final LLMOptions? options,
}) async {
return promptValue
.toString()
.split(' ')
@@ -116,7 +122,10 @@ class FakeHandlerLLM extends SimpleLLM {
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final LLMOptions? options,
}) async {
return promptValue
.toString()
.split(' ')
5 changes: 4 additions & 1 deletion packages/langchain/test/core/runnable/binding_test.dart
@@ -76,7 +76,10 @@ class _FakeOptionsChatModel
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final _FakeOptionsChatModelOptions? options,
}) async {
return promptValue
.toString()
.split(' ')
5 changes: 4 additions & 1 deletion packages/langchain_google/lib/src/chat_models/vertex_ai.dart
@@ -206,7 +206,10 @@ class ChatVertexAI extends BaseChatModel<ChatVertexAIOptions> {
///
/// - [promptValue] The prompt to tokenize.
@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final ChatVertexAIOptions? options,
}) async {
final encoding = encodingForModel('text-davinci-003');
return encoding.encode(promptValue.toString());
}
5 changes: 4 additions & 1 deletion packages/langchain_google/lib/src/llms/vertex_ai.dart
@@ -187,7 +187,10 @@ class VertexAI extends BaseLLM<VertexAIOptions> {
///
/// - [promptValue] The prompt to tokenize.
@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final VertexAIOptions? options,
}) async {
final encoding = encodingForModel('text-davinci-003');
return encoding.encode(promptValue.toString());
}
@@ -13,7 +13,10 @@ void main() async {
/// The most basic building block of LangChain is calling an LLM on some input.
Future<void> _example1() async {
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final openai = OpenAI(apiKey: openaiApiKey, temperature: 0.9);
final openai = OpenAI(
apiKey: openaiApiKey,
defaultOptions: const OpenAIOptions(temperature: 0.9),
);
final result = await openai('Tell me a joke');
print(result);
}
2 changes: 2 additions & 0 deletions packages/langchain_openai/lib/langchain_openai.dart
@@ -1,6 +1,8 @@
/// LangChain.dart integration module for OpenAI (GPT-3, GPT-4, Functions, etc.).
library;

export 'package:openai_dart/openai_dart.dart' show OpenAIClientException;

export 'src/agents/agents.dart';
export 'src/chains/chains.dart';
export 'src/chat_models/chat_models.dart';
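`OpenAIClientException` is now re-exported from `langchain_openai`, so callers can handle client errors without importing `openai_dart` directly. A hedged sketch; only the exception type comes from this diff, and the `llm(...)` call mirrors the docs examples above:

```dart
try {
  final result = await llm('Tell me a joke');
  print(result);
} on OpenAIClientException catch (e) {
  // The exception's fields are not shown in this diff, so only its type and
  // toString() are relied on here.
  print('OpenAI request failed: $e');
}
```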
5 changes: 4 additions & 1 deletion packages/langchain_openai/lib/src/chat_models/openai.dart
@@ -353,7 +353,10 @@ class ChatOpenAI extends BaseChatModel<ChatOpenAIOptions> {
///
/// - [promptValue] The prompt to tokenize.
@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
Future<List<int>> tokenize(
final PromptValue promptValue, {
final ChatOpenAIOptions? options,
}) async {
return _getTiktoken().encode(promptValue.toString());
}

100 changes: 93 additions & 7 deletions packages/langchain_openai/lib/src/llms/models/models.dart
@@ -1,30 +1,116 @@
import 'package:langchain/langchain.dart';
import '../openai.dart';

/// {@template openai_options}
/// Options to pass into the OpenAI LLM.
/// {@endtemplate}
class OpenAIOptions extends LLMOptions {
/// {@macro openai_options}
const OpenAIOptions({
this.model = 'gpt-3.5-turbo-instruct',
this.bestOf = 1,
this.frequencyPenalty = 0,
this.logitBias,
this.logprobs,
this.maxTokens = 256,
this.n = 1,
this.presencePenalty = 0,
this.seed,
this.stop,
this.suffix,
this.temperature = 1,
this.topP = 1,
this.user,
});

/// ID of the model to use (e.g. 'gpt-3.5-turbo-instruct').
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-model
final String model;

/// Generates best_of completions server-side and returns the "best"
/// (the one with the highest log probability per token).
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-best_of
final int bestOf;

/// Number between -2.0 and 2.0. Positive values penalize new tokens based on
/// their existing frequency in the text so far, decreasing the model's
/// likelihood to repeat the same line verbatim.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-frequency_penalty
final double frequencyPenalty;

/// Modify the likelihood of specified tokens appearing in the completion.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-logit_bias
final Map<String, int>? logitBias;

/// Include the log probabilities on the `logprobs` most likely tokens, as
/// well as the chosen tokens. For example, if `logprobs` is 5, the API will
/// return a list of the 5 most likely tokens. The API will always return the
/// `logprob` of the sampled token, so there may be up to `logprobs+1`
/// elements in the response.
///
/// The maximum value for logprobs is 5.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-logprobs
final int? logprobs;

/// The maximum number of tokens to generate in the completion.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-max_tokens
final int? maxTokens;

/// How many completions to generate for each prompt.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-n
final int n;

/// Number between -2.0 and 2.0. Positive values penalize new tokens based on
/// whether they appear in the text so far, increasing the model's likelihood
/// to talk about new topics.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-presence_penalty
final double presencePenalty;

/// If specified, our system will make a best effort to sample
/// deterministically, such that repeated requests with the same seed and
/// parameters should return the same result.
///
/// Determinism is not guaranteed, and you should refer to the
/// `system_fingerprint` response parameter to monitor changes in the backend.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-seed
final int? seed;

/// Up to 4 sequences where the API will stop generating further tokens.
/// The returned text will not contain the stop sequence.
///
/// Ref: https://platform.openai.com/docs/api-reference/completions/create#stop
/// Ref: https://platform.openai.com/docs/api-reference/completions/create#completions-create-stop
final List<String>? stop;

/// The suffix that comes after a completion of inserted text.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-suffix
final String? suffix;

/// What sampling temperature to use, between 0 and 2.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-temperature
final double temperature;

/// An alternative to sampling with temperature, called nucleus sampling,
/// where the model considers the results of the tokens with top_p
/// probability mass.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-top_p
final double topP;

/// A unique identifier representing your end-user, which can help OpenAI to
/// monitor and detect abuse.
///
/// If the user does not change between requests, you can set this field in
/// [OpenAI.user] instead.
///
/// If you specify it in both places, the value in [OpenAIOptions.user] will
/// be used.
/// If you need to send different users in different requests, you can set
/// this field in [OpenAIOptions.user] instead.
///
/// Ref: https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids
final String? user;
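A short usage sketch of the new options class, assuming the `openaiApiKey` variable from the docs: defaults are set once on the wrapper, and per-call options are assumed to take precedence over `defaultOptions` for the fields they set (as the `user` docs above state for that field). The `options` parameter on `generate` is an assumption mirroring the new `tokenize` signature:

```dart
final llm = OpenAI(
  apiKey: openaiApiKey,
  defaultOptions: const OpenAIOptions(temperature: 0.9, maxTokens: 256),
);
// Per-call override (assumed to win over defaultOptions for the fields it sets).
final result = await llm.generate(
  'List three prime numbers',
  options: const OpenAIOptions(temperature: 0, seed: 42),
);
print(result.usage?.promptTokens); // usage accessor as shown in the docs diff
```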