feat: flash attention in model selection (#266)
* feat: flash attention in model selection
* fix: adapt to `llama.cpp` breaking changes
* fix: Llama 3 function calling
giladgd authored Jul 8, 2024
1 parent c2e322c commit c35ff5a
Showing 9 changed files with 120 additions and 46 deletions.
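Note: the diffs below thread a single `flashAttention` flag from the CLI commands into interactive model selection. A rough TypeScript sketch of the wiring (the signatures come from this commit's diffs; the `modelArg`, `llama` and `headers` variables and the relative import path are assumed from the surrounding CLI command code):

// Sketch only - not code from this commit.
import {resolveCommandGgufPath} from "../utils/resolveCommandGgufPath.js"; // path assumed, relative to src/cli/commands/

const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
    flashAttention // forwarded to interactivelyAskForModel() and into the GGUF compatibility scoring
});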
35 changes: 9 additions & 26 deletions llama/addon.cpp
@@ -108,20 +108,6 @@ static void adjustNapiExternalMemorySubtract(Napi::Env env, uint64_t size) {
}
}

std::string addon_model_token_to_piece(const struct llama_model* model, llama_token token, bool specialTokens) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), specialTokens);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(model, token, result.data(), result.size(), specialTokens);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}

return std::string(result.data(), result.size());
}

#ifdef GPU_INFO_USE_CUDA
void logCudaError(const char* message) {
addonLlamaCppLogCallback(GGML_LOG_LEVEL_ERROR, (std::string("CUDA error: ") + std::string(message)).c_str(), nullptr);
@@ -395,21 +381,18 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
? info[1].As<Napi::Boolean>().Value()
: false;

// Create a stringstream for accumulating the decoded string.
std::stringstream ss;
std::vector<char> result(8, 0);
const int n_length = llama_detokenize(model, (llama_token*)tokens.Data(), tokens.ElementLength(), result.data(), result.size(), false, decodeSpecialTokens);

// Decode each token and accumulate the result.
for (size_t i = 0; i < tokens.ElementLength(); i++) {
const std::string piece = addon_model_token_to_piece(model, (llama_token)tokens[i], decodeSpecialTokens);

if (piece.empty()) {
continue;
}

ss << piece;
if (n_length < 0) {
result.resize(-n_length);
int check = llama_detokenize(model, (llama_token*)tokens.Data(), tokens.ElementLength(), result.data(), result.size(), false, decodeSpecialTokens);
GGML_ASSERT(check == -n_length);
} else {
result.resize(n_length);
}

return Napi::String::New(info.Env(), ss.str());
return Napi::String::New(info.Env(), result.data(), result.size());
}

Napi::Value GetTrainContextSize(const Napi::CallbackInfo& info) {
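The hunk above replaces the per-token `addon_model_token_to_piece` loop with a single `llama_detokenize` call over the whole token array, retrying once with a resized buffer when the first call returns the required length as a negative number. A minimal sketch of the JS-visible behavior this backs (API usage mirrors the tests added later in this commit; the model path is illustrative):

import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "path/to/model.gguf"}); // illustrative path

const tokens = model.tokenize("Hi there", false);
console.log(model.detokenize(tokens)); // "Hi there" - now produced by one native detokenize call

// When the tokens include special tokens, the second argument controls whether
// their text (e.g. "<|start_header_id|>") appears in the output:
// model.detokenize(tokensWithSpecialTokens, true)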
4 changes: 2 additions & 2 deletions src/chatWrappers/Llama3ChatWrapper.ts
@@ -28,7 +28,7 @@ export class Llama3ChatWrapper extends ChatWrapper {
functions: {
call: {
optionalPrefixSpace: true,
prefix: "||call:",
prefix: "||call: ",
paramsPrefix: LlamaText(new SpecialTokensText("(")),
suffix: LlamaText(new SpecialTokensText(")"))
},
@@ -56,7 +56,7 @@ export class Llama3ChatWrapper extends ChatWrapper {
functions: {
call: {
optionalPrefixSpace: true,
prefix: "||call:",
prefix: "||call: ",
paramsPrefix: LlamaText(new SpecialTokensText("(")),
suffix: LlamaText(new SpecialTokensText(")"))
},
12 changes: 12 additions & 0 deletions src/chatWrappers/generic/JinjaTemplateChatWrapper.ts
@@ -9,8 +9,20 @@ import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate

export type JinjaTemplateChatWrapperOptions = {
template: string,

/**
* Defaults to `"assistant"`.
*/
modelRoleName?: string,

/**
* Defaults to `"user"`.
*/
userRoleName?: string,

/**
* Defaults to `"system"`.
*/
systemRoleName?: string,

/**
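The added JSDoc above documents the role-name defaults of `JinjaTemplateChatWrapperOptions`. A hedged usage sketch (the constructor call and the template string are illustrative, not part of this commit):

import {JinjaTemplateChatWrapper} from "node-llama-cpp";

const chatWrapper = new JinjaTemplateChatWrapper({
    template: "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}",
    // These default to "assistant", "user" and "system" respectively; override
    // them only when the template expects different role names:
    modelRoleName: "assistant",
    userRoleName: "user",
    systemRoleName: "system"
});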
4 changes: 3 additions & 1 deletion src/cli/commands/ChatCommand.ts
@@ -326,7 +326,9 @@ async function RunChat({
});
const logBatchSize = batchSize != null;

const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers);
const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
flashAttention
});

if (systemInfo)
console.log(llama.systemInfo);
4 changes: 3 additions & 1 deletion src/cli/commands/CompleteCommand.ts
@@ -249,7 +249,9 @@ async function RunCompletion({
});
const logBatchSize = batchSize != null;

const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers);
const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
flashAttention
});

if (systemInfo)
console.log(llama.systemInfo);
4 changes: 3 additions & 1 deletion src/cli/commands/InfillCommand.ts
@@ -259,7 +259,9 @@ async function RunInfill({
});
const logBatchSize = batchSize != null;

const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers);
const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
flashAttention
});

if (systemInfo)
console.log(llama.systemInfo);
41 changes: 31 additions & 10 deletions src/cli/utils/interactivelyAskForModel.ts
@@ -56,12 +56,14 @@ export async function interactivelyAskForModel({
llama,
modelsDirectory,
allowLocalModels = true,
downloadIntent = true
downloadIntent = true,
flashAttention = false
}: {
llama: Llama,
modelsDirectory?: string,
allowLocalModels?: boolean,
downloadIntent?: boolean
downloadIntent?: boolean,
flashAttention?: boolean
}): Promise<string> {
let localModelFileOptions: (ModelOption & { type: "localModel" })[] = [];
const recommendedModelOptions: (ModelOption & { type: "recommendedModel" })[] = [];
@@ -112,7 +114,9 @@ export async function interactivelyAskForModel({
readItems++;
progressUpdater.setProgress(readItems / ggufFileNames.length, renderProgress());

const compatibilityScore = await ggufInsights?.configurationResolver.scoreModelConfigurationCompatibility();
const compatibilityScore = await ggufInsights?.configurationResolver.scoreModelConfigurationCompatibility({
flashAttention: flashAttention && ggufInsights?.flashAttentionSupported
});

return {
type: "localModel",
@@ -211,7 +215,7 @@
try {
// eslint-disable-next-line no-constant-condition
while (true) {
const minWidth = Math.min(80, process.stdout.columns - 1);
const minWidth = Math.min(80 + (flashAttention ? 26 : 0), process.stdout.columns - 1);
const selectedItem = await basicChooseFromListConsoleInteraction({
title(item, rerender) {
const title = chalk.bold("Select a model:") + " ";
@@ -235,6 +239,17 @@
(String(Math.floor((vramState.used / vramState.total) * 100 * 100) / 100) + "%") + " " +
chalk.dim("(" + bytes(vramState.used) + "/" + bytes(vramState.total) + ")") +
" "
) + (
!flashAttention
? ""
: (
" " +
chalk.bgGray(
" " +
chalk.yellow("Flash attention:") + " " + "enabled" +
" "
)
)
)
);

@@ -273,7 +288,7 @@
},
items: options,
renderItem(item, focused, rerender) {
return renderSelectionItem(item, focused, rerender, activeInteractionController.signal, llama);
return renderSelectionItem(item, focused, rerender, activeInteractionController.signal, llama, flashAttention);
},
canFocusItem(item) {
return item.type === "recommendedModel" || item.type === "localModel" || item.type === "action";
@@ -374,7 +389,9 @@ async function askForModelUrlOrPath(allowLocalModels: boolean): Promise<string |
);
}

function renderSelectionItem(item: ModelOption, focused: boolean, rerender: () => void, abortSignal: AbortSignal, llama: Llama) {
function renderSelectionItem(
item: ModelOption, focused: boolean, rerender: () => void, abortSignal: AbortSignal, llama: Llama, flashAttention: boolean
) {
if (item.type === "localModel") {
let modelText = item.title instanceof Function
? item.title()
@@ -398,7 +415,8 @@ function renderSelectionItem(item: ModelOption, focused: boolean, rerender: () =
recommendedModelOption: item,
abortSignal,
rerenderOption: rerender,
llama
llama,
flashAttention
});
}

@@ -542,12 +560,13 @@ function renderCompatibilityPercentageWithColors(percentage: number, {
}

async function selectFileForModelRecommendation({
recommendedModelOption, llama, abortSignal, rerenderOption
recommendedModelOption, llama, abortSignal, rerenderOption, flashAttention
}: {
recommendedModelOption: ModelOption & { type: "recommendedModel" },
llama: Llama,
abortSignal: AbortSignal,
rerenderOption(): void
rerenderOption(): void,
flashAttention: boolean
}) {
try {
let bestScore: number | undefined = undefined;
@@ -567,7 +586,9 @@
if (abortSignal.aborted)
return;

const compatibilityScore = await ggufInsights.configurationResolver.scoreModelConfigurationCompatibility();
const compatibilityScore = await ggufInsights.configurationResolver.scoreModelConfigurationCompatibility({
flashAttention
});

if (bestScore == null || compatibilityScore.compatibilityScore > bestScore) {
bestScore = compatibilityScore.compatibilityScore;
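Both the local-model listing and the recommended-model file selection above now feed the flag into the compatibility scoring, gated on model support. A condensed sketch of the pattern (setup of `ggufInsights`, `flashAttention` and `bestScore` is assumed from the surrounding function):

// Only request flash attention in the score when the model can actually use it:
const compatibilityScore = await ggufInsights.configurationResolver.scoreModelConfigurationCompatibility({
    flashAttention: flashAttention && ggufInsights.flashAttentionSupported
});

if (bestScore == null || compatibilityScore.compatibilityScore > bestScore)
    bestScore = compatibilityScore.compatibilityScore;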
7 changes: 4 additions & 3 deletions src/cli/utils/resolveCommandGgufPath.ts
@@ -12,9 +12,9 @@ import {getReadablePath} from "./getReadablePath.js";
import {interactivelyAskForModel} from "./interactivelyAskForModel.js";

export async function resolveCommandGgufPath(ggufPath: string | undefined, llama: Llama, fetchHeaders?: Record<string, string>, {
targetDirectory = cliModelsDirectory
targetDirectory = cliModelsDirectory, flashAttention = false
}: {
targetDirectory?: string
targetDirectory?: string, flashAttention?: boolean
} = {}) {
let resolvedGgufPath = ggufPath;

@@ -23,7 +23,8 @@ export async function resolveCommandGgufPath(ggufPath: string | undefined, llama
llama,
modelsDirectory: targetDirectory,
allowLocalModels: true,
downloadIntent: true
downloadIntent: true,
flashAttention
});

if (!isUrl(resolvedGgufPath)) {
55 changes: 53 additions & 2 deletions test/modelDependent/functionary/sanity.test.ts
@@ -1,5 +1,5 @@
import {describe, expect, test} from "vitest";
import {LlamaChatSession} from "../../../src/index.js";
import {LlamaChatSession, SpecialTokensText, LlamaText} from "../../../src/index.js";
import {getModelFile} from "../../utils/modelFiles.js";
import {getTestLlama} from "../../utils/getTestLlama.js";

@@ -86,7 +86,7 @@ describe("functionary", () => {
`);
});

test("tokenizing text and then detokenizing it arrive at the same text", {timeout: 1000 * 60 * 60 * 2}, async () => {
test("tokenizing a text and then detokenizing it arrives at the same text", {timeout: 1000 * 60 * 60 * 2}, async () => {
const modelPath = await getModelFile("functionary-small-v2.5.Q4_0.gguf");
const llama = await getTestLlama();

@@ -178,6 +178,57 @@ describe("functionary", () => {
expect(textWithSpecialTokens).to.eql(text);
expect(textNoSpecialTokens).to.eql(text);
}

{
const text = "Hi there";

const tokensWithTrim = model.tokenize(text, false, "trimLeadingSpace");
const tokensWithoutTrim = model.tokenize(text, false);

expect(model.detokenize(tokensWithTrim)).to.eql(text);
expect(model.detokenize(tokensWithoutTrim)).to.eql(text);
}
{
const text = " Hi there";

const tokensWithTrim = model.tokenize(text, false, "trimLeadingSpace");
const tokensWithoutTrim = model.tokenize(text, false);

expect(model.detokenize(tokensWithTrim)).to.eql(text);
expect(model.detokenize(tokensWithoutTrim)).to.eql(text);
}
});

test("tokenizing a LlamaText and then detokenizing it arrives at the same text", {timeout: 1000 * 60 * 60 * 2}, async () => {
const modelPath = await getModelFile("functionary-small-v2.5.Q4_0.gguf");
const llama = await getTestLlama();

const model = await llama.loadModel({
modelPath
});

{
const text = LlamaText([
new SpecialTokensText("<|start_header_id|>system<|end_header_id|>\n\n"),
"How much is 6+6\n"
]);

const tokens = text.tokenize(model.tokenizer);

expect(model.detokenize(tokens, true)).to.eql("<|start_header_id|>system<|end_header_id|>\n\nHow much is 6+6\n");
expect(model.detokenize(tokens, false)).to.eql("system\n\nHow much is 6+6\n");
}
{
const text = LlamaText([
new SpecialTokensText("Hi <|start_header_id|>there\n\n"),
"How much is 6+6\n"
]);

const tokens = text.tokenize(model.tokenizer);

expect(model.detokenize(tokens, true)).to.eql("Hi <|start_header_id|>there\n\nHow much is 6+6\n");
expect(model.detokenize(tokens, false)).to.eql("Hi there\n\nHow much is 6+6\n");
}
});
});
});
