Skip to content

Commit

Permalink
check n_ubatch >= n_tokens with non-casual attention
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren committed Mar 13, 2024
1 parent 54cdd47 commit cda49d3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
5 changes: 3 additions & 2 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1738,7 +1738,8 @@ struct server_context {
}

// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);

// next, batch any pending prompts without exceeding n_batch
if (params.cont_batching || batch.n_tokens == 0) {
Expand Down Expand Up @@ -1811,7 +1812,7 @@ struct server_context {

if (slot.embedding) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_batch) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
slot.release();
Expand Down
9 changes: 6 additions & 3 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8774,6 +8774,8 @@ static int llama_decode_internal(

GGML_ASSERT(n_tokens_all <= cparams.n_batch);

GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");

if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
Expand Down Expand Up @@ -9011,9 +9013,6 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_MEAN:
{
// FIXME: this may not work if the sequences are split into different batches
GGML_ASSERT(n_tokens_all == n_tokens);

GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);

// extract sequence embeddings
Expand Down Expand Up @@ -13076,6 +13075,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
return ctx->cparams.n_batch;
}

uint32_t llama_n_ubatch(const struct llama_context * ctx) {
return ctx->cparams.n_ubatch;
}

uint32_t llama_n_seq_max(const struct llama_context * ctx) {
return ctx->kv_self.size;
}
Expand Down
1 change: 1 addition & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ extern "C" {

LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);

LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
Expand Down

0 comments on commit cda49d3

Please sign in to comment.