Skip to content

Commit

Permalink
feat: Implement additive options merging for cascade bind calls (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored Jul 20, 2024
1 parent 44363e4 commit 8691eb2
Show file tree
Hide file tree
Showing 26 changed files with 733 additions and 138 deletions.
39 changes: 27 additions & 12 deletions packages/langchain_anthropic/lib/src/chat_models/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,18 @@ class ChatAnthropicOptions extends ChatModelOptions {
/// name, email address, or phone number.
final String? userId;

/// Creates a copy of this [ChatAnthropicOptions] object with the given fields
/// replaced with the new values.
@override
ChatAnthropicOptions copyWith({
String? model,
int? maxTokens,
List<String>? stopSequences,
double? temperature,
int? topK,
double? topP,
String? userId,
List<Tool>? tools,
ChatToolChoice? toolChoice,
int? concurrencyLimit,
final String? model,
final int? maxTokens,
final List<String>? stopSequences,
final double? temperature,
final int? topK,
final double? topP,
final String? userId,
final List<ToolSpec>? tools,
final ChatToolChoice? toolChoice,
final int? concurrencyLimit,
}) {
return ChatAnthropicOptions(
model: model ?? this.model,
Expand All @@ -114,6 +113,22 @@ class ChatAnthropicOptions extends ChatModelOptions {
);
}

@override
ChatAnthropicOptions merge(covariant final ChatAnthropicOptions? other) {
return copyWith(
model: other?.model,
maxTokens: other?.maxTokens,
stopSequences: other?.stopSequences,
temperature: other?.temperature,
topK: other?.topK,
topP: other?.topP,
userId: other?.userId,
tools: other?.tools,
toolChoice: other?.toolChoice,
concurrencyLimit: other?.concurrencyLimit,
);
}

@override
bool operator ==(covariant final ChatAnthropicOptions other) {
return model == other.model &&
Expand Down
50 changes: 50 additions & 0 deletions packages/langchain_community/lib/src/tools/tavily/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,15 @@ class TavilySearchResultsToolOptions extends ToolOptions {
/// {@template tavily_answer_tool_options}
/// Generation options to pass into the [TavilyAnswerTool].
/// {@endtemplate}
@immutable
class TavilyAnswerToolOptions extends ToolOptions {
/// {@macro tavily_answer_tool_options}
const TavilyAnswerToolOptions({
this.maxResults = 5,
this.searchDepth = TavilySearchDepth.basic,
this.includeDomains,
this.excludeDomains,
super.concurrencyLimit,
});

/// The number of maximum search results to return.
Expand All @@ -128,4 +130,52 @@ class TavilyAnswerToolOptions extends ToolOptions {

/// A list of domains to specifically exclude from the search results.
final List<String>? excludeDomains;

@override
TavilyAnswerToolOptions copyWith({
final int? maxResults,
final TavilySearchDepth? searchDepth,
final List<String>? includeDomains,
final List<String>? excludeDomains,
final int? concurrencyLimit,
}) {
return TavilyAnswerToolOptions(
maxResults: maxResults ?? this.maxResults,
searchDepth: searchDepth ?? this.searchDepth,
includeDomains: includeDomains ?? this.includeDomains,
excludeDomains: excludeDomains ?? this.excludeDomains,
concurrencyLimit: concurrencyLimit ?? super.concurrencyLimit,
);
}

@override
TavilyAnswerToolOptions merge(
covariant final TavilyAnswerToolOptions? other,
) {
return copyWith(
maxResults: other?.maxResults,
searchDepth: other?.searchDepth,
includeDomains: other?.includeDomains,
excludeDomains: other?.excludeDomains,
concurrencyLimit: other?.concurrencyLimit,
);
}

@override
bool operator ==(covariant final TavilyAnswerToolOptions other) {
return maxResults == other.maxResults &&
searchDepth == other.searchDepth &&
includeDomains == other.includeDomains &&
excludeDomains == other.excludeDomains &&
concurrencyLimit == other.concurrencyLimit;
}

@override
int get hashCode {
return maxResults.hashCode ^
searchDepth.hashCode ^
includeDomains.hashCode ^
excludeDomains.hashCode ^
concurrencyLimit.hashCode;
}
}
2 changes: 1 addition & 1 deletion packages/langchain_core/lib/src/chains/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import '../langchain/types.dart';
typedef ChainValues = Map<String, dynamic>;

/// {@template chain_options}
/// Options to pass to a chain.
/// Options to pass to the chain.
/// {@endtemplate}
@immutable
class ChainOptions extends BaseLangChainOptions {
Expand Down
179 changes: 159 additions & 20 deletions packages/langchain_core/lib/src/chat_models/fake.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import 'package:collection/collection.dart';

import '../../language_models.dart';
import '../prompts/types.dart';
import 'base.dart';
Expand All @@ -7,11 +9,12 @@ import 'types.dart';
/// Fake Chat Model for testing.
/// You can pass in a list of responses to return in order when called.
/// {@endtemplate}
class FakeChatModel extends SimpleChatModel {
class FakeChatModel extends BaseChatModel<FakeChatModelOptions> {
/// {@macro fake_list_llm}
FakeChatModel({
required this.responses,
}) : super(defaultOptions: const ChatModelOptions());
super.defaultOptions = const FakeChatModelOptions(),
});

/// Responses to return in order when called.
final List<String> responses;
Expand All @@ -22,25 +25,39 @@ class FakeChatModel extends SimpleChatModel {
String get modelType => 'fake-chat-model';

@override
Future<String> callInternal(
final List<ChatMessage> messages, {
final ChatModelOptions? options,
}) {
return Future<String>.value(responses[_i++ % responses.length]);
Future<ChatResult> invoke(
final PromptValue input, {
final FakeChatModelOptions? options,
}) async {
final text = responses[_i++ % responses.length];
final message = AIChatMessage(content: text);
return ChatResult(
id: '1',
output: message,
finishReason: FinishReason.unspecified,
metadata: {
'model': options?.model ?? defaultOptions.model,
...?options?.metadata ?? defaultOptions.metadata,
},
usage: const LanguageModelUsage(),
);
}

@override
Stream<ChatResult> stream(
final PromptValue input, {
final ChatModelOptions? options,
final FakeChatModelOptions? options,
}) {
final res = responses[_i++ % responses.length].split('');
return Stream.fromIterable(res).map(
(final char) => ChatResult(
id: 'fake-chat-model',
output: AIChatMessage(content: char),
finishReason: FinishReason.stop,
metadata: const {},
metadata: {
'model': options?.model ?? defaultOptions.model,
...?options?.metadata ?? defaultOptions.metadata,
},
usage: const LanguageModelUsage(),
streaming: true,
),
Expand All @@ -60,38 +77,107 @@ class FakeChatModel extends SimpleChatModel {
}
}

/// {@template fake_echo_llm}
/// {@template fake_chat_model_options}
/// Fake Chat Model Options for testing.
/// {@endtemplate}
class FakeChatModelOptions extends ChatModelOptions {
/// {@macro fake_chat_model_options}
const FakeChatModelOptions({
super.model,
this.metadata,
super.concurrencyLimit,
});

/// Metadata.
final Map<String, dynamic>? metadata;

@override
FakeChatModelOptions copyWith({
final String? model,
final Map<String, dynamic>? metadata,
final int? concurrencyLimit,
}) {
return FakeChatModelOptions(
model: model ?? this.model,
metadata: metadata ?? this.metadata,
concurrencyLimit: concurrencyLimit ?? this.concurrencyLimit,
);
}

@override
FakeChatModelOptions merge(
covariant final FakeChatModelOptions? other,
) {
return copyWith(
model: other?.model,
metadata: other?.metadata,
concurrencyLimit: other?.concurrencyLimit,
);
}

@override
bool operator ==(covariant final FakeChatModelOptions other) {
return model == other.model &&
const MapEquality<String, dynamic>().equals(metadata, other.metadata) &&
concurrencyLimit == other.concurrencyLimit;
}

@override
int get hashCode {
return model.hashCode ^
const MapEquality<String, dynamic>().hash(metadata) ^
concurrencyLimit.hashCode;
}
}

/// {@template fake_echo_chat_model}
/// Fake Chat Model for testing.
/// It just returns the content of the last message of the prompt
/// or streams it char by char.
/// {@endtemplate}
class FakeEchoChatModel extends SimpleChatModel {
/// {@macro fake_echo_llm}
const FakeEchoChatModel() : super(defaultOptions: const ChatModelOptions());
class FakeEchoChatModel extends BaseChatModel<FakeEchoChatModelOptions> {
/// {@macro fake_echo_chat_model}
const FakeEchoChatModel({
super.defaultOptions = const FakeEchoChatModelOptions(),
});

@override
String get modelType => 'fake-echo-chat-model';

@override
Future<String> callInternal(
final List<ChatMessage> messages, {
final ChatModelOptions? options,
}) {
return Future<String>.value(messages.last.contentAsString);
Future<ChatResult> invoke(
final PromptValue input, {
final FakeEchoChatModelOptions? options,
}) async {
final text = input.toChatMessages().last.contentAsString;
final message = AIChatMessage(content: text);
return ChatResult(
id: '1',
output: message,
finishReason: FinishReason.unspecified,
metadata: {
'model': options?.model ?? defaultOptions.model,
...?options?.metadata ?? defaultOptions.metadata,
},
usage: const LanguageModelUsage(),
);
}

@override
Stream<ChatResult> stream(
final PromptValue input, {
final ChatModelOptions? options,
final FakeEchoChatModelOptions? options,
}) {
final prompt = input.toChatMessages().first.contentAsString.split('');
return Stream.fromIterable(prompt).map(
(final char) => ChatResult(
id: 'fake-echo-chat-model',
output: AIChatMessage(content: char),
finishReason: FinishReason.stop,
metadata: const {},
metadata: {
'model': options?.model ?? defaultOptions.model,
...?options?.metadata ?? defaultOptions.metadata,
},
usage: const LanguageModelUsage(),
streaming: true,
),
Expand All @@ -110,3 +196,56 @@ class FakeEchoChatModel extends SimpleChatModel {
.toList(growable: false);
}
}

/// {@template fake_chat_model_options}
/// Fake Echo Chat Model Options for testing.
/// {@endtemplate}
class FakeEchoChatModelOptions extends ChatModelOptions {
/// {@macro fake_chat_model_options}
const FakeEchoChatModelOptions({
super.model,
this.metadata,
super.concurrencyLimit,
});

/// Metadata.
final Map<String, dynamic>? metadata;

@override
FakeEchoChatModelOptions copyWith({
final String? model,
final Map<String, dynamic>? metadata,
final int? concurrencyLimit,
}) {
return FakeEchoChatModelOptions(
model: model ?? this.model,
metadata: metadata ?? this.metadata,
concurrencyLimit: concurrencyLimit ?? this.concurrencyLimit,
);
}

@override
FakeEchoChatModelOptions merge(
covariant final FakeEchoChatModelOptions? other,
) {
return copyWith(
model: other?.model,
metadata: other?.metadata,
concurrencyLimit: other?.concurrencyLimit,
);
}

@override
bool operator ==(covariant final FakeEchoChatModelOptions other) {
return model == other.model &&
const MapEquality<String, dynamic>().equals(metadata, other.metadata) &&
concurrencyLimit == other.concurrencyLimit;
}

@override
int get hashCode {
return model.hashCode ^
const MapEquality<String, dynamic>().hash(metadata) ^
concurrencyLimit.hashCode;
}
}
3 changes: 2 additions & 1 deletion packages/langchain_core/lib/src/chat_models/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import '../tools/base.dart';
/// {@template chat_model_options}
/// Generation options to pass into the Chat Model.
/// {@endtemplate}
class ChatModelOptions extends LanguageModelOptions {
@immutable
abstract class ChatModelOptions extends LanguageModelOptions {
/// {@macro chat_model_options}
const ChatModelOptions({
super.model,
Expand Down
2 changes: 1 addition & 1 deletion packages/langchain_core/lib/src/langchain/types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import 'package:meta/meta.dart';
import '../runnables/types.dart';

/// {@template base_lang_chain_options}
/// Base class for LangChain components' options.
/// Base options class for LangChain components.
/// {@endtemplate}
@immutable
class BaseLangChainOptions extends RunnableOptions {
Expand Down
Loading

0 comments on commit 8691eb2

Please sign in to comment.