Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more models. #58

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fake_quant/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ In this directory, we provide the torch scripts for the experiments in QuaRot.

## Language Generation and Zero-Shot Evaluations

Currently, we only support **LLaMa-2** models. You can simply run the `main.py` to reproduce the results in the paper. The most important arguments are:
Currently, we only support **LLaMa/Mistral/Qwen2** models. You can simply run the `main.py` to reproduce the results in the paper. The most important arguments are:

- `--model`: the model name (or path to the weights)
- `--bsz`: the batch size for PPL evaluation
Expand Down
Empty file added fake_quant/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion fake_quant/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def evaluator(model, testenc, dev, args):
if 'opt' in args.model:
opt_type = True
llama_type = False
elif 'meta' in args.model:
elif model_utils.is_llama_like_causal_lm(model):
llama_type = True
opt_type = False
else:
Expand Down
321 changes: 320 additions & 1 deletion fake_quant/hadamard_utils.py

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions fake_quant/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import eval_utils
import hadamard_utils


def main():
args = utils.parser_gen()
if args.wandb:
Expand Down Expand Up @@ -57,8 +58,9 @@ def main():
model.load_state_dict(save_dict["model"])

elif not args.w_rtn: # GPTQ Weight Quantization
assert "llama" in args.model, "Only llama is supported for GPTQ!"

# assert "llama" in args.model, "Only llama is supported for GPTQ!"
assert model_utils.is_supported_llama_like_model_name(args.model), "Only llama, qwen and mistral are supported for GPTQ!"

trainloader = data_utils.get_loaders(
args.cal_dataset, nsamples=args.nsamples,
seed=args.seed, model=args.model,
Expand All @@ -79,7 +81,7 @@ def main():
if args.a_bits < 16 or args.v_bits < 16:
qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
down_proj_groupsize = -1
if args.a_groupsize > 0 and "llama" in args.model:
if args.a_groupsize > 0 and model_utils.is_supported_llama_like_model_name(args.model):
down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize)

for name in qlayers:
Expand Down
71 changes: 51 additions & 20 deletions fake_quant/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,24 @@
import os
import logging

from transformers import (LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM)

OPT_MODEL = transformers.models.opt.modeling_opt.OPTForCausalLM
OPT_LAYER = transformers.models.opt.modeling_opt.OPTDecoderLayer
LLAMA_MODEL = transformers.models.llama.modeling_llama.LlamaForCausalLM
LLAMA_LAYER = transformers.models.llama.modeling_llama.LlamaDecoderLayer


def is_supported_llama_like_model_name(model_name: str):
return ('llama' in model_name.lower()) or (
'qwen' in model_name.lower()) or (
'mistral' in model_name.lower())

def model_type_extractor(model):
if isinstance(model, LLAMA_MODEL):
return LLAMA_MODEL
if isinstance(model, LlamaForCausalLM):
return LlamaForCausalLM
elif isinstance(model, MistralForCausalLM):
return MistralForCausalLM
elif isinstance(model, Qwen2ForCausalLM):
return Qwen2ForCausalLM
elif isinstance(model, OPT_MODEL):
return OPT_MODEL
else:
Expand All @@ -23,27 +32,34 @@ def skip(*args, **kwargs):
# This is a helper function to save time during the initialization!
pass

def is_llama_like_causal_lm(model):
return isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM))


def is_llama_like_model_type(model_type):
return model_type in (LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM)

def get_rope_function_name(model):
if isinstance(model, LLAMA_MODEL):
if is_llama_like_causal_lm(model):
return "apply_rotary_pos_emb"
raise NotImplementedError


def get_layers(model):
if isinstance(model, OPT_MODEL):
return model.model.decoder.layers
if isinstance(model, LLAMA_MODEL):
if is_llama_like_causal_lm(model):
return model.model.layers
raise NotImplementedError


def get_llama(model_name, hf_token):
def get_llama_like(model_name, hf_token):
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model = transformers.LlamaForCausalLM.from_pretrained(model_name, torch_dtype='auto',
use_auth_token=hf_token,
low_cpu_mem_usage=True)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype='auto',
use_auth_token=hf_token,
low_cpu_mem_usage=True)
model.seqlen = 2048
logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen))
return model
Expand All @@ -64,15 +80,16 @@ def get_opt(model_name):
def get_model(
model_name, hf_token=None
):
if 'llama' in model_name:
return get_llama(model_name, hf_token)
if is_supported_llama_like_model_name(model_name):
return get_llama_like(model_name, hf_token)
elif 'opt' in model_name:
return get_opt(model_name)
else:
raise ValueError(f'Unknown model {model_name}')


def get_model_type(model):
return model_type_extractor(model)
if isinstance(model, OPT_MODEL):
model_type = OPT_MODEL
elif isinstance(model, LLAMA_MODEL):
Expand All @@ -82,7 +99,7 @@ def get_model_type(model):
return model_type

def get_embeddings(model, model_type) -> list[torch.nn.Module]:
if model_type == LLAMA_MODEL:
if is_llama_like_causal_lm(model):
return [model.model.embed_tokens]
elif model_type == OPT_MODEL:
return [model.model.decoder.embed_tokens, model.model.decoder.embed_positions]
Expand All @@ -91,7 +108,7 @@ def get_embeddings(model, model_type) -> list[torch.nn.Module]:


def get_transformer_layers(model, model_type):
if model_type == LLAMA_MODEL:
if is_llama_like_causal_lm(model):
return [layer for layer in model.model.layers]
elif model_type == OPT_MODEL:
return [layer for layer in model.model.decoder.layers]
Expand All @@ -100,18 +117,32 @@ def get_transformer_layers(model, model_type):


def get_lm_head(model, model_type):
if model_type == LLAMA_MODEL:
if is_llama_like_causal_lm(model):
return model.lm_head
elif model_type == OPT_MODEL:
return model.lm_head
else:
raise ValueError(f'Unknown model type {model_type}')

def get_norm_type(model):
if is_llama_like_causal_lm(model):
if isinstance(model, LlamaForCausalLM):
return transformers.models.llama.modeling_llama.LlamaRMSNorm
elif isinstance(model, MistralForCausalLM):
return transformers.models.mistral.modeling_mistral.MistralRMSNorm
elif isinstance(model, Qwen2ForCausalLM):
return transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
elif get_model_type(model) == OPT_MODEL:
return torch.nn.LayerNorm


def get_pre_head_layernorm(model, model_type):
if model_type == LLAMA_MODEL:
if is_llama_like_causal_lm(model):
pre_head_layernorm = model.model.norm
assert isinstance(pre_head_layernorm,
transformers.models.llama.modeling_llama.LlamaRMSNorm)
(transformers.models.llama.modeling_llama.LlamaRMSNorm,
transformers.models.mistral.modeling_mistral.MistralRMSNorm,
transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm))
elif model_type == OPT_MODEL:
pre_head_layernorm = model.model.decoder.final_layer_norm
assert pre_head_layernorm is not None
Expand All @@ -121,7 +152,7 @@ def get_pre_head_layernorm(model, model_type):

def get_mlp_bottleneck_size(model):
model_type = get_model_type(model)
if model_type == LLAMA_MODEL:
if is_llama_like_causal_lm(model):
return model.config.intermediate_size
elif model_type == OPT_MODEL:
return model.config.ffn_dim
Expand Down Expand Up @@ -167,7 +198,7 @@ class RMSN(torch.nn.Module):
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L75
"""

def __init__(self, mean_dim: int, eps=1e-5):
def __init__(self, mean_dim: int, eps=1e-6):
super().__init__()
self.eps = eps
self.mean_dim = mean_dim
Expand Down Expand Up @@ -196,7 +227,7 @@ def hook(module, input, output):

handles = []

if model_type == LLAMA_MODEL:
if is_llama_like_causal_lm(layer):
captured_inputs = {
'k_proj': [], # q_proj, v_proj has the same input as k_proj
'o_proj': [],
Expand Down
14 changes: 7 additions & 7 deletions fake_quant/rotation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def fuse_layer_norms(model):
for layer in layers:

# fuse the input layernorms into the linear layers
if model_type == model_utils.LLAMA_MODEL:
if model_utils.is_llama_like_causal_lm(model):
fuse_ln_linear(layer.post_attention_layernorm, [layer.mlp.up_proj, layer.mlp.gate_proj])
fuse_ln_linear(layer.input_layernorm, [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj])
elif model_type == model_utils.OPT_MODEL:
Expand All @@ -79,7 +79,7 @@ def fuse_layer_norms(model):

model_utils.replace_modules(
model,
transformers.models.llama.modeling_llama.LlamaRMSNorm if model_type == model_utils.LLAMA_MODEL else torch.nn.LayerNorm,
model_utils.get_norm_type(model),
lambda _: model_utils.RMSN(model.config.hidden_size),
replace_layers=False,
)
Expand Down Expand Up @@ -132,7 +132,7 @@ def rotate_attention_inputs(layer, Q, model_type) -> None:

def rotate_attention_output(layer, Q, model_type) -> None:
# Rotate output matrix of the self-attention layer.
if model_type == model_utils.LLAMA_MODEL:
if model_utils.is_llama_like_model_type(model_type):
W = layer.self_attn.o_proj
elif model_type == model_utils.OPT_MODEL:
W = layer.self_attn.out_proj
Expand All @@ -148,7 +148,7 @@ def rotate_attention_output(layer, Q, model_type) -> None:

def rotate_mlp_input(layer, Q, model_type):
# Rotate the MLP input weights.
if model_type == model_utils.LLAMA_MODEL:
if model_utils.is_llama_like_model_type(model_type):
mlp_inputs = [layer.mlp.up_proj, layer.mlp.gate_proj]
elif model_type == model_utils.OPT_MODEL:
mlp_inputs = [layer.fc1]
Expand All @@ -161,7 +161,7 @@ def rotate_mlp_input(layer, Q, model_type):

def rotate_mlp_output(layer, Q, model_type):
# Rotate the MLP output weights and bias.
if model_type == model_utils.LLAMA_MODEL:
if model_utils.is_llama_like_model_type(model_type):
W = layer.mlp.down_proj
elif model_type == model_utils.OPT_MODEL:
W = layer.fc2
Expand Down Expand Up @@ -196,7 +196,7 @@ def matmul_hadU_cuda_had(X, hadK, transpose=False):

def rotate_faster_down_proj(layer, model_type, hardK):
from fast_hadamard_transform import hadamard_transform
if model_type == model_utils.LLAMA_MODEL:
if model_utils.is_llama_like_model_type(model_type):
W = layer.mlp.down_proj
else:
raise ValueError(f'Faster MLP is onlu supported for LLaMa models!')
Expand All @@ -215,7 +215,7 @@ def rotate_head(model, Q: torch.Tensor) -> None:

def rotate_ov_proj(layer, model_type, head_num, head_dim):
v_proj = layer.self_attn.v_proj
if model_type == model_utils.LLAMA_MODEL:
if model_utils.is_llama_like_model_type(model_type):
o_proj = layer.self_attn.o_proj
elif model_type == model_utils.OPT_MODEL:
o_proj = layer.self_attn.out_proj
Expand Down
18 changes: 18 additions & 0 deletions fake_quant/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# MODEL=Qwen/Qwen2.5-0.5B # Will fail because the head number is 14.
# MODEL=Qwen/Qwen2.5-1.5B
MODEL=Qwen/Qwen2.5-7B
# MODEL=Qwen/Qwen2.5-32B
# MODEL='NousResearch/Llama-3.2-1B'
# MODEL='NousResearch/Meta-Llama-3.1-8B'
# MODEL='NousResearch/Meta-Llama-3.1-70B'
# MODEL=mistralai/Mistral-7B-v0.3

python main.py \
--rotate \
--w_bits 4 --w_clip \
--a_bits 8 \
--percdamp 0.1 \
--w_asym \
--model $MODEL
# --w_groupsize 128 --a_groupsize 128 \
# --w_rtn \
12 changes: 2 additions & 10 deletions fake_quant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

supported_models = [
'meta-llama/Llama-2-7b-hf',
'meta-llama/Llama-2-13b-hf',
'meta-llama/Llama-2-70b-hf',
'meta-llama/Meta-Llama-3-8B',
'meta-llama/Meta-Llama-3-70B',
'facebook/opt-125m'
]
supported_datasets = ['wikitext2', 'ptb', 'c4']

# These flags disable using TensorFloat-32 tensor cores (to avoid numerical issues)
Expand Down Expand Up @@ -73,7 +65,7 @@ def parser_gen():

# General Arguments
parser.add_argument('--model', type=str, default='meta-llama/Llama-2-7b-hf',
help='Model to load;', choices=supported_models)
help='Model to load;')
parser.add_argument('--seed', type=int, default=0, help='Random Seed for HuggingFace and PyTorch')
parser.add_argument('--eval_dataset', type=str, default='wikitext2',
help='Dataset for Evaluation (default: wikitext2)', choices=supported_datasets,)
Expand Down Expand Up @@ -251,7 +243,7 @@ def total_reserved_mem() -> int:

def distribute_model(model) -> None:
"""Distribute the model across available GPUs. NB: only implemented for Llama-2."""
no_split_module_classes = ['LlamaDecoderLayer']
no_split_module_classes = ['LlamaDecoderLayer', 'MistralDecoderLayer', 'Qwen2DecoderLayer']
max_memory = get_balanced_memory(
model,
no_split_module_classes=no_split_module_classes,
Expand Down