fix: align embedding input with WPM vocabulary type models (#393)
giladgd authored Dec 7, 2024
1 parent 4d387de commit 28c7984
Showing 19 changed files with 570 additions and 35 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
@@ -389,7 +389,7 @@ jobs:

   model-dependent-tests:
     name: Model dependent tests
-    runs-on: macos-12
+    runs-on: ubuntu-24.04
     env:
       NODE_LLAMA_CPP_GPU: false
     needs:
@@ -412,10 +412,10 @@ jobs:
           name: llama.cpp
           path: llama

-      - name: Install dependencies on macOS
+      - name: Install dependencies on Ubuntu
         run: |
-          brew install cmake ninja
-          alias make=cmake
+          sudo apt-get update
+          sudo apt-get install ninja-build cmake
       - name: Install modules
         run: npm ci
3 changes: 2 additions & 1 deletion llama/addon/AddonContext.cpp
@@ -531,7 +531,8 @@ Napi::Value AddonContext::GetEmbedding(const Napi::CallbackInfo& info) {
     }

     const int n_embd = llama_n_embd(model->model);
-    const auto* embeddings = llama_get_embeddings_seq(ctx, 0);
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    const auto* embeddings = pooling_type == LLAMA_POOLING_TYPE_NONE ? NULL : llama_get_embeddings_seq(ctx, 0);
    if (embeddings == NULL) {
        embeddings = llama_get_embeddings_ith(ctx, inputTokensLength - 1);
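For context: the fix above makes the embedding lookup respect the context's pooling type — llama_get_embeddings_seq() only returns a meaningful vector when pooling is enabled, so with LLAMA_POOLING_TYPE_NONE the code now goes straight to the last token's embedding. A minimal TypeScript sketch of the same decision logic (the helper names are illustrative stand-ins, not part of the actual binding):

// Illustrative sketch only; the real logic lives in AddonContext::GetEmbedding (C++).
type Embedding = Float32Array;

function resolveEmbedding(
    poolingType: "none" | "mean" | "cls",
    getSequenceEmbedding: () => Embedding | null, // stands in for llama_get_embeddings_seq(ctx, 0)
    getTokenEmbedding: (index: number) => Embedding | null, // stands in for llama_get_embeddings_ith(ctx, index)
    inputTokensLength: number
): Embedding | null {
    // Without pooling there is no per-sequence embedding to fetch,
    // so go straight to the fallback
    const pooled = poolingType === "none"
        ? null
        : getSequenceEmbedding();

    if (pooled != null)
        return pooled;

    // Fall back to the embedding of the last evaluated token
    return getTokenEmbedding(inputTokensLength - 1);
}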
30 changes: 26 additions & 4 deletions llama/addon/AddonModel.cpp
@@ -9,7 +9,7 @@
 #include "AddonModelLora.h"

 static Napi::Value getNapiToken(const Napi::CallbackInfo& info, llama_model* model, llama_token token) {
-    if (token < 0) {
+    if (token < 0 || token == LLAMA_TOKEN_NULL) {
         return Napi::Number::From(info.Env(), -1);
     }

@@ -565,6 +565,22 @@ Napi::Value AddonModel::EotToken(const Napi::CallbackInfo& info) {

     return getNapiToken(info, model, llama_token_eot(model));
 }
+Napi::Value AddonModel::ClsToken(const Napi::CallbackInfo& info) {
+    if (disposed) {
+        Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException();
+        return info.Env().Undefined();
+    }
+
+    return getNapiToken(info, model, llama_token_cls(model));
+}
+Napi::Value AddonModel::SepToken(const Napi::CallbackInfo& info) {
+    if (disposed) {
+        Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException();
+        return info.Env().Undefined();
+    }
+
+    return getNapiToken(info, model, llama_token_sep(model));
+}
 Napi::Value AddonModel::GetTokenString(const Napi::CallbackInfo& info) {
     if (disposed) {
         Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException();
@@ -624,11 +640,14 @@ Napi::Value AddonModel::GetVocabularyType(const Napi::CallbackInfo& info) {
     return Napi::Number::From(info.Env(), int32_t(vocabularyType));
 }
 Napi::Value AddonModel::ShouldPrependBosToken(const Napi::CallbackInfo& info) {
-    const int addBos = llama_add_bos_token(model);
+    const bool addBos = llama_add_bos_token(model);

-    bool shouldPrependBos = addBos != -1 ? bool(addBos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+    return Napi::Boolean::New(info.Env(), addBos);
+}
+Napi::Value AddonModel::ShouldAppendEosToken(const Napi::CallbackInfo& info) {
+    const bool addEos = llama_add_eos_token(model);

-    return Napi::Boolean::New(info.Env(), shouldPrependBos);
+    return Napi::Boolean::New(info.Env(), addEos);
 }

 Napi::Value AddonModel::GetModelSize(const Napi::CallbackInfo& info) {
@@ -659,11 +678,14 @@ void AddonModel::init(Napi::Object exports) {
             InstanceMethod("middleToken", &AddonModel::MiddleToken),
             InstanceMethod("suffixToken", &AddonModel::SuffixToken),
             InstanceMethod("eotToken", &AddonModel::EotToken),
+            InstanceMethod("clsToken", &AddonModel::ClsToken),
+            InstanceMethod("sepToken", &AddonModel::SepToken),
             InstanceMethod("getTokenString", &AddonModel::GetTokenString),
             InstanceMethod("getTokenAttributes", &AddonModel::GetTokenAttributes),
             InstanceMethod("isEogToken", &AddonModel::IsEogToken),
             InstanceMethod("getVocabularyType", &AddonModel::GetVocabularyType),
             InstanceMethod("shouldPrependBosToken", &AddonModel::ShouldPrependBosToken),
+            InstanceMethod("shouldAppendEosToken", &AddonModel::ShouldAppendEosToken),
             InstanceMethod("getModelSize", &AddonModel::GetModelSize),
             InstanceMethod("dispose", &AddonModel::Dispose),
         }
3 changes: 3 additions & 0 deletions llama/addon/AddonModel.h
@@ -49,12 +49,15 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
         Napi::Value MiddleToken(const Napi::CallbackInfo& info);
         Napi::Value SuffixToken(const Napi::CallbackInfo& info);
         Napi::Value EotToken(const Napi::CallbackInfo& info);
+        Napi::Value ClsToken(const Napi::CallbackInfo& info);
+        Napi::Value SepToken(const Napi::CallbackInfo& info);
         Napi::Value GetTokenString(const Napi::CallbackInfo& info);

         Napi::Value GetTokenAttributes(const Napi::CallbackInfo& info);
         Napi::Value IsEogToken(const Napi::CallbackInfo& info);
         Napi::Value GetVocabularyType(const Napi::CallbackInfo& info);
         Napi::Value ShouldPrependBosToken(const Napi::CallbackInfo& info);
+        Napi::Value ShouldAppendEosToken(const Napi::CallbackInfo& info);
         Napi::Value GetModelSize(const Napi::CallbackInfo& info);

         static void init(Napi::Object exports);
3 changes: 3 additions & 0 deletions src/bindings/AddonTypes.ts
@@ -99,11 +99,14 @@ export type AddonModel = {
     middleToken(): Token,
     suffixToken(): Token,
     eotToken(): Token,
+    clsToken(): Token,
+    sepToken(): Token,
     getTokenString(token: number): string,
     getTokenAttributes(token: Token): number,
     isEogToken(token: Token): boolean,
     getVocabularyType(): number,
     shouldPrependBosToken(): boolean,
+    shouldAppendEosToken(): boolean,
     getModelSize(): number
 };

4 changes: 3 additions & 1 deletion src/bindings/types.ts
@@ -85,7 +85,9 @@ export enum LlamaVocabularyType {
     none = "none",
     spm = "spm",
     bpe = "bpe",
-    wpm = "wpm"
+    wpm = "wpm",
+    ugm = "ugm",
+    rwkv = "rwkv"
 }
 export const LlamaVocabularyTypeValues = Object.freeze([
     LlamaVocabularyType.none,
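The addon reports the vocabulary type as a number (see getVocabularyType() in AddonTypes.ts above), so the two new enum members need numeric counterparts. A hedged sketch of such a mapping, assuming the values follow llama.cpp's llama_vocab_type order; the library's actual conversion code may be structured differently:

// Illustrative only: map a numeric llama.cpp vocabulary type to the enum.
// Assumes llama_vocab_type ordering: NONE = 0, SPM = 1, BPE = 2, WPM = 3, UGM = 4, RWKV = 5.
const vocabularyTypeByNumber: Record<number, LlamaVocabularyType> = {
    0: LlamaVocabularyType.none,
    1: LlamaVocabularyType.spm,
    2: LlamaVocabularyType.bpe,
    3: LlamaVocabularyType.wpm,
    4: LlamaVocabularyType.ugm,
    5: LlamaVocabularyType.rwkv
};

function vocabularyTypeFromNumber(value: number): LlamaVocabularyType {
    const resolved = vocabularyTypeByNumber[value];

    if (resolved == null)
        throw new Error(`Unknown vocabulary type: ${value}`);

    return resolved;
}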
25 changes: 15 additions & 10 deletions src/evaluator/LlamaCompletion.ts
@@ -11,6 +11,7 @@ import {getQueuedTokensBeforeStopTrigger} from "../utils/getQueuedTokensBeforeSt
 import {safeEventCallback} from "../utils/safeEventCallback.js";
 import {pushAll} from "../utils/pushAll.js";
 import {GgufArchitectureType} from "../gguf/types/GgufMetadataTypes.js";
+import {resolveBeginningTokenToPrepend} from "../utils/tokenizerUtils.js";
 import {LlamaGrammarEvaluationState} from "./LlamaGrammarEvaluationState.js";
 import {LlamaGrammar} from "./LlamaGrammar.js";
 import {EvaluationPriority} from "./LlamaContext/types.js";
@@ -262,8 +263,10 @@ export class LlamaCompletion {
         if (this._sequence == null || this.disposed)
             throw new DisposedError();

-        const bosToken = this._sequence.model.tokens.bos;
-        const shouldPrependBosToken = this._sequence.model.tokens.shouldPrependBosToken;
+        const beginningTokenToPrepend = resolveBeginningTokenToPrepend(
+            this._sequence.model.vocabularyType,
+            this._sequence.model.tokens
+        );

         const extraEosTokens = getExtraCompletionEosTokens(this._sequence.model);

@@ -274,8 +277,8 @@
         }): Promise<Token[]> {
             const res = [];

-            if (shouldPrependBosToken && bosToken != null)
-                res.push(bosToken);
+            if (beginningTokenToPrepend != null)
+                res.push(beginningTokenToPrepend);

             const inputTokensSize = Math.max(0, Math.min(maxTokens - res.length, tokens.length));

@@ -305,7 +308,7 @@
         const resolvedInput = tokenizeInput(
             input,
             this._sequence.model.tokenizer,
-            (shouldPrependBosToken && bosToken != null)
+            beginningTokenToPrepend != null
                 ? "trimLeadingSpace"
                 : undefined
         );
@@ -406,8 +409,10 @@
         const prefixToken = this._sequence.model.tokens.infill.prefix;
         const suffixToken = this._sequence.model.tokens.infill.suffix;
         const middleToken = this._sequence.model.tokens.infill.middle;
-        const bosToken = this._sequence.model.tokens.bos;
-        const shouldPrependBosToken = this._sequence.model.tokens.shouldPrependBosToken;
+        const beginningTokenToPrepend = resolveBeginningTokenToPrepend(
+            this._sequence.model.vocabularyType,
+            this._sequence.model.tokens
+        );

         if (prefixToken == null || suffixToken == null)
             throw new UnsupportedError("Infill completions are not supported by this model");
@@ -425,7 +430,7 @@
         // 2 - InfillPrefix token, InfillSuffix token
         const specialTokensInContext = 2 +
             (middleToken != null ? 1 : 0) +
-            ((shouldPrependBosToken && bosToken != null) ? 1 : 0);
+            (beginningTokenToPrepend != null ? 1 : 0);
         const resolvedMaxTokens = maxTokens - specialTokensInContext;
         let sizeLeftToFill = resolvedMaxTokens;

@@ -464,8 +469,8 @@

         const newContextState: Token[] = [];

-        if (shouldPrependBosToken && bosToken != null)
-            newContextState.push(bosToken);
+        if (beginningTokenToPrepend != null)
+            newContextState.push(beginningTokenToPrepend);

         if (middleToken != null) {
             newContextState.push(prefixToken);
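The resolveBeginningTokenToPrepend() helper used throughout this file comes from src/utils/tokenizerUtils.ts, which is among the 19 changed files but not rendered in this excerpt. A plausible sketch of its behavior, inferred from its call sites here and from the commit's goal of prepending CLS for WPM-vocabulary models — the committed implementation may differ:

// Hypothetical sketch of resolveBeginningTokenToPrepend; not the committed code.
import {Token} from "../types.js";
import {LlamaVocabularyType} from "../bindings/types.js";
import type {LlamaModelTokens} from "../evaluator/LlamaModel/LlamaModel.js";

export function resolveBeginningTokenToPrepend(
    vocabularyType: LlamaVocabularyType,
    tokens: LlamaModelTokens
): Token | null {
    // WPM (BERT-style) models expect a CLS token at the start of the input
    if (vocabularyType === LlamaVocabularyType.wpm)
        return tokens.cls;

    // Other vocabularies keep the BOS behavior, gated on the model metadata
    if (tokens.shouldPrependBosToken)
        return tokens.bos;

    return null;
}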
18 changes: 17 additions & 1 deletion src/evaluator/LlamaEmbeddingContext.ts
@@ -2,6 +2,7 @@ import {AsyncDisposeAggregator, EventRelay, withLock} from "lifecycle-utils";
 import {Token} from "../types.js";
 import {LlamaText} from "../utils/LlamaText.js";
 import {tokenizeInput} from "../utils/tokenizeInput.js";
+import {resolveBeginningTokenToPrepend, resolveEndTokenToAppend} from "../utils/tokenizerUtils.js";
 import {LlamaEmbedding} from "./LlamaEmbedding.js";
 import type {LlamaModel} from "./LlamaModel/LlamaModel.js";
 import type {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
@@ -72,7 +73,7 @@ export class LlamaEmbeddingContext {
     }

     public async getEmbeddingFor(input: Token[] | string | LlamaText) {
-        const resolvedInput = tokenizeInput(input, this._llamaContext.model.tokenizer);
+        const resolvedInput = tokenizeInput(input, this._llamaContext.model.tokenizer, undefined, true);

         if (resolvedInput.length > this._llamaContext.contextSize)
             throw new Error(
@@ -84,6 +85,14 @@ export class LlamaEmbeddingContext {
                 vector: []
             });

+        const beginningToken = resolveBeginningTokenToPrepend(this.model.vocabularyType, this.model.tokens);
+        if (beginningToken != null && resolvedInput[0] !== beginningToken)
+            resolvedInput.unshift(beginningToken);
+
+        const endToken = resolveEndTokenToAppend(this.model.vocabularyType, this.model.tokens);
+        if (endToken != null && resolvedInput.at(-1) !== endToken)
+            resolvedInput.push(endToken);
+
         return await withLock(this, "evaluate", async () => {
             await this._sequence.eraseContextTokenRanges([{
                 start: 0,
@@ -118,6 +127,10 @@
         return this._llamaContext.disposed;
     }

+    public get model() {
+        return this._llamaContext.model;
+    }
+
    /** @internal */
    public static async _create({
        _model
@@ -130,6 +143,9 @@
        createSignal,
        ignoreMemorySafetyChecks
    }: LlamaEmbeddingContextOptions) {
+        if (_model.fileInsights.hasEncoder && _model.fileInsights.hasDecoder)
+            throw new Error("Computing embeddings is not supported for encoder-decoder models.");
+
        const llamaContext = await _model.createContext({
            contextSize,
            batchSize,
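With these changes, getEmbeddingFor() wraps the tokenized input in whatever boundary tokens the model's vocabulary expects (e.g. CLS/SEP for WPM models) before evaluating it. A usage sketch — the model file name is a placeholder for any WPM-vocabulary (BERT-style) embedding model:

// Usage sketch; the model path is a placeholder.
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "bge-small-en-v1.5.Q8_0.gguf"});
const embeddingContext = await model.createEmbeddingContext();

// The input is now automatically aligned with the tokens the vocabulary
// expects (e.g. [CLS] ... [SEP] for WPM models)
const embedding = await embeddingContext.getEmbeddingFor("Hello world");
console.log(embedding.vector.length);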
83 changes: 82 additions & 1 deletion src/evaluator/LlamaModel/LlamaModel.ts
@@ -733,7 +733,7 @@
         if (modelLoaded)
             await model._model.dispose();

-        throw loadSignal.reason;
+        throw loadSignal!.reason;
     } else if (!modelLoaded)
         throw new Error("Failed to load model");

@@ -757,12 +757,17 @@ export class LlamaModelTokens {
     /** @internal */ private _bosToken?: Token;
     /** @internal */ private _eosToken?: Token;
     /** @internal */ private _eotToken?: Token;
+    /** @internal */ private _clsToken?: Token;
+    /** @internal */ private _sepToken?: Token;
     /** @internal */ private _nlToken?: Token;
     /** @internal */ private _bosString?: string;
     /** @internal */ private _eosString?: string;
     /** @internal */ private _eotString?: string;
+    /** @internal */ private _clsString?: string;
+    /** @internal */ private _sepString?: string;
     /** @internal */ private _nlString?: string;
     /** @internal */ private _shouldPrependBosToken?: boolean;
+    /** @internal */ private _shouldAppendEosToken?: boolean;

     private constructor(model: AddonModel, disposedState: DisposedState) {
         this._model = model;
@@ -826,6 +831,36 @@
         return this._eotToken;
     }

+    /**
+     * @returns The CLS (Classification) token.
+     */
+    public get cls(): Token | null {
+        this._ensureNotDisposed();
+
+        if (this._clsToken == null)
+            this._clsToken = this._model.clsToken();
+
+        if (this._clsToken === -1)
+            return null;
+
+        return this._clsToken;
+    }
+
+    /**
+     * @returns The SEP (Sentence Separator) token.
+     */
+    public get sep(): Token | null {
+        this._ensureNotDisposed();
+
+        if (this._sepToken == null)
+            this._sepToken = this._model.sepToken();
+
+        if (this._sepToken === -1)
+            return null;
+
+        return this._sepToken;
+    }
+
     /**
      * @returns The NL (New Line) token.
      */
@@ -892,6 +927,40 @@
         return this._eotString;
     }

+    /**
+     * @returns The CLS (Classification) token text representation.
+     */
+    public get clsString(): string | null {
+        this._ensureNotDisposed();
+
+        const clsToken = this.cls;
+
+        if (clsToken == null)
+            return null;
+
+        if (this._clsString == null)
+            this._clsString = this._model.getTokenString(clsToken);
+
+        return this._clsString;
+    }
+
+    /**
+     * @returns The SEP (Sentence Separator) token text representation.
+     */
+    public get sepString(): string | null {
+        this._ensureNotDisposed();
+
+        const sepToken = this.sep;
+
+        if (sepToken == null)
+            return null;
+
+        if (this._sepString == null)
+            this._sepString = this._model.getTokenString(sepToken);
+
+        return this._sepString;
+    }
+
     /**
      * @returns The NL (New Line) token text representation.
      */
@@ -921,6 +990,18 @@
         return this._shouldPrependBosToken;
     }

+    /**
+     * @returns Whether we should append an EOS (End Of Sequence) token for evaluations with this model.
+     */
+    public get shouldAppendEosToken(): boolean {
+        this._ensureNotDisposed();
+
+        if (this._shouldAppendEosToken == null)
+            this._shouldAppendEosToken = this.bos != null && this._model.shouldAppendEosToken();
+
+        return this._shouldAppendEosToken;
+    }
+
     /** @internal */
     private _ensureNotDisposed() {
         if (this._disposedState.disposed)
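The new getters surface the CLS/SEP tokens and the EOS-append flag on model.tokens, mirroring the addon methods added above. A quick usage sketch (the model path is a placeholder):

// Usage sketch for the new token getters; the model path is a placeholder.
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "model.gguf"});

// cls and sep are null for models whose vocabulary doesn't define them
console.log("CLS token:", model.tokens.cls, model.tokens.clsString);
console.log("SEP token:", model.tokens.sep, model.tokens.sepString);
console.log("append EOS:", model.tokens.shouldAppendEosToken);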
(Diffs for the remaining changed files are not shown.)