feat(llms): Migrate OpenAI to openai_dart client (#184)
davidmigloz authored Nov 2, 2023
1 parent 8f626fe commit 6c90b37
Showing 3 changed files with 202 additions and 112 deletions.
16 changes: 7 additions & 9 deletions packages/langchain_openai/lib/src/llms/models/mappers.dart
@@ -1,8 +1,7 @@
import 'package:langchain/langchain.dart';
import '../../client/models/models.dart';
import 'package:openai_dart/openai_dart.dart';

/// Mapper for [OpenAICompletionModel] to [LLMResult].
extension OpenAICompletionMapper on OpenAICompletion {
extension CreateCompletionResponseMapper on CreateCompletionResponse {
LLMResult toLLMResult() {
return LLMResult(
generations: choices
@@ -18,22 +17,21 @@ extension OpenAICompletionMapper on OpenAICompletion {
}
}

/// Mapper for [OpenAICompletionChoice] to [LLMGeneration].
extension _OpenAICompletionChoiceMapper on OpenAICompletionChoice {
extension _CompletionChoiceMapper on CompletionChoice {
LLMGeneration toLLMGeneration() {
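// Serialize the choice so generationInfo stores the raw JSON values of
// logprobs and finish_reason rather than the client's typed objects.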
final json = toJson();
return LLMGeneration(
text,
generationInfo: {
'index': index,
'logprobs': logprobs,
'finish_reason': finishReason,
'logprobs': json['logprobs'],
'finish_reason': json['finish_reason'],
},
);
}
}

/// Mapper for [OpenAICompletionUsage] to [LanguageModelUsage].
extension _OpenAICompletionUsageMapper on OpenAICompletionUsage {
extension _CompletionUsageMapper on CompletionUsage {
LanguageModelUsage toLanguageModelUsage() {
return LanguageModelUsage(
promptTokens: promptTokens,
296 changes: 194 additions & 102 deletions packages/langchain_openai/lib/src/llms/openai.dart
@@ -1,89 +1,211 @@
import 'package:http/http.dart' as http;
import 'package:langchain/langchain.dart';
import 'package:openai_dart/openai_dart.dart';
import 'package:tiktoken/tiktoken.dart';

import '../client/base.dart';
import '../client/openai_client.dart';
import 'models/mappers.dart';
import 'models/models.dart';

/// {@template base_openai}
/// Wrapper around OpenAI large language models.
/// {@endtemplate}
abstract base class BaseOpenAI extends BaseLLM<OpenAIOptions> {
/// {@macro base_openai}
BaseOpenAI({
required final String? apiKey,
required final BaseOpenAIClient? apiClient,
required this.model,
required this.maxTokens,
required this.temperature,
required this.topP,
required this.n,
required this.presencePenalty,
required this.frequencyPenalty,
required this.bestOf,
required this.logitBias,
required this.encoding,
required this.user,
}) : assert(
apiKey != null || apiClient != null,
'Either apiKey or apiClient must be provided.',
),
_client = apiClient ?? OpenAIClient.instanceFor(apiKey: apiKey!);

final BaseOpenAIClient _client;
/// Wrapper around OpenAI Completions API.
///
/// Example:
/// ```dart
/// final llm = OpenAI(apiKey: '...', temperature: 1);
/// final res = await llm('Tell me a joke');
/// ```
///
/// - [Completions guide](https://platform.openai.com/docs/guides/gpt/completions-api)
/// - [Completions API docs](https://platform.openai.com/docs/api-reference/completions)
///
/// ### Authentication
///
/// The OpenAI API uses API keys for authentication. Visit your
/// [API Keys](https://platform.openai.com/account/api-keys) page to retrieve
/// the API key you'll use in your requests.
///
/// #### Organization (optional)
///
/// If you belong to multiple organizations, you can specify which
/// organization is used for an API request. Usage from these API requests
/// will count against the specified organization's subscription quota.
///
/// ```dart
/// final client = OpenAI(
/// apiKey: 'OPENAI_API_KEY',
/// organization: 'org-dtDDtkEGoFccn5xaP5W1p3Rr',
/// );
/// ```
///
/// ### Advanced
///
/// #### Custom HTTP client
///
/// You can always provide your own implementation of `http.Client` for further
/// customization:
///
/// ```dart
/// final client = OpenAI(
/// apiKey: 'OPENAI_API_KEY',
/// client: MyHttpClient(),
/// );
/// ```
///
/// #### Using a proxy
///
/// ##### HTTP proxy
///
/// You can use your own HTTP proxy by overriding the `baseUrl` and providing
/// any required `headers`:
///
/// ```dart
/// final client = OpenAI(
/// baseUrl: 'https://my-proxy.com',
/// headers: {'x-my-proxy-header': 'value'},
/// );
/// ```
///
/// If you need further customization, you can always provide your own
/// `http.Client`.
///
/// ##### SOCKS5 proxy
///
/// To use a SOCKS5 proxy, you can use the
/// [`socks5_proxy`](https://pub.dev/packages/socks5_proxy) package and a
/// custom `http.Client`.
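///
/// A minimal sketch (the `socks5_proxy` imports and APIs shown here are
/// assumptions based on that package's documentation; verify before use):
///
/// ```dart
/// import 'dart:io';
///
/// import 'package:http/io_client.dart';
/// import 'package:socks5_proxy/socks_client.dart';
///
/// void main() {
///   // Route the underlying dart:io HttpClient through the SOCKS5 proxy.
///   final baseClient = HttpClient();
///   SocksTCPClient.assignToHttpClient(baseClient, [
///     ProxySettings(InternetAddress.loopbackIPv4, 1080),
///   ]);
///
///   final llm = OpenAI(
///     apiKey: 'OPENAI_API_KEY',
///     client: IOClient(baseClient),
///   );
/// }
/// ```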
class OpenAI extends BaseLLM<OpenAIOptions> {
/// Create a new [OpenAI] instance.
///
/// Main configuration options:
/// - `apiKey`: your OpenAI API key. You can find your API key in the
/// [OpenAI dashboard](https://platform.openai.com/account/api-keys).
/// - `organization`: your OpenAI organization ID (if applicable).
/// - [OpenAI.model]
/// - [OpenAI.bestOf]
/// - [OpenAI.frequencyPenalty]
/// - [OpenAI.logitBias]
/// - [OpenAI.logprobs]
/// - [OpenAI.maxTokens]
/// - [OpenAI.n]
/// - [OpenAI.presencePenalty]
/// - [OpenAI.suffix]
/// - [OpenAI.temperature]
/// - [OpenAI.topP]
/// - [OpenAI.user]
/// - [OpenAI.encoding]
///
/// Advanced configuration options:
/// - `baseUrl`: the base URL to use. Defaults to OpenAI's API URL. You can
/// override this to use a different API URL, or to use a proxy.
/// - `headers`: global headers to send with every request. You can use
/// this to set custom headers, or to override the default headers.
/// - `client`: the HTTP client to use. You can set your own HTTP client if
/// you need further customization (e.g. to use a SOCKS5 proxy).
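///
/// For example (a sketch with illustrative values):
///
/// ```dart
/// final llm = OpenAI(
///   apiKey: 'OPENAI_API_KEY',
///   model: 'text-davinci-003',
///   maxTokens: 256,
///   temperature: 0.7,
/// );
/// ```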
OpenAI({
final String? apiKey,
final String? organization,
final String? baseUrl,
final Map<String, String>? headers,
final http.Client? client,
this.model = 'text-davinci-003',
this.bestOf = 1,
this.frequencyPenalty = 0,
this.logitBias,
this.logprobs,
this.maxTokens,
this.n = 1,
this.presencePenalty = 0,
this.suffix,
this.temperature = 1,
this.topP = 1,
this.user,
this.encoding,
}) : _client = OpenAIClient(
apiKey: apiKey ?? '',
organization: organization,
baseUrl: baseUrl,
headers: headers,
client: client,
);

/// A client for interacting with OpenAI API.
final OpenAIClient _client;

/// ID of the model to use (e.g. 'text-davinci-003').
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-model
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-model
final String model;

/// The maximum number of tokens to generate in the completion.
/// Generates best_of completions server-side and returns the "best"
/// (the one with the highest log probability per token).
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-max_tokens
final int maxTokens;
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-best_of
final int bestOf;

/// What sampling temperature to use, between 0 and 2.
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on
/// their existing frequency in the text so far, decreasing the model's
/// likelihood to repeat the same line verbatim.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-temperature
final double temperature;
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-frequency_penalty
final double frequencyPenalty;

/// An alternative to sampling with temperature, called nucleus sampling,
/// where the model considers the results of the tokens with top_p
/// probability mass.
/// Modify the likelihood of specified tokens appearing in the completion.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-top_p
final double topP;
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-logit_bias
final Map<String, int>? logitBias;

/// Include the log probabilities on the `logprobs` most likely tokens, as
/// well as the chosen tokens. For example, if `logprobs` is 5, the API will
/// return a list of the 5 most likely tokens. The API will always return the
/// `logprob` of the sampled token, so there may be up to `logprobs+1`
/// elements in the response.
///
/// The maximum value for logprobs is 5.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-logprobs
final int? logprobs;

/// The maximum number of tokens to generate in the completion.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-max_tokens
final int? maxTokens;

/// How many completions to generate for each prompt.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-n
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-n
final int n;

/// Number between -2.0 and 2.0. Positive values penalize new tokens based on
/// whether they appear in the text so far, increasing the model's likelihood
/// to talk about new topics.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-presence_penalty
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-presence_penalty
final double presencePenalty;

/// Number between -2.0 and 2.0. Positive values penalize new tokens based on
/// their existing frequency in the text so far, decreasing the model's
/// likelihood to repeat the same line verbatim.
/// The suffix that comes after a completion of inserted text.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-frequency_penalty
final double frequencyPenalty;
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-suffix
final String? suffix;

/// Generates best_of completions server-side and returns the "best"
/// (the one with the highest log probability per token).
/// What sampling temperature to use, between 0 and 2.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-best_of
final int bestOf;
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-temperature
final double temperature;

/// Modify the likelihood of specified tokens appearing in the completion.
/// An alternative to sampling with temperature, called nucleus sampling,
/// where the model considers the results of the tokens with top_p
/// probability mass.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-top_p
final double topP;

/// A unique identifier representing your end-user, which can help OpenAI to
/// monitor and detect abuse.
///
/// If you need to send different users in different requests, you can set
/// this field in [OpenAIOptions.user] instead.
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias
final Map<String, double>? logitBias;
/// Ref: https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids
final String? user;

/// The encoding for tiktoken to use when [tokenize] is called.
///
@@ -104,15 +226,6 @@ abstract base class BaseOpenAI extends BaseLLM<OpenAIOptions> {
/// https://github.com/mvitlov/tiktoken/blob/master/lib/tiktoken.dart
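///
/// For example (an illustrative value; any encoding name supported by
/// tiktoken works):
/// ```dart
/// final llm = OpenAI(apiKey: 'OPENAI_API_KEY', encoding: 'cl100k_base');
/// ```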
final String? encoding;

/// A unique identifier representing your end-user, which can help OpenAI to
/// monitor and detect abuse.
///
/// If you need to send different users in different requests, you can set
/// this field in [OpenAIOptions.user] instead.
///
/// Ref: https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids
final String? user;

@override
String get modelType => 'openai';

@@ -122,18 +235,24 @@ abstract base class BaseOpenAI extends BaseLLM<OpenAIOptions> {
final OpenAIOptions? options,
}) async {
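// Build a typed CreateCompletionRequest from this instance's defaults;
// per-call options supply the stop sequences and may override the user.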
final completion = await _client.createCompletion(
model: model,
prompts: [prompt],
maxTokens: maxTokens,
temperature: temperature,
topP: topP,
n: n,
stop: options?.stop,
presencePenalty: presencePenalty,
frequencyPenalty: frequencyPenalty,
bestOf: bestOf,
logitBias: logitBias,
user: options?.user,
request: CreateCompletionRequest(
model: CompletionModel.string(model),
prompt: CompletionPrompt.string(prompt),
bestOf: bestOf,
frequencyPenalty: frequencyPenalty,
logitBias: logitBias,
logprobs: logprobs,
maxTokens: maxTokens,
n: n,
presencePenalty: presencePenalty,
stop: options?.stop != null
? CompletionStop.arrayString(options!.stop!)
: null,
suffix: suffix,
temperature: temperature,
topP: topP,
user: options?.user ?? user,
),
);
return completion.toLLMResult();
}
@@ -151,30 +270,3 @@ abstract base class BaseOpenAI extends BaseLLM<OpenAIOptions> {
return encoding.encode(promptValue.toString());
}
}

/// {@template openai}
/// Wrapper around [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions).
///
/// Example:
/// ```dart
/// final llm = OpenAI(apiKey: '...', temperature: 1);
/// ```
/// {@endtemplate}
final class OpenAI extends BaseOpenAI {
/// {@macro openai}
OpenAI({
super.apiKey,
super.apiClient,
super.model = 'text-davinci-003',
super.maxTokens = 256,
super.temperature = 1,
super.topP = 1,
super.n = 1,
super.presencePenalty = 0,
super.frequencyPenalty = 0,
super.bestOf = 1,
super.logitBias,
super.encoding,
super.user,
});
}
2 changes: 1 addition & 1 deletion packages/langchain_openai/test/llms/openai_test.dart
@@ -22,7 +22,7 @@ void main() {
presencePenalty: 0.1,
frequencyPenalty: 0.1,
bestOf: 10,
logitBias: {'foo': 1.0},
logitBias: {'foo': 1},
user: 'foo',
);
expect(llm.model, 'foo');
