speculative : simplify the implementation
ggml-ci
ggerganov committed Nov 25, 2024
1 parent 9fd8c26 commit 2649e27
Showing 1 changed file with 23 additions and 32 deletions.

examples/speculative-simple/speculative-simple.cpp
@@ -181,54 +181,45 @@ int main(int argc, char ** argv) {
        GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token

        n_past    += ids.size() - 1;
-       n_drafted += batch_tgt.n_tokens - 1;
+       n_drafted += draft.size(); // note: we ignore the discarded small drafts
        n_accept  += ids.size() - 1;
+       n_predict += ids.size();

        // process the accepted tokens and update contexts
        //
        // this is the standard token post-processing that we normally do
        // in this case, we do it for a group of accepted tokens at once
        //
        {
-           llama_token id;
-           std::string token_str;
-
-           for (size_t i = 0; i < ids.size(); ++i) {
-               id = ids[i];
-
-               ++n_predict;
-
-               if (llama_token_is_eog(model_tgt, id)) {
-                   has_eos = true;
-                   break;
-               }
-
-               token_str = common_token_to_piece(ctx_tgt, id);
-
-               if (params.use_color && i + 1 < ids.size()) {
-                   LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-               } else {
-                   LOG("%s", token_str.c_str());
-               }
-           }
-
-           if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
-               break;
-           }
-
-           LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
-
-           {
-               LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-               llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
-           }
-
-           prompt_tgt.push_back(id_last);
-           prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
-
-           // remember the last accepted token for the next iteration
-           id_last = id;
+           for (size_t i = 0; i < ids.size(); ++i) {
+               const llama_token id = ids[i];
+
+               prompt_tgt.push_back(id_last);
+
+               id_last = id;
+
+               if (llama_token_is_eog(model_tgt, id)) {
+                   has_eos = true;
+                   break;
+               }
+
+               const std::string token_str = common_token_to_piece(ctx_tgt, id);
+
+               if (params.use_color && i + 1 < ids.size()) {
+                   LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+               } else {
+                   LOG("%s", token_str.c_str());
+               }
+           }
+
+           LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+           {
+               LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+               llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+           }
+
+           if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+               break;
+           }
        }
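The simplification folds the old post-loop prompt update (one push_back of id_last followed by a bulk insert of all but the last accepted token) into a single push_back per loop iteration. The following standalone sketch — not part of the commit, with plain ints standing in for llama_token and made-up token values — shows that both update orders leave the prompt and id_last in the same state:

// Sketch: verify the old and new prompt_tgt updates are equivalent.
// Plain ints stand in for llama_token; the values are hypothetical.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const std::vector<int> ids = {101, 102, 103}; // accepted tokens this round
    const int prev_last = 100;                    // id_last carried over from the previous iteration

    // old: batch update after the loop
    std::vector<int> prompt_old = {1, 2, 3};
    int id_last_old = prev_last;
    prompt_old.push_back(id_last_old);
    prompt_old.insert(prompt_old.end(), ids.begin(), ids.end() - 1);
    id_last_old = ids.back();

    // new: one push_back per accepted token, inside the loop
    std::vector<int> prompt_new = {1, 2, 3};
    int id_last_new = prev_last;
    for (std::size_t i = 0; i < ids.size(); ++i) {
        prompt_new.push_back(id_last_new);
        id_last_new = ids[i];
    }

    // same prompt contents and same carried-over token
    assert(prompt_old == prompt_new);
    assert(id_last_old == id_last_new);
    return 0;
}

The end-of-generation path differs only superficially: the old code broke out before touching prompt_tgt at all, while the new loop records the tokens accepted up to the EOG token. In both versions the outer loop then exits via the has_eos check, so the prompt is never consulted again.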
