This is the official repository of
FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization. Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel, Anurag Ranjan. ICCV 2023
All models are trained on ImageNet-1K and benchmarked on an iPhone 12 Pro using the ModelBench app.
conda create -n fastvit python=3.9
conda activate fastvit
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
To use our models, follow the code snippet below:
import torch
import models
from timm.models import create_model
from models.modules.mobileone import reparameterize_model
# To train from scratch or fine-tune
model = create_model("fastvit_t8")
# ... train ...
# Load unfused pre-trained checkpoint for fine-tuning
# or for downstream task training like detection/segmentation
checkpoint = torch.load('/path/to/unfused_checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
# ... train ...
# For inference
model.eval()
model_inf = reparameterize_model(model)
# Use model_inf at test-time
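As a rough intuition for why the reparameterized model can replace the train-time model at inference, the sketch below folds batch normalization into a scale and bias and then merges parallel branches into a single affine map. It is a toy scalar illustration of structural reparameterization, not the repository's implementation: `fold_bn`, `reparameterize`, the `BRANCHES` values, and the `EPS` constant are all hypothetical.

```python
import math

EPS = 1e-5  # hypothetical BatchNorm epsilon


def batch_norm(x, gamma, beta, mean, var):
    """Inference-time BatchNorm using fixed running statistics."""
    return gamma * (x - mean) / math.sqrt(var + EPS) + beta


def fold_bn(weight, bias, gamma, beta, mean, var):
    """Fold BatchNorm into the preceding scalar linear op.

    BN(w*x + b) = (w*s)*x + ((b - mean)*s + beta), with s = gamma / sqrt(var + EPS)
    """
    scale = gamma / math.sqrt(var + EPS)
    return weight * scale, (bias - mean) * scale + beta


# Hypothetical train-time branches:
# (weight, bias, gamma, beta, running_mean, running_var)
BRANCHES = [
    (0.8, 0.1, 1.2, -0.3, 0.05, 0.9),
    (-0.4, 0.0, 0.7, 0.2, -0.10, 1.5),
    (1.0, 0.0, 1.1, 0.0, 0.02, 1.0),  # skip branch: identity followed by BN
]


def multi_branch_forward(x):
    """Train-time structure: sum of several linear+BN branches."""
    return sum(batch_norm(w * x + b, g, bt, m, v) for w, b, g, bt, m, v in BRANCHES)


def reparameterize(branches):
    """Fold BN in every branch, then sum the branches into one (weight, bias)."""
    folded = [fold_bn(*branch) for branch in branches]
    return sum(w for w, _ in folded), sum(b for _, b in folded)


FUSED_W, FUSED_B = reparameterize(BRANCHES)


def fused_forward(x):
    """Inference-time structure: a single affine op."""
    return FUSED_W * x + FUSED_B


if __name__ == "__main__":
    for x in (-2.0, 0.0, 0.5, 3.0):
        print(x, multi_branch_forward(x), fused_forward(x))
```

The two forward functions agree up to floating-point error, which is why the fused model can be used at test time without retraining. Real convolutional branches are folded the same way per output channel, with extra kernel padding when branch kernel sizes differ.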
Models trained on ImageNet-1K
Model | Top-1 Acc. (%) | Latency (ms) | PyTorch Checkpoint (URL) | Core ML Model |
---|---|---|---|---|
FastViT-T8 | 76.2 | 0.8 | T8(unfused) | fastvit_t8.mlpackage.zip |
FastViT-T12 | 79.3 | 1.2 | T12(unfused) | fastvit_t12.mlpackage.zip |
FastViT-S12 | 79.9 | 1.4 | S12(unfused) | fastvit_s12.mlpackage.zip |
FastViT-SA12 | 80.9 | 1.6 | SA12(unfused) | fastvit_sa12.mlpackage.zip |
FastViT-SA24 | 82.7 | 2.6 | SA24(unfused) | fastvit_sa24.mlpackage.zip |
FastViT-SA36 | 83.6 | 3.5 | SA36(unfused) | fastvit_sa36.mlpackage.zip |
FastViT-MA36 | 83.9 | 4.6 | MA36(unfused) | fastvit_ma36.mlpackage.zip |
Models trained on ImageNet-1K with knowledge distillation.
Model | Top-1 Acc. (%) | Latency (ms) | PyTorch Checkpoint (URL) | Core ML Model |
---|---|---|---|---|
FastViT-T8 | 77.2 | 0.8 | T8(unfused) | fastvit_t8.mlpackage.zip |
FastViT-T12 | 80.3 | 1.2 | T12(unfused) | fastvit_t12.mlpackage.zip |
FastViT-S12 | 81.1 | 1.4 | S12(unfused) | fastvit_s12.mlpackage.zip |
FastViT-SA12 | 81.9 | 1.6 | SA12(unfused) | fastvit_sa12.mlpackage.zip |
FastViT-SA24 | 83.4 | 2.6 | SA24(unfused) | fastvit_sa24.mlpackage.zip |
FastViT-SA36 | 84.2 | 3.5 | SA36(unfused) | fastvit_sa36.mlpackage.zip |
FastViT-MA36 | 84.6 | 4.6 | MA36(unfused) | fastvit_ma36.mlpackage.zip |
Latency of all models was measured on an iPhone 12 Pro using the ModelBench app. For further details, please contact James Gabriel and Jeff Zhu. All reported numbers are rounded to one decimal place.
Download the ImageNet-1K dataset and structure the data as follows:
/path/to/imagenet-1k/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  validation/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
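Before starting a long training run, it can help to confirm that the dataset directory matches the layout above. The following is a minimal sketch using only the standard library; `check_layout`, `summarize_split`, and the accepted image suffixes are assumptions made for illustration and are not part of this repository.

```python
from pathlib import Path

# Assumption: images use one of these suffixes; adjust to your copy of the dataset.
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def summarize_split(split_dir):
    """Return {class_name: image_count} for one split directory."""
    counts = {}
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def check_layout(root, splits=("train", "validation")):
    """Return a list of problems; an empty list means the layout looks right."""
    root = Path(root)
    problems = []
    summaries = {}
    for split in splits:
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing split directory: {split_dir}")
            continue
        summary = summarize_split(split_dir)
        if not summary:
            problems.append(f"no class folders in {split_dir}")
        for name, count in summary.items():
            if count == 0:
                problems.append(f"no images in {split_dir / name}")
        summaries[split] = summary
    # Every split should contain the same set of class folders.
    if len(summaries) == len(splits):
        reference = set(summaries[splits[0]])
        for split in splits[1:]:
            mismatch = reference ^ set(summaries[split])
            if mismatch:
                problems.append(
                    f"class folders differ between {splits[0]} and {split}: {sorted(mismatch)}"
                )
    return problems


if __name__ == "__main__":
    for problem in check_layout("/path/to/imagenet-1k"):
        print(problem)
```

An empty result only means the folder structure is consistent; it does not verify image contents or class counts against the official ImageNet-1K release.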
To train a FastViT variant, use the corresponding command below:
FastViT-T8
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256
# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"
FastViT-T12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256
# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"
FastViT-S12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_s12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256
# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_s12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"
FastViT-SA12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1
# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa12 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"
FastViT-SA24
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa24 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1
# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa24 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.05 \
--distillation-type "hard"
FastViT-SA36
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa36 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.2
# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1 \
--distillation-type "hard"
FastViT-MA36
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_ma36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.35
# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_ma36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.2 \
--distillation-type "hard"
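The commands above differ only in the model name, `--mixup`, `--drop-path`, and the distillation flag, so they can be generated from a small table. The sketch below transcribes those per-variant settings and assembles the argument list. `RECIPES` and `build_command` are hypothetical helpers, not part of this repository, and the `fastvit_ma36` key assumes the FastViT-MA36 commands are meant to train the MA36 model.

```python
import shlex

# (mixup, drop_path) per variant, transcribed from the commands above.
# None means the flag is omitted from the command.
RECIPES = {
    "fastvit_t8":   {"base": (0.2, None), "distill": (0.2, None)},
    "fastvit_t12":  {"base": (0.2, None), "distill": (0.2, None)},
    "fastvit_s12":  {"base": (0.2, None), "distill": (0.2, None)},
    "fastvit_sa12": {"base": (0.2, 0.1),  "distill": (None, None)},
    "fastvit_sa24": {"base": (0.2, 0.1),  "distill": (None, 0.05)},
    "fastvit_sa36": {"base": (0.2, 0.2),  "distill": (None, 0.1)},
    "fastvit_ma36": {"base": (None, 0.35), "distill": (None, 0.2)},
}


def build_command(
    variant,
    distill=False,
    data_dir="/path/to/ImageNet/dataset",
    output_dir="/path/to/save/results",
    gpus=8,
    batch_size=128,
):
    """Return the training command for one variant as a list of arguments."""
    mixup, drop_path = RECIPES[variant]["distill" if distill else "base"]
    cmd = [
        "python", "-m", "torch.distributed.launch", f"--nproc_per_node={gpus}",
        "train.py", data_dir, "--model", variant,
        "-b", str(batch_size), "--lr", "1e-3", "--native-amp",
    ]
    if mixup is not None:
        cmd += ["--mixup", str(mixup)]
    cmd += ["--output", output_dir, "--input-size", "3", "256", "256"]
    if drop_path is not None:
        cmd += ["--drop-path", str(drop_path)]
    if distill:
        cmd += ["--distillation-type", "hard"]
    return cmd


if __name__ == "__main__":
    # Print a copy-pasteable shell command for each variant and mode.
    for name in RECIPES:
        for distill in (False, True):
            print(shlex.join(build_command(name, distill=distill)))
```

Building the command as a list also avoids line-continuation mistakes, since it can be passed directly to `subprocess.run` without a shell.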
To run evaluation on ImageNet, follow the example command below:
FastViT-T8
# Evaluate unfused checkpoint
python validate.py /path/to/ImageNet/dataset --model fastvit_t8 \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8.pth.tar
# Evaluate fused checkpoint
python validate.py /path/to/ImageNet/dataset --model fastvit_t8 \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8_reparam.pth.tar \
--use-inference-mode
To export a Core ML package file from a PyTorch checkpoint, follow the example command below:
FastViT-T8
python export_model.py --variant fastvit_t8 --output-dir /path/to/save/exported_model \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8_reparam.pth.tar
@inproceedings{vasufastvit2023,
author = {Pavan Kumar Anasosalu Vasu and James Gabriel and Jeff Zhu and Oncel Tuzel and Anurag Ranjan},
title = {FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
year = {2023}
}
Our codebase is built using multiple open-source contributions; please see ACKNOWLEDGEMENTS for more details.