langchain[minor]: Added multi-message streaming to llama_cpp #3463

Merged 2 commits on Nov 30, 2023
8 changes: 7 additions & 1 deletion docs/core_docs/docs/integrations/chat/llama_cpp.mdx
@@ -55,8 +55,14 @@ import ChainExample from "@examples/models/chat/integration_llama_cpp_chain.ts";

### Streaming

We can also stream with Llama CPP:
We can also stream with Llama CPP. This can be done using a raw 'single prompt' string:

import StreamExample from "@examples/models/chat/integration_llama_cpp_stream.ts";

<CodeBlock language="typescript">{StreamExample}</CodeBlock>

Or you can provide multiple messages; note that this takes the input and then submits a Llama2-formatted prompt to the model.

import StreamMultiExample from "@examples/models/chat/integration_llama_cpp_stream_multi.ts";

<CodeBlock language="typescript">{StreamMultiExample}</CodeBlock>
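
For reference, the sketch below (not part of this PR's diff; the variable name is illustrative) approximates the Llama2-style prompt string that the new _buildPrompt helper in the llama_cpp.ts changes further down would build for the system + human messages used in the multi-message example: system content is wrapped in <<SYS>>/<</SYS>>, human content in [INST]/[/INST], and the pieces are joined with newlines. Exact spacing is an assumption.

// Sketch only: approximates the prompt produced by _buildPrompt for the example's
// SystemMessage + HumanMessage input. Spacing and newlines are assumptions.
const approximatePrompt = [
  "<<SYS>> You are a pirate, responses must be very verbose and in pirate dialect. <</SYS>>",
  "[INST] Tell me about Llamas? [/INST]",
].join("\n");

console.log(approximatePrompt);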
4 changes: 1 addition & 3 deletions examples/src/models/chat/integration_llama_cpp_stream.ts
@@ -4,9 +4,7 @@ const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });

const stream = await model.stream([
["human", "Tell me a short story about a happy Llama."],
]);
const stream = await model.stream("Tell me a short story about a happy Llama.");

for await (const chunk of stream) {
console.log(chunk.content);
43 changes: 43 additions & 0 deletions examples/src/models/chat/integration_llama_cpp_stream_multi.ts
@@ -0,0 +1,43 @@
import { ChatLlamaCpp } from "langchain/chat_models/llama_cpp";
import { SystemMessage, HumanMessage } from "langchain/schema";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });

const stream = await model.stream([
new SystemMessage(
"You are a pirate, responses must be very verbose and in pirate dialect."
),
new HumanMessage("Tell me about Llamas?"),
]);

for await (const chunk of stream) {
console.log(chunk.content);
}

/*

Ar
rr
r
,
me
heart
y
!

Ye
be
ask
in
'
about
llam
as
,
e
h
?
...
*/
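
If you want the whole response as a single string instead of token-by-token output, a small variation of the loop above (a sketch, not part of this PR) accumulates the chunk contents:

// Sketch: replaces the for-await loop in the example above, reusing the same
// `stream` returned by model.stream(...), and collects the chunks into one string.
let fullText = "";
for await (const chunk of stream) {
  fullText += chunk.content;
}
console.log(fullText);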
73 changes: 47 additions & 26 deletions langchain/src/chat_models/llama_cpp.ts
@@ -16,6 +16,7 @@ import {
BaseMessage,
ChatGenerationChunk,
AIMessageChunk,
ChatMessage,
} from "../schema/index.js";

/**
@@ -155,34 +156,27 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
_options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
if (input.length !== 1) {
throw new Error("Only one human message should be provided.");
} else {
const promptOptions = {
temperature: this?.temperature,
topK: this?.topK,
topP: this?.topP,
};
const promptOptions = {
temperature: this?.temperature,
topK: this?.topK,
topP: this?.topP,
};

const stream = await this.caller.call(async () =>
this._context.evaluate(
this._context.encode(`${input[0].content}`),
promptOptions
)
);
const prompt = this._buildPrompt(input);

for await (const chunk of stream) {
yield new ChatGenerationChunk({
text: this._context.decode([chunk]),
message: new AIMessageChunk({
content: this._context.decode([chunk]),
}),
generationInfo: {},
});
await runManager?.handleLLMNewToken(
this._context.decode([chunk]) ?? ""
);
}
const stream = await this.caller.call(async () =>
this._context.evaluate(this._context.encode(prompt), promptOptions)
);

for await (const chunk of stream) {
yield new ChatGenerationChunk({
text: this._context.decode([chunk]),
message: new AIMessageChunk({
content: this._context.decode([chunk]),
}),
generationInfo: {},
});
await runManager?.handleLLMNewToken(this._context.decode([chunk]) ?? "");
}
}

@@ -297,4 +291,31 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {

return result;
}

protected _buildPrompt(input: BaseMessage[]): string {
const prompt = input
.map((message) => {
let messageText;
if (message._getType() === "human") {
messageText = `[INST] ${message.content} [/INST]`;
} else if (message._getType() === "ai") {
messageText = message.content;
} else if (message._getType() === "system") {
messageText = `<<SYS>> ${message.content} <</SYS>>`;
} else if (ChatMessage.isInstance(message)) {
messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice(
1
)}: ${message.content}`;
} else {
console.warn(
`Unsupported message type passed to llama_cpp: "${message._getType()}"`
);
messageText = "";
}
return messageText;
})
.join("\n");

return prompt;
}
}
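
As a quick illustration of the formatting rules in _buildPrompt above, the standalone sketch below mirrors its branches. The function name and the simplified string-only message shape are assumptions for illustration; they are not part of the PR or of the langchain API.

// Standalone sketch of the per-role formatting implemented by _buildPrompt above.
// The name and simplified message shape are assumptions for illustration only.
type SimpleMessage = {
  type: "human" | "ai" | "system" | "generic";
  role?: string;
  content: string;
};

function formatLlama2Prompt(messages: SimpleMessage[]): string {
  return messages
    .map((message) => {
      if (message.type === "human") return `[INST] ${message.content} [/INST]`;
      if (message.type === "ai") return message.content; // prior model turns pass through unwrapped
      if (message.type === "system") return `<<SYS>> ${message.content} <</SYS>>`;
      if (message.type === "generic" && message.role) {
        // generic ChatMessage instances get a capitalized role prefix
        return `\n\n${message.role[0].toUpperCase()}${message.role.slice(1)}: ${message.content}`;
      }
      return ""; // unsupported types contribute nothing (the real method also warns)
    })
    .join("\n");
}

// Example:
// formatLlama2Prompt([
//   { type: "system", content: "You are a pirate..." },
//   { type: "human", content: "Tell me about Llamas?" },
// ])
// => "<<SYS>> You are a pirate... <</SYS>>\n[INST] Tell me about Llamas? [/INST]"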
23 changes: 21 additions & 2 deletions langchain/src/chat_models/tests/chatllama_cpp.int.test.ts
@@ -85,14 +85,33 @@ test.skip("Test chain with memory", async () => {
test.skip("test streaming call", async () => {
const llamaCpp = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });

const stream = await llamaCpp.stream(
"Tell me a short story about a happy Llama."
);

const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk.content);
console.log(chunk.content);
}

expect(chunks.length).toBeGreaterThan(1);
});

test.skip("test multi-mesage streaming call", async () => {
const llamaCpp = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });

const stream = await llamaCpp.stream([
["human", "Tell me a short story about a happy Llama."],
new SystemMessage(
"You are a pirate, responses must be very verbose and in pirate dialect."
),
new HumanMessage("Tell me about Llamas?"),
]);

const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk.content);
process.stdout.write(chunks.join(""));
console.log(chunk.content);
}

expect(chunks.length).toBeGreaterThan(1);