Skip to content

Commit

Permalink
speculative : simplify the implementation
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Nov 26, 2024
1 parent 9fd8c26 commit 1bc0e9a
Showing 1 changed file with 24 additions and 33 deletions.
57 changes: 24 additions & 33 deletions examples/speculative-simple/speculative-simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
llama_token id_last = inp.back();

// all tokens currently in the target context
auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
prompt_tgt.reserve(llama_n_ctx(ctx_tgt));

int n_past = inp.size() - 1;

Expand Down Expand Up @@ -181,54 +182,44 @@ int main(int argc, char ** argv) {
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token

n_past += ids.size() - 1;
n_drafted += batch_tgt.n_tokens - 1;
n_drafted += draft.size(); // note: we ignore the discarded small drafts
n_accept += ids.size() - 1;
n_predict += ids.size();

// process the accepted tokens and update contexts
//
// this is the standard token post-processing that we normally do
// in this case, we do it for a group of accepted tokens at once
//
{
llama_token id;
std::string token_str;

for (size_t i = 0; i < ids.size(); ++i) {
id = ids[i];

++n_predict;

if (llama_token_is_eog(model_tgt, id)) {
has_eos = true;
break;
}

token_str = common_token_to_piece(ctx_tgt, id);
for (size_t i = 0; i < ids.size(); ++i) {
prompt_tgt.push_back(id_last);

if (params.use_color && i + 1 < ids.size()) {
LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
} else {
LOG("%s", token_str.c_str());
}
}
id_last = ids[i];

if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
if (llama_token_is_eog(model_tgt, id_last)) {
has_eos = true;
break;
}

LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
const std::string token_str = common_token_to_piece(ctx_tgt, id_last);

{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);

llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
if (params.use_color && i + 1 < ids.size()) {
LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
} else {
LOG("%s", token_str.c_str());
}
}

prompt_tgt.push_back(id_last);
prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);

{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);

llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
}

// remember the last accepted token for the next iteration
id_last = id;
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
break;
}
}

Expand Down

0 comments on commit 1bc0e9a

Please sign in to comment.