diff --git a/docs/core_docs/docs/integrations/chat/llama_cpp.mdx b/docs/core_docs/docs/integrations/chat/llama_cpp.mdx
index 343c8f5ada95..b1612ed40d16 100644
--- a/docs/core_docs/docs/integrations/chat/llama_cpp.mdx
+++ b/docs/core_docs/docs/integrations/chat/llama_cpp.mdx
@@ -61,7 +61,11 @@ import StreamExample from "@examples/models/chat/integration_llama_cpp_stream.ts";
 
 {StreamExample}
 
-Or you can provide multiple messages, note that this takes the input and then submits a Llama2 formatted prompt to the model.
+Or you can provide multiple messages. Note that by default messages are submitted in Llama2 format; if you are using a different backend model with `node-llama-cpp`, you can specify the model format to use with the `streamingModel` option. The supported formats are:
+- `llama2` - A Llama2 model, this is the default.
+- `chatML` - A ChatML model
+- `falcon` - A Falcon model
+- `general` - Any other model, this uses the `### Human\n`, `### Assistant\n` format
 
 import StreamMultiExample from "@examples/models/chat/integration_llama_cpp_stream_multi.ts";
 
diff --git a/langchain/src/chat_models/llama_cpp.ts b/langchain/src/chat_models/llama_cpp.ts
index 3df8b0d2a3c9..55c4df71eede 100644
--- a/langchain/src/chat_models/llama_cpp.ts
+++ b/langchain/src/chat_models/llama_cpp.ts
@@ -30,8 +30,16 @@ export interface LlamaCppInputs
 export interface LlamaCppCallOptions extends BaseLanguageModelCallOptions {
   /** The maximum number of tokens the response should contain. */
   maxTokens?: number;
-  /** A function called when matching the provided token array */
-  onToken?: (tokens: number[]) => void;
+  /** @deprecated This function is not used. */
+  onToken?: (tokens: number[]) => void;
+  /** This lets the streaming method know which model we are using;
+   * valid options are:
+   * 'llama2' - A Llama2 model, this is the default.
+   * 'chatML' - A ChatML model
+   * 'falcon' - A Falcon model
+   * 'general' - Any other model, uses "### Human\n", "### Assistant\n" format
+   */
+  streamingModel?: string;
 }
 
 /**
@@ -64,6 +72,8 @@ export class ChatLlamaCpp extends SimpleChatModel {
 
   maxTokens?: number;
 
+  streamingModel?: string;
+
   temperature?: number;
 
   topK?: number;
@@ -84,7 +94,18 @@ export class ChatLlamaCpp extends SimpleChatModel {
 
   constructor(inputs: LlamaCppInputs) {
     super(inputs);
+    if (inputs.streamingModel) {
+      if (
+        inputs.streamingModel !== "llama2" &&
+        inputs.streamingModel !== "chatML" &&
+        inputs.streamingModel !== "falcon" &&
+        inputs.streamingModel !== "general"
+      ) {
+        throw new Error("Unknown streaming model specified.");
+      }
+    }
     this.maxTokens = inputs?.maxTokens;
+    this.streamingModel = inputs?.streamingModel;
     this.temperature = inputs?.temperature;
     this.topK = inputs?.topK;
     this.topP = inputs?.topP;
@@ -116,7 +137,7 @@ export class ChatLlamaCpp extends SimpleChatModel {
   /** @ignore */
   async _call(
     messages: BaseMessage[],
-    options: this["ParsedCallOptions"]
+    _options: this["ParsedCallOptions"]
   ): Promise<string> {
     let prompt = "";
 
@@ -137,7 +158,6 @@ export class ChatLlamaCpp extends SimpleChatModel {
 
     try {
       const promptOptions = {
-        onToken: options.onToken,
         maxTokens: this?.maxTokens,
         temperature: this?.temperature,
         topK: this?.topK,
@@ -164,7 +184,6 @@ export class ChatLlamaCpp extends SimpleChatModel {
     };
 
     const prompt = this._buildPrompt(input);
-
     const stream = await this.caller.call(async () =>
      this._context.evaluate(this._context.encode(prompt), promptOptions)
     );
@@ -298,11 +317,50 @@ export class ChatLlamaCpp extends SimpleChatModel {
       .map((message) => {
         let messageText;
         if (message._getType() === "human") {
-          messageText = `[INST] ${message.content} [/INST]`;
+          switch (this.streamingModel) {
+            case "chatML":
+              messageText = `<|im_start|>user\n${message.content}<|im_end|>`;
+              break;
+            case "falcon":
+              messageText = `User: ${message.content}`;
+              break;
+            case "general":
+              messageText = `### Human\n${message.content}`;
+              break;
+            default:
+              messageText = `[INST] ${message.content} [/INST]`;
+              break;
+          }
         } else if (message._getType() === "ai") {
-          messageText = message.content;
+          switch (this.streamingModel) {
+            case "chatML":
+              messageText = `<|im_start|>assistant\n${message.content}<|im_end|>`;
+              break;
+            case "falcon":
+              messageText = `Assistant: ${message.content}`;
+              break;
+            case "general":
+              messageText = `### Assistant\n${message.content}`;
+              break;
+            default:
+              messageText = message.content;
+              break;
+          }
         } else if (message._getType() === "system") {
-          messageText = `<<SYS>> ${message.content} <</SYS>>`;
+          switch (this.streamingModel) {
+            case "chatML":
+              messageText = `<|im_start|>system\n${message.content}<|im_end|>`;
+              break;
+            case "falcon":
+              messageText = message.content;
+              break;
+            case "general":
+              messageText = message.content;
+              break;
+            default:
+              messageText = `<<SYS>> ${message.content} <</SYS>>`;
+              break;
+          }
         } else if (ChatMessage.isInstance(message)) {
           messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice(
             1