From cda49d38288421133f6dcdee6e8fdd2d8fd6990f Mon Sep 17 00:00:00 2001
From: slaren
Date: Wed, 13 Mar 2024 13:59:08 +0100
Subject: [PATCH] check n_ubatch >= n_tokens with non-causal attention

---
 examples/server/server.cpp | 5 +++--
 llama.cpp                  | 9 ++++++---
 llama.h                    | 1 +
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7a210b075f01b..fe18d88ed4db1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1738,7 +1738,8 @@ struct server_context {
         }
 
         // process in chunks of params.n_batch
-        int32_t n_batch = params.n_batch;
+        int32_t n_batch = llama_n_batch(ctx);
+        int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
@@ -1811,7 +1812,7 @@ struct server_context {
 
                     if (slot.embedding) {
                         // this prompt is too large to process - discard it
-                        if (slot.n_prompt_tokens > n_batch) {
+                        if (slot.n_prompt_tokens > n_ubatch) {
                             slot.state = SLOT_STATE_PROCESSING;
                             slot.command = SLOT_COMMAND_NONE;
                             slot.release();
diff --git a/llama.cpp b/llama.cpp
index a2ceac163cb48..50c744ad8744a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8774,6 +8774,8 @@ static int llama_decode_internal(
 
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
@@ -9011,9 +9013,6 @@ static int llama_decode_internal(
             case LLAMA_POOLING_TYPE_CLS:
             case LLAMA_POOLING_TYPE_MEAN:
                 {
-                    // FIXME: this may not work if the sequences are split into different batches
-                    GGML_ASSERT(n_tokens_all == n_tokens);
-
                     GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
 
                     // extract sequence embeddings
@@ -13076,6 +13075,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
     return ctx->cparams.n_batch;
 }
 
+uint32_t llama_n_ubatch(const struct llama_context * ctx) {
+    return ctx->cparams.n_ubatch;
+}
+
 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
     return ctx->kv_self.size;
 }
diff --git a/llama.h b/llama.h
index 7f89fe6d4b49e..5e1d74a097514 100644
--- a/llama.h
+++ b/llama.h
@@ -378,6 +378,7 @@ extern "C" {
 
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
    LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);