Skip to content

Commit

Permalink
Allow passing grammar to completion endpoint (ggerganov#2532)
Browse files Browse the repository at this point in the history
* Allow passing grammar to completion endpoint
  • Loading branch information
krasserm authored Aug 8, 2023
1 parent acfc547 commit f5bfea0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ embedding: examples/embedding/embedding.cpp build-info.h ggml.
save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o $(OBJS)
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)

$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
Expand Down
2 changes: 2 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ node .

`mirostat_eta`: Set the Mirostat learning rate, parameter eta (default: 0.1).

`grammar`: Set grammar for grammar-based sampling (default: no grammar)

`seed`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).

`ignore_eos`: Ignore end of stream token and continue generating (default: false).
Expand Down
60 changes: 58 additions & 2 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "common.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
Expand Down Expand Up @@ -195,6 +196,8 @@ struct llama_server_context
llama_context *ctx = nullptr;
gpt_params params;

llama_grammar *grammar = nullptr;

bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
Expand Down Expand Up @@ -226,6 +229,7 @@ struct llama_server_context
void rewind()
{
params.antiprompt.clear();
params.grammar.clear();
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
Expand All @@ -237,6 +241,7 @@ struct llama_server_context
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
grammar = nullptr;

n_remain = 0;
n_past = 0;
Expand All @@ -257,6 +262,33 @@ struct llama_server_context
return true;
}

bool loadGrammar()
{
if (!params.grammar.empty()) {
grammar_parser::parse_state parsed_grammar;

parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
return false;
}
grammar_parser::print_grammar(stderr, parsed_grammar);

{
auto it = params.logit_bias.find(llama_token_eos());
if (it != params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
}
}

std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
return true;
}

void loadPrompt()
{
params.prompt.insert(0, 1, ' '); // always add a first space
Expand Down Expand Up @@ -420,6 +452,10 @@ struct llama_server_context
logits[llama_token_nl()] = nl_logit;
}

if (grammar != nullptr) {
llama_sample_grammar(ctx, &candidates_p, grammar);
}

if (temp <= 0)
{
// Greedy sampling
Expand Down Expand Up @@ -457,10 +493,15 @@ struct llama_server_context
}
}

if (grammar != nullptr) {
llama_grammar_accept_token(ctx, grammar, result.tok);
}

for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
{
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok);
num_tokens_predicted++;
Expand Down Expand Up @@ -947,6 +988,7 @@ static json format_generation_settings(llama_server_context &llama)
{"stream", llama.stream},
{"logit_bias", llama.params.logit_bias},
{"n_probs", llama.params.n_probs},
{"grammar", llama.params.grammar},
};
}

Expand Down Expand Up @@ -1048,6 +1090,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.grammar = body.value("grammar", default_params.grammar);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);

llama.params.logit_bias.clear();
Expand Down Expand Up @@ -1179,6 +1222,12 @@ int main(int argc, char **argv)

parse_options_completion(json::parse(req.body), llama);

if (!llama.loadGrammar())
{
res.status = 400;
return;
}

llama.loadPrompt();
llama.beginCompletion();

Expand Down Expand Up @@ -1334,8 +1383,12 @@ int main(int argc, char **argv)

svr.set_error_handler([](const Request &, Response &res)
{
res.set_content("File Not Found", "text/plain");
res.status = 404; });
if (res.status == 400) {
res.set_content("Invalid request", "text/plain");
} else {
res.set_content("File Not Found", "text/plain");
res.status = 404;
} });

// set timeouts and change hostname and port
svr.set_read_timeout(sparams.read_timeout);
Expand Down Expand Up @@ -1363,6 +1416,9 @@ int main(int argc, char **argv)
return 1;
}

if (llama.grammar != nullptr) {
llama_grammar_free(llama.grammar);
}
llama_backend_free();

return 0;
Expand Down

0 comments on commit f5bfea0

Please sign in to comment.