diff --git a/CMakeLists.txt b/CMakeLists.txt
index 56ef52053dd2d..8a65ae8b6afe0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -251,6 +251,15 @@ if (LLAMA_CUBLAS)
         set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
     endif()
 
+    if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        if (LLAMA_CUDA_DMMV_F16)
+            set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics
+        else()
+            set(CMAKE_CUDA_ARCHITECTURES "52") # lowest CUDA 12 standard
+        endif()
+    endif()
+    message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
 else()
     message(WARNING "cuBLAS not found")
 endif()
@@ -525,22 +534,6 @@ if (BUILD_SHARED_LIBS)
     endif()
 endif()
 
-if (GGML_SOURCES_CUDA)
-    message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
-    set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES "native")
-    set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
-
-    set_property(TARGET ggml_static PROPERTY CUDA_ARCHITECTURES "native")
-    set_property(TARGET ggml_static PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
-
-    if (BUILD_SHARED_LIBS)
-        set_property(TARGET ggml_shared PROPERTY CUDA_ARCHITECTURES "native")
-        set_property(TARGET ggml_shared PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
-    endif()
-
-    set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES "native")
-endif()
-
 #
 # programs, examples and tests
diff --git a/README.md b/README.md
index 2d05de333cb23..ace588606ee8c 100644
--- a/README.md
+++ b/README.md
@@ -9,12 +9,8 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 **Hot topics:**
 
+- p1 : LLM-based code completion engine at the edge : https://github.com/ggml-org/p1/discussions/1
 - Roadmap June 2023: https://github.com/ggerganov/llama.cpp/discussions/1729
-- GPU support with Metal (Apple Silicon): https://github.com/ggerganov/llama.cpp/pull/1642
-- High-quality 2,3,4,5,6-bit quantization: https://github.com/ggerganov/llama.cpp/pull/1684
-- Multi-GPU support: https://github.com/ggerganov/llama.cpp/pull/1607
-- Training LLaMA models from scratch: https://github.com/ggerganov/llama.cpp/pull/1652
-- CPU threading improvements: https://github.com/ggerganov/llama.cpp/pull/1632
 
 <details>
   <summary>Table of Contents</summary>
@@ -344,7 +340,7 @@ Building the program with BLAS support may lead to some performance improvements
   | LLAMA_CUDA_DMMV_X       | Positive integer >= 32 | 32      | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
   | LLAMA_CUDA_DMMV_Y       | Positive integer       | 1       | Block size in y direction for the CUDA dequantization + mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
   | LLAMA_CUDA_DMMV_F16     | Boolean                | false   | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
-  | LLAMA_CUDA_KQUANTS_ITER | 1 or 2                 | 2       | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value 2 1 can improve performance for slow GPUs. |
+  | LLAMA_CUDA_KQUANTS_ITER | 1 or 2                 | 2       | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
 
 - #### CLBlast
 
@@ -378,7 +374,7 @@ Building the program with BLAS support may lead to some performance improvements
   ```sh
   git clone https://github.com/CNugteren/CLBlast.git
   mkdir CLBlast/build
-  cd CLBLast/build
+  cd CLBlast/build
   cmake .. -DBUILD_SHARED_LIBS=OFF -DTUNERS=OFF
   cmake --build . --config Release
   cmake --install . --prefix /some/path
diff --git a/convert.py b/convert.py
index 265c41fa04b18..de6c39c67672b 100644
--- a/convert.py
+++ b/convert.py
@@ -130,6 +130,14 @@ def make_tensors_list() -> List[str]:
 
 TENSORS_SET = set(TENSORS_LIST)
 
+def find_n_mult(n_ff: int, n_embd: int) -> int:
+    # hardcoded magic range
+    for n_mult in range(256, 1, -1):
+        calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
+        if calc_ff == n_ff:
+            return n_mult
+    return 1
+
 @dataclass
 class Params:
     n_vocab: int
@@ -137,21 +145,61 @@ class Params:
     n_mult:  int
     n_head:  int
     n_layer: int
-    file_type: GGMLFileType
 
     @staticmethod
-    def guessed(model: 'LazyModel', file_type: GGMLFileType) -> 'Params':
-        n_vocab, n_embd = model["tok_embeddings.weight"].shape
+    def guessed(model: 'LazyModel') -> 'Params':
+        # try transformer naming first
+        n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape
+
+        # try transformer naming first
+        if "model.layers.0.self_attn.q_proj.weight" in model:
+            n_layer=next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model)
+        else:
+            n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
+
+        n_head=n_embd // 128 # guessed
 
         return Params(
             n_vocab=n_vocab,
             n_embd=n_embd,
             n_mult=256,
-            n_head=n_embd // 128,
-            n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model),
-            file_type=file_type,
+            n_head=n_head,
+            n_layer=n_layer,
         )
 
+    @staticmethod
+    def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+        config = json.load(open(config_path))
+
+        n_vocab = config["vocab_size"];
+        n_embd = config["hidden_size"];
+        n_head = config["num_attention_heads"];
+        n_layer = config["num_hidden_layers"];
+        n_ff = config["intermediate_size"];
+
+        n_mult = find_n_mult(n_ff, n_embd);
+
+        return Params(
+            n_vocab=n_vocab,
+            n_embd=n_embd,
+            n_mult=n_mult,
+            n_head=n_head,
+            n_layer=n_layer,
+        )
+
+    @staticmethod
+    def load(model_plus: 'ModelPlus') -> 'Params':
+        orig_config_path = model_plus.paths[0].parent / "params.json"
+        hf_transformer_config_path = model_plus.paths[0].parent / "config.json"
+
+        if hf_transformer_config_path.exists():
+            params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
+        else:
+            params = Params.guessed(model_plus.model)
+
+        print(f'params: n_vocab:{params.n_vocab} n_embd:{params.n_embd} n_mult:{params.n_mult} n_head:{params.n_head} n_layer:{params.n_layer}')
+        return params
+
 
 class SentencePieceVocab:
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
@@ -595,18 +643,17 @@ def load() -> Tensor:
     return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
 
 
-def convert_transformers_to_orig(model: LazyModel) -> LazyModel:
+def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     out: LazyModel = {}
     out["tok_embeddings.weight"] = model["model.embed_tokens.weight"]
     out["norm.weight"] = model["model.norm.weight"]
     out["output.weight"] = model["lm_head.weight"]
 
-    n_head = model["model.layers.0.self_attn.q_proj.weight"].shape[1] // 128
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
             break
-        out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], n_head)
-        out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], n_head)
+        out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
+        out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
         out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
 
@@ -920,7 +967,7 @@ class OutputFile:
     def __init__(self, fname_out: Path) -> None:
         self.fout = open(fname_out, "wb")
 
-    def write_file_header(self, params: Params) -> None:
+    def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
         self.fout.write(b"ggjt"[::-1])  # magic
         values = [
             1,  # file version
@@ -930,7 +977,7 @@ def write_file_header(self, params: Params) -> None:
             params.n_head,
             params.n_layer,
             params.n_embd // params.n_head,  # rot (obsolete)
-            params.file_type.value,
+            file_type.value,
         ]
         self.fout.write(struct.pack("i" * len(values), *values))
 
@@ -958,10 +1005,10 @@ def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of.fout.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel, vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
         of = OutputFile(fname_out)
-        of.write_file_header(params)
+        of.write_file_header(params, file_type)
         print("Writing vocab...")
         of.write_vocab(vocab)
 
@@ -997,11 +1044,11 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
     raise Exception(f"Unexpected combination of types: {name_to_type}")
 
 
-def do_necessary_conversions(model: LazyModel) -> LazyModel:
+def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel:
     model = handle_quantization(model)
 
     if "lm_head.weight" in model:
-        model = convert_transformers_to_orig(model)
+        model = convert_transformers_to_orig(model, params)
     model = filter_and_sort_tensors(model)
 
     return model
@@ -1107,14 +1154,14 @@ def load_vocab(path: Path) -> SentencePieceVocab:
     return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
 
 
-def default_outfile(model_paths: List[Path], params: Params) -> Path:
+def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32: "f32",
         GGMLFileType.MostlyF16: "f16",
         GGMLFileType.MostlyQ4_0: "q4_0",
         GGMLFileType.MostlyQ4_1: "q4_1",
         GGMLFileType.PerLayerIsQ4_1: "q4_1",
-    }[params.file_type]
+    }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.bin"
     if ret in model_paths:
         sys.stderr.write(
@@ -1164,13 +1211,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
     else:
        vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
        vocab = load_vocab(vocab_dir)
+    params = Params.load(model_plus)
     model = model_plus.model
-    model = do_necessary_conversions(model)
+    model = do_necessary_conversions(model, params)
     output_type = pick_output_type(model, args.outtype)
     model = convert_to_output_type(model, output_type)
-    params = Params.guessed(model, output_type)
-    outfile = args.outfile or default_outfile(model_plus.paths, params)
-    OutputFile.write_all(outfile, params, model, vocab)
+    outfile = args.outfile or default_outfile(model_plus.paths, output_type)
+    OutputFile.write_all(outfile, params, output_type, model, vocab)
     print(f"Wrote {outfile}")
diff --git a/llama.cpp b/llama.cpp
index 4a7d01b3297b2..e597f5048234b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -925,21 +925,21 @@ static bool kv_cache_init(
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
+        /*.seed                        =*/ -1,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 512,
         /*.gpu_layers                  =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ {0},
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
         /*.low_vram                    =*/ false,
-        /*.seed                        =*/ -1,
         /*.f16_kv                      =*/ true,
         /*.logits_all                  =*/ false,
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
         /*.embedding                   =*/ false,
-        /*.progress_callback           =*/ nullptr,
-        /*.progress_callback_user_data =*/ nullptr,
     };
 
     return result;
diff --git a/llama.h b/llama.h
index 1241ba6c0ec44..0de530d456932 100644
--- a/llama.h
+++ b/llama.h
@@ -71,28 +71,27 @@ extern "C" {
 
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
-   struct llama_context_params {
+    struct llama_context_params {
+        int seed;                              // RNG seed, -1 for random
         int n_ctx;                             // text context
         int n_batch;                           // prompt processing batch size
         int n_gpu_layers;                      // number of layers to store in VRAM
         int main_gpu;                          // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
-        bool low_vram;                         // if true, reduce VRAM usage at the cost of performance
-        int seed;                              // RNG seed, -1 for random
+        // called with a progress value between 0 and 1, pass NULL to disable
+        llama_progress_callback progress_callback;
+        // context pointer passed to the progress callback
+        void * progress_callback_user_data;
+        // Keep the booleans together to avoid misalignment during copy-by-value.
+        bool low_vram;   // if true, reduce VRAM usage at the cost of performance
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap;   // use mmap if possible
         bool use_mlock;  // force system to keep model in RAM
         bool embedding;  // embedding mode only
-
-        // called with a progress value between 0 and 1, pass NULL to disable
-        llama_progress_callback progress_callback;
-        // context pointer passed to the progress callback
-        void * progress_callback_user_data;
     };
-
     // model file types
     enum llama_ftype {
         LLAMA_FTYPE_ALL_F32 = 0,
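
For reference, the `find_n_mult` helper added to `convert.py` recovers the `n_mult` hyperparameter from the feed-forward size reported in a Hugging Face `config.json`: it searches downward from 256 for a multiple that reproduces `n_ff` from `n_embd`. Below is a minimal standalone sketch of that logic; the example hyperparameters (`n_embd = 4096`, `n_ff = 11008`, typical LLaMA-7B values) are assumptions for illustration and are not taken from this patch.

```python
# Standalone sketch of the n_mult recovery logic added in convert.py.
# The sample hyperparameters are assumed LLaMA-7B values
# (hidden_size=4096, intermediate_size=11008); they do not come from the patch.

def find_n_mult(n_ff: int, n_embd: int) -> int:
    # scan the hardcoded magic range, largest candidate first
    for n_mult in range(256, 1, -1):
        # round 8*n_embd/3 up to the nearest multiple of n_mult
        calc_ff = (((8 * n_embd) // 3 + n_mult - 1) // n_mult) * n_mult
        if calc_ff == n_ff:
            return n_mult
    return 1  # fallback when no candidate reproduces n_ff


if __name__ == "__main__":
    n_embd, n_ff = 4096, 11008  # assumed LLaMA-7B hyperparameters
    # 8*4096 // 3 = 10922, rounded up to a multiple of 256 -> 11008, so this prints 256
    print(find_n_mult(n_ff, n_embd))
```

With `n_mult` recovered this way, `Params.load` can take exact hyperparameters from `config.json` when it exists and only fall back to `Params.guessed` otherwise.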