support ViT model (#299)
* init example of vit training

* init vit proto

* init patch emb kernel

* init model and pywrapper

* fix bugs

* init vit export example

* fix export bug

* fix blockreducesum bug

* update export

* fix last layernorm

* rm redundant ispostln

* update readme and test example

* with_lightseq true

* delete useless moeKernel

* support channel*patch*patch>=1024

* update pre-commit and code format

* update format of run_vit.sh
zjersey authored Apr 21, 2022
1 parent d0906bf commit b5bc246
Showing 25 changed files with 2,255 additions and 67 deletions.
12 changes: 11 additions & 1 deletion examples/inference/python/README.md
@@ -25,6 +25,12 @@ Export Hugging Face GPT2 models to hdf5 format.
 ```shell
 python export/huggingface/hf_gpt2_export.py
 ```
+4. Hugging Face ViT
+
+Export Hugging Face ViT models to hdf5 format.
+```shell
+python export/huggingface/hf_vit_export.py
+```
 ### Native Fairseq
 1. Native Fairseq Transformer
 
@@ -112,8 +118,12 @@ python test/ls_bert.py
 ```shell
 python test/ls_gpt2.py
 ```
+4. ViT
+```shell
+python test/ls_vit.py
+```
 
-4. Fairseq based models using LightSeq inference
+5. Fairseq based models using LightSeq inference
 ```shell
 bash test/ls_fairseq.sh --model ${model_path}
 ```
149 changes: 149 additions & 0 deletions examples/inference/python/export/huggingface/hf_vit_export.py
@@ -0,0 +1,149 @@
"""
Export Hugging Face ViT models to hdf5 format.
"""
import os
import h5py
from collections import OrderedDict
from transformers import ViTModel
from lightseq.training.ops.pytorch.export import fill_hdf5_layer

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


"""
For the mapping dictionary: key is the value of the proto parameter,
value is a powerful expression, each && split tensor name of the matching path or expression.
The sub-pattern of the path is separated by spaces, and the expression starts with a expression_.
You can operate separately on each tensor and support multiple expressions. Multiple matching paths
and the expression will finally be concatenated on axis = -1.
"""
enc_layer_mapping_dict = OrderedDict(
    {
        # VIT is pre_layernorm
        # NOTE: add an additional "final" at the beginning for some weight
        # to distinguish them from "attention output *"
        "multihead_norm_scale": "layernorm_before weight",
        "multihead_norm_bias": "layernorm_before bias",
        "multihead_project_kernel_qkv": "attention attention query weight&&attention attention key weight&&attention attention value weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_qkv": "attention attention query bias&&attention attention key bias&&attention attention value bias",
        "multihead_project_kernel_output": "attention output dense weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_output": "attention output dense bias",
        "ffn_norm_scale": "layernorm_after weight",
        "ffn_norm_bias": "layernorm_after bias",
        "ffn_first_kernel": "intermediate dense weight&&expression_.transpose(0, 1)",
        "ffn_first_bias": "intermediate dense bias",
        "ffn_second_kernel": "final output dense weight&&expression_.transpose(0, 1)",
        "ffn_second_bias": "final output dense bias",
    }
)

src_emb_mapping_dict = OrderedDict(
    {
        "conv_weight": "embeddings patch_embeddings projection weight",
        "conv_bias": "embeddings patch_embeddings projection bias",
        "position_embedding": "embeddings position_embeddings",
        "cls_embedding": "embeddings cls_token",
        "norm_scale": "layernorm weight",
        "norm_bias": "layernorm bias",
    }
)


def extract_vit_weights(
    output_file,
    model_dir,
    head_num,
    image_size,
    patch_size,
):
    # load var names
    encoder_state_dict = ViTModel.from_pretrained(model_dir).state_dict()

    # Insert additional "final" to some weight to prevent ambiguous match
    def _insert_final(key):
        l = key.split(".")
        l.insert(3, "final")
        return ".".join(l)

    encoder_state_dict = OrderedDict(
        [
            (_insert_final(k), v)
            if len(k.split(".")) > 3 and k.split(".")[3] == "output"
            else (k, v)
            for k, v in encoder_state_dict.items()
        ]
    )

    enc_var_name_list = list(encoder_state_dict.keys())

    # initialize output file
    output_file += ".hdf5"
    print("Saving model to hdf5...")
    print("Writing to {0}".format(output_file))
    hdf5_file = h5py.File(output_file, "w")

    # fill each encoder layer's params
    enc_tensor_names = {}
    for name in enc_var_name_list:
        name_split = name.split(".")
        if len(name_split) <= 2 or not name_split[2].isdigit():
            continue
        layer_id = int(name_split[2])
        enc_tensor_names.setdefault(layer_id, []).append(name)

    # fill encoder_stack
    for layer_id in sorted(enc_tensor_names.keys()):
        fill_hdf5_layer(
            enc_tensor_names[layer_id],
            encoder_state_dict,
            hdf5_file,
            f"encoder_stack/{layer_id}/",
            enc_layer_mapping_dict,
        )

    # fill src_embedding - except for position embedding
    fill_hdf5_layer(
        enc_var_name_list,
        encoder_state_dict,
        hdf5_file,
        "src_embedding/",
        src_emb_mapping_dict,
    )

    # save number of layers metadata
    hdf5_file.create_dataset(
        "model_conf/n_encoder_stack", data=len(enc_tensor_names), dtype="i4"
    )
    # fill in model_conf
    hdf5_file.create_dataset("model_conf/head_num", data=head_num, dtype="i4")
    hdf5_file.create_dataset("model_conf/use_gelu", data=True, dtype="?")
    hdf5_file.create_dataset("model_conf/is_post_ln", data=False, dtype="?")
    hdf5_file.create_dataset("model_conf/image_size", data=image_size, dtype="i4")
    hdf5_file.create_dataset("model_conf/patch_size", data=patch_size, dtype="i4")

    hdf5_file.close()
    # read-in again to double check
    hdf5_file = h5py.File(output_file, "r")

    def _print_pair(key, value):
        value = value[()]
        print(f"{key}: {value}")

    list(map(lambda x: _print_pair(*x), hdf5_file["model_conf"].items()))


if __name__ == "__main__":
    output_lightseq_model_name = "lightseq_vit"
    input_huggingface_vit_model = "google/vit-base-patch16-224-in21k"
    head_number = 12
    image_size = 224
    patch_size = 16

    extract_vit_weights(
        output_lightseq_model_name,
        input_huggingface_vit_model,
        head_number,
        image_size,
        patch_size,
    )
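
As a rough illustration of the rule convention described in the docstring above (this is only a sketch, not LightSeq's `fill_hdf5_layer` implementation; the toy `hidden_size`, random tensors, and variable names are made up), the following shows how a rule such as the one for `multihead_project_kernel_qkv` would combine its three matched tensors: the trailing `expression_.transpose(0, 1)` is applied to each matched tensor, and the results are concatenated on the last axis.

```python
import torch

# Toy stand-ins for the three tensors matched by
# "attention attention query weight && ... key weight && ... value weight";
# each would be (768, 768) for a real ViT-base layer.
hidden_size = 4
q_w = torch.randn(hidden_size, hidden_size)
k_w = torch.randn(hidden_size, hidden_size)
v_w = torch.randn(hidden_size, hidden_size)

# "expression_.transpose(0, 1)" is applied to every matched tensor ...
transposed = [w.transpose(0, 1) for w in (q_w, k_w, v_w)]

# ... and the results are concatenated on the last axis before being written
# to the "multihead_project_kernel_qkv" dataset in the hdf5 file.
multihead_project_kernel_qkv = torch.cat(transposed, dim=-1)
print(multihead_project_kernel_qkv.shape)  # torch.Size([4, 12])
```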
88 changes: 88 additions & 0 deletions examples/inference/python/test/ls_vit.py
@@ -0,0 +1,88 @@
import time
import torch
import lightseq.inference as lsi
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests


def ls_vit(model, inputs):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    ls_output = model.infer(inputs)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return ls_output, end_time - start_time


def hf_vit(model, inputs):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    hf_output = model(inputs.cuda())
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return hf_output, end_time - start_time


def ls_generate(model, inputs):
    print("=========lightseq=========")
    print("lightseq generating...")
    ls_output, ls_time = ls_vit(model, inputs)
    print(f"lightseq time: {ls_time}s")
    print("lightseq results (class predictions):")
    print(ls_output.argmax(axis=1).detach().cpu().numpy())


def hf_generate(model, inputs):
    print("=========huggingface=========")
    print("huggingface generating...")
    hf_output, hf_time = hf_vit(model, inputs)
    print(f"huggingface time: {hf_time}s")
    print("huggingface results (class predictions):")
    print(hf_output.logits.argmax(axis=1).detach().cpu().numpy())


def one_infer(inputs, ls_model, hf_model):
    ls_generate(ls_model, inputs)
    hf_generate(hf_model, inputs)


class LightseqVitClassification:
    def __init__(self, ls_weight_path, hf_model):
        self.ls_vit = lsi.Vit(ls_weight_path, 8)
        self.classifier = hf_model.classifier

    def infer(self, inputs):
        last_hidden_states = self.ls_vit.infer(inputs)
        last_hidden_states = torch.Tensor(last_hidden_states).float().cuda()
        logits = self.classifier(last_hidden_states[:, 0, :])
        return logits


def main():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    feature_extractor = ViTFeatureExtractor.from_pretrained(
        "google/vit-base-patch16-224-in21k"
    )
    inputs = feature_extractor(images=image, return_tensors="pt")
    inputs = inputs["pixel_values"]

    print("creating huggingface model...")
    hf_model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k"
    ).cuda()

    print("creating lightseq model...")
    ls_model = LightseqVitClassification("lightseq_vit.hdf5", hf_model)

    print("====================START warmup====================")
    one_infer(inputs, ls_model, hf_model)
    print("====================END warmup====================")

    one_infer(inputs, ls_model, hf_model)


if __name__ == "__main__":
    main()
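
A hypothetical extension of the test above (not part of this commit): instead of only eyeballing the printed class ids, the LightSeq and Hugging Face logits can be compared numerically. The function below is a sketch that assumes the same `inputs`, `ls_model`, and `hf_model` objects created in `main()`; the `atol` value is an arbitrary choice.

```python
import numpy as np
import torch


def compare_logits(ls_model, hf_model, inputs, atol=5e-2):
    """Return True if LightSeq and Hugging Face logits roughly agree."""
    with torch.no_grad():
        ls_logits = ls_model.infer(inputs).detach().cpu().numpy()
        hf_logits = hf_model(inputs.cuda()).logits.detach().cpu().numpy()
    max_diff = np.abs(ls_logits - hf_logits).max()
    print(f"max |lightseq - huggingface| logit difference: {max_diff:.6f}")
    return np.allclose(ls_logits, hf_logits, atol=atol)


# e.g. called at the end of main():
#     print("logits match:", compare_logits(ls_model, hf_model, inputs))
```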
Empty file.
71 changes: 71 additions & 0 deletions examples/training/huggingface/vit/ls_hf_vit_encoder_layer.py
@@ -0,0 +1,71 @@
import torch
from lightseq.training.ops.pytorch.transformer_encoder_layer import (
    LSTransformerEncoderLayer,
)


class LSVITTransformerEncoderLayer(LSTransformerEncoderLayer):
    def __init__(self, *args, **kwargs):
        super(LSVITTransformerEncoderLayer, self).__init__(*args, **kwargs)

    def forward(self, hidden_states, *args, **kwargs):
        ls_encoder_padding_mask = torch.zeros(hidden_states.size()[:-1])
        output = super().forward(hidden_states, ls_encoder_padding_mask)
        return (output,)


def gen_vit_config(training_args, config):
    num_patches = (config.image_size // config.patch_size) ** 2 + 1
    max_batch_size = max(
        training_args.per_device_train_batch_size,
        training_args.per_device_eval_batch_size,
    )
    vit_config = LSTransformerEncoderLayer.get_config(
        max_batch_tokens=num_patches * max_batch_size,
        max_seq_len=num_patches,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        nhead=config.num_attention_heads,
        attn_prob_dropout_ratio=config.attention_probs_dropout_prob,
        activation_dropout_ratio=config.hidden_dropout_prob,
        hidden_dropout_ratio=config.hidden_dropout_prob,
        pre_layer_norm=True,
        fp16=training_args.fp16,
        local_rank=training_args.local_rank,
        activation_fn="gelu",
    )
    return vit_config


def inject_ls_enc_layer(model, training_args, config):
    for i in range(config.num_hidden_layers):
        vit_config = gen_vit_config(training_args, config)
        init_ws, init_bs = get_hf_vit_enc_layer_params(model.vit.encoder.layer[i])
        model.vit.encoder.layer[i] = LSVITTransformerEncoderLayer(
            vit_config, init_ws, init_bs
        ).cuda()


def get_hf_vit_enc_layer_params(layer):
    init_ws = []
    init_bs = []

    init_ws.append(layer.attention.attention.query.weight.detach().clone())
    init_bs.append(layer.attention.attention.query.bias.detach().clone())
    init_ws.append(layer.attention.attention.key.weight.detach().clone())
    init_bs.append(layer.attention.attention.key.bias.detach().clone())
    init_ws.append(layer.attention.attention.value.weight.detach().clone())
    init_bs.append(layer.attention.attention.value.bias.detach().clone())
    init_ws.append(layer.attention.output.dense.weight.detach().clone())
    init_bs.append(layer.attention.output.dense.bias.detach().clone())
    init_ws.append(layer.layernorm_before.weight.detach().clone())
    init_bs.append(layer.layernorm_before.bias.detach().clone())

    init_ws.append(layer.intermediate.dense.weight.detach().clone())
    init_bs.append(layer.intermediate.dense.bias.detach().clone())
    init_ws.append(layer.output.dense.weight.detach().clone())
    init_bs.append(layer.output.dense.bias.detach().clone())
    init_ws.append(layer.layernorm_after.weight.detach().clone())
    init_bs.append(layer.layernorm_after.bias.detach().clone())

    return init_ws, init_bs
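
A hypothetical usage sketch (not part of this commit) of how `inject_ls_enc_layer` could be wired into a Hugging Face training script. The checkpoint name, `num_labels`, output directory, and batch sizes are illustrative, the import assumes the script sits next to `ls_hf_vit_encoder_layer.py`, and a CUDA device is required because the LightSeq layers are moved to the GPU.

```python
from transformers import TrainingArguments, ViTConfig, ViTForImageClassification

from ls_hf_vit_encoder_layer import inject_ls_enc_layer

# Illustrative ViT checkpoint and label count.
config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=10)
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", config=config
)

training_args = TrainingArguments(
    output_dir="/tmp/ls_vit",        # illustrative output path
    per_device_train_batch_size=32,  # read by gen_vit_config
    per_device_eval_batch_size=32,   # read by gen_vit_config
    fp16=True,                       # requires a CUDA device
)

# Replace every Hugging Face ViTLayer with an LSVITTransformerEncoderLayer
# initialized from the original layer's weights; training then proceeds as
# usual, e.g. with transformers.Trainer.
inject_ls_enc_layer(model, training_args, config)
```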