support QAT, export and inference for quantized BERT, GPT2 (#285)
* modify readme of examples
* modify table in example readme
* add cpp example of quant_transformer
* support huggingface bert ptq (stage 1)
* fix huggingface bert weight loading fp16 bug
* finetune quant bert from fp16 ckpt
* add emb quant of bert
* add example of hf bert squad training, modify dir of huggingface training
* format
* rename huggingface dir to fix conflict with datasets
* fix typo of gpt
* export fairseq models to hdf5
* quant hdf5 load (stage 1)
* quant hdf5 transformer finished
* fix fairseq infer bug
* export quant bert, delete hf quant pos emb
* add quant bert files
* support quant bert inference (not tested)
* fix quant bert export name bug
* support quant bert inference
* update black pre-commit version
* add quant bert test example
* support cpp quant bert example
* format
* modify readme
* do not use ffn2 out quant if using gelu
* polish gemm test
* fix gemm test lt col bug
* support gpt2 qat
* add causal mask for gpt encoder
* support quant gpt export
* add quant gpt required files
* support quant gpt inference (stage 1)
* add fake quant for logits gemm
* support quant gpt inference (stage 2)
* support quant gpt inference (stage 3)
* support quant gpt inference (ppl)
* support quant gpt inference (TODO: fix qkv bias out clip_max, sampling)
* support quant gpt inference (ppl)
* support quant gpt inference (sampling)
* support quant decoder sampling
* modify readme (add install command)
* optimize quant gpt gemm, fix gelu bug
* optimize cpp example
* replace quant gpt cache memcpy with pointer switch
* fuse quant gpt softmax kernel
* optimize quant gpt arrange-qkv kernel
* fix PyPI spelling
* fix gpt memory spelling
1 parent 3c1c506 · commit 4024ae1
Showing 111 changed files with 9,766 additions and 1,715 deletions.
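Several items in the commit message above (huggingface bert ptq, gpt2 qat, add fake quant for logits gemm) revolve around the same building block: fake quantization, where a tensor is clipped to a learned range [-clip_max, clip_max], rounded onto an int8 grid, and immediately dequantized, so training sees the quantization error while all arithmetic stays in floating point. The sketch below only illustrates that idea; it is not LightSeq's kernel, and the function name is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Illustrative symmetric int8 fake quantization (quantize-dequantize).
// clip_max is the learned clipping threshold referred to in the commit
// message; everything else here is a simplified assumption.
float fake_quant_int8(float x, float clip_max) {
  float scale = 127.0f / clip_max;                             // fp32 -> int8 scale
  float clipped = std::min(std::max(x, -clip_max), clip_max);  // saturate
  float q = std::round(clipped * scale);                       // simulated int8 value
  return q / scale;                                            // back to fp32
}

int main() {
  // Values outside the clip range saturate; values inside snap to one of
  // the 255 representable int8 steps.
  std::printf("%f\n", fake_quant_int8(0.503f, 0.5f));  // 0.500000
  std::printf("%f\n", fake_quant_int8(0.123f, 0.5f));  // ~0.122047
  return 0;
}
```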
@@ -1,5 +1,5 @@
 ## Dockerfiles of lightseq

-Pypi: for publish python package.
+PyPI: for publish python package.

 Tritonserver: for publish tritonserver
@@ -0,0 +1,81 @@
#include "model_base.h" | ||
#include "util.h" | ||
|
||
/** | ||
@file | ||
Example of how to run QuantBert inference using our implementation. | ||
*/ | ||
|
||
int main(int argc, char* argv[]) { | ||
std::string model_weights_path = argv[1]; | ||
std::vector<int> example_input = {2859, 2758, 2051, 2157, | ||
2005, 6629, 7566, 1012}; | ||
int eg_seq_len = example_input.size(); | ||
int max_batch_size = 128; | ||
int batch_size = 1; | ||
int batch_seq_len = eg_seq_len; | ||
|
||
if (argc == 4) { | ||
batch_size = atoi(argv[2]); | ||
batch_seq_len = atoi(argv[3]); | ||
} | ||
if (batch_size > max_batch_size) { | ||
throw std::runtime_error("batch_size exceeds the maximum (128)!"); | ||
} | ||
|
||
std::vector<int> host_input; | ||
for (int i = 0; i < batch_size; ++i) { | ||
for (int j = 0; j < batch_seq_len; ++j) { | ||
host_input.push_back(example_input[j % eg_seq_len]); | ||
} | ||
} | ||
|
||
auto model = lightseq::cuda::LSModelFactory::GetInstance().CreateModel( | ||
"QuantBert", model_weights_path, max_batch_size); | ||
|
||
void* d_input; | ||
lightseq::cuda::CHECK_GPU_ERROR( | ||
cudaMalloc(&d_input, sizeof(int) * batch_size * batch_seq_len)); | ||
lightseq::cuda::CHECK_GPU_ERROR(cudaMemcpy( | ||
d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len, | ||
cudaMemcpyHostToDevice)); | ||
|
||
model->set_input_ptr(0, d_input); | ||
model->set_input_shape(0, {batch_size, batch_seq_len}); | ||
|
||
for (int i = 0; i < model->get_output_size(); i++) { | ||
void* d_output; | ||
std::vector<int> shape = model->get_output_max_shape(i); | ||
int total_size = 1; | ||
for (int j = 0; j < shape.size(); j++) { | ||
total_size *= shape[j]; | ||
} | ||
lightseq::cuda::CHECK_GPU_ERROR( | ||
cudaMalloc(&d_output, total_size * sizeof(int))); | ||
model->set_output_ptr(i, d_output); | ||
} | ||
lightseq::cuda::CHECK_GPU_ERROR(cudaStreamSynchronize(0)); | ||
std::cout << "infer preprocessing finished" << std::endl; | ||
|
||
/* ---step5. infer and log--- */ | ||
for (int i = 0; i < 10; i++) { | ||
auto start = std::chrono::high_resolution_clock::now(); | ||
model->Infer(); | ||
lightseq::cuda::print_time_duration(start, "one infer time", 0); | ||
} | ||
|
||
for (int i = 0; i < model->get_output_size(); i++) { | ||
const float* d_output; | ||
d_output = static_cast<const float*>(model->get_output_ptr(i)); | ||
std::vector<int> shape = model->get_output_shape(i); | ||
std::cout << "output shape: "; | ||
for (int j = 0; j < shape.size(); j++) { | ||
std::cout << shape[j] << " "; | ||
} | ||
std::cout << std::endl; | ||
|
||
lightseq::cuda::print_vec(d_output, "output", 5); | ||
} | ||
|
||
return 0; | ||
} |
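This example and the quantized GPT one below share the same driver pattern: create the model through LSModelFactory, bind a device input buffer, allocate each output at its maximum shape, time ten Infer() calls, then print the results. Assuming the example is built as quant_bert_example from an exported checkpoint quant_bert.hdf5 (both names are hypothetical here; check the examples' build targets), an invocation like `./quant_bert_example quant_bert.hdf5 8 16` would run a batch of 8 sequences of length 16, while passing only the weights path falls back to the hard-coded 8-token input.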
@@ -0,0 +1,80 @@
#include "model_base.h" | ||
#include "gpt.h" | ||
|
||
/** | ||
@file | ||
Example of how to run gpt inference using our implementation. | ||
*/ | ||
|
||
int main(int argc, char* argv[]) { | ||
std::string model_weights_path = argv[1]; | ||
std::vector<int> example_input = {40, 1842, 345, 11, 475, 345, 910, 326}; | ||
int eg_seq_len = example_input.size(); | ||
int max_batch_size = 128; | ||
int batch_size = 1; | ||
int batch_seq_len = eg_seq_len; | ||
|
||
if (argc == 4) { | ||
batch_size = atoi(argv[2]); | ||
batch_seq_len = atoi(argv[3]); | ||
} | ||
if (batch_size > max_batch_size) { | ||
throw std::runtime_error("batch_size exceeds the maximum (128)!"); | ||
} | ||
|
||
std::vector<int> host_input; | ||
for (int i = 0; i < batch_size; ++i) { | ||
for (int j = 0; j < batch_seq_len; ++j) { | ||
host_input.push_back(example_input[j % eg_seq_len]); | ||
} | ||
} | ||
|
||
auto model = lightseq::cuda::LSModelFactory::GetInstance().CreateModel( | ||
"QuantGpt", model_weights_path, max_batch_size); | ||
|
||
void* d_input; | ||
lightseq::cuda::CHECK_GPU_ERROR( | ||
cudaMalloc(&d_input, sizeof(int) * batch_size * batch_seq_len)); | ||
lightseq::cuda::CHECK_GPU_ERROR(cudaMemcpy( | ||
d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len, | ||
cudaMemcpyHostToDevice)); | ||
|
||
model->set_input_ptr(0, d_input); | ||
model->set_input_shape(0, {batch_size, batch_seq_len}); | ||
|
||
for (int i = 0; i < model->get_output_size(); i++) { | ||
void* d_output; | ||
std::vector<int> shape = model->get_output_max_shape(i); | ||
int total_size = 1; | ||
for (int j = 0; j < shape.size(); j++) { | ||
total_size *= shape[j]; | ||
} | ||
lightseq::cuda::CHECK_GPU_ERROR( | ||
cudaMalloc(&d_output, total_size * sizeof(int))); | ||
model->set_output_ptr(i, d_output); | ||
} | ||
lightseq::cuda::CHECK_GPU_ERROR(cudaStreamSynchronize(0)); | ||
std::cout << "infer preprocessing finished" << std::endl; | ||
|
||
/* ---step5. infer and log--- */ | ||
for (int i = 0; i < 10; i++) { | ||
auto start = std::chrono::high_resolution_clock::now(); | ||
model->Infer(); | ||
lightseq::cuda::print_time_duration(start, "one infer time", 0); | ||
} | ||
|
||
for (int i = 0; i < model->get_output_size(); i++) { | ||
const int* d_output; | ||
d_output = static_cast<const int*>(model->get_output_ptr(i)); | ||
std::vector<int> shape = model->get_output_shape(i); | ||
std::cout << "output shape: "; | ||
for (int j = 0; j < shape.size(); j++) { | ||
std::cout << shape[j] << " "; | ||
} | ||
std::cout << std::endl; | ||
|
||
lightseq::cuda::print_vec(d_output, "output", 10); | ||
} | ||
|
||
return 0; | ||
} |
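One line in the commit message, add causal mask for gpt encoder, is what lets a decoder-only GPT-2 reuse encoder-style self-attention: each query position may attend only to itself and earlier positions. A minimal sketch of the masking step follows; it is illustrative only, not the fused CUDA kernel this commit adds.

```cpp
#include <limits>
#include <vector>

// Mask future positions in a [seq_len, seq_len] row-major matrix of
// attention logits. Entries with key position j > query position i are
// set to -inf so softmax assigns them zero weight.
void apply_causal_mask(std::vector<float>& scores, int seq_len) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < seq_len; ++i) {        // query position
    for (int j = i + 1; j < seq_len; ++j) {  // future key positions
      scores[i * seq_len + j] = neg_inf;
    }
  }
}
```

In the commit itself the softmax runs as a single fused kernel (see fuse quant gpt softmax kernel above), where masking of this kind would typically be folded in rather than done as a separate pass like this one.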