Added multiple model prompt formats to llama_cpp chat streaming. #3588
@@ -30,8 +30,16 @@ export interface LlamaCppInputs
 export interface LlamaCppCallOptions extends BaseLanguageModelCallOptions {
   /** The maximum number of tokens the response should contain. */
   maxTokens?: number;
-  /** A function called when matching the provided token array */
-  onToken?: (tokens: number[]) => void;
+  /** deprecated, This function is not used. */
+  onToken?: (tokens: number[]) => void;
+  /** This lets the streaming method know which model we are using,
+   * valid options are:
+   * 'llama2' - A Llama2 model, this is the default.
+   * 'chatML' - A ChatML model
+   * 'falcon' - A Falcon model
+   * 'general' - Any other model, uses "### Human\n", "### Assistant\n" format
+   */
+  streamingModel?: string;
 }
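To show what the new option is for, here is a rough usage sketch; the import paths, model file, and use of `.stream()` are illustrative assumptions rather than code from this PR.

```ts
// Hypothetical usage sketch (import paths and model file are assumptions).
import { ChatLlamaCpp } from "langchain/chat_models/llama_cpp";
import { HumanMessage } from "langchain/schema";

const chat = new ChatLlamaCpp({
  modelPath: "/path/to/your-model.gguf", // any local GGUF model
  // Tell the streaming prompt builder which chat format the model expects.
  streamingModel: "chatML",
});

// Stream a reply; chunks arrive as the tokens are generated.
const stream = await chat.stream([new HumanMessage("Tell me a short joke.")]);
for await (const chunk of stream) {
  process.stdout.write(chunk.content as string);
}
```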
@@ -64,6 +72,8 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
   maxTokens?: number;

+  streamingModel?: string;
+
   temperature?: number;

   topK?: number;
@@ -84,7 +94,18 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
   constructor(inputs: LlamaCppInputs) {
     super(inputs);
+    if (inputs.streamingModel) {
+      if (
+        inputs.streamingModel !== "llama2" &&
+        inputs.streamingModel !== "chatML" &&
+        inputs.streamingModel !== "falcon" &&
+        inputs.streamingModel !== "general"
+      ) {
+        throw new Error("Unknown streaming model specified.");
+      }
+    }
     this.maxTokens = inputs?.maxTokens;
+    this.streamingModel = inputs?.streamingModel;
     this.temperature = inputs?.temperature;
     this.topK = inputs?.topK;
     this.topP = inputs?.topP;
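As an aside, the same constructor guard could keep the list of valid names in one place with a string-literal union; the helper below is only a sketch, not part of this PR.

```ts
// Sketch only: equivalent validation with the valid names defined once.
const STREAMING_MODELS = ["llama2", "chatML", "falcon", "general"] as const;
type StreamingModel = (typeof STREAMING_MODELS)[number];

function checkStreamingModel(value?: string): void {
  if (value !== undefined && !STREAMING_MODELS.includes(value as StreamingModel)) {
    throw new Error("Unknown streaming model specified.");
  }
}

// e.g. in the constructor: checkStreamingModel(inputs.streamingModel);
```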
@@ -116,7 +137,7 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
   /** @ignore */
   async _call(
     messages: BaseMessage[],
-    options: this["ParsedCallOptions"]
+    _options: this["ParsedCallOptions"]
   ): Promise<string> {
     let prompt = "";
@@ -137,7 +158,6 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
     try {
       const promptOptions = {
-        onToken: options.onToken,
         maxTokens: this?.maxTokens,
         temperature: this?.temperature,
         topK: this?.topK,
@@ -164,7 +184,6 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
     };

     const prompt = this._buildPrompt(input);

     const stream = await this.caller.call(async () =>
       this._context.evaluate(this._context.encode(prompt), promptOptions)
     );
@@ -298,11 +317,50 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions>
       .map((message) => {
         let messageText;
         if (message._getType() === "human") {
-          messageText = `[INST] ${message.content} [/INST]`;
+          switch (this.streamingModel) {
+            case "chatML":
+              messageText = `<|im_start|>user\n${message.content}<|im_end|>`;
+              break;
+            case "falcon":
+              messageText = `User: ${message.content}`;
+              break;
+            case "general":
+              messageText = `### Human\n${message.content}`;
+              break;
+            default:
+              messageText = `[INST] ${message.content} [/INST]`;
+              break;
+          }
         } else if (message._getType() === "ai") {
-          messageText = message.content;
+          switch (this.streamingModel) {
+            case "chatML":
+              messageText = `<|im_start|>assistant\n${message.content}<|im_end|>`;
+              break;
+            case "falcon":
+              messageText = `Assistant: ${message.content}`;
+              break;
+            case "general":
+              messageText = `### Assistant\n${message.content}`;
+              break;
+            default:
+              messageText = message.content;
+              break;
+          }
         } else if (message._getType() === "system") {
-          messageText = `<<SYS>> ${message.content} <</SYS>>`;
+          switch (this.streamingModel) {
+            case "chatML":
+              messageText = `<|im_start|>system\n${message.content}<|im_end|>`;
+              break;
+            case "falcon":
+              messageText = message.content;
+              break;
+            case "general":
+              messageText = message.content;
+              break;
+            default:
+              messageText = `<<SYS>> ${message.content} <</SYS>>`;
+              break;
+          }
         } else if (ChatMessage.isInstance(message)) {
           messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice(
             1

Review comment on this hunk: The next version will support the generation of a response based on chat histories that can have "system", "user" and "model" roles, so I think most of this code could be replaced by the … I predict that the next beta that includes this functionality will be ready in about ~2 weeks if you're willing to wait for it, but in any case, I think this is better handled by …
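To make the new prompt formats concrete, the helper below distills the human-message case from the switch statements above; it is a standalone illustration, not code from this PR.

```ts
// Distilled from the diff above: how a human message is wrapped per streamingModel.
function formatHumanMessage(streamingModel: string, content: string): string {
  switch (streamingModel) {
    case "chatML":
      return `<|im_start|>user\n${content}<|im_end|>`;
    case "falcon":
      return `User: ${content}`;
    case "general":
      return `### Human\n${content}`;
    default: // "llama2"
      return `[INST] ${content} [/INST]`;
  }
}

// formatHumanMessage("chatML", "Hi")  -> "<|im_start|>user\nHi<|im_end|>"
// formatHumanMessage("falcon", "Hi")  -> "User: Hi"
// formatHumanMessage("general", "Hi") -> "### Human\nHi"
// formatHumanMessage("llama2", "Hi")  -> "[INST] Hi [/INST]"
```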
Review comment: @nigel-daniels Why aren't you using a LlamaChatSession with the relevant ChatPromptWrapper for it? For example, you have a ChatPromptWrapper for ChatML under ChatMLChatPromptWrapper. It's better to only have one centralized place for the handling of chat formats, especially as more will get supported in node-llama-cpp over time. Furthermore, in the next major version of node-llama-cpp it'll figure out the right chat wrapper by default, so it won't require a manual config most of the time.
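For reference, the suggested approach would look roughly like this, going by the node-llama-cpp v2 examples; the exact option names and constructor shapes here are assumptions, not verified against this PR.

```ts
// Rough sketch of the suggested approach (API details assumed from node-llama-cpp v2 docs).
import {
  LlamaModel,
  LlamaContext,
  LlamaChatSession,
  ChatMLChatPromptWrapper,
} from "node-llama-cpp";

const model = new LlamaModel({ modelPath: "/path/to/model.gguf" });
const context = new LlamaContext({ model });

// The prompt wrapper centralizes the chat format instead of hand-building strings.
const session = new LlamaChatSession({
  context,
  promptWrapper: new ChatMLChatPromptWrapper(),
});

const answer = await session.prompt("Tell me a short joke.");
console.log(answer);
```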
Reply: This is because the call for streaming uses the context.evaluate() method; this returns a generator, so I can use yield to pass back the tokens as they arrive.
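In other words, the streaming path looks roughly like the sketch below: context.evaluate() yields token ids, which are decoded and re-yielded as they arrive. This is a simplified illustration; the encode/decode signatures are assumed from node-llama-cpp v2, not taken from this PR.

```ts
import type { LlamaContext } from "node-llama-cpp";

// Simplified illustration of the streaming path described above.
async function* streamText(context: LlamaContext, prompt: string) {
  const tokens = context.encode(prompt);
  // evaluate() produces token ids one at a time as the model generates them.
  for await (const token of context.evaluate(tokens)) {
    // Decode each token id back to text and hand it to the caller immediately
    // (decode accepting a single-element token array is an assumption here).
    yield context.decode([token]);
  }
}
```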