speculative : refactor and add a simpler example (ggerganov#10362)
* speculative : refactor and add a simpler example

ggml-ci

* speculative : clean-up and add comments and TODOs [no ci]

* speculative : manage context in common_speculative

ggml-ci

* speculative : simplify

ggml-ci

* speculative : simplify (cont)

ggml-ci

* speculative : add --draft-min CLI arg

* speculative : minor fixup

* make : build fixes

* speculative : do not redraft previous drafts

ggml-ci

* speculative : fix the draft sampling

ggml-ci

* speculative : fix compile warning

* common : refactor args

ggml-ci

* common : change defaults [no ci]

* common : final touches

ggml-ci
ggerganov authored Nov 25, 2024
1 parent cce5a90 commit d9d54e4
Showing 28 changed files with 1,028 additions and 326 deletions.
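Editor's note, for orientation: the helpers introduced below (llama_tokens, common_lcp/common_lcs, common_params_speculative, common_sampler_sample_and_accept_n) all serve the speculative-decoding loop, where a small draft model proposes a run of tokens and the target model verifies them in a single batch. The following is a rough, hypothetical sketch of such a verification step; it is not the example added by this commit, and every name other than the helpers listed above is illustrative.

// Editor's sketch, not part of the diff. Assumes ctx_tgt is a llama_context for the
// target model, smpl is an initialized common_sampler, `draft` holds tokens proposed
// by a draft model, and id_last is the last accepted token at position n_past.
#include "common.h"
#include "sampling.h"

static llama_tokens verify_draft(llama_context * ctx_tgt, common_sampler * smpl,
                                 const llama_tokens & draft, llama_token id_last, int n_past) {
    // evaluate the last accepted token plus the whole draft in one target-model batch,
    // requesting logits at every position so each draft token can be verified
    llama_batch batch = llama_batch_init((int) draft.size() + 1, 0, 1);
    common_batch_clear(batch);
    common_batch_add(batch, id_last, n_past, { 0 }, true);
    for (size_t i = 0; i < draft.size(); ++i) {
        common_batch_add(batch, draft[i], n_past + 1 + (int) i, { 0 }, true);
    }
    llama_decode(ctx_tgt, batch);

    // accepted draft prefix plus one token sampled by the target model itself:
    // between 1 and draft.size() + 1 tokens (see common/sampling.h below)
    llama_tokens ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);

    llama_batch_free(batch);
    return ids;
}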
1 change: 1 addition & 0 deletions Makefile
@@ -966,6 +966,7 @@ OBJ_COMMON = \
$(DIR_COMMON)/console.o \
$(DIR_COMMON)/ngram-cache.o \
$(DIR_COMMON)/sampling.o \
$(DIR_COMMON)/speculative.o \
$(DIR_COMMON)/build-info.o \
$(DIR_COMMON)/json-schema-to-grammar.o

2 changes: 2 additions & 0 deletions common/CMakeLists.txt
@@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
ngram-cache.h
sampling.cpp
sampling.h
speculative.cpp
speculative.h
)

if (BUILD_SHARED_LIBS)
438 changes: 231 additions & 207 deletions common/arg.cpp

Large diffs are not rendered by default.

76 changes: 68 additions & 8 deletions common/common.cpp
@@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());

- buf << "\n" << std::to_string(i)
- << ":token '" << detokenized << "'"
- << ":pos " << std::to_string(batch.pos[i])
- << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
- << ":seq_id " << std::to_string(batch.seq_id[i][0])
- << ":logits " << std::to_string(batch.logits[i]);
+ buf << "\n" << std::to_string(i)
+ << ", token '" << detokenized << "'"
+ << ", pos " << std::to_string(batch.pos[i])
+ << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
+ << ", seq_id " << std::to_string(batch.seq_id[i][0])
+ << ", logits " << std::to_string(batch.logits[i]);
}

buf << " ]";
@@ -925,9 +925,9 @@ struct common_init_result common_init_from_params(common_params & params) {
common_lora_adapters_apply(lctx, iparams.lora_adapters);
}

- if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
+ if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
- params.sparams.ignore_eos = false;
+ params.sampling.ignore_eos = false;
}

if (params.warmup) {
@@ -1490,6 +1490,66 @@ void common_batch_add(
batch.n_tokens++;
}

//
// Token utils
//

size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

return i;
}

size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}

// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();

// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;

// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);

// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequence, set the LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}

// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}

// update the previous row for the next iteration
prev_row = curr_row;
}

// return the maximum length of the LCS
return max_length;
}

//
// Vocab utils
//
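Editor's aside (not part of the diff): despite the name, common_lcs as implemented above measures the longest common contiguous run (the DP row is reset to 0 on a mismatch), while common_lcp measures the shared prefix. A small illustrative check, with arbitrary token values:

// Editor's illustration of the two new token utils; token values are arbitrary.
#include "common.h"
#include <cassert>

int main() {
    const llama_tokens a = { 1, 2, 3, 7, 8 };
    const llama_tokens b = { 1, 2, 9, 7, 8 };

    assert(common_lcp(a, b) == 2); // shared prefix { 1, 2 } stops at the first mismatch
    assert(common_lcs(a, b) == 2); // longest contiguous common run is { 1, 2 } (or { 7, 8 })

    const llama_tokens c = { 5, 1, 2, 3, 6 };
    assert(common_lcp(a, c) == 0); // different first token
    assert(common_lcs(a, c) == 3); // { 1, 2, 3 } appears contiguously in both

    return 0;
}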
41 changes: 32 additions & 9 deletions common/common.h
@@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
struct llama_lora_adapter * adapter;
};

using llama_tokens = std::vector<llama_token>;

// build info
extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT;
@@ -101,8 +103,8 @@ enum dimre_method {
DIMRE_METHOD_MEAN,
};

- // sampler parameters
- struct common_sampler_params {
+ // sampling parameters
+ struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

int32_t n_prev = 64; // number of previous tokens to remember
@@ -153,19 +155,30 @@ struct common_sampler_params {
std::string print() const;
};

struct common_params_speculative {
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.9f; // minimum speculative decoding probability (greedy)

struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;

std::string model = ""; // draft model for speculative decoding // NOLINT
};

struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_draft = 5; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
- float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
- int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t grp_attn_n = 1; // group-attention factor
@@ -182,8 +195,6 @@ struct common_params {

struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
- struct cpu_params draft_cpuparams;
- struct cpu_params draft_cpuparams_batch;

ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
@@ -195,10 +206,10 @@
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

- struct common_sampler_params sparams;
+ struct common_params_sampling sampling;
+ struct common_params_speculative speculative;

std::string model = ""; // model path // NOLINT
- std::string model_draft = ""; // draft model for speculative decoding // NOLINT
std::string model_alias = "unknown"; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT
@@ -461,7 +472,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);

//
// Batch utils
//

void common_batch_clear(struct llama_batch & batch);

@@ -472,6 +485,16 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);

//
// Token utils
//

// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);

// longest common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);

//
// Vocab utils
//
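Editor's sketch (not from the diff): the fields that used to live directly in common_params (n_draft, p_split, n_gpu_layers_draft, model_draft, draft_cpuparams*) are now grouped under common_params_speculative. A hypothetical way to populate the new block, with illustrative values:

// Editor's sketch; the draft model path and all values are illustrative.
#include "common.h"

static common_params make_speculative_params() {
    common_params params;

    params.speculative.model        = "models/draft.gguf"; // hypothetical draft model path
    params.speculative.n_max        = 16;   // draft at most 16 tokens per step
    params.speculative.n_min        = 5;    // require at least 5 draft tokens to speculate
    params.speculative.p_min        = 0.9f; // minimum draft-token probability (greedy)
    params.speculative.n_gpu_layers = -1;   // default VRAM offload for the draft model

    // n_min presumably backs the new --draft-min CLI arg from the commit message
    return params;
}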
45 changes: 42 additions & 3 deletions common/sampling.cpp
@@ -99,7 +99,7 @@ struct ring_buffer {
};

struct common_sampler {
- common_sampler_params params;
+ common_params_sampling params;

struct llama_sampler * grmr;
struct llama_sampler * chain;
@@ -125,7 +125,7 @@ }
}
};

- std::string common_sampler_params::print() const {
+ std::string common_params_sampling::print() const {
char result[1024];

snprintf(result, sizeof(result),
@@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
return std::string(result);
}

- struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+ struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

lparams.no_perf = params.no_perf;
@@ -320,6 +320,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return cur_p.data[cur_p.selected].id;
}

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

std::vector<llama_token> result;
result.reserve(idxs.size());

size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

common_sampler_accept(gsmpl, id, true);

result.push_back(id);

if (draft[i] != id) {
break;
}
}

if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

common_sampler_accept(gsmpl, id, true);

result.push_back(id);
}

return result;
}

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}

return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}
23 changes: 22 additions & 1 deletion common/sampling.h
@@ -36,7 +36,7 @@ struct common_sampler;

// llama_sampler API overloads

- struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
+ struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);

void common_sampler_free(struct common_sampler * gsmpl);

@@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

// generalized version of common_sampler_sample
//
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
//
// common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
//
// is equivalent to
//
// common_sampler_sample(gsmpl, ctx, idx);
// common_sampler_accept(gsmpl, token, true);
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);

// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

// helpers
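Editor's example (not part of the diff), restating the equivalence described in the comment above: with an empty draft and a single index, the generalized helper reduces to one sample-and-accept call. Assumes gsmpl and ctx are an initialized common_sampler and llama_context.

// Editor's example; gsmpl and ctx are assumed to be initialized elsewhere.
const llama_tokens draft;                 // no draft tokens
const std::vector<int> idxs = { 0 };      // idxs.size() == draft.size() + 1

// samples one token at output index 0, accepts it, and returns it ...
const auto ids = common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);

// ... which is the same as:
//   const llama_token id = common_sampler_sample(gsmpl, ctx, 0);
//   common_sampler_accept(gsmpl, id, true);
// if any draft tokens were passed, ids would instead contain the accepted draft
// prefix plus the sampler's own token at the first position where it disagreed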
(Diffs for the remaining changed files are not rendered here.)
