feat: Support estimating the number of tokens for a given prompt (#3)
davidmigloz authored Jul 10, 2023
1 parent 7bfa6d1 commit e22f22c
Showing 16 changed files with 348 additions and 18 deletions.
6 changes: 3 additions & 3 deletions docs/modules/model_io/models/chat_models/chat_models.md
@@ -76,15 +76,15 @@ print(chatRes);

The `generate` APIs return a `ChatResult` which contains a `ChatGeneration`
object with the `output` messages and some metadata about the generation. It
-also contains some additional information like `tokensUsage` and `modelOutput`.
+also contains some additional information like `usage` and `modelOutput`.

```dart
final chatRes1 = await chat.generate(messages);
print(chatRes1.generations);
// -> [ChatGeneration{
// output: AIChatMessage{content: J'adore la programmation., example: false},
// generationInfo: {index: 0, finish_reason: stop}}]
-print(chatRes1.tokensUsage);
+print(chatRes1.usage?.totalTokens);
// -> 36
print(chatRes1.modelOutput);
// -> {id: chatcmpl-7QHTjpTCELFuGbxRaazFqvYtepXOc, created: 2023-06-11 17:41:11.000, model: gpt-3.5-turbo}
@@ -95,6 +95,6 @@ print(chatRes1.modelOutput);
```dart
final chatRes2 = await chat.generatePrompt(ChatPromptValue(messages));
print(chatRes2.generations);
-print(chatRes2.tokensUsage);
+print(chatRes2.usage);
print(chatRes2.modelOutput);
```
@@ -6,6 +6,8 @@ currently only implemented for the OpenAI API.
```dart
final openai = OpenAI(apiKey: openaiApiKey, temperature: 0.9);
final result = await openai.generate('Tell me a joke');
-print(result.tokensUsage);
-// -> 24
+final usage = result.usage;
+print(usage?.promptTokens); // 4
+print(usage?.responseTokens); // 20
+print(usage?.totalTokens); // 24
```
4 changes: 2 additions & 2 deletions docs/modules/model_io/models/llms/llms.md
@@ -71,11 +71,11 @@ print(llmRes.generations.first);
// [LLMGeneration(output='\n\nWhy did the chicken cross the road?\n\nTo get to the other side!')]
```

-`tokensUsage` field contains the amount of tokens used for the generation. This is useful for
+`usage` field contains the amount of tokens used for the generation. This is useful for
tracking usage and billing.

```dart
-print(llmRes.tokensUsage); // 641
+print(llmRes.usage?.totalTokens); // 641
```

You can also access provider specific information that is returned.
@@ -21,7 +21,7 @@ class ChatResult extends LanguageModelResult<ChatMessage> {
/// {@macro chat_result}
const ChatResult({
required super.generations,
-super.tokensUsage,
+super.usage,
super.modelOutput,
});

@@ -30,7 +30,7 @@ class ChatResult extends LanguageModelResult<ChatMessage> {
return '''
ChatResult{
generations: $generations,
-tokensUsage: $tokensUsage,
+usage: $usage,
modelOutput: $modelOutput},
''';
}
@@ -281,7 +281,8 @@ class FunctionChatMessage extends ChatMessage {
@override
String toString() {
return '''
-SystemChatMessage{
+FunctionChatMessage{
+name: $name,
content: $content,
}
''';
25 changes: 25 additions & 0 deletions packages/langchain/lib/src/model_io/language_models/base.dart
@@ -48,6 +48,31 @@ abstract class BaseLanguageModel<Input extends Object,
final Options? options,
});

/// Tokenizes the given prompt using the encoding used by the language
/// model.
///
/// - [promptValue] The prompt to tokenize.
Future<List<int>> tokenize(final PromptValue promptValue);

/// Returns the number of tokens resulting from [tokenize] the given prompt.
///
/// Knowing how many tokens are in a text string can tell you:
/// - Whether the string is too long for a text model to process.
/// - How much the API call can cost (as usage is usually priced by token).
///
/// In message-based models the exact way that tokens are counted from
/// messages may change from model to model. Consider the result from this
/// method an estimate, not a timeless guarantee.
///
/// - [promptValue] The prompt to tokenize.
///
/// Note: subclasses can override this method to provide a more accurate
/// implementation.
Future<int> countTokens(final PromptValue promptValue) async {
final tokens = await tokenize(promptValue);
return tokens.length;
}

@override
String toString() => modelType;
}
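A minimal usage sketch of the new token-counting API. The `OpenAI` model and `openaiApiKey` value are borrowed from the docs above; whether the OpenAI wrapper overrides `tokenize` is not shown in this excerpt, so treat that as an assumption:

```dart
// Count tokens before sending a prompt, e.g. to check model context limits.
final llm = OpenAI(apiKey: openaiApiKey);
final prompt = PromptValue.string('Tell me a joke');

final tokens = await llm.tokenize(prompt); // token ids from the model's encoding
final numTokens = await llm.countTokens(prompt); // defaults to tokens.length
print('The prompt is estimated at $numTokens tokens');
```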
@@ -22,15 +22,15 @@ abstract class LanguageModelResult<O extends Object> {
/// {@macro language_model}
const LanguageModelResult({
required this.generations,
-this.tokensUsage,
+this.usage,
this.modelOutput,
});

/// Generated outputs.
final List<LanguageModelGeneration<O>> generations;

-/// The total number of tokens used for the generation.
-final int? tokensUsage;
+/// Usage stats for the generation.
+final LanguageModelUsage? usage;

/// For arbitrary model provider specific output.
final Map<String, dynamic>? modelOutput;
@@ -41,6 +41,43 @@ abstract class LanguageModelResult<O extends Object> {
}
}

/// {@template language_model_usage}
/// Usage stats for the generation.
///
/// You can use this information to determine how much the model call cost
/// (as usage is usually priced by token).
///
/// This is only available for some models.
/// {@endtemplate}
@immutable
class LanguageModelUsage {
/// {@macro language_model_usage}
const LanguageModelUsage({
this.promptTokens,
this.responseTokens,
this.totalTokens,
});

/// The number of tokens in the prompt.
final int? promptTokens;

/// The number of tokens in the completion.
final int? responseTokens;

/// The total number of tokens in the prompt and completion.
final int? totalTokens;

@override
String toString() {
return '''
LanguageModelUsage{
promptTokens: $promptTokens,
responseTokens: $responseTokens,
totalTokens: $totalTokens},
''';
}
}

/// {@template language_model_generation}
/// Output of a single generation.
/// {@endtemplate}
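Since usage is usually priced by token, `LanguageModelUsage` can drive a rough cost estimate. A sketch using the `llmRes` result from the docs above; the per-1K-token prices are made-up placeholders, not real provider pricing:

```dart
// Hypothetical prices per 1K tokens; substitute your provider's actual rates.
const promptPricePer1k = 0.0015;
const completionPricePer1k = 0.002;

final usage = llmRes.usage;
final estimatedCost = ((usage?.promptTokens ?? 0) / 1000) * promptPricePer1k +
    ((usage?.responseTokens ?? 0) / 1000) * completionPricePer1k;
print('Estimated cost: \$${estimatedCost.toStringAsFixed(6)} '
    'for ${usage?.totalTokens ?? 0} tokens');
```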
19 changes: 19 additions & 0 deletions packages/langchain/lib/src/model_io/llms/fake.dart
@@ -1,3 +1,4 @@
import '../prompts/models/models.dart';
import 'base.dart';
import 'models/models.dart';

@@ -24,6 +25,15 @@ class FakeListLLM extends SimpleLLM {
}) {
return Future<String>.value(responses[i++ % responses.length]);
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
return promptValue
.toString()
.split(' ')
.map((final word) => word.hashCode)
.toList(growable: false);
}
}

/// {@template fake_echo_llm}
@@ -44,4 +54,13 @@ class FakeEchoLLM extends SimpleLLM {
}) {
return Future<String>.value(prompt);
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
return promptValue
.toString()
.split(' ')
.map((final word) => word.hashCode)
.toList(growable: false);
}
}
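The fake models tokenize by splitting the prompt on whitespace and hashing each word, which is enough to exercise `countTokens` in tests without a real encoder. A sketch; the exact count assumes `PromptValue.toString()` returns the raw prompt text:

```dart
final llm = FakeEchoLLM();
final prompt = PromptValue.string('Tell me a joke');
final tokens = await llm.tokenize(prompt);
print(tokens.length); // 4 pseudo-tokens, one per whitespace-separated word
print(await llm.countTokens(prompt)); // 4
```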
4 changes: 2 additions & 2 deletions packages/langchain/lib/src/model_io/llms/models/models.dart
@@ -20,7 +20,7 @@ class LLMResult extends LanguageModelResult<String> {
/// {@macro llm_result}
const LLMResult({
required super.generations,
-super.tokensUsage,
+super.usage,
super.modelOutput,
});

@@ -29,7 +29,7 @@ class LLMResult extends LanguageModelResult<String> {
return '''
LLMResult{
generations: $generations,
-tokensUsage: $tokensUsage,
+usage: $usage,
modelOutput: $modelOutput},
''';
}
12 changes: 11 additions & 1 deletion packages/langchain/lib/src/model_io/prompts/models/models.dart
@@ -11,7 +11,7 @@ import '../../chat_models/utils.dart';
/// When working with a Chat model, the [toChatMessages] method will be used.
/// {@endtemplate}
@immutable
-abstract interface class PromptValue {
+sealed class PromptValue {
/// {@macro prompt_value}
const PromptValue();

@@ -21,6 +21,16 @@ abstract interface class PromptValue {

/// Returns a list of messages representing the prompt.
List<ChatMessage> toChatMessages();

/// {@macro string_prompt_template}
factory PromptValue.string(final String value) {
return StringPromptValue(value);
}

/// {@macro chat_prompt_template}
factory PromptValue.chat(final List<ChatMessage> messages) {
return ChatPromptValue(messages);
}
}

/// {@template string_prompt_template}
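A short sketch of the new `PromptValue` factory constructors; the `ChatMessage.human` helper used below is an assumed convenience constructor, not something shown in this diff:

```dart
// Wrap a plain string prompt.
final stringPrompt = PromptValue.string('Tell me a joke');

// Wrap an explicit list of chat messages (message constructor is assumed).
final chatPrompt = PromptValue.chat([
  ChatMessage.human('Tell me a joke'),
]);

// Either form can be passed to tokenize/countTokens or the generatePrompt APIs.
```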
13 changes: 12 additions & 1 deletion packages/langchain_openai/lib/src/chat_models/mappers.dart
@@ -37,7 +37,7 @@ extension OpenAIChatCompletionMapper on OpenAIChatCompletion {
generations: choices
.map((final choice) => choice.toChatGeneration())
.toList(growable: false),
-tokensUsage: usage.totalTokens,
+usage: usage.toLanguageModelUsage(),
modelOutput: {
'id': id,
'created': created,
@@ -60,6 +60,17 @@ extension _OpenAIChatCompletionChoiceMapper on OpenAIChatCompletionChoice {
}
}

/// Mapper for [OpenAIChatCompletionUsage] to [LanguageModelUsage].
extension _OpenAIChatCompletionUsageMapper on OpenAIChatCompletionUsage {
LanguageModelUsage toLanguageModelUsage() {
return LanguageModelUsage(
promptTokens: promptTokens,
responseTokens: completionTokens,
totalTokens: totalTokens,
);
}
}

/// Mapper for [OpenAIChatCompletionMessage] to [ChatMessage].
extension _OpenAIChatCompletionMessageMapper on OpenAIChatCompletionMessage {
ChatMessage toChatMessage() {