diff --git a/examples/inference/python/README.md b/examples/inference/python/README.md
index a2ea4130..43c56aed 100644
--- a/examples/inference/python/README.md
+++ b/examples/inference/python/README.md
@@ -25,6 +25,12 @@ Export Hugging Face GPT2 models to hdf5 format.
```shell
python export/huggingface/hf_gpt2_export.py
```
+4. Hugging Face ViT
+
+Export Hugging Face ViT models to hdf5 format.
+```shell
+python export/huggingface/hf_vit_export.py
+```
### Native Fairseq
1. Native Fairseq Transformer
@@ -112,8 +118,12 @@ python test/ls_bert.py
```shell
python test/ls_gpt2.py
```
+4. ViT
+```shell
+python test/ls_vit.py
+```
-4. Fairseq based models using LightSeq inference
+5. Fairseq based models using LightSeq inference
```shell
bash test/ls_fairseq.sh --model ${model_path}
```
diff --git a/examples/inference/python/export/huggingface/hf_vit_export.py b/examples/inference/python/export/huggingface/hf_vit_export.py
new file mode 100644
index 00000000..ac60b634
--- /dev/null
+++ b/examples/inference/python/export/huggingface/hf_vit_export.py
@@ -0,0 +1,149 @@
+"""
+Export Hugging Face ViT models to hdf5 format.
+"""
+import os
+import h5py
+from collections import OrderedDict
+from transformers import ViTModel
+from lightseq.training.ops.pytorch.export import fill_hdf5_layer
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+
+"""
+For the mapping dictionary: the key is the name of the proto parameter, and the value is an
+expression in which "&&" separates the matching paths or the expressions applied to the tensors.
+
+The sub-patterns of a path are separated by spaces, and an expression starts with expression_.
+Each matched tensor can be operated on separately, and multiple expressions are supported.
+All matched tensors and expression results are finally concatenated on axis = -1.
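+
+For example, the "multihead_project_kernel_qkv" rule below matches the query, key and value
+weight tensors (each tensor name must contain all of the rule's space-separated sub-patterns),
+applies .transpose(0, 1) to each matched tensor, and concatenates the results on the last axis.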
+""" +enc_layer_mapping_dict = OrderedDict( + { + # VIT is pre_layernorm + # NOTE: add an additional "final" at the beginning for some weight + # to distinguish them from "attention output *" + "multihead_norm_scale": "layernorm_before weight", + "multihead_norm_bias": "layernorm_before bias", + "multihead_project_kernel_qkv": "attention attention query weight&&attention attention key weight&&attention attention value weight&&expression_.transpose(0, 1)", + "multihead_project_bias_qkv": "attention attention query bias&&attention attention key bias&&attention attention value bias", + "multihead_project_kernel_output": "attention output dense weight&&expression_.transpose(0, 1)", + "multihead_project_bias_output": "attention output dense bias", + "ffn_norm_scale": "layernorm_after weight", + "ffn_norm_bias": "layernorm_after bias", + "ffn_first_kernel": "intermediate dense weight&&expression_.transpose(0, 1)", + "ffn_first_bias": "intermediate dense bias", + "ffn_second_kernel": "final output dense weight&&expression_.transpose(0, 1)", + "ffn_second_bias": "final output dense bias", + } +) + +src_emb_mapping_dict = OrderedDict( + { + "conv_weight": "embeddings patch_embeddings projection weight", + "conv_bias": "embeddings patch_embeddings projection bias", + "position_embedding": "embeddings position_embeddings", + "cls_embedding": "embeddings cls_token", + "norm_scale": "layernorm weight", + "norm_bias": "layernorm bias", + } +) + + +def extract_vit_weights( + output_file, + model_dir, + head_num, + image_size, + patch_size, +): + # load var names + encoder_state_dict = ViTModel.from_pretrained(model_dir).state_dict() + + # Insert additional "final" to some weight to prevent ambiguous match + def _insert_final(key): + l = key.split(".") + l.insert(3, "final") + return ".".join(l) + + encoder_state_dict = OrderedDict( + [ + (_insert_final(k), v) + if len(k.split(".")) > 3 and k.split(".")[3] == "output" + else (k, v) + for k, v in encoder_state_dict.items() + ] + ) + + enc_var_name_list = list(encoder_state_dict.keys()) + + # initialize output file + output_file += ".hdf5" + print("Saving model to hdf5...") + print("Writing to {0}".format(output_file)) + hdf5_file = h5py.File(output_file, "w") + + # fill each encoder layer's params + enc_tensor_names = {} + for name in enc_var_name_list: + name_split = name.split(".") + if len(name_split) <= 2 or not name_split[2].isdigit(): + continue + layer_id = int(name_split[2]) + enc_tensor_names.setdefault(layer_id, []).append(name) + + # fill encoder_stack + for layer_id in sorted(enc_tensor_names.keys()): + fill_hdf5_layer( + enc_tensor_names[layer_id], + encoder_state_dict, + hdf5_file, + f"encoder_stack/{layer_id}/", + enc_layer_mapping_dict, + ) + + # fill src_embedding - except for position embedding + fill_hdf5_layer( + enc_var_name_list, + encoder_state_dict, + hdf5_file, + "src_embedding/", + src_emb_mapping_dict, + ) + + # save number of layers metadata + hdf5_file.create_dataset( + "model_conf/n_encoder_stack", data=len(enc_tensor_names), dtype="i4" + ) + # fill in model_conf + hdf5_file.create_dataset("model_conf/head_num", data=head_num, dtype="i4") + hdf5_file.create_dataset("model_conf/use_gelu", data=True, dtype="?") + hdf5_file.create_dataset("model_conf/is_post_ln", data=False, dtype="?") + hdf5_file.create_dataset("model_conf/image_size", data=image_size, dtype="i4") + hdf5_file.create_dataset("model_conf/patch_size", data=patch_size, dtype="i4") + + hdf5_file.close() + # read-in again to double check + hdf5_file = 
h5py.File(output_file, "r") + + def _print_pair(key, value): + value = value[()] + print(f"{key}: {value}") + + list(map(lambda x: _print_pair(*x), hdf5_file["model_conf"].items())) + + +if __name__ == "__main__": + output_lightseq_model_name = "lightseq_vit" + input_huggingface_vit_model = "google/vit-base-patch16-224-in21k" + head_number = 12 + image_size = 224 + patch_size = 16 + + extract_vit_weights( + output_lightseq_model_name, + input_huggingface_vit_model, + head_number, + image_size, + patch_size, + ) diff --git a/examples/inference/python/test/ls_vit.py b/examples/inference/python/test/ls_vit.py new file mode 100644 index 00000000..914a576f --- /dev/null +++ b/examples/inference/python/test/ls_vit.py @@ -0,0 +1,88 @@ +import time +import torch +import lightseq.inference as lsi +from transformers import ViTFeatureExtractor, ViTForImageClassification +from PIL import Image +import requests + + +def ls_vit(model, inputs): + torch.cuda.synchronize() + start_time = time.perf_counter() + ls_output = model.infer(inputs) + torch.cuda.synchronize() + end_time = time.perf_counter() + return ls_output, end_time - start_time + + +def hf_vit(model, inputs): + torch.cuda.synchronize() + start_time = time.perf_counter() + hf_output = model(inputs.cuda()) + torch.cuda.synchronize() + end_time = time.perf_counter() + return hf_output, end_time - start_time + + +def ls_generate(model, inputs): + print("=========lightseq=========") + print("lightseq generating...") + ls_output, ls_time = ls_vit(model, inputs) + print(f"lightseq time: {ls_time}s") + print("lightseq results (class predictions):") + print(ls_output.argmax(axis=1).detach().cpu().numpy()) + + +def hf_generate(model, inputs): + print("=========huggingface=========") + print("huggingface generating...") + hf_output, hf_time = hf_vit(model, inputs) + print(f"huggingface time: {hf_time}s") + print("huggingface results (class predictions):") + print(hf_output.logits.argmax(axis=1).detach().cpu().numpy()) + + +def one_infer(inputs, ls_model, hf_model): + ls_generate(ls_model, inputs) + hf_generate(hf_model, inputs) + + +class LightseqVitClassification: + def __init__(self, ls_weight_path, hf_model): + self.ls_vit = lsi.Vit(ls_weight_path, 8) + self.classifier = hf_model.classifier + + def infer(self, inputs): + last_hidden_states = self.ls_vit.infer(inputs) + last_hidden_states = torch.Tensor(last_hidden_states).float().cuda() + logits = self.classifier(last_hidden_states[:, 0, :]) + return logits + + +def main(): + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + feature_extractor = ViTFeatureExtractor.from_pretrained( + "google/vit-base-patch16-224-in21k" + ) + inputs = feature_extractor(images=image, return_tensors="pt") + inputs = inputs["pixel_values"] + + print("creating huggingface model...") + hf_model = ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224-in21k" + ).cuda() + + print("creating lightseq model...") + ls_model = LightseqVitClassification("lightseq_vit.hdf5", hf_model) + + print("====================START warmup====================") + one_infer(inputs, ls_model, hf_model) + print("====================END warmup====================") + + one_infer(inputs, ls_model, hf_model) + + +if __name__ == "__main__": + main() diff --git a/examples/training/huggingface/vit/__init__.py b/examples/training/huggingface/vit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/examples/training/huggingface/vit/ls_hf_vit_encoder_layer.py b/examples/training/huggingface/vit/ls_hf_vit_encoder_layer.py new file mode 100644 index 00000000..8d492b6c --- /dev/null +++ b/examples/training/huggingface/vit/ls_hf_vit_encoder_layer.py @@ -0,0 +1,71 @@ +import torch +from lightseq.training.ops.pytorch.transformer_encoder_layer import ( + LSTransformerEncoderLayer, +) + + +class LSVITTransformerEncoderLayer(LSTransformerEncoderLayer): + def __init__(self, *args, **kwargs): + super(LSVITTransformerEncoderLayer, self).__init__(*args, **kwargs) + + def forward(self, hidden_states, *args, **kwargs): + ls_encoder_padding_mask = torch.zeros(hidden_states.size()[:-1]) + output = super().forward(hidden_states, ls_encoder_padding_mask) + return (output,) + + +def gen_vit_config(training_args, config): + num_patches = (config.image_size // config.patch_size) ** 2 + 1 + max_batch_size = max( + training_args.per_device_train_batch_size, + training_args.per_device_eval_batch_size, + ) + vit_config = LSTransformerEncoderLayer.get_config( + max_batch_tokens=num_patches * max_batch_size, + max_seq_len=num_patches, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + nhead=config.num_attention_heads, + attn_prob_dropout_ratio=config.attention_probs_dropout_prob, + activation_dropout_ratio=config.hidden_dropout_prob, + hidden_dropout_ratio=config.hidden_dropout_prob, + pre_layer_norm=True, + fp16=training_args.fp16, + local_rank=training_args.local_rank, + activation_fn="gelu", + ) + return vit_config + + +def inject_ls_enc_layer(model, training_args, config): + for i in range(config.num_hidden_layers): + vit_config = gen_vit_config(training_args, config) + init_ws, init_bs = get_hf_vit_enc_layer_params(model.vit.encoder.layer[i]) + model.vit.encoder.layer[i] = LSVITTransformerEncoderLayer( + vit_config, init_ws, init_bs + ).cuda() + + +def get_hf_vit_enc_layer_params(layer): + init_ws = [] + init_bs = [] + + init_ws.append(layer.attention.attention.query.weight.detach().clone()) + init_bs.append(layer.attention.attention.query.bias.detach().clone()) + init_ws.append(layer.attention.attention.key.weight.detach().clone()) + init_bs.append(layer.attention.attention.key.bias.detach().clone()) + init_ws.append(layer.attention.attention.value.weight.detach().clone()) + init_bs.append(layer.attention.attention.value.bias.detach().clone()) + init_ws.append(layer.attention.output.dense.weight.detach().clone()) + init_bs.append(layer.attention.output.dense.bias.detach().clone()) + init_ws.append(layer.layernorm_before.weight.detach().clone()) + init_bs.append(layer.layernorm_before.bias.detach().clone()) + + init_ws.append(layer.intermediate.dense.weight.detach().clone()) + init_bs.append(layer.intermediate.dense.bias.detach().clone()) + init_ws.append(layer.output.dense.weight.detach().clone()) + init_bs.append(layer.output.dense.bias.detach().clone()) + init_ws.append(layer.layernorm_after.weight.detach().clone()) + init_bs.append(layer.layernorm_after.bias.detach().clone()) + + return init_ws, init_bs diff --git a/examples/training/huggingface/vit/run_vit.py b/examples/training/huggingface/vit/run_vit.py new file mode 100644 index 00000000..d32cb670 --- /dev/null +++ b/examples/training/huggingface/vit/run_vit.py @@ -0,0 +1,417 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import datasets +import numpy as np +import torch +from datasets import load_dataset +from PIL import Image +from torchvision.transforms import ( + CenterCrop, + Compose, + Normalize, + RandomHorizontalFlip, + RandomResizedCrop, + Resize, + ToTensor, +) + +import transformers +from transformers import ( + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + AutoConfig, + AutoFeatureExtractor, + AutoModelForImageClassification, + HfArgumentParser, + Trainer, + TrainingArguments, +) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version +from ls_hf_vit_encoder_layer import inject_ls_enc_layer + + +""" Fine-tuning a 🤗 Transformers model for image classification""" + +logger = logging.getLogger(__name__) + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.18.0.dev0") + +require_version( + "datasets>=1.8.0", + "To fix: pip install -r examples/pytorch/image-classification/requirements.txt", +) + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) +# ('poolformer', 'convnext', 'van', 'resnet', 'swin', 'imagegpt', 'segformer', 'perceiver', 'beit', 'deit', 'vit') + + +def pil_loader(path: str): + with open(path, "rb") as f: + im = Image.open(f) + return im.convert("RGB") + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + dataset_name: Optional[str] = field( + default="nateraw/image-folder", + metadata={"help": "Name of a dataset from the datasets package"}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + }, + ) + train_dir: Optional[str] = field( + default=None, metadata={"help": "A folder containing the training data."} + ) + validation_dir: Optional[str] = field( + default=None, metadata={"help": "A folder containing the validation data."} + ) + train_val_split: Optional[float] = field( + default=0.15, metadata={"help": "Percent to split off of train for validation."} + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." 
+ }, + ) + + def __post_init__(self): + data_files = dict() + if self.train_dir is not None: + data_files["train"] = self.train_dir + if self.validation_dir is not None: + data_files["val"] = self.validation_dir + self.data_files = data_files if data_files else None + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + default="google/vit-base-patch16-224-in21k", + metadata={ + "help": "Path to pretrained model or model identifier from huggingface.co/models" + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={ + "help": "If training from scratch, pass a model type from the list: " + + ", ".join(MODEL_TYPES) + }, + ) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from s3" + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + feature_extractor_name: str = field( + default=None, metadata={"help": "Name or path of preprocessor config."} + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + with_lightseq: bool = field( + default=True, + metadata={"help": "Whether to use lightseq TransformerEncoder"}, + ) + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + labels = torch.tensor([example["labels"] for example in examples]) + return {"pixel_values": pixel_values, "labels": labels} + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments) + ) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. 
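+    # If output_dir already contains a checkpoint and --overwrite_output_dir is not set,
+    # training resumes from the latest checkpoint unless --resume_from_checkpoint is given.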
+ last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif ( + last_checkpoint is not None and training_args.resume_from_checkpoint is None + ): + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Initialize our dataset and prepare it for the 'image-classification' task. + ds = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + data_files=data_args.data_files, + cache_dir=model_args.cache_dir, + task="image-classification", + ) + + # If we don't have a validation split, split off a percentage of train as validation. + data_args.train_val_split = ( + None if "validation" in ds.keys() else data_args.train_val_split + ) + if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: + split = ds["train"].train_test_split(data_args.train_val_split) + ds["train"] = split["train"] + ds["validation"] = split["test"] + + # Prepare label mappings. + # We'll include these in the model's config to get human readable labels in the Inference API. + labels = ds["train"].features["labels"].names + label2id, id2label = dict(), dict() + for i, label in enumerate(labels): + label2id[label] = str(i) + id2label[str(i)] = label + + # Load the accuracy metric from the datasets package + metric = datasets.load_metric("accuracy") + + # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p): + """Computes accuracy on a batch of predictions""" + return metric.compute( + predictions=np.argmax(p.predictions, axis=1), references=p.label_ids + ) + + config = AutoConfig.from_pretrained( + model_args.config_name or model_args.model_name_or_path, + num_labels=len(labels), + label2id=label2id, + id2label=id2label, + finetuning_task="image-classification", + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForImageClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Replace with LightSeq encoder layers. + if model_args.with_lightseq: + inject_ls_enc_layer(model, training_args, config) + + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name or model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Define torchvision transforms to be applied to each image. 
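+    # _train_transforms applies random resized crop + horizontal flip for augmentation;
+    # _val_transforms applies a deterministic resize + center crop. Both normalize with
+    # the feature extractor's image_mean / image_std.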
+ normalize = Normalize( + mean=feature_extractor.image_mean, std=feature_extractor.image_std + ) + _train_transforms = Compose( + [ + RandomResizedCrop(feature_extractor.size), + RandomHorizontalFlip(), + ToTensor(), + normalize, + ] + ) + _val_transforms = Compose( + [ + Resize(feature_extractor.size), + CenterCrop(feature_extractor.size), + ToTensor(), + normalize, + ] + ) + + def train_transforms(example_batch): + """Apply _train_transforms across a batch.""" + example_batch["pixel_values"] = [ + _train_transforms(pil_img.convert("RGB")) + for pil_img in example_batch["image"] + ] + return example_batch + + def val_transforms(example_batch): + """Apply _val_transforms across a batch.""" + example_batch["pixel_values"] = [ + _val_transforms(pil_img.convert("RGB")) + for pil_img in example_batch["image"] + ] + return example_batch + + if training_args.do_train: + if "train" not in ds: + raise ValueError("--do_train requires a train dataset") + if data_args.max_train_samples is not None: + ds["train"] = ( + ds["train"] + .shuffle(seed=training_args.seed) + .select(range(data_args.max_train_samples)) + ) + # Set the training transforms + ds["train"].set_transform(train_transforms) + + if training_args.do_eval: + if "validation" not in ds: + raise ValueError("--do_eval requires a validation dataset") + if data_args.max_eval_samples is not None: + ds["validation"] = ( + ds["validation"] + .shuffle(seed=training_args.seed) + .select(range(data_args.max_eval_samples)) + ) + # Set the validation transforms + ds["validation"].set_transform(val_transforms) + + # Initalize our trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=ds["train"] if training_args.do_train else None, + eval_dataset=ds["validation"] if training_args.do_eval else None, + compute_metrics=compute_metrics, + tokenizer=feature_extractor, + data_collator=collate_fn, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Write model card and (optionally) push to hub + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "tasks": "image-classification", + "dataset": data_args.dataset_name, + "tags": ["image-classification"], + } + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/examples/training/huggingface/vit/run_vit.sh b/examples/training/huggingface/vit/run_vit.sh new file mode 100644 index 00000000..49f38424 --- /dev/null +++ b/examples/training/huggingface/vit/run_vit.sh @@ -0,0 +1,37 @@ +# Copyright 2021 The LightSeq Team +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +THIS_DIR=$(dirname $(readlink -f $0)) + +python3 $THIS_DIR/run_vit.py \ + --dataset_name beans \ + --output_dir /tmp/beans_outputs \ + --overwrite_output_dir \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --learning_rate 2e-5 \ + --num_train_epochs 5 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 8 \ + --logging_strategy steps \ + --logging_steps 10 \ + --evaluation_strategy epoch \ + --save_strategy epoch \ + --load_best_model_at_end True \ + --save_total_limit 3 \ + --seed 1337 \ + --fp16 \ + --with_lightseq true diff --git a/lightseq/inference/kernels/embKernels.cc.cu b/lightseq/inference/kernels/embKernels.cc.cu index 265c541f..e4e918f6 100644 --- a/lightseq/inference/kernels/embKernels.cc.cu +++ b/lightseq/inference/kernels/embKernels.cc.cu @@ -543,5 +543,91 @@ template void launch_dec_emb<__half>(const __half *token_emb, int beam_size, int hidden_dim, int vocab_size, int step, int max_step, int multilg_type, cudaStream_t stream); + +/** +@brief: ker_patch_emb +patch embedding by conv2d, concat cls embedding, add position embedding + +@thread +gridDim.x = batch_size +gridDim.y = max_step +gridDim.z = hidden_dim +blockDim.x = MAX_THREADS + +@param +conv_weight: [hidden_dim, channel_input, patch_size, patch_size] +conv_bias: [hidden_dim] +pos_emb: [max_step, hidden_dim] +cls_emb: [hidden_dim] +input: [batch_size, channel_input, image_size, image_size] +output: result, [batch_size, max_step, hidden_dim] +*/ +template +__global__ void ker_patch_emb(const T *conv_weight, const T *conv_bias, + const T *pos_emb, const T *cls_emb, + const float *input, T *output, int patch_size, + int image_size, int channel_input) { + if (blockIdx.y == 0) { + if (threadIdx.x == 0) { + output[flat_3dim(blockIdx.x, 0, blockIdx.z, gridDim.y, gridDim.z)] = + __ldg(&cls_emb[blockIdx.z]) + __ldg(&pos_emb[blockIdx.z]); + } + return; + } + + int val_num_per_block = channel_input * patch_size * patch_size; + int patch_row_id, patch_col_id, value_row_id, value_col_id, channel_id; + decompose_2dim(blockIdx.y - 1, image_size / patch_size, &patch_row_id, + &patch_col_id); + + float val = 0.f; + for (int idx = threadIdx.x; idx < val_num_per_block; idx += blockDim.x) { + decompose_3dim(idx, patch_size, patch_size, &channel_id, &value_row_id, + &value_col_id); + int conv_weight_offset = flat_2dim(blockIdx.z, idx, val_num_per_block); + int in_offset = flat_4dim(blockIdx.x, channel_id, + patch_row_id * patch_size + value_row_id, + patch_col_id * patch_size + value_col_id, + channel_input, image_size, image_size); + val += __ldg(&input[in_offset]) * + (float)__ldg(&conv_weight[conv_weight_offset]); + } + + float rsum = blockReduceSum(val); + if (threadIdx.x == 0) { + float out_float; + int out_offset = + flat_3dim(blockIdx.x, blockIdx.y, blockIdx.z, gridDim.y, gridDim.z); + out_float = + rsum + (float)__ldg(&conv_bias[blockIdx.z]) + + (float)__ldg(&pos_emb[flat_2dim(blockIdx.y, blockIdx.z, gridDim.z)]); + output[out_offset] = (T)out_float; + } +} + +template +void launch_patch_emb(const T *conv_weight, const T *conv_bias, + const T *pos_emb, const T *cls_emb, const float 
*input, + T *output, int patch_size, int image_size, int batch_size, + int max_step, int hidden_dim, int channel_input, + cudaStream_t stream) { + ker_patch_emb + <<>>( + conv_weight, conv_bias, pos_emb, cls_emb, input, output, patch_size, + image_size, channel_input); +} + +template void launch_patch_emb( + const float *conv_weight, const float *conv_bias, const float *pos_emb, + const float *cls_emb, const float *input, float *output, int patch_size, + int image_size, int batch_size, int max_step, int hidden_dim, + int channel_input, cudaStream_t stream); + +template void launch_patch_emb<__half>( + const __half *conv_weight, const __half *conv_bias, const __half *pos_emb, + const __half *cls_emb, const float *input, __half *output, int patch_size, + int image_size, int batch_size, int max_step, int hidden_dim, + int channel_input, cudaStream_t stream); + } // namespace cuda } // namespace lightseq diff --git a/lightseq/inference/kernels/embKernels.h b/lightseq/inference/kernels/embKernels.h index 8ae555fd..77782de9 100644 --- a/lightseq/inference/kernels/embKernels.h +++ b/lightseq/inference/kernels/embKernels.h @@ -23,5 +23,12 @@ void launch_dec_emb(const T *token_emb, const T *pos_emb, int *tokens, int vocab_size, int step, int max_step, int multilg_type, cudaStream_t stream); +template +void launch_patch_emb(const T *conv_weight, const T *conv_bias, + const T *pos_emb, const T *cls_emb, const float *input, + T *output, int patch_size, int image_size, int batch_size, + int max_step, int hidden_dim, int channel_input, + cudaStream_t stream); + } // namespace cuda } // namespace lightseq diff --git a/lightseq/inference/kernels/moeKernels.cc.cu b/lightseq/inference/kernels/moeKernels.cc.cu index 58b8b033..6a2cadec 100644 --- a/lightseq/inference/kernels/moeKernels.cc.cu +++ b/lightseq/inference/kernels/moeKernels.cc.cu @@ -149,62 +149,6 @@ template void ker_norm_layer_prepost_launcher<__half>( __half* output, const __half* scale, const __half* bias, const int max_thread_per_block, bool is_post_ln); -template -__global__ void ker_residual(const T* input, T* output, const int hidden_size) { - uint block_start = blockIdx.x * hidden_size; - uint start = block_start + threadIdx.x; - uint end = block_start + hidden_size; - for (uint i = start; i < end; i += blockDim.x) { - output[i] += __ldg(&input[i]); - } -} - -template <> -__global__ void ker_residual<__half>(const __half* input, __half* output, - const int half_hidden_size) { - uint block_start = blockIdx.x * half_hidden_size; - uint start = block_start + threadIdx.x; - uint end = blockIdx.x * half_hidden_size + half_hidden_size; - half2* pinput = (half2*)input; - half2* poutput = (half2*)output; - - for (uint i = start; i < end; i += blockDim.x) { - float2 local_f2 = safe_half2_to_float2(poutput[i]); - float2 residual_val = __half22float2(__ldg(&pinput[i])); - float2 new_output_f2; - new_output_f2.x = local_f2.x + residual_val.x; - new_output_f2.y = local_f2.y + residual_val.y; - poutput[i] = __float22half2_rn(new_output_f2); - } -} - -template -void ker_residual_launcher(int token_num, int hidden_size, cudaStream_t stream, - const T* input, T* output, - const int max_thread_per_block) { - ker_residual<<>>(input, output, - hidden_size); -} - -template <> -void ker_residual_launcher<__half>(int token_num, int hidden_size, - cudaStream_t stream, const __half* input, - __half* output, - const int max_thread_per_block) { - ker_residual<__half><<>>( - input, output, hidden_size / 2); -} - -template void ker_residual_launcher(int token_num, int 
hidden_size, - cudaStream_t stream, - const float* input, float* output, - const int max_thread_per_block); - -template void ker_residual_launcher<__half>(int token_num, int hidden_size, - cudaStream_t stream, - const __half* input, __half* output, - const int max_thread_per_block); - /** @brief: ker_softmax_topk_router softmax of gate output and route each token to topk experts diff --git a/lightseq/inference/kernels/moeKernels.h b/lightseq/inference/kernels/moeKernels.h index 7c8a4d61..4e17355f 100644 --- a/lightseq/inference/kernels/moeKernels.h +++ b/lightseq/inference/kernels/moeKernels.h @@ -7,10 +7,6 @@ namespace lightseq { namespace cuda { -template -void ker_residual_launcher(int token_num, int hidden_size, cudaStream_t stream, - const T* input, T* output, - const int max_thread_per_block); template void ker_norm_layer_prepost_launcher(int token_num, int hidden_size, cudaStream_t stream, T* input, T* output, diff --git a/lightseq/inference/model/CMakeLists.txt b/lightseq/inference/model/CMakeLists.txt index ba5b7668..16275320 100644 --- a/lightseq/inference/model/CMakeLists.txt +++ b/lightseq/inference/model/CMakeLists.txt @@ -62,3 +62,13 @@ else() target_link_libraries(moe_model PRIVATE CUDA::cublas_static CUDA::cublasLt_static) endif() + +add_library(vit_model STATIC vit_encoder.cc.cu) +target_link_libraries(vit_model PUBLIC cuda_kernels) +target_link_libraries(vit_model PUBLIC vit_weight) +if(DYNAMIC_API) + target_link_libraries(vit_model PRIVATE CUDA::cublas CUDA::cublasLt) +else() + target_link_libraries(vit_model PRIVATE CUDA::cublas_static + CUDA::cublasLt_static) +endif() diff --git a/lightseq/inference/model/bert_encoder.cc.cu b/lightseq/inference/model/bert_encoder.cc.cu index 755267b8..a7160eb3 100644 --- a/lightseq/inference/model/bert_encoder.cc.cu +++ b/lightseq/inference/model/bert_encoder.cc.cu @@ -4,7 +4,7 @@ /** @file -Transformer encoder, composed by gemm lib and +Bert encoder, composed by gemm lib and custom cuda kernel function */ diff --git a/lightseq/inference/model/bert_encoder.h b/lightseq/inference/model/bert_encoder.h index ef7ae6af..4a4a5ca1 100644 --- a/lightseq/inference/model/bert_encoder.h +++ b/lightseq/inference/model/bert_encoder.h @@ -17,7 +17,7 @@ /** @file -Transformer decoder, composed by gemm lib and +Bert encoder, composed by gemm lib and custom cuda kernel function */ diff --git a/lightseq/inference/model/vit_encoder.cc.cu b/lightseq/inference/model/vit_encoder.cc.cu new file mode 100644 index 00000000..96f30c26 --- /dev/null +++ b/lightseq/inference/model/vit_encoder.cc.cu @@ -0,0 +1,301 @@ +#include "vit_encoder.h" +#include "../kernels/embKernels.h" +#include "../kernels/transformerKernels.h" + +/** +@file +ViT encoder, composed by gemm lib and + custom cuda kernel function +*/ + +namespace lightseq { +namespace cuda { + +template +VitEncoder::VitEncoder(int max_batch_size, + const float *p_d_pixel_input, + int *p_d_padding_mask, _DataType *p_d_output, + const VitWeight &tw, + cudaStream_t stream, cublasHandle_t hd) + : _max_batch_size(max_batch_size), + _p_d_pixel_input(p_d_pixel_input), + _p_d_padding_mask(p_d_padding_mask), + _p_d_output(p_d_output), + _tw(tw), + _stream(stream), + _hd(hd), + _p_d_src_emb_wei(tw.get_src_emb_wei()), + _p_d_enc_wei(tw.get_enc_wei()), + _fone((_DataType)1.f), + _fzero((_DataType)0.f), + _atten_scaler((_DataType)sqrt(1.f / tw._dim_per_head)), + _max_batch_dim(max_batch_size * tw._max_step * tw._hidden_size), + _max_thread_per_block(1024) {} + +/** +Compute GPU memory size needed by transformer encoder, + 
to see how these memory is used, checkout init_buffer() for detail +*/ +template +long VitEncoder::compute_buffer_bytesize() { + long sz1 = _max_batch_dim * 6 + + _max_batch_size * _tw._head_num * _tw._max_step * _tw._max_step; + long sz2 = _max_batch_dim + _max_batch_size * _tw._max_step * _tw._inner_size; + return max(sz1, sz2) * sizeof(_DataType); +} + +/** +Init the GPU memory pointer which point to + the memory buffer needed by encoder. +These buffer are used during custom cuda kernel function, + find the corresponding function to see how these buffer are used +*/ +template +void VitEncoder::init_buffer(void *pbuf) { + _DataType *p_d_buf = reinterpret_cast<_DataType *>(pbuf); + _p_d_qkv_projected = p_d_buf; + _p_d_q = _p_d_qkv_projected + _max_batch_dim * 3; + _p_d_k = _p_d_q + _max_batch_dim; + _p_d_v = _p_d_k + _max_batch_dim; + _p_d_c = _p_d_v + _max_batch_dim; + _p_d_ffn_buf1 = p_d_buf; + _p_d_ffn_buf2 = _p_d_ffn_buf1 + _max_batch_dim; + return; +} + +/** +Some requirements needed by custom cuda kernel function +*/ +template +std::string VitEncoder::check() { + // if (_max_thread_per_block < _tw._hidden_size) { + // return "violate hidden_size <= max_thread_per_block"; + // } + if (_tw._inner_size & 1) { + return "violate inner_size % 2 = 0"; + } + if (_tw._dim_per_head & 1) { + return "violate dim_per_head % 2 = 0"; + } + if (_p_d_src_emb_wei.size() != 6) { + return "violate p_d_src_emb_wei.size() = 6"; + } + if (_p_d_enc_wei.size() != _tw._weight_per_enc_layer * _tw._n_enc_layer) { + return "violate p_d_enc_wei.size() = weight_per_enc_layer * n_enc_layer"; + } + if (_tw._max_step != (_tw._image_size / _tw._patch_size) * + (_tw._image_size / _tw._patch_size) + + 1) { + return "violate max_step = (image_size / patch_size) ** 2 + 1"; + } + return ""; +} + +/** +Encoder inference +*/ +template +void VitEncoder::run_one_infer(int batch_size) { + if (batch_size > _max_batch_size) { + throw std::runtime_error("batch size of input greater than max_batch_size"); + } + /* ---step1. init--- */ + _batch_size = batch_size; + _batch_seq_len = _tw._max_step; + _batch_token_num = batch_size * _batch_seq_len; + + /* ---step2. encoder feedforward--- */ + launch_patch_emb<_DataType>(_p_d_src_emb_wei[0], _p_d_src_emb_wei[1], + _p_d_src_emb_wei[2], _p_d_src_emb_wei[3], + _p_d_pixel_input, _p_d_output, _tw._patch_size, + _tw._image_size, _batch_size, _tw._max_step, + _tw._hidden_size, _tw._channel_input, _stream); +#ifdef DEBUG_RESULT + for (int i = 0; i < _batch_size; i++) { // batch_id + for (int j = 0; j < 10; j++) { // patch_id + std::cout << "emb out: patch-" << j << std::endl; + print_vec(_p_d_output + i * _batch_seq_len * _tw._hidden_size + + j * _tw._hidden_size, + "emb out", 20); + } + } +#endif + for (_layer_id = 0; _layer_id < _tw._n_enc_layer; _layer_id++) { + _weight_offset = _layer_id * _tw._weight_per_enc_layer; + self_attention(); + ffn_add_norm(); + } + // last layer norm + ker_norm_layer_launcher<_DataType>( + _batch_token_num, _tw._hidden_size, _stream, _p_d_output, + _p_d_src_emb_wei[4], _p_d_src_emb_wei[5], _max_thread_per_block); + +#ifdef DEBUG_RESULT + for (int i = 0; i < _batch_size; i++) { // batch_id + for (int j = 0; j < _batch_seq_len; j++) { // patch_id + std::cout << "encoder output: token-" << j << std::endl; + print_vec(_p_d_output + i * _batch_seq_len * _tw._hidden_size + + j * _tw._hidden_size, + "encoder_output", _tw._dim_per_head); + } + } +#endif + return; +} + +/** +Encoder self attention +*/ +template +void VitEncoder::self_attention() { + /* ---step 0. 
layer_norm, add output_bias to "query"--- */ + ker_norm_layer_resual_launcher<_DataType>( + _batch_token_num, _tw._hidden_size, _stream, _p_d_output, _p_d_q, + _p_d_enc_wei[_weight_offset], _p_d_enc_wei[_weight_offset + 1], + _p_d_enc_wei[_weight_offset + 5], _max_thread_per_block, _tw._is_post_ln); + +#ifdef DEBUG_RESULT + print_vec(_p_d_enc_wei[_weight_offset], "layer norm scale(head): ", 5); + print_vec(_p_d_enc_wei[_weight_offset + 1], "layer norm bias(head): ", 5); + print_vec(_p_d_q, "layer norm out(head): ", 5); + print_vec(_p_d_q + _batch_token_num * _tw._hidden_size - 5, + "layer norm out(tail): ", 5); +#endif + + /* ---step 1. qkv = ori_q * qkv_wei + bias, and reshape qkv for multi-head + * gemm--- */ + CHECK_GPU_ERROR(cublasGemmEx( + _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size * 3, _batch_token_num, + _tw._hidden_size, &_fone, _p_d_enc_wei[_weight_offset + 2], _AType, + _tw._hidden_size * 3, _p_d_q, _BType, _tw._hidden_size, &_fzero, + _p_d_qkv_projected, _CType, _tw._hidden_size * 3, _computeType, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + +#ifdef DEBUG_RESULT + print_vec(_p_d_qkv_projected, "self qkv(head): ", 5); + print_vec(_p_d_qkv_projected + _batch_token_num * _tw._hidden_size * 3 - 5, + "self qkv(tail): ", 5); +#endif + + // get q, k, v by split and reshape qkv + ker_arrange_encself_qkv_launcher<_DataType>( + _batch_token_num, _tw._hidden_size, _stream, _p_d_qkv_projected, + _p_d_enc_wei[_weight_offset + 3], _p_d_q, _max_batch_dim, _batch_seq_len, + _tw._dim_per_head, _tw._head_num, _max_thread_per_block); + + /* ---step 2. correlation = q * k, perform softmax on correlation--- */ + CHECK_GPU_ERROR(cublasGemmStridedBatchedEx( + _hd, CUBLAS_OP_T, CUBLAS_OP_N, _batch_seq_len, _batch_seq_len, + _tw._dim_per_head, &_atten_scaler, _p_d_k, _AType, _tw._dim_per_head, + _batch_seq_len * _tw._dim_per_head, _p_d_q, _BType, _tw._dim_per_head, + _batch_seq_len * _tw._dim_per_head, &_fzero, _p_d_c, _CType, + _batch_seq_len, _batch_seq_len * _batch_seq_len, + _batch_size * _tw._head_num, _computeType, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + ker_correlation_softmax_encself_launcher<_DataType>( + _batch_size, _batch_seq_len, _tw._head_num, _stream, _p_d_c, + _p_d_padding_mask); + +#ifdef DEBUG_RESULT + print_vec(_p_d_c, "self attn correlation(head): ", 5); + print_vec(_p_d_c + _batch_token_num * _tw._head_num * _batch_seq_len - 5, + "self attn correlation(tail): ", 5); +#endif + + /* ---step 3. new_q = correlation * v--- */ + CHECK_GPU_ERROR(cublasGemmStridedBatchedEx( + _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._dim_per_head, _batch_seq_len, + _batch_seq_len, &_fone, _p_d_v, _AType, _tw._dim_per_head, + _batch_seq_len * _tw._dim_per_head, _p_d_c, _BType, _batch_seq_len, + _batch_seq_len * _batch_seq_len, &_fzero, _p_d_q, _CType, + _tw._dim_per_head, _batch_seq_len * _tw._dim_per_head, + _batch_size * _tw._head_num, _computeType, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + // use v to save reshaped q, since they are in same size and v + // will not be use again before the next multi-head-attention + ker_arrange_atten_output_launcher<_DataType>( + _batch_token_num, _tw._hidden_size, _stream, _p_d_q, _p_d_v, + _batch_seq_len, _tw._dim_per_head, _tw._head_num, _max_thread_per_block); + +#ifdef DEBUG_RESULT + print_vec(_p_d_v, "self attn before ffn(head): ", 5); +#endif + + /* ---step 4. 
new_q = ori_q + new_q * output_wei--- */ + CHECK_GPU_ERROR(cublasGemmEx( + _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size, _batch_token_num, + _tw._hidden_size, &_fone, _p_d_enc_wei[_weight_offset + 4], _AType, + _tw._hidden_size, _p_d_v, _BType, _tw._hidden_size, &_fone, _p_d_output, + _CType, _tw._hidden_size, _computeType, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + +#ifdef DEBUG_RESULT + print_vec(_p_d_output, "self attn ffn out(head): ", 5); + print_vec(_p_d_output + _batch_token_num * _tw._hidden_size - 5, + "self attn ffn out(tail): ", 5); + + print_vec(_p_d_enc_wei[_weight_offset + 4], "enc wei:", 5); +#endif + + return; +} + +template +void VitEncoder::ffn_add_norm() { + /* ---step 0. layer_norm, add output_bias to "query"--- */ + ker_norm_layer_resual_launcher<_DataType>( + _batch_token_num, _tw._hidden_size, _stream, _p_d_output, _p_d_ffn_buf1, + _p_d_enc_wei[_weight_offset + 6], _p_d_enc_wei[_weight_offset + 7], + _p_d_enc_wei[_weight_offset + 11], _max_thread_per_block, + _tw._is_post_ln); + +#ifdef DEBUG_RESULT + print_vec(_p_d_enc_wei[_weight_offset + 6], "layer norm scale(head): ", 5); + print_vec(_p_d_enc_wei[_weight_offset + 7], "layer norm bias(head): ", 5); + print_vec(_p_d_ffn_buf1, "layer norm(head): ", 5); + print_vec(_p_d_ffn_buf1 + _batch_token_num * _tw._hidden_size - 5, + "layer norm(tail): ", 5); +#endif + + /* ---step 1. first ffn layer--- */ + CHECK_GPU_ERROR(cublasGemmEx( + _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._inner_size, _batch_token_num, + _tw._hidden_size, &_fone, _p_d_enc_wei[_weight_offset + 8], _AType, + _tw._inner_size, _p_d_ffn_buf1, _BType, _tw._hidden_size, &_fzero, + _p_d_ffn_buf2, _CType, _tw._inner_size, _computeType, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + if (_tw._use_gelu) { + ker_bias_gelu_launcher<_DataType>( + _batch_token_num, _max_thread_per_block, _stream, _p_d_ffn_buf2, + _p_d_enc_wei[_weight_offset + 9], _tw._inner_size); + } else { + ker_bias_relu_launcher<_DataType>( + _batch_token_num, _max_thread_per_block, _stream, _p_d_ffn_buf2, + _p_d_enc_wei[_weight_offset + 9], _tw._inner_size); + } + +#ifdef DEBUG_RESULT + print_vec(_p_d_ffn_buf2, "ffn activation(head): ", 5); + print_vec(_p_d_ffn_buf2 + _batch_token_num * _tw._hidden_size - 5, + "ffn activation(tail): ", 5); +#endif + + /* ---step 2. 
second ffn layer--- */ + CHECK_GPU_ERROR(cublasGemmEx( + _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size, _batch_token_num, + _tw._inner_size, &_fone, _p_d_enc_wei[_weight_offset + 10], _AType, + _tw._hidden_size, _p_d_ffn_buf2, _BType, _tw._inner_size, &_fone, + _p_d_output, _CType, _tw._hidden_size, _computeType, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +#ifdef DEBUG_RESULT + print_vec(_p_d_output, "ffn output (head): ", 5); + print_vec(_p_d_output + _batch_token_num * _tw._hidden_size - 5, + "ffn output (tail): ", 5); +#endif +} + +template class VitEncoder; +template class VitEncoder; + +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/inference/model/vit_encoder.h b/lightseq/inference/model/vit_encoder.h new file mode 100644 index 00000000..cb33e260 --- /dev/null +++ b/lightseq/inference/model/vit_encoder.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../proto/vit_weight.h" +#include "../tools/util.h" + +/** +@file +ViT encoder, composed by gemm lib and + custom cuda kernel function +*/ + +namespace lightseq { +namespace cuda { + +template +class VitEncoder { + private: + typedef OperationTypeTraits _optraits; + typedef typename _optraits::DataType _DataType; + const cudaDataType_t _computeType = _optraits::computeType; + const cudaDataType_t _AType = _optraits::AType; + const cudaDataType_t _BType = _optraits::BType; + const cudaDataType_t _CType = _optraits::CType; + + // private member function + void self_attention(); + void ffn_add_norm(); + + const int _max_batch_size; + int *_p_d_padding_mask; // true sequence length(remove padding), [batch_size] + + const VitWeight &_tw; + cudaStream_t _stream; + cublasHandle_t _hd; + const _DataType _fone; + const _DataType _fzero; + const _DataType _atten_scaler; + const int _max_batch_dim; + const int _max_thread_per_block; + + _DataType *_p_d_qkv_projected; + _DataType *_p_d_q; + _DataType *_p_d_k; + _DataType *_p_d_v; + _DataType *_p_d_c; + _DataType *_p_d_ffn_buf1; + _DataType *_p_d_ffn_buf2; + + // {conv_weight, conv_bias, pos_emb, cls_embedding} + const std::vector &_p_d_src_emb_wei; + // {multihead_norm_scale, multihead_norm_bias, multihead_qkv_kernel, + // multihead_qkv_bias multihead_output_kernel, multihead_output_bias + // ffn_norm_scale, ffn_norm_bias} + // ffn_first_kernel, ffn_first_bias, ffn_second_kernel, ffn_second_bias} * + // encoder_layer_num + const std::vector &_p_d_enc_wei; + + int _batch_size; + int _batch_seq_len; + int _batch_token_num; + int _layer_id; + int _weight_offset; + + public: + const float *_p_d_pixel_input; // input pixels [batch_size, channel_input, + // image_size, image_size] + _DataType + *_p_d_output; // encoder output, [batch_size, batch_seq_len, hidden_size] + + VitEncoder(int max_batch_size, const float *p_d_pixel_input, + int *p_d_padding_mask, _DataType *p_d_output, + const VitWeight &tw, cudaStream_t stream, + cublasHandle_t hd); + long compute_buffer_bytesize(); + void init_buffer(void *pbuf); + std::string check(); + void run_one_infer(int batch_size); +}; + +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/inference/proto/CMakeLists.txt b/lightseq/inference/proto/CMakeLists.txt index 7d162569..bfa710f5 100644 --- a/lightseq/inference/proto/CMakeLists.txt +++ b/lightseq/inference/proto/CMakeLists.txt @@ -15,6 +15,7 @@ protobuf_generate_cpp(Q_TRANSFORMER_PROTO_SRC Q_TRANSFORMER_PROTO_HEADER protobuf_generate_cpp(TRANSFORMER_PROTO_SRC TRANSFORMER_PROTO_HEADER 
transformer.proto) protobuf_generate_cpp(MOE_PROTO_SRC MOE_PROTO_HEADER moe.proto) +protobuf_generate_cpp(VIT_PROTO_SRC VIT_PROTO_HEADER vit.proto) add_library(gpt_weight STATIC gpt_weight.cc ${GPT_PROTO_SRC} ${GPT_PROTO_HEADER}) @@ -56,3 +57,9 @@ target_link_libraries(moe_weight PUBLIC utils ${Protobuf_LIBRARIES} ${HDF5_LIBRARIES}) target_include_directories(moe_weight PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(moe_weight PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + +add_library(vit_weight STATIC vit_weight.cc ${VIT_PROTO_SRC} + ${VIT_PROTO_HEADER}) +target_link_libraries(vit_weight PUBLIC utils ${Protobuf_LIBRARIES}) +target_include_directories(vit_weight PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_include_directories(vit_weight PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/lightseq/inference/proto/vit.proto b/lightseq/inference/proto/vit.proto new file mode 100644 index 00000000..1b7122c3 --- /dev/null +++ b/lightseq/inference/proto/vit.proto @@ -0,0 +1,65 @@ +syntax = "proto3"; +option optimize_for = LITE_RUNTIME; +// all the matrix are stored in row-major order, +// plz see https://en.wikipedia.org/wiki/Row-_and_column-major_order for details + +// the definition of "Multi-Head Attention", "Scaled Dot-Product Attention" and +// "Feed-Forward Networks" +// plz see https://arxiv.org/abs/1706.03762 for details + +message VitEncoderLayer { + // layer norm before "Multi-Head Attention" + repeated float multihead_norm_scale = 1; // [hidden_size] + repeated float multihead_norm_bias = 2; // [hidden_size] + + // "Multi-Head Attention" linearly project weights kernel for query, key, + // value, + // before "Scaled Dot-Product Attention, with shape (hidden_size, + // hidden_size*3) + // is built by numpy.concatenate((query_kernel, key_kernel, value_kernel), + // axis=1) + // perform numpy.dot(input, multihead_project_kernel_qkv) will get the [query, + // key, value] of + // "Scaled Dot-Product Attention" + repeated float multihead_project_kernel_qkv = 3; // [hidden_size, 3, hidden_size] + repeated float multihead_project_bias_qkv = 4; // [3, hidden_size] + repeated float multihead_project_kernel_output = 5; // [hidden_size, hidden_size] + repeated float multihead_project_bias_output = 6; // [hidden_size] + + // layer norm before "Feed-Forward Networks" + repeated float ffn_norm_scale = 7; // [hidden_size] + repeated float ffn_norm_bias = 8; // [hidden_size] + + // "Feed-Forward Networks" + repeated float ffn_first_kernel = 9; // [hidden_size, inner_size] + repeated float ffn_first_bias = 10; // [inner_size] + repeated float ffn_second_kernel = 11; // [inner_size, hidden_size] + repeated float ffn_second_bias = 12; // [hidden_size] +} + +message VitEmbeddingLayer { + // weight and bias of convolution in patch embedding + repeated float conv_weight = 1; // [hidden_size, channel_input, patch_size, patch_size] + repeated float conv_bias = 2; // [hidden_size] + // learnable position embedding + repeated float position_embedding = 3; // [max_seq_len, hidden_size] + repeated float cls_embedding = 4; // [hidden_size] + // the last layer_norm of encoder, + // only for pre layer norm, + repeated float norm_scale = 5; // [hidden_size] + repeated float norm_bias = 6; // [hidden_size] +} + +message VitModelConf { + int32 head_num = 1; + bool use_gelu = 2; // use gelu for activation otherwise relu + int32 image_size = 3; // width of input image + int32 patch_size = 4; //width of patch and convolution kernel + bool is_post_ln = 5; // Pre-LN or Post-LN +} + +message Vit { + VitEmbeddingLayer 
src_embedding = 1; + repeated VitEncoderLayer encoder_stack = 2; + VitModelConf model_conf = 3; +} diff --git a/lightseq/inference/proto/vit_weight.cc b/lightseq/inference/proto/vit_weight.cc new file mode 100644 index 00000000..d958f2a6 --- /dev/null +++ b/lightseq/inference/proto/vit_weight.cc @@ -0,0 +1,510 @@ +#include "vit_weight.h" + +#include + +/** +@file +Load the model weights which stored in custom proto file into GPU memory. +Currently, fp16 and fp32 versions are provided. +Weights in proto file will always be in fp32. For fp16, the weights + will be casted from fp32 into fp16 +*/ + +namespace lightseq { +namespace cuda { + +/** +Cast weights into required datatype. +The datatype of weights in custom proto file will always be in fp32. +*/ +template <> +float VitWeight::float2required(float value) { + return value; +} + +/** +fp16 version, cast fp32 into fp16 +*/ +template <> +__half VitWeight::float2required(float value) { + return __float2half_rn(value); +} + +/** +Read model config stored in custom proto file. +*/ +template +void VitWeight::proto_get_model_config(const Vit &vit) { + _hidden_size = vit.src_embedding().cls_embedding_size(); + _inner_size = vit.encoder_stack()[0].ffn_first_kernel_size() / _hidden_size; + _max_step = vit.src_embedding().position_embedding_size() / _hidden_size; + _n_enc_layer = vit.encoder_stack_size(); + _head_num = vit.model_conf().head_num(); + _dim_per_head = _hidden_size / _head_num; + _weight_per_enc_layer = 12; + _use_gelu = vit.model_conf().use_gelu(); + _image_size = vit.model_conf().image_size(); + _patch_size = vit.model_conf().patch_size(); + _channel_input = vit.src_embedding().conv_weight_size() / + (_hidden_size * _patch_size * _patch_size); + _is_post_ln = vit.model_conf().is_post_ln(); +} + +/** +Load the weights of embedding layer into GPU memory. 
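+The device pointers are appended to _p_d_src_emb_wei in the order:
+conv_weight, conv_bias, position_embedding, cls_embedding, norm_scale, norm_bias.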
+*/ +template +std::string VitWeight::proto_parse_emb_wei( + const VitEmbeddingLayer &layer) { + std::vector offset; + std::vector value; + int idx = 0; + + offset.push_back(idx); + if (layer.conv_weight_size() != + _hidden_size * _channel_input * _patch_size * _patch_size) + return "wrong conv_weight_size !"; + for (float ele : layer.conv_weight()) value.push_back(ele); + idx += _channel_input * _hidden_size * _patch_size * _patch_size; + + offset.push_back(idx); + if (layer.conv_bias_size() != _hidden_size) return "wrong conv_bias_size !"; + for (float ele : layer.conv_bias()) value.push_back(ele); + idx += _hidden_size; + + offset.push_back(idx); + if (layer.position_embedding_size() != _max_step * _hidden_size) + return "wrong position_embedding_size !"; + for (float ele : layer.position_embedding()) value.push_back(ele); + idx += _max_step * _hidden_size; + + offset.push_back(idx); + if (layer.cls_embedding_size() != _hidden_size) + return "wrong cls_embedding_size !"; + for (float ele : layer.cls_embedding()) value.push_back(ele); + idx += _hidden_size; + + offset.push_back(idx); + if (layer.norm_scale_size() != _hidden_size) return "wrong norm_scale_size !"; + for (float ele : layer.norm_scale()) value.push_back(ele); + idx += _hidden_size; + + offset.push_back(idx); + if (layer.norm_bias_size() != _hidden_size) return "wrong norm_bias_size !"; + for (float ele : layer.norm_bias()) value.push_back(ele); + idx += _hidden_size; + + std::vector<_DataType> raw_value; + for (float e : value) raw_value.push_back(float2required(e)); + _d_src_emb_wei = raw_value; + for (int e : offset) + _p_d_src_emb_wei.push_back(thrust::raw_pointer_cast(_d_src_emb_wei.data()) + + e); + + std::cout << "finish initializing emb_wei from host to device" << std::endl; + return ""; +} + +/** +Load the weights of encoder into GPU memory. 
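+Each encoder layer contributes 12 weight tensors, appended to _p_d_enc_wei
+layer by layer in the field order of the VitEncoderLayer message.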
+*/ +template +std::string VitWeight::proto_parse_enc_wei(const Vit &vit) { + std::vector offset; + std::vector value; + int idx = 0; + + for (auto enc_layer : vit.encoder_stack()) { + offset.push_back(idx); + if (enc_layer.multihead_norm_scale_size() != _hidden_size) + return "wrong multihead_norm_scale_size !"; + for (float ele : enc_layer.multihead_norm_scale()) value.push_back(ele); + idx += _hidden_size; + + offset.push_back(idx); + if (enc_layer.multihead_norm_bias_size() != _hidden_size) + return "wrong multihead_norm_bias_size !"; + for (float ele : enc_layer.multihead_norm_bias()) value.push_back(ele); + idx += _hidden_size; + + offset.push_back(idx); + if (enc_layer.multihead_project_kernel_qkv_size() != + _hidden_size * _hidden_size * 3) + return "wrong multihead_project_kernel_qkv_size !"; + for (float ele : enc_layer.multihead_project_kernel_qkv()) + value.push_back(ele); + idx += _hidden_size * _hidden_size * 3; + + offset.push_back(idx); + if (enc_layer.multihead_project_bias_qkv_size() != _hidden_size * 3) + return "wrong multihead_project_bias_qkv_size !"; + for (float ele : enc_layer.multihead_project_bias_qkv()) + value.push_back(ele); + idx += _hidden_size * 3; + + offset.push_back(idx); + if (enc_layer.multihead_project_kernel_output_size() != + _hidden_size * _hidden_size) + return "wrong multihead_project_kernel_output_size !"; + for (float ele : enc_layer.multihead_project_kernel_output()) + value.push_back(ele); + idx += _hidden_size * _hidden_size; + + offset.push_back(idx); + if (enc_layer.multihead_project_bias_output_size() != _hidden_size) + return "wrong multihead_project_bias_output_size !"; + for (float ele : enc_layer.multihead_project_bias_output()) + value.push_back(ele); + idx += _hidden_size; + + offset.push_back(idx); + if (enc_layer.ffn_norm_scale_size() != _hidden_size) + return "wrong ffn_norm_scale_size !"; + for (float ele : enc_layer.ffn_norm_scale()) value.push_back(ele); + idx += _hidden_size; + + offset.push_back(idx); + if (enc_layer.ffn_norm_bias_size() != _hidden_size) + return "wrong ffn_norm_bias_size !"; + for (float ele : enc_layer.ffn_norm_bias()) value.push_back(ele); + idx += _hidden_size; + + offset.push_back(idx); + if (enc_layer.ffn_first_kernel_size() != _hidden_size * _inner_size) + return "wrong ffn_first_kernel_size !"; + for (float ele : enc_layer.ffn_first_kernel()) value.push_back(ele); + idx += _hidden_size * _inner_size; + + offset.push_back(idx); + if (enc_layer.ffn_first_bias_size() != _inner_size) + return "wrong ffn_first_bias_size !"; + for (float ele : enc_layer.ffn_first_bias()) value.push_back(ele); + idx += _inner_size; + + offset.push_back(idx); + if (enc_layer.ffn_second_kernel_size() != _hidden_size * _inner_size) + return "wrong ffn_second_kernel_size !"; + for (float ele : enc_layer.ffn_second_kernel()) value.push_back(ele); + idx += _hidden_size * _inner_size; + + offset.push_back(idx); + if (enc_layer.ffn_second_bias_size() != _hidden_size) + return "wrong ffn_second_bias_size !"; + for (float ele : enc_layer.ffn_second_bias()) value.push_back(ele); + idx += _hidden_size; + + } // for + + std::vector<_DataType> raw_value; + for (float e : value) raw_value.push_back(float2required(e)); + _d_enc_wei = raw_value; + + for (int e : offset) + _p_d_enc_wei.push_back(thrust::raw_pointer_cast(_d_enc_wei.data()) + e); + std::cout << "finish initializing enc_wei from host to device" << std::endl; + return ""; +} + +/** +Read model config stored in custom hdf5 file. 
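+hidden_size, inner_size, max_step and channel_input are derived from the dataset sizes;
+head_num, use_gelu, is_post_ln, image_size, patch_size and the number of encoder layers
+are read from the model_conf group.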
+*/ +template +void VitWeight::hdf5_get_model_config(hid_t hdf5_file) { + _hidden_size = + get_hdf5_dataset_size(hdf5_file, "src_embedding/cls_embedding"); + + _inner_size = + get_hdf5_dataset_size(hdf5_file, "encoder_stack/0/ffn_first_kernel") / + _hidden_size; + + _max_step = + get_hdf5_dataset_size(hdf5_file, "src_embedding/position_embedding") / + _hidden_size; + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/n_encoder_stack", + H5T_NATIVE_INT, &_n_enc_layer); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/head_num", H5T_NATIVE_INT, + &_head_num); + + _dim_per_head = _hidden_size / _head_num; + _weight_per_enc_layer = 12; + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/is_post_ln", H5T_NATIVE_HBOOL, + &_is_post_ln); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/use_gelu", H5T_NATIVE_HBOOL, + &_use_gelu); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/image_size", H5T_NATIVE_INT, + &_image_size); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/patch_size", H5T_NATIVE_INT, + &_patch_size); + + _channel_input = + get_hdf5_dataset_size(hdf5_file, "src_embedding/conv_weight") / + (_hidden_size * _patch_size * _patch_size); +} + +/** +Load the weights of embedding layer into GPU memory. +*/ +template +void VitWeight::hdf5_parse_emb_wei(hid_t hdf5_file) { + std::string dataset_prefix = "src_embedding"; + + size_t value_size = + _channel_input * _hidden_size * _patch_size * _patch_size + + _max_step * _hidden_size + 4 * _hidden_size; + + std::vector offset; + std::vector value(value_size); // preallocate vector for performance + std::cout << "loading " << value_size * sizeof(OpType_) / (1024 * 1024) + << " MB of embedding weight." << std::endl; + int idx = 0; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/conv_weight", H5T_NATIVE_FLOAT, + value.data() + idx, + [=](int size) { + return size != + _hidden_size * _channel_input * _patch_size * _patch_size; + }, + "Wrong conv_weight_size !"); + idx += _channel_input * _hidden_size * _patch_size * _patch_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/conv_bias", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong conv_bias_size !"); + idx += _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/position_embedding", H5T_NATIVE_FLOAT, + value.data() + idx, + [=](int size) { return size != _max_step * _hidden_size; }, + "Wrong position_embedding_size !"); + idx += _max_step * _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/cls_embedding", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong cls_embedding_size !"); + idx += _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/norm_scale", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong norm_scale_size !"); + idx += _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/norm_bias", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong norm_bias_size !"); + idx += _hidden_size; + + std::vector<_DataType> raw_value; + raw_value.reserve(value.size()); + for (float e : value) raw_value.push_back(float2required(e)); + _d_src_emb_wei = raw_value; + for (int e : offset) + 
_p_d_src_emb_wei.push_back(thrust::raw_pointer_cast(_d_src_emb_wei.data()) + + e); + + std::cout << "Finish loading src_emb_wei from host to device" << std::endl; +} + +/** +Load the weights of encoder into GPU memory. +*/ +template +void VitWeight::hdf5_parse_enc_wei(hid_t hdf5_file) { + size_t value_size = + (_hidden_size * 2 + _hidden_size * _hidden_size * 3 + _hidden_size * 3 + + _hidden_size * _hidden_size + _hidden_size * 3 + + _hidden_size * _inner_size + _inner_size + _hidden_size * _inner_size + + _hidden_size) * + _n_enc_layer; + std::vector offset; + std::vector value(value_size); + std::cout << "loading " << value_size * sizeof(OpType_) / (1024 * 1024) + << " MB of encoder weight." << std::endl; + + int idx = 0; + for (int layer_id = 0; layer_id < _n_enc_layer; ++layer_id) { + std::string dataset_prefix = "encoder_stack/" + std::to_string(layer_id); + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/multihead_norm_scale", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong multihead_norm_scale_size !"); + idx += _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/multihead_norm_bias", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong multihead_norm_bias_size !"); + idx += _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/multihead_project_kernel_qkv", + H5T_NATIVE_FLOAT, value.data() + idx, + [=](int size) { return size != _hidden_size * _hidden_size * 3; }, + "Wrong multihead_project_kernel_qkv_size !"); + idx += _hidden_size * _hidden_size * 3; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/multihead_project_bias_qkv", + H5T_NATIVE_FLOAT, value.data() + idx, + [=](int size) { return size != _hidden_size * 3; }, + "Wrong multihead_project_bias_qkv_size !"); + idx += _hidden_size * 3; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/multihead_project_kernel_output", + H5T_NATIVE_FLOAT, value.data() + idx, + [=](int size) { return size != _hidden_size * _hidden_size; }, + "Wrong multihead_project_kernel_output_size !"); + idx += _hidden_size * _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/multihead_project_bias_output", + H5T_NATIVE_FLOAT, value.data() + idx, + [=](int size) { return size != _hidden_size; }, + "Wrong multihead_project_bias_output_size !"); + idx += _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/ffn_norm_scale", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong ffn_norm_scale_size !"); + idx += _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/ffn_norm_bias", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong ffn_norm_bias_size !"); + idx += _hidden_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/ffn_first_kernel", H5T_NATIVE_FLOAT, + value.data() + idx, + [=](int size) { return size != _hidden_size * _inner_size; }, + "Wrong ffn_first_kernel_size !"); + idx += _hidden_size * _inner_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/ffn_first_bias", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _inner_size; }, + "Wrong 
ffn_first_bias_size !"); + idx += _inner_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/ffn_second_kernel", H5T_NATIVE_FLOAT, + value.data() + idx, + [=](int size) { return size != _hidden_size * _inner_size; }, + "Wrong ffn_second_kernel_size !"); + idx += _hidden_size * _inner_size; + + offset.push_back(idx); + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/ffn_second_bias", H5T_NATIVE_FLOAT, + value.data() + idx, [=](int size) { return size != _hidden_size; }, + "Wrong ffn_second_bias_size !"); + idx += _hidden_size; + + } // for + + std::vector<_DataType> raw_value; + raw_value.reserve(value.size()); + for (float e : value) raw_value.push_back(float2required(e)); + _d_enc_wei = raw_value; + + for (int e : offset) + _p_d_enc_wei.push_back(thrust::raw_pointer_cast(_d_enc_wei.data()) + e); + std::cout << "Finish loading enc_wei from host to device" << std::endl; +} + +/** +Load the proto file into CPU memory and parse it. +*/ +template +std::string VitWeight::initializing(std::string weight_path) { + if (endswith(weight_path, ".pb")) { + std::cout << "Parsing protobuf: " << weight_path << std::endl; + Vit vit; + // Verify that the version of the library that we linked against is + // compatible with the version of the headers we compiled against. + GOOGLE_PROTOBUF_VERIFY_VERSION; + + std::fstream raw_input(weight_path, std::ios::in | std::ios::binary); + if (!vit.ParseFromIstream(&raw_input)) { + return "Parse weights from [" + weight_path + "] failed."; + } + + proto_get_model_config(vit); + if (_hidden_size % 4 != 0) { + return "hidden_size should be a multiple of 4 to avoid misaligned " + "address in CUDA"; + } + + std::string res = proto_parse_emb_wei(vit.src_embedding()); + if (!res.empty()) return res; + + res = proto_parse_enc_wei(vit); + if (!res.empty()) return res; + + std::cout << "finish initializing all weight from host to device" + << std::endl; + // Optional: Delete all global objects allocated by libprotobuf. + // google::protobuf::ShutdownProtobufLibrary(); + return ""; + } else if (endswith(weight_path, ".hdf5")) { + std::cout << "Parsing hdf5: " << weight_path << std::endl; + + hid_t hdf5_file = H5Fopen(weight_path.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + if (hdf5_file < 0) { + return "Unable to read HDF5 file from " + weight_path; + } + hdf5_get_model_config(hdf5_file); + if (_hidden_size % 4 != 0) { + return "hidden_size should be a multiple of 4 to avoid misaligned " + "address in CUDA"; + } + // hdf5_parse_* would throw std::runtime_error on error + hdf5_parse_emb_wei(hdf5_file); + hdf5_parse_enc_wei(hdf5_file); + H5Fclose(hdf5_file); + + std::cout << "Finish loading all weight from host to device" << std::endl; + return ""; + } else { + return "Unsupported weight extention for [" + weight_path + + "]; Supported extensions: .pb, .hdf5\n"; + } +} + +template class VitWeight; +template class VitWeight; + +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/inference/proto/vit_weight.h b/lightseq/inference/proto/vit_weight.h new file mode 100644 index 00000000..77904d41 --- /dev/null +++ b/lightseq/inference/proto/vit_weight.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "vit.pb.h" +#include "../tools/util.h" + +namespace lightseq { +namespace cuda { + +/* +Load the model weights which stored in custom proto file into GPU memory. 
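+
+VitWeight::initializing() accepts either a protobuf (.pb) or an hdf5 (.hdf5)
+checkpoint; both paths produce the same device-side layout: six embedding
+tensors addressed through _p_d_src_emb_wei and twelve tensors per encoder
+layer addressed through _p_d_enc_wei.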
+*/ +template +class VitWeight { + private: + typedef OperationTypeTraits _optraits; + typedef typename _optraits::DataType _DataType; + _DataType float2required(float value); + void proto_get_model_config(const Vit &vit); + std::string proto_parse_emb_wei(const VitEmbeddingLayer &layer); + std::string proto_parse_enc_wei(const Vit &vit); + + void hdf5_get_model_config(hid_t hdf5_file); + void hdf5_parse_emb_wei(hid_t hdf5_file); + void hdf5_parse_enc_wei(hid_t hdf5_file); + // store the weights pointer + std::vector _p_d_src_emb_wei; // size: 6 + std::vector _p_d_enc_wei; // size: 12 * enc_layer_num + + // store the weights on gpu memory + thrust::device_vector<_DataType> _d_src_emb_wei; + thrust::device_vector<_DataType> _d_enc_wei; + + public: + std::string initializing(std::string proto_path); + + const std::vector &get_src_emb_wei() const { + // {conv_weight, conv_bias, pos_emb, cls_embedding} + return _p_d_src_emb_wei; + } + + const std::vector &get_enc_wei() const { + // {multihead_norm_scale, multihead_norm_bias, multihead_qkv_kernel, + // multihead_qkv_bias multihead_output_kernel, multihead_output_bias + // ffn_norm_scale, ffn_norm_bias} + // ffn_first_kernel, ffn_first_bias, ffn_second_kernel, ffn_second_bias} * + // encoder_layer_num + return _p_d_enc_wei; + } + + int _hidden_size; + int _inner_size; + int _max_step; + int _n_enc_layer; // number of encoder layer + int _dim_per_head; + int _weight_per_enc_layer; // 12 + int _image_size; + int _patch_size; + int _channel_input; + + int _head_num; + bool _is_post_ln; + bool _use_gelu; + + void print_model_config() { + std::cout << "***model config***" << std::endl; + std::cout << "encoder layers: " << _n_enc_layer << std::endl; + std::cout << "hidden size: " << _hidden_size << std::endl; + std::cout << "inner size: " << _inner_size << std::endl; + std::cout << "head number: " << _head_num << std::endl; + std::cout << "dim per head: " << _dim_per_head << std::endl; + std::cout << "use_gelu: " << _use_gelu << std::endl; + std::cout << "is_post_ln: " << _is_post_ln << std::endl; + std::cout << "image_size: " << _image_size << std::endl; + std::cout << "patch_size: " << _patch_size << std::endl; + std::cout << "channel_input: " << _channel_input << std::endl; + std::cout << std::endl; + } +}; + +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/inference/pywrapper/CMakeLists.txt b/lightseq/inference/pywrapper/CMakeLists.txt index a1d6cd6d..9b1a9219 100644 --- a/lightseq/inference/pywrapper/CMakeLists.txt +++ b/lightseq/inference/pywrapper/CMakeLists.txt @@ -8,21 +8,24 @@ pybind11_add_module( gpt.cc bert.cc quant_transformer.cc - moe.cc) + moe.cc + vit.cc) target_link_libraries(lightseq PUBLIC gpt_model) target_link_libraries(lightseq PUBLIC bert_model) target_link_libraries(lightseq PUBLIC transformer_model) target_link_libraries(lightseq PUBLIC quant_transformer_model) target_link_libraries(lightseq PUBLIC moe_model) +target_link_libraries(lightseq PUBLIC vit_model) set_target_properties(lightseq PROPERTIES OUTPUT_NAME inference) add_library(liblightseq SHARED transformer.cc gpt.cc bert.cc - quant_transformer.cc moe.cc) + quant_transformer.cc moe.cc vit.cc) target_link_libraries(liblightseq PUBLIC transformer_model) target_link_libraries(liblightseq PUBLIC quant_transformer_model) target_link_libraries(liblightseq PUBLIC gpt_model) target_link_libraries(liblightseq PUBLIC bert_model) target_link_libraries(liblightseq PUBLIC moe_model) +target_link_libraries(liblightseq PUBLIC vit_model) 
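+# vit.cc wires the ViT encoder into both the pybind11 extension and the shared
+# liblightseq library by linking against the vit_model target.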
target_link_options(liblightseq PUBLIC $) target_include_directories(liblightseq PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/lightseq/inference/pywrapper/vit.cc b/lightseq/inference/pywrapper/vit.cc new file mode 100644 index 00000000..f641acb1 --- /dev/null +++ b/lightseq/inference/pywrapper/vit.cc @@ -0,0 +1,161 @@ +#include "vit.h" + +namespace lightseq { +namespace cuda { + +Vit::Vit(const std::string weight_path, const int max_batch_size) + : LSModel({"pixel_values"}, {"encoder_output"}), + _max_batch_size(max_batch_size) { + /* ---step1. init environment--- */ + CHECK_GPU_ERROR(cudaSetDevice(0)); + CHECK_GPU_ERROR(cudaStreamCreate(&stream_)); + CHECK_GPU_ERROR(cublasCreate(&hd_)); + CHECK_GPU_ERROR(cublasSetStream(hd_, stream_)); + + /* ---step2. load model weights into GPU memory--- */ + + // saved in custom proto file + std::string model_weights_path = weight_path; + std::string res = tw_.initializing(model_weights_path); + if (!res.empty()) { + throw std::runtime_error(res); + } + + tw_.print_model_config(); + + /* + step3. instantiate encoder and decoder, init the gpu memory buffer. + using thrust vector to avoid manage gpu memory by hand + */ + + // register device memory for inputs and outputs + CHECK_GPU_ERROR(cudaMalloc(&d_input_, _max_batch_size * tw_._channel_input * + tw_._image_size * tw_._image_size * + sizeof(float))); + CHECK_GPU_ERROR(cudaMalloc(&d_padding_mask_, + _max_batch_size * tw_._max_step * sizeof(int))); + + CHECK_GPU_ERROR(cudaMalloc( + &d_encoder_output_, _max_batch_size * tw_._max_step * tw_._hidden_size * + sizeof(optraits::DataType))); + + encoder_ = std::make_shared>( + max_batch_size, d_input_, d_padding_mask_, d_encoder_output_, tw_, + stream_, hd_); + res = encoder_->check(); + if (!res.empty()) { + throw std::runtime_error(res); + } + + long buf_bytesize = encoder_->compute_buffer_bytesize(); + std::cout << "Vit buf_bytesize: " << buf_bytesize << std::endl; + + // encoder and decoder use the same buffer to save gpu memory useage + CHECK_GPU_ERROR(cudaMalloc(&d_buf_, (size_t)buf_bytesize)); + encoder_->init_buffer(d_buf_); + CHECK_GPU_ERROR(cudaStreamSynchronize(stream_)); +} + +Vit::~Vit() { + CHECK_GPU_ERROR(cudaFree(d_input_)); + CHECK_GPU_ERROR(cudaFree(d_padding_mask_)); + CHECK_GPU_ERROR(cudaFree(d_encoder_output_)); + CHECK_GPU_ERROR(cudaFree(d_buf_)); + CHECK_GPU_ERROR(cublasDestroy(hd_)); + CHECK_GPU_ERROR(cudaStreamDestroy(stream_)); +} + +void Vit::Infer() { + int batch_size = input_shapes_[0][0]; + encoder_->run_one_infer(batch_size); + CHECK_GPU_ERROR(cudaStreamSynchronize(stream_)); + set_output_shape(0, {batch_size, tw_._max_step, tw_._hidden_size}); +} + +void Vit::set_input_ptr(int index, void *input_ptr) { + switch (index) { + case 0: + encoder_->_p_d_pixel_input = static_cast(input_ptr); + break; + + default: + throw std::runtime_error("invalid input index"); + break; + } +} + +void Vit::set_output_ptr(int index, void *output_ptr) { + switch (index) { + case 0: + encoder_->_p_d_output = static_cast(output_ptr); + break; + + default: + throw std::runtime_error("invalid output index"); + break; + } +} + +const void *Vit::get_output_ptr(int index) { + switch (index) { + case 0: + return static_cast(encoder_->_p_d_output); + + default: + throw std::runtime_error("invalid output index"); + break; + } +} + +std::vector Vit::get_input_max_shape(int index) { + switch (index) { + case 0: + return {_max_batch_size, tw_._channel_input, tw_._image_size, + tw_._image_size}; + + default: + throw std::runtime_error("invalid input index"); + 
break; + } +} +std::vector Vit::get_output_max_shape(int index) { + switch (index) { + case 0: + return {_max_batch_size, tw_._max_step, tw_._hidden_size}; + + default: + throw std::runtime_error("invalid output index"); + break; + } +} + +DataType Vit::get_input_dtype(int index) { + switch (index) { + case 0: + return DataType::kFloat32; + break; + + default: + throw std::runtime_error("invalid input index"); + break; + } +} + +DataType Vit::get_output_dtype(int index) { + switch (index) { + case 0: + if (vit_optype == OperationType::FP32) { + return DataType::kFloat32; + } else { + return DataType::kFloat16; + } + break; + + default: + throw std::runtime_error("invalid output index"); + break; + } +} + +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/inference/pywrapper/vit.h b/lightseq/inference/pywrapper/vit.h new file mode 100644 index 00000000..9bfd89c0 --- /dev/null +++ b/lightseq/inference/pywrapper/vit.h @@ -0,0 +1,49 @@ + +#include "model_base.h" +#include "../model/vit_encoder.h" +#include "../proto/vit_weight.h" +#include "../tools/util.h" + +#ifdef FP16_MODE +const lightseq::cuda::OperationType vit_optype = + lightseq::cuda::OperationType::FP16; +#else +const lightseq::cuda::OperationType vit_optype = + lightseq::cuda::OperationType::FP32; +#endif + +namespace lightseq { +namespace cuda { +class Vit : public LSModel { + private: + typedef OperationTypeTraits optraits; + std::shared_ptr> encoder_; + + optraits::DataType *d_encoder_output_; + float *d_input_; + int *d_padding_mask_; + int _max_batch_size; + cudaStream_t stream_; + cublasHandle_t hd_; + void *d_buf_; + VitWeight tw_; + + public: + Vit(const std::string weight_path, const int max_batch_size); + + ~Vit(); + + void Infer() override; + void set_input_ptr(int index, void *input_ptr) override; + void set_output_ptr(int index, void *output_ptr) override; + const void *get_output_ptr(int index) override; + std::vector get_input_max_shape(int index) override; + std::vector get_output_max_shape(int index) override; + DataType get_input_dtype(int index) override; + DataType get_output_dtype(int index) override; +}; + +LSMODEL_REGISTER(Vit); + +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/inference/pywrapper/wrapper.cc b/lightseq/inference/pywrapper/wrapper.cc index 7b0dd438..1d5de797 100644 --- a/lightseq/inference/pywrapper/wrapper.cc +++ b/lightseq/inference/pywrapper/wrapper.cc @@ -415,6 +415,90 @@ class PyMoe { } }; +class PyVit { + private: + lightseq::cuda::LSModel *model_; + float *d_input_; + std::vector d_outputs_; + + public: + PyVit(std::string weight_path, int max_batch_size) { + model_ = lightseq::cuda::LSModelFactory::GetInstance().CreateModel( + "Vit", weight_path, max_batch_size); + std::vector max_input_shape = model_->get_input_max_shape(0); + int max_size = + std::accumulate(max_input_shape.begin(), max_input_shape.end(), 1, + std::multiplies()); + lightseq::cuda::CHECK_GPU_ERROR( + cudaMalloc(&d_input_, sizeof(float) * max_size)); + + for (int i = 0; i < model_->get_output_size(); i++) { + void *d_output; + std::vector shape = model_->get_output_max_shape(i); + int output_size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + lightseq::cuda::CHECK_GPU_ERROR( + cudaMalloc(&d_output, output_size * sizeof(int))); + model_->set_output_ptr(i, d_output); + d_outputs_.push_back(d_output); + } + } + ~PyVit() { + delete model_; + lightseq::cuda::CHECK_GPU_ERROR(cudaFree(d_input_)); + for (auto d_output : d_outputs_) { + 
lightseq::cuda::CHECK_GPU_ERROR(cudaFree(d_output)); + } + } + + py::array_t infer( + py::array_t input_seq) { + auto input_seq_out = input_seq.mutable_unchecked<4>(); + const float *input_seq_data = input_seq_out.data(0, 0, 0, 0); + int batch_size = input_seq_out.shape(0); + int channel_input = input_seq_out.shape(1); + int image_size = input_seq_out.shape(2); + + lightseq::cuda::CHECK_GPU_ERROR(cudaMemcpy( + d_input_, input_seq_data, sizeof(float) * input_seq_out.size(), + cudaMemcpyHostToDevice)); + + model_->set_input_ptr(0, d_input_); + model_->set_input_shape( + 0, {batch_size, channel_input, image_size, image_size}); + + model_->Infer(); + + std::vector output_shape = model_->get_output_shape(0); + auto output = py::array_t(output_shape); + float *output_data = output.mutable_data(0, 0); + lightseq::cuda::DataType output_type = model_->get_output_dtype(0); + if (output_type == lightseq::cuda::kFloat32) { + const float *d_output = + static_cast(model_->get_output_ptr(0)); + + lightseq::cuda::CHECK_GPU_ERROR(cudaMemcpy(output_data, d_output, + sizeof(float) * output.size(), + cudaMemcpyDeviceToHost)); + } else if (output_type == lightseq::cuda::kFloat16) { + const half *d_output = + static_cast(model_->get_output_ptr(0)); + std::vector h_vit_out(output.size()); + lightseq::cuda::CHECK_GPU_ERROR(cudaMemcpy(h_vit_out.data(), d_output, + sizeof(half) * output.size(), + cudaMemcpyDeviceToHost)); + for (auto i = 0; i < h_vit_out.size(); i++) { + float f_data = __half2float(h_vit_out[i]); + output_data[i] = f_data; + } + } else { + throw std::runtime_error("Not supported output type"); + } + + return output; + } +}; + PYBIND11_MODULE(inference, m) { m.attr("__name__") = "lightseq.inference"; py::class_(m, "TransformerDecoder") @@ -453,4 +537,10 @@ PYBIND11_MODULE(inference, m) { py::arg("max_batch_size")) .def("infer", &PyMoe::infer, py::return_value_policy::reference_internal, py::arg("input_seq")); + + py::class_(m, "Vit") + .def(py::init(), py::arg("weight_path"), + py::arg("max_batch_size")) + .def("infer", &PyVit::infer, py::return_value_policy::reference_internal, + py::arg("pixel_values")); } diff --git a/lightseq/training/ops/pytorch/export.py b/lightseq/training/ops/pytorch/export.py index 0a692c3c..cf4ec914 100644 --- a/lightseq/training/ops/pytorch/export.py +++ b/lightseq/training/ops/pytorch/export.py @@ -73,8 +73,8 @@ def check_rule(tensor_name, rule): except: target_tensor = tt["save"] print( - "%s -> %s, convert finished!" - % (target_tn if target_tn else "created", proto_name) + "%s -> %s, shape: %s, convert finished!" + % (target_tn if target_tn else "created", proto_name, target_tensor.shape) ) return target_tensor
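
For reference, a minimal sketch of how the `Vit` wrapper registered in `wrapper.cc` might be driven once the extension is built. The checkpoint file name is a hypothetical placeholder for a converted hdf5 weight file, and the 3 x 224 x 224 input shape assumes a patch16-224 style model:

```python
import numpy as np
import lightseq.inference as lsi

# "lightseq_vit.hdf5" is a placeholder for a converted ViT checkpoint.
model = lsi.Vit("lightseq_vit.hdf5", max_batch_size=8)

# PyVit::infer expects a 4-D float32 array: (batch, channels, height, width).
pixel_values = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Returns the encoder output as float32 with shape
# (batch_size, max_step, hidden_size), e.g. (1, 197, 768) for ViT-Base;
# fp16 outputs are converted back to float32 on the host before returning.
encoder_output = model.infer(pixel_values)
print(encoder_output.shape)
```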