Commit
* init example of vit training
* init vit proto
* init patch emb kernel
* init model and pywrapper
* fix bugs
* init vit export example
* fix export bug
* fix blockReduceSum bug
* update export
* fix last layernorm
* rm redundant is_post_ln
* update readme and test example
* with_lightseq true
* delete useless moeKernel
* support channel*patch*patch >= 1024
* update pre-commit and code format
* update format of run_vit.sh
Showing 25 changed files with 2,255 additions and 67 deletions.
examples/inference/python/export/huggingface/hf_vit_export.py (149 additions, 0 deletions)
@@ -0,0 +1,149 @@
""" | ||
Export Hugging Face ViT models to hdf5 format. | ||
""" | ||
import os | ||
import h5py | ||
from collections import OrderedDict | ||
from transformers import ViTModel | ||
from lightseq.training.ops.pytorch.export import fill_hdf5_layer | ||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | ||
|
||
|
||
""" | ||
For the mapping dictionary: key is the value of the proto parameter, | ||
value is a powerful expression, each && split tensor name of the matching path or expression. | ||
The sub-pattern of the path is separated by spaces, and the expression starts with a expression_. | ||
You can operate separately on each tensor and support multiple expressions. Multiple matching paths | ||
and the expression will finally be concatenated on axis = -1. | ||
""" | ||
enc_layer_mapping_dict = OrderedDict(
    {
        # ViT is pre_layernorm
        # NOTE: add an additional "final" at the beginning for some weights
        # to distinguish them from "attention output *"
        "multihead_norm_scale": "layernorm_before weight",
        "multihead_norm_bias": "layernorm_before bias",
        "multihead_project_kernel_qkv": "attention attention query weight&&attention attention key weight&&attention attention value weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_qkv": "attention attention query bias&&attention attention key bias&&attention attention value bias",
        "multihead_project_kernel_output": "attention output dense weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_output": "attention output dense bias",
        "ffn_norm_scale": "layernorm_after weight",
        "ffn_norm_bias": "layernorm_after bias",
        "ffn_first_kernel": "intermediate dense weight&&expression_.transpose(0, 1)",
        "ffn_first_bias": "intermediate dense bias",
        "ffn_second_kernel": "final output dense weight&&expression_.transpose(0, 1)",
        "ffn_second_bias": "final output dense bias",
    }
)

src_emb_mapping_dict = OrderedDict(
    {
        "conv_weight": "embeddings patch_embeddings projection weight",
        "conv_bias": "embeddings patch_embeddings projection bias",
        "position_embedding": "embeddings position_embeddings",
        "cls_embedding": "embeddings cls_token",
        "norm_scale": "layernorm weight",
        "norm_bias": "layernorm bias",
    }
)

def extract_vit_weights(
    output_file,
    model_dir,
    head_num,
    image_size,
    patch_size,
):
    # load variable names
    encoder_state_dict = ViTModel.from_pretrained(model_dir).state_dict()

    # insert an additional "final" into some weight names to prevent
    # ambiguous matches against "attention output *"
    def _insert_final(key):
        parts = key.split(".")
        parts.insert(3, "final")
        return ".".join(parts)

    encoder_state_dict = OrderedDict(
        [
            (_insert_final(k), v)
            if len(k.split(".")) > 3 and k.split(".")[3] == "output"
            else (k, v)
            for k, v in encoder_state_dict.items()
        ]
    )

    enc_var_name_list = list(encoder_state_dict.keys())

    # initialize output file
    output_file += ".hdf5"
    print("Saving model to hdf5...")
    print("Writing to {0}".format(output_file))
    hdf5_file = h5py.File(output_file, "w")

    # group each encoder layer's tensor names by layer id
    enc_tensor_names = {}
    for name in enc_var_name_list:
        name_split = name.split(".")
        if len(name_split) <= 2 or not name_split[2].isdigit():
            continue
        layer_id = int(name_split[2])
        enc_tensor_names.setdefault(layer_id, []).append(name)

    # fill encoder_stack
    for layer_id in sorted(enc_tensor_names.keys()):
        fill_hdf5_layer(
            enc_tensor_names[layer_id],
            encoder_state_dict,
            hdf5_file,
            f"encoder_stack/{layer_id}/",
            enc_layer_mapping_dict,
        )

    # fill src_embedding
    fill_hdf5_layer(
        enc_var_name_list,
        encoder_state_dict,
        hdf5_file,
        "src_embedding/",
        src_emb_mapping_dict,
    )

    # save number of layers metadata
    hdf5_file.create_dataset(
        "model_conf/n_encoder_stack", data=len(enc_tensor_names), dtype="i4"
    )
    # fill in model_conf
    hdf5_file.create_dataset("model_conf/head_num", data=head_num, dtype="i4")
    hdf5_file.create_dataset("model_conf/use_gelu", data=True, dtype="?")
    hdf5_file.create_dataset("model_conf/is_post_ln", data=False, dtype="?")
    hdf5_file.create_dataset("model_conf/image_size", data=image_size, dtype="i4")
    hdf5_file.create_dataset("model_conf/patch_size", data=patch_size, dtype="i4")

    hdf5_file.close()
    # read the file back in to double-check
    hdf5_file = h5py.File(output_file, "r")

    def _print_pair(key, value):
        value = value[()]
        print(f"{key}: {value}")

    list(map(lambda x: _print_pair(*x), hdf5_file["model_conf"].items()))

if __name__ == "__main__":
    output_lightseq_model_name = "lightseq_vit"
    input_huggingface_vit_model = "google/vit-base-patch16-224-in21k"
    head_number = 12
    image_size = 224
    patch_size = 16

    extract_vit_weights(
        output_lightseq_model_name,
        input_huggingface_vit_model,
        head_number,
        image_size,
        patch_size,
    )
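A quick way to sanity-check the export (a minimal sketch; it assumes the script above has been run and produced lightseq_vit.hdf5, and uses only the dataset names and prefixes created above):

import h5py

# open the exported checkpoint and dump the model_conf entries
with h5py.File("lightseq_vit.hdf5", "r") as f:
    for key, value in f["model_conf"].items():
        print(f"model_conf/{key}: {value[()]}")
    # one group per encoder layer under encoder_stack/
    print("encoder layers:", len(f["encoder_stack"]))
    # embedding tensors live under src_embedding/
    print("src_embedding datasets:", list(f["src_embedding"].keys()))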
@@ -0,0 +1,88 @@
import time
import torch
import lightseq.inference as lsi
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests


def ls_vit(model, inputs):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    ls_output = model.infer(inputs)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return ls_output, end_time - start_time


def hf_vit(model, inputs):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    hf_output = model(inputs.cuda())
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return hf_output, end_time - start_time


def ls_generate(model, inputs):
    print("=========lightseq=========")
    print("lightseq generating...")
    ls_output, ls_time = ls_vit(model, inputs)
    print(f"lightseq time: {ls_time}s")
    print("lightseq results (class predictions):")
    print(ls_output.argmax(axis=1).detach().cpu().numpy())


def hf_generate(model, inputs):
    print("=========huggingface=========")
    print("huggingface generating...")
    hf_output, hf_time = hf_vit(model, inputs)
    print(f"huggingface time: {hf_time}s")
    print("huggingface results (class predictions):")
    print(hf_output.logits.argmax(axis=1).detach().cpu().numpy())


def one_infer(inputs, ls_model, hf_model):
    ls_generate(ls_model, inputs)
    hf_generate(hf_model, inputs)

class LightseqVitClassification:
    def __init__(self, ls_weight_path, hf_model):
        # LightSeq ViT encoder (the second argument is the max batch size),
        # combined with the Hugging Face classification head
        self.ls_vit = lsi.Vit(ls_weight_path, 8)
        self.classifier = hf_model.classifier

    def infer(self, inputs):
        last_hidden_states = self.ls_vit.infer(inputs)
        last_hidden_states = torch.Tensor(last_hidden_states).float().cuda()
        # classify on the [CLS] token's hidden state
        logits = self.classifier(last_hidden_states[:, 0, :])
        return logits


def main():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    feature_extractor = ViTFeatureExtractor.from_pretrained(
        "google/vit-base-patch16-224-in21k"
    )
    inputs = feature_extractor(images=image, return_tensors="pt")
    inputs = inputs["pixel_values"]

    print("creating huggingface model...")
    hf_model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k"
    ).cuda()

    print("creating lightseq model...")
    ls_model = LightseqVitClassification("lightseq_vit.hdf5", hf_model)

    print("====================START warmup====================")
    one_infer(inputs, ls_model, hf_model)
    print("====================END warmup====================")

    one_infer(inputs, ls_model, hf_model)


if __name__ == "__main__":
    main()
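The timing above reflects a single call after warmup. For a steadier comparison, one could average over several runs; a minimal sketch reusing the ls_vit and hf_vit helpers defined in this file (the bench helper and n=10 are illustrative, not part of the example):

def bench(run_fn, model, inputs, n=10):
    # average latency over n timed runs (assumes warmup has already happened)
    total = 0.0
    for _ in range(n):
        _, elapsed = run_fn(model, inputs)
        total += elapsed
    return total / n

# usage:
# print("ls:", bench(ls_vit, ls_model, inputs))
# print("hf:", bench(hf_vit, hf_model, inputs))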
Empty file.
examples/training/huggingface/vit/ls_hf_vit_encoder_layer.py (71 additions, 0 deletions)
@@ -0,0 +1,71 @@
import torch
from lightseq.training.ops.pytorch.transformer_encoder_layer import (
    LSTransformerEncoderLayer,
)


class LSVITTransformerEncoderLayer(LSTransformerEncoderLayer):
    def __init__(self, *args, **kwargs):
        super(LSVITTransformerEncoderLayer, self).__init__(*args, **kwargs)

    def forward(self, hidden_states, *args, **kwargs):
        # ViT sequences are fixed-length, so the padding mask is all zeros
        ls_encoder_padding_mask = torch.zeros(hidden_states.size()[:-1])
        output = super().forward(hidden_states, ls_encoder_padding_mask)
        return (output,)


def gen_vit_config(training_args, config):
    # sequence length is the number of patches plus one for the [CLS] token,
    # e.g. (224 // 16) ** 2 + 1 = 197 for ViT-Base with 224x224 inputs
    num_patches = (config.image_size // config.patch_size) ** 2 + 1
    max_batch_size = max(
        training_args.per_device_train_batch_size,
        training_args.per_device_eval_batch_size,
    )
    vit_config = LSTransformerEncoderLayer.get_config(
        max_batch_tokens=num_patches * max_batch_size,
        max_seq_len=num_patches,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        nhead=config.num_attention_heads,
        attn_prob_dropout_ratio=config.attention_probs_dropout_prob,
        activation_dropout_ratio=config.hidden_dropout_prob,
        hidden_dropout_ratio=config.hidden_dropout_prob,
        pre_layer_norm=True,
        fp16=training_args.fp16,
        local_rank=training_args.local_rank,
        activation_fn="gelu",
    )
    return vit_config


def inject_ls_enc_layer(model, training_args, config):
    # replace every Hugging Face ViT encoder layer with a LightSeq fused
    # layer initialized from the original layer's weights
    for i in range(config.num_hidden_layers):
        vit_config = gen_vit_config(training_args, config)
        init_ws, init_bs = get_hf_vit_enc_layer_params(model.vit.encoder.layer[i])
        model.vit.encoder.layer[i] = LSVITTransformerEncoderLayer(
            vit_config, init_ws, init_bs
        ).cuda()


def get_hf_vit_enc_layer_params(layer):
    # collect weights and biases in the order LSTransformerEncoderLayer
    # expects: q, k, v, attention output, pre-attention layernorm, then
    # the FFN and post-attention layernorm
    init_ws = []
    init_bs = []

    init_ws.append(layer.attention.attention.query.weight.detach().clone())
    init_bs.append(layer.attention.attention.query.bias.detach().clone())
    init_ws.append(layer.attention.attention.key.weight.detach().clone())
    init_bs.append(layer.attention.attention.key.bias.detach().clone())
    init_ws.append(layer.attention.attention.value.weight.detach().clone())
    init_bs.append(layer.attention.attention.value.bias.detach().clone())
    init_ws.append(layer.attention.output.dense.weight.detach().clone())
    init_bs.append(layer.attention.output.dense.bias.detach().clone())
    init_ws.append(layer.layernorm_before.weight.detach().clone())
    init_bs.append(layer.layernorm_before.bias.detach().clone())

    init_ws.append(layer.intermediate.dense.weight.detach().clone())
    init_bs.append(layer.intermediate.dense.bias.detach().clone())
    init_ws.append(layer.output.dense.weight.detach().clone())
    init_bs.append(layer.output.dense.bias.detach().clone())
    init_ws.append(layer.layernorm_after.weight.detach().clone())
    init_bs.append(layer.layernorm_after.bias.detach().clone())

    return init_ws, init_bs
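A minimal usage sketch for the injection helper (assumes a standard Hugging Face fine-tuning setup; the checkpoint name and TrainingArguments values are illustrative, not part of this file):

from transformers import TrainingArguments, ViTConfig, ViTForImageClassification

config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
training_args = TrainingArguments(output_dir="/tmp/vit", fp16=True)

# swap every Hugging Face encoder layer for the fused LightSeq layer,
# then train with Trainer or a custom loop as usual
inject_ls_enc_layer(model, training_args, config)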