feat: token biases (#196)
giladgd authored Apr 10, 2024
1 parent b542b53 commit 3ad4494
Showing 16 changed files with 253 additions and 22 deletions.
44 changes: 41 additions & 3 deletions llama/addon.cpp
@@ -3,6 +3,7 @@
#include <algorithm>
#include <sstream>
#include <vector>
#include <unordered_map>

#include "common.h"
#include "common/grammar-parser.h"
@@ -1334,6 +1335,8 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled
float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled
std::vector<llama_token> repeat_penalty_tokens;
std::unordered_map<llama_token, float> tokenBiases;
bool useTokenBiases = false;
bool use_repeat_penalty = false;

AddonContextSampleTokenWorker(const Napi::CallbackInfo& info, AddonContext* ctx)
@@ -1378,6 +1381,19 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
use_repeat_penalty = true;
}

if (options.Has("tokenBiasKeys") && options.Has("tokenBiasValues")) {
Napi::Uint32Array tokenBiasKeys = options.Get("tokenBiasKeys").As<Napi::Uint32Array>();
Napi::Float32Array tokenBiasValues = options.Get("tokenBiasValues").As<Napi::Float32Array>();

if (tokenBiasKeys.ElementLength() == tokenBiasValues.ElementLength()) {
for (size_t i = 0; i < tokenBiasKeys.ElementLength(); i++) {
tokenBiases[static_cast<llama_token>(tokenBiasKeys[i])] = tokenBiasValues[i];
}

useTokenBiases = true;
}
}

if (options.Has("repeatPenaltyPresencePenalty")) {
repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue();
}
@@ -1426,18 +1442,33 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
// Select the best prediction.
auto logits = llama_get_logits_ith(ctx->ctx, batchLogitIndex);
auto n_vocab = llama_n_vocab(ctx->model->model);
auto eos_token = llama_token_eos(ctx->model->model);

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data { token_id, logits[token_id], 0.0f });
auto logit = logits[token_id];

if (useTokenBiases) {
bool hasTokenBias = tokenBiases.find(token_id) != tokenBiases.end();
if (hasTokenBias) {
auto logitBias = tokenBiases.at(token_id);
if (logitBias == -INFINITY || logitBias < -INFINITY) {
if (token_id != eos_token) {
logit = -INFINITY;
}
} else {
logit += logitBias;
}
}
}

candidates.emplace_back(llama_token_data { token_id, logit, 0.0f });
}

llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

auto eos_token = llama_token_eos(ctx->model->model);

if (use_repeat_penalty && !repeat_penalty_tokens.empty()) {
llama_sample_repetition_penalties(
ctx->ctx,
@@ -1452,6 +1483,13 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {

if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
llama_sample_grammar(ctx->ctx, &candidates_p, (grammar_evaluation_state)->grammar);

if ((candidates_p.size == 0 || candidates_p.data[0].logit == -INFINITY) && useTokenBiases) {
// logit biases caused grammar sampling to fail, so sampling again without logit biases
useTokenBiases = false;
SampleToken();
return;
}
}

if (temperature <= 0) {
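The hunk above is the core of the feature: each configured bias is added to its token's logit before the candidate list is built, a bias of -Infinity suppresses the token entirely (except for the EOS token, which is never banned so generation can still terminate), and if the biases leave grammar-constrained sampling with no viable candidate, sampling is retried once with the biases turned off. A minimal TypeScript restatement of the per-token rule, written only to illustrate the C++ logic above (the function and parameter names are hypothetical):

// Mirrors the bias handling in AddonContextSampleTokenWorker::SampleToken above.
// Hypothetical helper for illustration — not part of the library.
function applyTokenBias(
    logit: number,
    tokenId: number,
    eosTokenId: number,
    biases: Map<number, number>
): number {
    const bias = biases.get(tokenId);
    if (bias == null)
        return logit; // no bias configured for this token

    if (bias === -Infinity)
        // a -Infinity bias bans the token, but EOS is exempt so the model can still stop
        return tokenId === eosTokenId ? logit : -Infinity;

    return logit + bias; // any other bias is simply added to the logit
}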
4 changes: 3 additions & 1 deletion src/bindings/AddonTypes.ts
@@ -108,7 +108,9 @@ export type AddonContext = {
repeatPenaltyTokens?: Uint32Array,
repeatPenaltyPresencePenalty?: number, // alpha_presence
repeatPenaltyFrequencyPenalty?: number, // alpha_frequency
grammarEvaluationState?: AddonGrammarEvaluationState
grammarEvaluationState?: AddonGrammarEvaluationState,
tokenBiasKeys?: Uint32Array,
tokenBiasValues?: Float32Array
}): Promise<Token>,
disposeSequence(sequenceId: number): void,

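As the two new fields above show, the JavaScript side hands the bias map to the addon as two parallel typed arrays of equal length, which is exactly what the equal-length check in addon.cpp expects: token IDs in a Uint32Array and bias values (including -Infinity for banned tokens) in a Float32Array. A hedged sketch of how a caller might flatten a token-to-bias map into that shape before passing it in the sampling options (only the two field names come from this diff; everything else is illustrative):

// Flatten a Map<token, bias> into the parallel arrays expected by the addon's sampling options.
const biases = new Map<number, number>([
    [1234, 5],         // nudge token 1234 towards being picked
    [5678, -Infinity]  // ban token 5678 (the EOS token can never be banned, see addon.cpp)
]);

const tokenBiasKeys = new Uint32Array(biases.size);
const tokenBiasValues = new Float32Array(biases.size);
let i = 0;
for (const [token, bias] of biases) {
    tokenBiasKeys[i] = token;
    tokenBiasValues[i] = bias;
    i++;
}
// tokenBiasKeys and tokenBiasValues are then passed alongside the other sampling options shown above.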
11 changes: 9 additions & 2 deletions src/cli/commands/ChatCommand.ts
@@ -323,7 +323,8 @@ async function RunChat({
successText: chalk.blue("Model loaded"),
failText: chalk.blue("Failed to load model"),
liveUpdates: !debug,
noProgress: debug
noProgress: debug,
liveCtrlCSendsAbortSignal: true
}, async (progressUpdater) => {
try {
return await llama.loadModel({
@@ -336,8 +337,14 @@
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
}
},
loadSignal: progressUpdater.abortSignal
});
} catch (err) {
if (err === progressUpdater.abortSignal?.reason)
process.exit(0);

throw err;
} finally {
if (llama.logLevel === LlamaLogLevel.debug) {
await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
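The CLI changes bundled into this commit are a separate usability fix: Ctrl+C while the model is loading now raises an abort signal (liveCtrlCSendsAbortSignal), the signal is forwarded to llama.loadModel() through the new loadSignal option, and an aborted load exits cleanly with status 0 instead of surfacing as an error. The same pattern is applied to the complete and infill commands below. A hedged, standalone sketch of the pattern outside the progress-updater wrapper (getLlama, the SIGINT handler, and the model path are illustrative; only loadSignal and the abort-reason comparison come from this diff):

// Abort a model load on Ctrl+C and exit quietly — sketch under the assumptions noted above.
import {getLlama} from "node-llama-cpp";

const abortController = new AbortController();
process.on("SIGINT", () => abortController.abort()); // Ctrl+C aborts the load instead of killing the process

const llama = await getLlama();
try {
    const model = await llama.loadModel({
        modelPath: "./model.gguf",
        loadSignal: abortController.signal
    });
    console.log("Model loaded");
} catch (err) {
    if (err === abortController.signal.reason)
        process.exit(0); // the load was intentionally aborted — not an error

    throw err; // anything else is a real failure
}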
11 changes: 9 additions & 2 deletions src/cli/commands/CompleteCommand.ts
@@ -238,7 +238,8 @@ async function RunCompletion({
successText: chalk.blue("Model loaded"),
failText: chalk.blue("Failed to load model"),
liveUpdates: !debug,
noProgress: debug
noProgress: debug,
liveCtrlCSendsAbortSignal: true
}, async (progressUpdater) => {
try {
return await llama.loadModel({
@@ -251,8 +252,14 @@
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
}
},
loadSignal: progressUpdater.abortSignal
});
} catch (err) {
if (err === progressUpdater.abortSignal?.reason)
process.exit(0);

throw err;
} finally {
if (llama.logLevel === LlamaLogLevel.debug) {
await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
11 changes: 9 additions & 2 deletions src/cli/commands/InfillCommand.ts
@@ -262,7 +262,8 @@ async function RunInfill({
successText: chalk.blue("Model loaded"),
failText: chalk.blue("Failed to load model"),
liveUpdates: !debug,
noProgress: debug
noProgress: debug,
liveCtrlCSendsAbortSignal: true
}, async (progressUpdater) => {
try {
return await llama.loadModel({
@@ -275,8 +276,14 @@
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
}
},
loadSignal: progressUpdater.abortSignal
});
} catch (err) {
if (err === progressUpdater.abortSignal?.reason)
process.exit(0);

throw err;
} finally {
if (llama.logLevel === LlamaLogLevel.debug) {
await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
1 change: 1 addition & 0 deletions src/cli/utils/ConsoleInteraction.ts
@@ -84,6 +84,7 @@ export class ConsoleInteraction {

if (callbacks.length === 0 && key === ConsoleInteractionKey.ctrlC) {
process.stdout.write("\n");
this.stop();
process.exit(0);
}

1 change: 1 addition & 0 deletions src/cli/utils/consolePromptQuestion.ts
@@ -36,6 +36,7 @@ export async function consolePromptQuestion(question: string, {
clearLastLines(linesUsed);

if (exitOnCtrlC) {
rl.close();
process.exit(0);
} else
accept(null);
10 changes: 10 additions & 0 deletions src/evaluator/LlamaChat/LlamaChat.ts
@@ -14,6 +14,7 @@ import {UNKNOWN_UNICODE_CHAR} from "../../consts.js";
import {getQueuedTokensBeforeStopTrigger} from "../../utils/getQueuedTokensBeforeStopTrigger.js";
import {resolveChatWrapper} from "../../chatWrappers/utils/resolveChatWrapper.js";
import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js";
import {TokenBias} from "../TokenBias.js";
import {
eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy
} from "./utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js";
@@ -85,6 +86,13 @@ export type LLamaChatGenerateResponseOptions<Functions extends ChatModelFunction

repeatPenalty?: false | LLamaContextualRepeatPenalty,

/**
* Adjust the probability of tokens being generated.
* Can be used to bias the model to generate tokens that you want it to lean towards,
* or to avoid generating tokens that you want it to avoid.
*/
tokenBias?: TokenBias | (() => TokenBias),

/**
* See the parameter `evaluationPriority` on the `LlamaContextSequence.evaluate()` function for more information.
*/
@@ -249,6 +257,7 @@ export class LlamaChat {
grammar,
trimWhitespaceSuffix = false,
repeatPenalty = {},
tokenBias,
evaluationPriority = 5,
functions,
documentFunctionParams,
@@ -532,6 +541,7 @@ export class LlamaChat {
frequencyPenalty,
presencePenalty
},
tokenBias,
evaluationPriority,
yieldEosToken: true
}));
18 changes: 15 additions & 3 deletions src/evaluator/LlamaChatSession/LlamaChatSession.ts
@@ -7,6 +7,7 @@ import {LlamaContextSequence} from "../LlamaContext/LlamaContext.js";
import {LlamaGrammar} from "../LlamaGrammar.js";
import {LlamaChat, LLamaChatContextShiftOptions, LlamaChatResponse} from "../LlamaChat/LlamaChat.js";
import {EvaluationPriority} from "../LlamaContext/types.js";
import {TokenBias} from "../TokenBias.js";


export type LlamaChatSessionOptions = {
@@ -96,7 +97,14 @@ export type LLamaChatPromptOptions<Functions extends ChatSessionModelFunctions |
*/
evaluationPriority?: EvaluationPriority,

repeatPenalty?: false | LlamaChatSessionRepeatPenalty
repeatPenalty?: false | LlamaChatSessionRepeatPenalty,

/**
* Adjust the probability of tokens being generated.
* Can be used to bias the model to generate tokens that you want it to lean towards,
* or to avoid generating tokens that you want it to avoid.
*/
tokenBias?: TokenBias | (() => TokenBias)
} & ({
grammar?: LlamaGrammar,
functions?: never,
@@ -249,14 +257,16 @@ export class LlamaChatSession {
topP,
grammar,
trimWhitespaceSuffix = false,
repeatPenalty
repeatPenalty,
tokenBias
}: LLamaChatPromptOptions<Functions> = {}) {
const {responseText} = await this.promptWithMeta<Functions>(prompt, {
// this is a workaround to allow passing both `functions` and `grammar`
functions: functions as undefined,
documentFunctionParams: documentFunctionParams as undefined,

onToken, signal, maxTokens, temperature, minP, topK, topP, grammar, trimWhitespaceSuffix, repeatPenalty
onToken, signal, maxTokens, temperature, minP, topK, topP, grammar, trimWhitespaceSuffix, repeatPenalty,
tokenBias
});

return responseText;
@@ -279,6 +289,7 @@ export class LlamaChatSession {
grammar,
trimWhitespaceSuffix = false,
repeatPenalty,
tokenBias,
evaluationPriority
}: LLamaChatPromptOptions<Functions> = {}) {
this._ensureNotDisposed();
@@ -325,6 +336,7 @@ export class LlamaChatSession {
minP,
topK,
topP,
tokenBias,
maxTokens,
temperature,
trimWhitespaceSuffix,
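With the plumbing above, tokenBias becomes an ordinary prompt option on LlamaChatSession.prompt() and promptWithMeta(), accepted either as a TokenBias instance or as a function returning one (handy for building a fresh bias per prompt); the option is forwarded to LlamaChat.generateResponse() shown earlier. The TokenBias class itself lives in src/evaluator/TokenBias.ts, which is not part of this excerpt, so the construction below only assumes an API roughly like new TokenBias(model) with a set() method and should be checked against that file:

// Hedged usage sketch — only the tokenBias option and its TokenBias | (() => TokenBias) type
// come from this diff; the TokenBias construction API is an assumption.
import {getLlama, LlamaChatSession, TokenBias} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "./model.gguf"});
const context = await model.createContext();
const session = new LlamaChatSession({contextSequence: context.getSequence()});

const tokenBias = new TokenBias(model);  // assumed constructor shape
tokenBias.set("Hello", -4);              // assumed: make these tokens less likely
tokenBias.set("Hi", "never");            // assumed: never generate these tokens

const answer = await session.prompt("Greet me", {
    tokenBias  // or a factory: tokenBias: () => buildFreshBias()
});
console.log(answer);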
14 changes: 14 additions & 0 deletions src/evaluator/LlamaCompletion.ts
@@ -12,6 +12,7 @@ import {LlamaGrammarEvaluationState} from "./LlamaGrammarEvaluationState.js";
import {LlamaGrammar} from "./LlamaGrammar.js";
import {EvaluationPriority} from "./LlamaContext/types.js";
import {LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
import {TokenBias} from "./TokenBias.js";

export type LlamaCompletionOptions = {
contextSequence: LlamaContextSequence,
@@ -76,6 +77,13 @@ export type LlamaCompletionGenerationOptions = {

repeatPenalty?: false | LLamaContextualRepeatPenalty,

/**
* Adjust the probability of tokens being generated.
* Can be used to bias the model to generate tokens that you want it to lean towards,
* or to avoid generating tokens that you want it to avoid.
*/
tokenBias?: TokenBias | (() => TokenBias),

/**
* See the parameter `evaluationPriority` on the `LlamaContextSequence.evaluate()` function for more information.
*/
@@ -195,6 +203,7 @@ export class LlamaCompletion {
topP,
trimWhitespaceSuffix = false,
repeatPenalty = {},
tokenBias,
evaluationPriority = 5,
grammar,
stopGenerationTriggers,
@@ -274,6 +283,7 @@ export class LlamaCompletion {
topP,
trimWhitespaceSuffix,
repeatPenalty,
tokenBias,
evaluationPriority,
grammar,
contextShiftSize,
@@ -326,6 +336,7 @@ export class LlamaCompletion {
topP,
trimWhitespaceSuffix = false,
repeatPenalty = {},
tokenBias,
evaluationPriority = 5,
grammar,
contextShiftSize = defaultContextShiftSize,
@@ -455,6 +466,7 @@ export class LlamaCompletion {
topP,
trimWhitespaceSuffix,
repeatPenalty,
tokenBias,
evaluationPriority,
grammar,
contextShiftSize,
@@ -489,6 +501,7 @@ export class LlamaCompletion {
topP,
trimWhitespaceSuffix = false,
repeatPenalty = {},
tokenBias,
evaluationPriority = 5,
grammar,
contextShiftSize = defaultContextShiftSize,
@@ -603,6 +616,7 @@ export class LlamaCompletion {
frequencyPenalty,
presencePenalty
},
tokenBias,
evaluationPriority,
yieldEosToken: true
}));
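LlamaCompletion gains the same tokenBias option on its generation options and threads it through both the completion and infill code paths shown above. A compact, hedged sketch (the generateCompletion method name, maxTokens, and the TokenBias construction are assumptions; only the tokenBias option comes from this diff):

// Hedged sketch — names outside this diff are assumptions, see the note above.
import {getLlama, LlamaCompletion, TokenBias} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "./model.gguf"});
const context = await model.createContext();
const completion = new LlamaCompletion({contextSequence: context.getSequence()});

const tokenBias = new TokenBias(model);  // assumed constructor shape
tokenBias.set("\n\n", "never");          // assumed: e.g. discourage starting a new paragraph

const text = await completion.generateCompletion("The best way to learn programming is ", {
    maxTokens: 64,
    tokenBias
});
console.log(text);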