feat: flash attention #264

Merged · 3 commits · Jul 5, 2024
4 changes: 4 additions & 0 deletions llama/addon.cpp
@@ -987,6 +987,10 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
context_params.embeddings = options.Get("embeddings").As<Napi::Boolean>().Value();
}

if (options.Has("flashAttention")) {
context_params.flash_attn = options.Get("flashAttention").As<Napi::Boolean>().Value();
}

if (options.Has("threads")) {
const auto n_threads = options.Get("threads").As<Napi::Number>().Uint32Value();
const auto resolved_n_threads = n_threads == 0 ? std::thread::hardware_concurrency() : n_threads;
1 change: 1 addition & 0 deletions src/bindings/AddonTypes.ts
@@ -20,6 +20,7 @@ export type BindingModule = {
contextSize?: number,
batchSize?: number,
sequences?: number,
flashAttention?: boolean,
logitsAll?: boolean,
embeddings?: boolean,
threads?: number
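The native side of the change is small: when the context options passed to the binding contain `flashAttention`, addon.cpp copies it into llama.cpp's `context_params.flash_attn`. A minimal sketch of how this could surface through the library's public API, assuming a per-context `flashAttention` option on `model.createContext()` (suggested by the binding option above, but not shown in this diff); the model path is illustrative:

```ts
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({
    modelPath: "path/to/model.gguf" // illustrative path
});

// Assumed public option name; the binding receives `flashAttention` and
// addon.cpp maps it onto llama.cpp's `context_params.flash_attn`.
const context = await model.createContext({
    flashAttention: true
});
```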
9 changes: 6 additions & 3 deletions src/bindings/Llama.ts
@@ -7,7 +7,7 @@ import {GbnfJsonSchema} from "../utils/gbnfJson/types.js";
import {LlamaJsonSchemaGrammar} from "../evaluator/LlamaJsonSchemaGrammar.js";
import {LlamaGrammar, LlamaGrammarOptions} from "../evaluator/LlamaGrammar.js";
import {BindingModule} from "./AddonTypes.js";
import {BuildGpu, BuildMetadataFile, LlamaLocks, LlamaLogLevel} from "./types.js";
import {BuildGpu, BuildMetadataFile, LlamaGpuType, LlamaLocks, LlamaLogLevel} from "./types.js";
import {MemoryOrchestrator, MemoryReservation} from "./utils/MemoryOrchestrator.js";

const LlamaLogLevelToAddonLogLevel: ReadonlyMap<LlamaLogLevel, number> = new Map([
@@ -31,7 +31,7 @@ export class Llama {
/** @internal */ public readonly _vramOrchestrator: MemoryOrchestrator;
/** @internal */ public readonly _vramPadding: MemoryReservation;
/** @internal */ public readonly _debug: boolean;
/** @internal */ private readonly _gpu: BuildGpu;
/** @internal */ private readonly _gpu: LlamaGpuType;
/** @internal */ private readonly _buildType: "localBuild" | "prebuilt";
/** @internal */ private readonly _cmakeOptions: Readonly<Record<string, string>>;
/** @internal */ private readonly _supportsGpuOffloading: boolean;
@@ -244,7 +244,10 @@ export class Llama {
await this._bindings.init();
}

/** @internal */
/**
* Log messages related to the Llama instance
* @internal
*/
public _log(level: LlamaLogLevel, message: string) {
this._onAddonLog(LlamaLogLevelToAddonLogLevel.get(level) ?? defaultLogLevel, message + "\n");
}
10 changes: 8 additions & 2 deletions src/bindings/getLlama.ts
@@ -16,7 +16,7 @@ import {
} from "./utils/compileLLamaCpp.js";
import {getLastBuildInfo} from "./utils/lastBuildInfo.js";
import {getClonedLlamaCppRepoReleaseInfo, isLlamaCppRepoCloned} from "./utils/cloneLlamaCppRepo.js";
import {BuildGpu, BuildMetadataFile, BuildOptions, LlamaLogLevel} from "./types.js";
import {BuildGpu, BuildMetadataFile, BuildOptions, LlamaGpuType, LlamaLogLevel} from "./types.js";
import {BinaryPlatform, getPlatform} from "./utils/getPlatform.js";
import {getBuildFolderNameForBuildOptions} from "./utils/getBuildFolderNameForBuildOptions.js";
import {resolveCustomCmakeOptions} from "./utils/resolveCustomCmakeOptions.js";
@@ -46,7 +46,10 @@ export type LlamaOptions = {
*
* `"auto"` by default.
*/
gpu?: "auto" | "metal" | "cuda" | "vulkan" | false,
gpu?: "auto" | LlamaGpuType | {
type: "auto",
exclude?: LlamaGpuType[]
},

/**
* Set the minimum log level for llama.cpp.
@@ -298,6 +301,9 @@ export async function getLlamaForOptions({
}
}

if (buildGpusToTry.length === 0)
throw new Error("No GPU types available to try building with");

if (build === "auto" || build === "never") {
for (let i = 0; i < buildGpusToTry.length; i++) {
const gpu = buildGpusToTry[i];
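The `gpu` option keeps accepting the existing string values and additionally takes an object form that auto-detects a GPU type while skipping specific candidates; if the exclude list leaves nothing to try (not even `false` for CPU-only), `getLlamaForOptions` now throws "No GPU types available to try building with". A short usage sketch based on the option type above:

```ts
import {getLlama} from "node-llama-cpp";

// Previous form, still supported: request a specific GPU type.
const llamaCuda = await getLlama({gpu: "cuda"});

// New object form: auto-detect the best compute layer, but never try Vulkan.
const llamaNoVulkan = await getLlama({
    gpu: {
        type: "auto",
        exclude: ["vulkan"]
    }
});
```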
1 change: 1 addition & 0 deletions src/bindings/types.ts
@@ -3,6 +3,7 @@ import {BinaryPlatform} from "./utils/getPlatform.js";
import {BinaryPlatformInfo} from "./utils/getPlatformInfo.js";

export const buildGpuOptions = ["metal", "cuda", "vulkan", false] as const;
export type LlamaGpuType = "metal" | "cuda" | "vulkan" | false;
export const nodeLlamaCppGpuOptions = [
"auto",
...buildGpuOptions
23 changes: 18 additions & 5 deletions src/bindings/utils/getGpuTypesToUseForOption.ts
@@ -1,28 +1,41 @@
import process from "process";
import {BuildGpu, buildGpuOptions} from "../types.js";
import {LlamaOptions} from "../getLlama.js";
import {BinaryPlatform, getPlatform} from "./getPlatform.js";
import {getBestComputeLayersAvailable} from "./getBestComputeLayersAvailable.js";

export async function getGpuTypesToUseForOption(gpu: BuildGpu | "auto", {
export async function getGpuTypesToUseForOption(gpu: Required<LlamaOptions>["gpu"], {
platform = getPlatform(),
arch = process.arch
}: {
platform?: BinaryPlatform,
arch?: typeof process.arch
} = {}): Promise<BuildGpu[]> {
const resolvedGpu = resolveValidGpuOptionForPlatform(gpu, {
const resolvedGpuOption = typeof gpu === "object"
? gpu.type
: gpu;

function withExcludedGpuTypesRemoved(gpuTypes: BuildGpu[]) {
const resolvedExcludeTypes = typeof gpu === "object"
? new Set(gpu.exclude ?? [])
: new Set();

return gpuTypes.filter(gpuType => !resolvedExcludeTypes.has(gpuType));
}

const resolvedGpu = resolveValidGpuOptionForPlatform(resolvedGpuOption, {
platform,
arch
});

if (resolvedGpu === "auto") {
if (arch === process.arch)
return await getBestComputeLayersAvailable();
return withExcludedGpuTypesRemoved(await getBestComputeLayersAvailable());

return [false];
return withExcludedGpuTypesRemoved([false]);
}

return [resolvedGpu];
return withExcludedGpuTypesRemoved([resolvedGpu]);
}

export function resolveValidGpuOptionForPlatform(gpu: BuildGpu | "auto", {
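The helper now resolves the requested GPU option first (unwrapping the object form), then filters the exclude list out of whichever candidate list it ends up with: the auto-detected compute layers, the `[false]` CPU-only fallback, or a single explicitly requested type. A standalone sketch of that filtering step, using a locally defined `BuildGpu` that mirrors the exported type:

```ts
type BuildGpu = "metal" | "cuda" | "vulkan" | false;

// Mirrors withExcludedGpuTypesRemoved() above: drop excluded GPU types
// from a candidate list before attempting to build with them.
function withoutExcluded(gpuTypes: BuildGpu[], exclude: BuildGpu[] = []): BuildGpu[] {
    const excluded = new Set<BuildGpu>(exclude);
    return gpuTypes.filter((gpuType) => !excluded.has(gpuType));
}

// Auto-detection returned ["cuda", "vulkan", false], but "vulkan" is excluded.
console.log(withoutExcluded(["cuda", "vulkan", false], ["vulkan"])); // ["cuda", false]
```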
22 changes: 15 additions & 7 deletions src/cli/commands/ChatCommand.ts
@@ -41,6 +41,7 @@ type ChatCommand = {
noJinja?: boolean,
contextSize?: number,
batchSize?: number,
flashAttention?: boolean,
noTrimWhitespace: boolean,
grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[1],
jsonSchemaGrammarFile?: string,
Expand Down Expand Up @@ -149,6 +150,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
type: "number",
description: "Batch size to use for the model context. The default value is the context size"
})
.option("flashAttention", {
alias: "fa",
type: "boolean",
default: false,
description: "Enable flash attention"
})
.option("noTrimWhitespace", {
type: "boolean",
alias: ["noTrim"],
@@ -269,7 +276,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
},
async handler({
modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt,
promptFile, wrapper, noJinja, contextSize, batchSize,
promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention,
noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK,
topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory,
@@ -278,9 +285,9 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
try {
await RunChat({
modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize,
batchSize, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers,
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
noHistory, environmentFunctions, debug, meter, printTimings
batchSize, flashAttention, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP,
gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
});
} catch (err) {
await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
@@ -293,9 +300,9 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {

async function RunChat({
modelPath: modelArg, header: headerArg, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja,
contextSize, batchSize, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature,
minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty,
repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
contextSize, batchSize, flashAttention, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
}: ChatCommand) {
if (contextSize === -1) contextSize = undefined;
if (gpuLayers === -1) gpuLayers = undefined;
@@ -360,6 +367,7 @@ async function RunChat({
: contextSize != null
? {fitContext: {contextSize}}
: undefined,
defaultContextFlashAttention: flashAttention,
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
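The `chat` command (and, below, `complete` and `infill`) gains a `--flashAttention` boolean flag (alias `fa`, default `false`) that is forwarded as `defaultContextFlashAttention` in the model-load options, so contexts created for the session inherit it. A rough programmatic equivalent, assuming the options object in `RunChat` above is the one handed to `llama.loadModel()`; the model path is illustrative:

```ts
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();

// Rough equivalent of `node-llama-cpp chat --model <path> --flashAttention`:
// the CLI flag becomes the model's default for newly created contexts.
const model = await llama.loadModel({
    modelPath: "path/to/model.gguf", // illustrative path
    defaultContextFlashAttention: true
});
```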
14 changes: 11 additions & 3 deletions src/cli/commands/CompleteCommand.ts
@@ -28,6 +28,7 @@ type CompleteCommand = {
textFile?: string,
contextSize?: number,
batchSize?: number,
flashAttention?: boolean,
threads: number,
temperature: number,
minP: number,
@@ -104,6 +105,12 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
type: "number",
description: "Batch size to use for the model context. The default value is the context size"
})
.option("flashAttention", {
alias: "fa",
type: "boolean",
default: false,
description: "Enable flash attention"
})
.option("threads", {
type: "number",
default: 6,
@@ -194,14 +201,14 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
},
async handler({
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
threads, temperature, minP, topK,
flashAttention, threads, temperature, minP, topK,
topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
debug, meter, printTimings
}) {
try {
await RunCompletion({
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty,
repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
debug, meter, printTimings
@@ -216,7 +223,7 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {


async function RunCompletion({
modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize,
modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
threads, temperature, minP, topK, topP, gpuLayers,
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
maxTokens, debug, meter, printTimings
@@ -276,6 +283,7 @@ async function RunCompletion({
: contextSize != null
? {fitContext: {contextSize}}
: undefined,
defaultContextFlashAttention: flashAttention,
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
14 changes: 11 additions & 3 deletions src/cli/commands/InfillCommand.ts
@@ -30,6 +30,7 @@ type InfillCommand = {
suffixFile?: string,
contextSize?: number,
batchSize?: number,
flashAttention?: boolean,
threads: number,
temperature: number,
minP: number,
@@ -114,6 +115,12 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
type: "number",
description: "Batch size to use for the model context. The default value is the context size"
})
.option("flashAttention", {
alias: "fa",
type: "boolean",
default: false,
description: "Enable flash attention"
})
.option("threads", {
type: "number",
default: 6,
@@ -204,14 +211,14 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
},
async handler({
modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
threads, temperature, minP, topK,
flashAttention, threads, temperature, minP, topK,
topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
debug, meter, printTimings
}) {
try {
await RunInfill({
modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, flashAttention,
threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty,
repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
debug, meter, printTimings
@@ -226,7 +233,7 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {


async function RunInfill({
modelPath: modelArg, header: headerArg, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
modelPath: modelArg, header: headerArg, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, flashAttention,
threads, temperature, minP, topK, topP, gpuLayers,
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
maxTokens, debug, meter, printTimings
@@ -300,6 +307,7 @@ async function RunInfill({
: contextSize != null
? {fitContext: {contextSize}}
: undefined,
defaultContextFlashAttention: flashAttention,
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);