Commit
llama : use enum types over int
ggerganov committed Mar 3, 2024
1 parent efecd06 commit 4661363
Showing 3 changed files with 18 additions and 14 deletions.
8 changes: 5 additions & 3 deletions common/common.h
@@ -76,9 +76,11 @@ struct gpt_params {
     float   yarn_beta_slow = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx  = 0;     // YaRN original context length
     float   defrag_thold   = -1.0f; // KV cache defragmentation threshold
-    int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
-    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
-    llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+
+    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
+
+    llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
+    llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
 
     // // sampling parameters
     struct llama_sampling_params sparams;
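
The change above is more than cosmetic: with the field declared as `int32_t`, any integer constant could be assigned to `rope_scaling_type` and a mistake would only surface at runtime, whereas the enum-typed field makes the compiler reject it. A minimal standalone sketch of the difference (the enum and its values are stand-ins modeled on llama.h, not the library's actual definitions):

    // sketch.cpp — illustrative only; build with: g++ -std=c++11 sketch.cpp
    #include <cstdint>

    // stand-in for the enum declared in llama.h
    enum llama_rope_scaling_type {
        LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1,
        LLAMA_ROPE_SCALING_TYPE_NONE        =  0,
        LLAMA_ROPE_SCALING_TYPE_LINEAR      =  1,
        LLAMA_ROPE_SCALING_TYPE_YARN        =  2,
    };

    struct params_old { int32_t                 rope_scaling_type; }; // before this commit
    struct params_new { llama_rope_scaling_type rope_scaling_type; }; // after this commit

    int main() {
        params_old a;
        a.rope_scaling_type = 12345; // compiles: any integer is silently accepted

        params_new b;
        // b.rope_scaling_type = 12345;                     // error: no implicit int -> enum conversion
        b.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; // OK: only named values fit
        return 0;
    }
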
18 changes: 9 additions & 9 deletions llama.cpp
@@ -873,16 +873,16 @@ struct LLM_TN {
 // gguf helpers
 //
 
-static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
+static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
 };
 
-static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
+static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
-            return kv.first;
+            return (llama_rope_scaling_type) kv.first;
         }
     }
 
@@ -1612,16 +1612,16 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
-    int32_t rope_scaling_type_train;
 
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
 
     bool causal_attn = true;
     bool need_kq_pos = false;
 
-    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
+    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
+    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
+    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -1670,8 +1670,8 @@ struct llama_cparams {
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
 
-    float    rope_freq_base;
-    float    rope_freq_scale;
+    float rope_freq_base;
+    float rope_freq_scale;
 
     uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
@@ -11848,6 +11848,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads         =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch   =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type      =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.rope_freq_base    =*/ 0.0f,
         /*.rope_freq_scale   =*/ 0.0f,
         /*.yarn_ext_factor   =*/ -1.0f,
@@ -11863,7 +11864,6 @@
         /*.logits_all          =*/ false,
         /*.embedding           =*/ false,
         /*.offload_kqv         =*/ true,
-        /*.pooling_type        =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.abort_callback      =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
     };
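
Typing the key of `LLAMA_ROPE_SCALING_TYPES` lets `llama_rope_scaling_type_from_string` return the enum directly instead of a raw `int32_t`, and keeping a single table for the name mapping avoids two parallel lists drifting apart. A self-contained sketch of the same lookup pattern (the enum values and the `UNSPECIFIED` fallback are assumptions modeled on llama.h; the map contents mirror the diff):

    // lookup.cpp — sketch of the map-backed string -> enum conversion; g++ -std=c++11 lookup.cpp
    #include <cstdio>
    #include <map>
    #include <string>

    enum llama_rope_scaling_type {
        LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1, // assumed fallback value
        LLAMA_ROPE_SCALING_TYPE_NONE        =  0,
        LLAMA_ROPE_SCALING_TYPE_LINEAR      =  1,
        LLAMA_ROPE_SCALING_TYPE_YARN        =  2,
    };

    static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
        { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
        { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
        { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
    };

    static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
        for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
            if (kv.second == name) {
                return kv.first; // the key already carries the enum type
            }
        }
        return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; // assumed: unknown names map to "unspecified"
    }

    int main() {
        printf("%d\n", llama_rope_scaling_type_from_string("yarn")); // prints 2
        printf("%d\n", llama_rope_scaling_type_from_string("nope")); // prints -1
        return 0;
    }
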
6 changes: 4 additions & 2 deletions llama.h
@@ -237,7 +237,10 @@ extern "C" {
         uint32_t n_batch;         // prompt processing maximum batch size
         uint32_t n_threads;       // number of threads to use for generation
         uint32_t n_threads_batch; // number of threads to use for batch processing
-        int32_t  rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+
+        enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+        enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
+                                                        // (ignored if no pooling layer)
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -259,7 +262,6 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
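
For API consumers, the struct layout changes (`rope_scaling_type` becomes an enum and `pooling_type` moves next to it), but code that starts from `llama_context_default_params()` and assigns named constants keeps working unchanged. A hedged usage sketch, assuming llama.h as of this commit is on the include path:

    // ctx_params.cpp — hedged sketch; uses only identifiers that appear in this diff
    #include "llama.h"

    int main() {
        struct llama_context_params cparams = llama_context_default_params();

        // both fields are now enum-typed, so these assignments are compiler-checked
        cparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN;
        cparams.pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED;

        // cparams.rope_scaling_type = 2; // no longer compiles in C++: ints do not
        //                                // implicitly convert to enum types
        return 0;
    }
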
