Skip to content

Commit

Permalink
feat(memory): Add support for ConversationTokenBufferMemory (#26)
Browse files Browse the repository at this point in the history
Co-authored-by: David Miguel <[email protected]>
  • Loading branch information
a-mpch and davidmigloz authored Aug 4, 2023
1 parent 0be06e0 commit 8113d1c
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 3 deletions.
10 changes: 10 additions & 0 deletions packages/langchain/lib/src/memory/stores/message/history.dart
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ abstract base class BaseChatMessageHistory {
return addChatMessage(ChatMessage.ai(message));
}

/// Removes and returns the first (oldest) element of the history.
///
/// The history must not be empty when this method is called.
Future<ChatMessage> removeFirst();

/// Removes and returns the last (newest) element of the history.
///
/// The history must not be empty when this method is called.
Future<ChatMessage> removeLast();

/// Clear the history.
Future<void> clear();
}
18 changes: 15 additions & 3 deletions packages/langchain/lib/src/memory/stores/message/in_memory.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import 'dart:collection';

import '../../../model_io/chat_models/chat_models.dart';
import 'history.dart';

Expand All @@ -8,20 +10,30 @@ import 'history.dart';
final class ChatMessageHistory extends BaseChatMessageHistory {
/// {@macro chat_message_history}
ChatMessageHistory({final List<ChatMessage>? messages})
: _messages = messages ?? [];
: _messages = Queue.from(messages ?? <ChatMessage>[]);

final List<ChatMessage> _messages;
final Queue<ChatMessage> _messages;

@override
Future<List<ChatMessage>> getChatMessages() {
return Future.value(_messages);
return Future.value(_messages.toList(growable: false));
}

@override
Future<void> addChatMessage(final ChatMessage message) async {
_messages.add(message);
}

@override
Future<ChatMessage> removeFirst() {
return Future.value(_messages.removeFirst());
}

@override
Future<ChatMessage> removeLast() {
return Future.value(_messages.removeLast());
}

@override
Future<void> clear() async {
_messages.clear();
Expand Down
100 changes: 100 additions & 0 deletions packages/langchain/lib/src/memory/token_buffer.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import '../model_io/chat_models/models/models.dart';
import '../model_io/chat_models/utils.dart';
import '../model_io/language_models/language_models.dart';
import '../model_io/prompts/prompts.dart';
import 'buffer_window.dart';
import 'chat.dart';
import 'models/models.dart';
import 'stores/message/in_memory.dart';

/// {@template conversation_token_buffer_memory}
/// Rolling buffer for storing a conversation and then retrieving the messages
/// at a later time.
///
/// It uses token length (rather than number of interactions like
/// [ConversationBufferWindowMemory]) to determine when to flush old
/// interactions from the buffer. This allows it to keep more context while
/// staying under a max token limit.
///
/// It uses [ChatMessageHistory] as in-memory storage by default.
///
/// Example:
/// ```dart
/// final memory = ConversationTokenBufferMemory(llm: OpenAI(apiKey: '...'));
/// await memory.saveContext({'foo': 'bar'}, {'bar': 'foo'});
/// final res = await memory.loadMemoryVariables();
/// // {'history': 'Human: bar\nAI: foo'}
/// ```
/// {@endtemplate}
final class ConversationTokenBufferMemory<
LLMInput extends Object,
LLMOptions extends LanguageModelOptions,
LLMOutput extends Object> extends BaseChatMemory {
/// {@macro conversation_token_buffer_memory}
ConversationTokenBufferMemory({
super.chatHistory,
super.inputKey,
super.outputKey,
super.returnMessages = false,
required this.llm,
this.humanPrefix = 'Human',
this.aiPrefix = 'AI',
this.memoryKey = 'history',
this.maxTokenLimit = 2000,
});

/// Language model to use for counting tokens.
final BaseLanguageModel<LLMInput, LLMOptions, LLMOutput> llm;

/// The prefix to use for human messages.
final String humanPrefix;

/// The prefix to use for AI messages.
final String aiPrefix;

/// The memory key to use for the chat history.
final String memoryKey;

/// Max number of tokens to use.
final int maxTokenLimit;

@override
Set<String> get memoryKeys => {memoryKey};

@override
Future<MemoryVariables> loadMemoryVariables([
final MemoryInputValues values = const {},
]) async {
final messages = await chatHistory.getChatMessages();
if (returnMessages) {
return {memoryKey: messages};
}
return {
memoryKey: messages.toBufferString(
humanPrefix: humanPrefix,
aiPrefix: aiPrefix,
),
};
}

@override
Future<void> saveContext({
required final MemoryInputValues inputValues,
required final MemoryOutputValues outputValues,
}) async {
await super.saveContext(
inputValues: inputValues,
outputValues: outputValues,
);
List<ChatMessage> buffer = await chatHistory.getChatMessages();
int currentBufferLength = await llm.countTokens(PromptValue.chat(buffer));
// Prune buffer if it exceeds max token limit
if (currentBufferLength > maxTokenLimit) {
while (currentBufferLength > maxTokenLimit) {
await chatHistory.removeFirst();
buffer = await chatHistory.getChatMessages();
currentBufferLength = await llm.countTokens(PromptValue.chat(buffer));
}
}
}
}
16 changes: 16 additions & 0 deletions packages/langchain/test/memory/stores/message/in_memory_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ void main() {
expect(messages.first.content, 'This is an AI msg');
});

test('Test removeOldestMessage', () async {
final history = ChatMessageHistory();
final message = ChatMessage.human('This is a test');
final message2 = ChatMessage.ai('This is an AI msg');
history
..addChatMessage(message)
..addChatMessage(message2);
final oldestMessage = await history.removeOldestMessage();
expect(oldestMessage, isA<HumanChatMessage>());
expect(oldestMessage.content, 'This is a test');
final messages = await history.getChatMessages();
expect(messages.length, 1);
expect(messages.first, isA<AIChatMessage>());
expect(messages.first.content, 'This is an AI msg');
});

test('Test clear', () async {
final history = ChatMessageHistory();
final message = ChatMessage.human('This is a test');
Expand Down
90 changes: 90 additions & 0 deletions packages/langchain/test/memory/token_buffer_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import 'package:langchain/src/memory/memory.dart';
import 'package:langchain/src/memory/token_buffer.dart';
import 'package:langchain/src/model_io/chat_models/chat_models.dart';
import 'package:langchain/src/model_io/llms/fake.dart';
import 'package:test/test.dart';

void main() {
group('ConversationTokenBufferMemory tests', () {
test('Test buffer memory', () async {
const model = FakeEchoLLM();
final memory = ConversationTokenBufferMemory(llm: model);
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': ''});

await memory.saveContext(
inputValues: {'foo': 'bar'},
outputValues: {'bar': 'foo'},
);
const expectedString = 'Human: bar\nAI: foo';
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': expectedString});
});

test('Test buffer memory return messages', () async {
const model = FakeEchoLLM();
final memory = ConversationTokenBufferMemory(
llm: model,
returnMessages: true,
maxTokenLimit: 4,
);
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': <ChatMessage>[]});

await memory.saveContext(
inputValues: {'foo': 'bar'},
outputValues: {'bar': 'foo'},
);
final expectedResult = [
ChatMessage.human('bar'),
ChatMessage.ai('foo'),
];
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': expectedResult});

await memory.saveContext(
inputValues: {'foo': 'bar1'},
outputValues: {'bar': 'foo1'},
);

final expectedResult2 = [
ChatMessage.ai('foo'),
ChatMessage.human('bar1'),
ChatMessage.ai('foo1'),
];
final result3 = await memory.loadMemoryVariables();
expect(result3, {'history': expectedResult2});
});

test('Test buffer memory with pre-loaded history', () async {
final pastMessages = [
ChatMessage.human("My name's Jonas"),
ChatMessage.ai('Nice to meet you, Jonas!'),
];
const model = FakeEchoLLM();
final memory = ConversationTokenBufferMemory(
llm: model,
maxTokenLimit: 3,
returnMessages: true,
chatHistory: ChatMessageHistory(messages: pastMessages),
);
final result = await memory.loadMemoryVariables();
expect(result, {'history': pastMessages});
});

test('Test clear memory', () async {
final memory = ConversationBufferMemory();
await memory.saveContext(
inputValues: {'foo': 'bar'},
outputValues: {'bar': 'foo'},
);
const expectedString = 'Human: bar\nAI: foo';
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': expectedString});

memory.clear();
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': ''});
});
});
}

0 comments on commit 8113d1c

Please sign in to comment.