diff --git a/docs/core_docs/docs/integrations/chat/llama_cpp.mdx b/docs/core_docs/docs/integrations/chat/llama_cpp.mdx
index 343c8f5ada95..b1612ed40d16 100644
--- a/docs/core_docs/docs/integrations/chat/llama_cpp.mdx
+++ b/docs/core_docs/docs/integrations/chat/llama_cpp.mdx
@@ -61,7 +61,11 @@ import StreamExample from "@examples/models/chat/integration_llama_cpp_stream.ts
{StreamExample}
-Or you can provide multiple messages, note that this takes the input and then submits a Llama2 formatted prompt to the model.
+Or you can provide multiple messages. By default, messages are submitted to the model in Llama2 prompt format; if you are using a different backend model with `node-llama-cpp`, you can select the prompt format with the `streamingModel` option (see the example after this list). The supported formats are:
+
+- `llama2` - A Llama2 model; this is the default.
+- `chatML` - A ChatML model.
+- `falcon` - A Falcon model.
+- `general` - Any other model; this uses the `### Human\n` / `### Assistant\n` format.
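+
+As a rough sketch (the model path below is a placeholder; point it at your own local GGUF file), selecting the ChatML format when constructing the model might look like this:
+
+```typescript
+import { ChatLlamaCpp } from "langchain/chat_models/llama_cpp";
+import { SystemMessage, HumanMessage } from "langchain/schema";
+
+// Placeholder path; replace with the location of your local GGUF model file.
+const llamaPath = "/path/to/your/chatml-model.gguf";
+
+const model = new ChatLlamaCpp({
+  modelPath: llamaPath,
+  // Build streaming prompts with ChatML markers instead of the Llama2 default.
+  streamingModel: "chatML",
+});
+
+const stream = await model.stream([
+  new SystemMessage("You are a helpful assistant."),
+  new HumanMessage("Tell me a short story about a happy Llama."),
+]);
+
+for await (const chunk of stream) {
+  console.log(chunk.content);
+}
+```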
import StreamMultiExample from "@examples/models/chat/integration_llama_cpp_stream_multi.ts";
diff --git a/langchain/src/chat_models/llama_cpp.ts b/langchain/src/chat_models/llama_cpp.ts
index 3df8b0d2a3c9..55c4df71eede 100644
--- a/langchain/src/chat_models/llama_cpp.ts
+++ b/langchain/src/chat_models/llama_cpp.ts
@@ -30,8 +30,16 @@ export interface LlamaCppInputs
export interface LlamaCppCallOptions extends BaseLanguageModelCallOptions {
/** The maximum number of tokens the response should contain. */
maxTokens?: number;
- /** A function called when matching the provided token array */
- onToken?: (tokens: number[]) => void;
+  /** @deprecated This function is not used. */
+ onToken?: (tokens: number[]) => void;
+  /** Tells the streaming method which model format to use when building prompts.
+   * Valid options are:
+   * 'llama2' - A Llama2 model; this is the default.
+   * 'chatML' - A ChatML model.
+   * 'falcon' - A Falcon model.
+   * 'general' - Any other model; uses "### Human\n", "### Assistant\n" format.
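+   *
+   * @example
+   * A sketch, assuming `llamaPath` points at a local GGUF model file:
+   * ```typescript
+   * const model = new ChatLlamaCpp({ modelPath: llamaPath, streamingModel: "chatML" });
+   * ```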
+ */
+ streamingModel?: string;
}
/**
@@ -64,6 +72,8 @@ export class ChatLlamaCpp extends SimpleChatModel {
maxTokens?: number;
+ streamingModel?: string;
+
temperature?: number;
topK?: number;
@@ -84,7 +94,18 @@ export class ChatLlamaCpp extends SimpleChatModel {
constructor(inputs: LlamaCppInputs) {
super(inputs);
+    if (
+      inputs.streamingModel &&
+      !["llama2", "chatML", "falcon", "general"].includes(inputs.streamingModel)
+    ) {
+      throw new Error("Unknown streaming model specified.");
+    }
this.maxTokens = inputs?.maxTokens;
+ this.streamingModel = inputs?.streamingModel;
this.temperature = inputs?.temperature;
this.topK = inputs?.topK;
this.topP = inputs?.topP;
@@ -116,7 +137,7 @@ export class ChatLlamaCpp extends SimpleChatModel {
/** @ignore */
async _call(
messages: BaseMessage[],
- options: this["ParsedCallOptions"]
+ _options: this["ParsedCallOptions"]
  ): Promise<string> {
let prompt = "";
@@ -137,7 +158,6 @@ export class ChatLlamaCpp extends SimpleChatModel {
try {
const promptOptions = {
- onToken: options.onToken,
maxTokens: this?.maxTokens,
temperature: this?.temperature,
topK: this?.topK,
@@ -164,7 +184,6 @@ export class ChatLlamaCpp extends SimpleChatModel {
};
const prompt = this._buildPrompt(input);
-
const stream = await this.caller.call(async () =>
this._context.evaluate(this._context.encode(prompt), promptOptions)
);
@@ -298,11 +317,50 @@ export class ChatLlamaCpp extends SimpleChatModel {
.map((message) => {
let messageText;
if (message._getType() === "human") {
- messageText = `[INST] ${message.content} [/INST]`;
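+          // Wrap the human message in the prompt markers expected by the selected model format.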
+ switch (this.streamingModel) {
+ case "chatML":
+ messageText = `<|im_start|>user\n${message.content}<|im_end|>`;
+ break;
+ case "falcon":
+ messageText = `User: ${message.content}`;
+ break;
+ case "general":
+ messageText = `### Human\n${message.content}`;
+ break;
+ default:
+ messageText = `[INST] ${message.content} [/INST]`;
+ break;
+ }
} else if (message._getType() === "ai") {
- messageText = message.content;
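+          // Wrap the AI (assistant) message in the markers expected by the selected model format.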
+ switch (this.streamingModel) {
+ case "chatML":
+ messageText = `<|im_start|>assistant\n${message.content}<|im_end|>`;
+ break;
+ case "falcon":
+ messageText = `Assistant: ${message.content}`;
+ break;
+ case "general":
+ messageText = `### Assistant\n${message.content}`;
+ break;
+ default:
+ messageText = message.content;
+ break;
+ }
} else if (message._getType() === "system") {
-          messageText = `<<SYS>> ${message.content} <</SYS>>`;
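+          // Format the system message for the selected model; Falcon and the general format pass it through unchanged.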
+ switch (this.streamingModel) {
+ case "chatML":
+ messageText = `<|im_start|>system\n${message.content}<|im_end|>`;
+ break;
+ case "falcon":
+ messageText = message.content;
+ break;
+ case "general":
+ messageText = message.content;
+ break;
+ default:
+              messageText = `<<SYS>> ${message.content} <</SYS>>`;
+ break;
+ }
} else if (ChatMessage.isInstance(message)) {
messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice(
1