Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : save and restore kv cache for single seq id #6341

Merged
merged 34 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
662aaea
llama : save and restore kv cache for single seq id
kaetemi Mar 27, 2024
5462817
remove trailing whitespace
kaetemi Mar 27, 2024
ab1c46a
respond error in case there's no space in the kv cache
kaetemi Mar 27, 2024
02a1840
add kv seq save restore to test case
kaetemi Mar 27, 2024
b8e8fac
add --slot-save-path arg to enable save restore and restrict save loc…
kaetemi Mar 27, 2024
b182f8f
Returning 0 for some cases, instead of asserting.
martindevans Mar 27, 2024
a2b48b9
cleanup error cases
kaetemi Mar 27, 2024
c4443d7
rename sequence state functions
kaetemi Mar 28, 2024
4d5356b
rename state get set functions
kaetemi Mar 28, 2024
bbcbf47
add previous function names back in with DEPRECATED notice
kaetemi Mar 29, 2024
8b5ae29
update doc
kaetemi Mar 29, 2024
a71ec3d
adjust endpoints to preferred style
kaetemi Mar 29, 2024
bf1d493
fix restoring zero cell count
kaetemi Mar 29, 2024
8ab1a17
handle seq rm return value
kaetemi Mar 29, 2024
0d22136
unused param
kaetemi Mar 29, 2024
29f18c2
keep in the size check
kaetemi Mar 29, 2024
f2e41b3
fix return types
kaetemi Mar 29, 2024
92c4681
add server test case for slot save restore
kaetemi Mar 29, 2024
60f685f
cleanup
kaetemi Mar 29, 2024
d38eef4
add cake
kaetemi Mar 30, 2024
ea717f7
cleanup style
kaetemi Mar 30, 2024
b509b8b
add special
kaetemi Mar 30, 2024
129b6ff
removing a whole sequence never fails
kaetemi Mar 30, 2024
8af7211
move sequence state file functionality from server to llama to match …
kaetemi Mar 30, 2024
3d6fa5b
catch exceptions on save as well
kaetemi Apr 1, 2024
b3f6da3
error log messages
kaetemi Apr 1, 2024
be714a0
check types for stricter restore
kaetemi Apr 1, 2024
0ccfbf2
update server doc
kaetemi Apr 1, 2024
205c44c
readme : update API changes date
ggerganov Apr 4, 2024
d9fd0d7
Merge branch 'master' into feature/save-restore-seq
kaetemi Apr 4, 2024
f2a4777
strict filename validation
kaetemi Apr 5, 2024
4a4f399
move include, reject bom as well
kaetemi Apr 5, 2024
2fbf0c3
also reject empty filename
kaetemi Apr 5, 2024
bf94e9f
reject whitespace and trailing dot
kaetemi Apr 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Recent API changes

- [2024 Mar 30] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341
ggerganov marked this conversation as resolved.
Show resolved Hide resolved
- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
Expand Down
6 changes: 3 additions & 3 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
// The file exists and is not empty
session_tokens.resize(n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
Expand Down Expand Up @@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
need_to_save_session = false;
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());

LOG("saved session to %s\n", path_session.c_str());
}
Expand Down Expand Up @@ -935,7 +935,7 @@ int main(int argc, char ** argv) {

if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}

llama_print_timings(ctx);
Expand Down
101 changes: 95 additions & 6 deletions examples/save-load-state/save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ int main(int argc, char ** argv) {

std::string result0;
std::string result1;
std::string result2;

// init
llama_model * model;
Expand All @@ -44,8 +45,8 @@ int main(int argc, char ** argv) {

// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
const size_t written = llama_copy_state_data(ctx, state_mem.data());
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data());

FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
Expand Down Expand Up @@ -97,13 +98,13 @@ int main(int argc, char ** argv) {

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));

FILE * fp_read = fopen("dump_state.bin", "rb");
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_set_state_data(ctx2, state_mem.data())) {
if (read != llama_state_set_data(ctx2, state_mem.data())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
Expand Down Expand Up @@ -141,16 +142,104 @@ int main(int argc, char ** argv) {
n_past += 1;
}

printf("\n");
printf("\n\n");

llama_free(ctx2);
llama_free_model(model);

if (result0 != result1) {
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
return 1;
}

// make new context
auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

printf("\nsingle seq run: %s", params.prompt.c_str());

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));

FILE * fp_read = fopen("dump_state.bin", "rb");
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_state_set_data(ctx3, state_mem.data())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx3);
llama_free_model(model);
return 1;
}

fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
}

// restore state (last tokens)
n_past = n_past_saved;

// save seq 0 and load into seq 1
{
// save kv of seq 0
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
llama_free(ctx3);
llama_free_model(model);
return 1;
}
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);

// erase whole kv
llama_kv_cache_clear(ctx3);
fprintf(stderr, "%s : kv cache cleared\n", __func__);

// restore kv into seq 1
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
llama_free(ctx3);
llama_free_model(model);
return 1;
}
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
}

// third run with seq 1 instead of 0
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx3);
auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
auto next_token = llama_sample_token(ctx3, &candidates_p);
auto next_token_str = llama_token_to_piece(ctx3, next_token);

printf("%s", next_token_str.c_str());
result2 += next_token_str;

if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx3);
llama_free_model(model);
return 1;
}
n_past += 1;
}

printf("\n");

llama_free(ctx3);
llama_free_model(model);

if (result0 != result2) {
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
return 1;
}

fprintf(stderr, "\n%s : success\n", __func__);

return 0;
Expand Down
Loading
Loading