
[Misc] Move weights mapper (#11443)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored Dec 24, 2024
1 parent 5c79632 commit 196c34b
Showing 8 changed files with 74 additions and 68 deletions.
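This commit moves each model's `WeightsMapper` from a local variable built inside `load_weights()` to a class-level `hf_to_vllm_mapper` attribute, so the HF-to-vLLM name mapping is defined once per class and `load_weights()` only applies it. A minimal sketch of the resulting pattern, assuming the usual `vllm.model_executor.models.utils` import path and using a hypothetical `MyModel` class (the real classes are in the diffs below; the embedding models instead call `self.hf_to_vllm_mapper.apply(weights)` directly):

```python
from typing import Iterable, Tuple

import torch
from torch import nn

from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper


class MyModel(nn.Module):
    # Defined once at class level instead of being rebuilt on every
    # load_weights() call; strips the HF "model." prefix from weight names.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
```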
@@ -13,6 +13,7 @@


class MyGemma2Embedding(nn.Module):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -62,8 +63,8 @@ def pooler(
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)

weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
return self.model.load_weights(weights)
20 changes: 10 additions & 10 deletions vllm/model_executor/models/aria.py
@@ -521,6 +521,15 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
This model combines a vision tower, a multi-modal projector, and a language
model to perform tasks that involve both image and text inputs.
"""
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "language_model",
"language_model.lm_head": "lm_head",
},
orig_to_new_suffix={
"router.weight": "router_weight",
},
)

def __init__(
self,
@@ -662,15 +671,6 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "language_model",
"language_model.lm_head": "lm_head",
},
orig_to_new_suffix={
"router.weight": "router_weight",
},
)

loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=hf_to_vllm_mapper)
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
4 changes: 2 additions & 2 deletions vllm/model_executor/models/bert.py
@@ -409,6 +409,7 @@ class BertEmbeddingModel(nn.Module):
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -441,8 +442,7 @@ def pooler(
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)
58 changes: 30 additions & 28 deletions vllm/model_executor/models/molmo.py
@@ -1123,6 +1123,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
# vision backbone mapping
"image_projector.w1.": "image_projector.gate_proj.",
"image_projector.w3.": "image_projector.up_proj.",
"image_projector.w2.": "image_projector.down_proj.",
# language backbone mapping
"att_proj": "self_attn.qkv_proj",
"attn_out": "self_attn.o_proj",
"q_norm": "self_attn.q_norm",
"k_norm": "self_attn.k_norm",
"ff_proj": "mlp.gate_up_proj",
"ff_out": "mlp.down_proj",
"attn_norm": "input_layernorm",
"ff_norm": "post_attention_layernorm",
},
orig_to_new_prefix={
# vision backbone mapping
"model.vision_backbone.": "vision_backbone.",
# language backbone mapping
"model.transformer.blocks.": "model.layers.",
"model.transformer.ln_f.": "model.norm.",
# lm_head is renamed to model.transformer.mlp.down_proj firstly,
# we need to run a second renaming for it
"model.transformer.mlp.down_proj.": "lm_head.",
},
)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -1298,36 +1326,10 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
# vision backbone mapping
"image_projector.w1.": "image_projector.gate_proj.",
"image_projector.w3.": "image_projector.up_proj.",
"image_projector.w2.": "image_projector.down_proj.",
# language backbone mapping
"att_proj": "self_attn.qkv_proj",
"attn_out": "self_attn.o_proj",
"q_norm": "self_attn.q_norm",
"k_norm": "self_attn.k_norm",
"ff_proj": "mlp.gate_up_proj",
"ff_out": "mlp.down_proj",
"attn_norm": "input_layernorm",
"ff_norm": "post_attention_layernorm",
},
orig_to_new_prefix={
# vision backbone mapping
"model.vision_backbone.": "vision_backbone.",
# language backbone mapping
"model.transformer.blocks.": "model.layers.",
"model.transformer.ln_f.": "model.norm.",
# lm_head is renamed to model.transformer.mlp.down_proj firstly,
# we need to run a second renaming for it
"model.transformer.mlp.down_proj.": "lm_head.",
},
)

loader = AutoWeightsLoader(self)
weights = _get_weights_with_merged_embedding(weights)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def _get_weights_with_merged_embedding(
16 changes: 8 additions & 8 deletions vllm/model_executor/models/phi3v.py
@@ -408,6 +408,13 @@ def _get_dummy_mm_inputs(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens",
"model.vision_embed_tokens.": "vision_embed_tokens.",
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -616,17 +623,10 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens",
"model.vision_embed_tokens.": "vision_embed_tokens.",
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})

loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights,
mapper=hf_to_vllm_mapper)
mapper=self.hf_to_vllm_mapper)

# The HF config doesn't specify whether these are tied,
# so we detect it this way
5 changes: 3 additions & 2 deletions vllm/model_executor/models/qwen2.py
@@ -529,6 +529,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []

hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -577,8 +579,7 @@ def pooler(
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)
27 changes: 14 additions & 13 deletions vllm/model_executor/models/telechat2.py
@@ -31,6 +31,19 @@

class TeleChat2Model(LlamaModel):

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"transformer.": "model.",
},
orig_to_new_substr={
".h.": ".layers.",
".self_attention.": ".self_attn.",
".word_embeddings.": ".embed_tokens.",
".dense.": ".o_proj.",
".ln_f.": ".norm.",
},
)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# 1. Initialize the LlamaModel with bias
vllm_config.model_config.hf_config.bias = True
@@ -111,21 +124,9 @@ def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"transformer.": "model.",
},
orig_to_new_substr={
".h.": ".layers.",
".self_attention.": ".self_attn.",
".word_embeddings.": ".embed_tokens.",
".dense.": ".o_proj.",
".ln_f.": ".norm.",
},
)
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
7 changes: 4 additions & 3 deletions vllm/model_executor/models/ultravox.py
@@ -302,6 +302,9 @@ def forward(
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -494,9 +497,7 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["audio_tower."])
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
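The mapping rules themselves are unchanged by this commit; a `WeightsMapper` just rewrites HuggingFace checkpoint parameter names into vLLM module names via prefix, substring, and suffix replacements. Below is a simplified stand-in to illustrate what, for example, the TeleChat2 mapping above does to a checkpoint key; it is not vLLM's implementation, and the matching order shown is an assumption.

```python
def rename(name: str, prefixes: dict, substrs: dict, suffixes: dict) -> str:
    """Toy re-implementation of prefix/substring/suffix renaming."""
    for old, new in prefixes.items():
        if name.startswith(old):
            name = new + name[len(old):]
    for old, new in substrs.items():
        name = name.replace(old, new)
    for old, new in suffixes.items():
        if name.endswith(old):
            name = name[:-len(old)] + new
    return name


print(rename(
    "transformer.h.0.self_attention.dense.weight",
    prefixes={"transformer.": "model."},
    substrs={".h.": ".layers.", ".self_attention.": ".self_attn.",
             ".dense.": ".o_proj."},
    suffixes={},
))
# -> model.layers.0.self_attn.o_proj.weight
```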
