Skip to content

Commit

Permalink
feat!: Support different logic for streaming in RunnableFunction (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored Apr 30, 2024
1 parent a2b6bbb commit 8bb2b8e
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 11 deletions.
58 changes: 54 additions & 4 deletions packages/langchain_core/lib/src/runnables/function.dart
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,34 @@ import 'types.dart';
class RunnableFunction<RunInput extends Object, RunOutput extends Object>
extends Runnable<RunInput, RunnableOptions, RunOutput> {
/// {@macro runnable_function}
const RunnableFunction(this.function)
: super(defaultOptions: const RunnableOptions());
const RunnableFunction({
final FutureOr<RunOutput> Function(
RunInput input,
RunnableOptions? options,
)? invoke,
final Stream<RunOutput> Function(
Stream<RunInput> inputStream,
RunnableOptions? options,
)? stream,
super.defaultOptions = const RunnableOptions(),
}) : _invokeFunc = invoke,
_streamFunc = stream,
assert(
invoke != null || stream != null,
'Either invoke or stream must be provided',
);

/// The function to run.
final FutureOr<RunOutput> Function(
RunInput input,
RunnableOptions? options,
) function;
)? _invokeFunc;

/// The stream transformer to run.
final Stream<RunOutput> Function(
Stream<RunInput> inputStream,
RunnableOptions? options,
)? _streamFunc;

/// Invokes the [RunnableFunction] on the given [input].
///
Expand All @@ -68,6 +88,36 @@ class RunnableFunction<RunInput extends Object, RunOutput extends Object>
final RunInput input, {
final RunnableOptions? options,
}) async {
return function(input, options);
if (_invokeFunc != null) {
return _invokeFunc!(input, options);
} else {
return stream(input, options: options).first;
}
}

/// Streams the [input] through the [RunnableFunction].
///
/// - [input] - the input to stream through the [RunnableFunction].
/// - [options] - the options to use when streaming the [input].
@override
Stream<RunOutput> stream(
final RunInput input, {
final RunnableOptions? options,
}) {
return streamFromInputStream(Stream.value(input), options: options);
}

@override
Stream<RunOutput> streamFromInputStream(
final Stream<RunInput> inputStream, {
final RunnableOptions? options,
}) async* {
if (_streamFunc != null) {
yield* _streamFunc!(inputStream, options);
} else {
yield* inputStream.asyncMap((final input) async {
return invoke(input, options: options);
});
}
}
}
15 changes: 11 additions & 4 deletions packages/langchain_core/lib/src/runnables/runnable.dart
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,20 @@ abstract class Runnable<RunInput extends Object?,
///
/// - [function] - the function to run.
static Runnable<RunInput, RunnableOptions, RunOutput>
fromFunction<RunInput extends Object, RunOutput extends Object>(
fromFunction<RunInput extends Object, RunOutput extends Object>({
final FutureOr<RunOutput> Function(
RunInput input,
RunnableOptions? options,
) function,
) {
return RunnableFunction<RunInput, RunOutput>(function);
)? invoke,
final Stream<RunOutput> Function(
Stream<RunInput> inputStream,
RunnableOptions? options,
)? stream,
}) {
return RunnableFunction<RunInput, RunOutput>(
invoke: invoke,
stream: stream,
);
}

/// Creates a [RunnableRouter] from a Dart function.
Expand Down
36 changes: 33 additions & 3 deletions packages/langchain_core/test/runnables/function_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ import 'package:test/test.dart';

void main() {
group('RunnableFunction tests', () {
test('RunnableFunction from Runnable.fromFunction', () async {
test('Invoke RunnableFunction', () async {
final prompt = PromptTemplate.fromTemplate('Hello {input}!');
const model = FakeEchoChatModel();
const outputParser = StringOutputParser<ChatResult>();
final chain = prompt |
model |
outputParser |
Runnable.fromFunction<String, int>(
(final input, final options) => input.length,
invoke: (final input, final options) => input.length,
);

final res = await chain.invoke({'input': 'world'});
Expand All @@ -24,7 +24,7 @@ void main() {

test('Streaming RunnableFunction', () async {
final function = Runnable.fromFunction<String, int>(
(final input, final options) => input.length,
invoke: (final input, final options) => input.length,
);
final stream = function.stream('world');

Expand All @@ -35,5 +35,35 @@ void main() {
final item = streamList.first;
expect(item, 5);
});

test('Streaming input RunnableFunction', () async {
final function = Runnable.fromFunction<String, int>(
invoke: (final input, final options) => input.length,
);
final stream = function.streamFromInputStream(
Stream.fromIterable(['w', 'o', 'r', 'l', 'd']),
);

final streamList = await stream.toList();
expect(streamList.length, 5);
expect(streamList, [1, 1, 1, 1, 1]);
});

test('Separate logic for invoke and stream', () async {
final function = Runnable.fromFunction<String, int>(
invoke: (final input, final options) => input.length,
stream: (final inputStream, final options) async* {
final input = (await inputStream.toList()).reduce((a, b) => a + b);
yield input.length;
},
);

final invokeRes = await function.invoke('world');
expect(invokeRes, 5);
final streamRes = await function
.streamFromInputStream(Stream.fromIterable(['w', 'o', 'r', 'l', 'd']))
.toList();
expect(streamRes, [5]);
});
});
}

0 comments on commit 8bb2b8e

Please sign in to comment.