[Bugfix] Fix Qwen2-VL LoRA weight loading (#11430)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored Dec 24, 2024
1 parent 9edca6b commit b1b1038
Showing 7 changed files with 168 additions and 14 deletions.
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
@@ -200,6 +200,11 @@ def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")


@pytest.fixture(scope="session")
def qwen2vl_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
30 changes: 30 additions & 0 deletions tests/lora/test_lora_checkpoints.py
@@ -4,6 +4,7 @@

from vllm.lora.models import LoRAModel
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

lora_lst = [
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
@@ -71,3 +72,32 @@ def test_load_checkpoints(
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)


def test_lora_weights_mapping(baichuan_lora_files, ):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)

hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"model.": "language_model.model.",
}, )

lora_model = LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules,
weights_mapper=hf_to_vllm_mapper,
)
for name in lora_model.loras:
assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
78 changes: 78 additions & 0 deletions tests/lora/test_qwen2vl.py
@@ -0,0 +1,78 @@
from typing import List

import pytest

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"

PROMPT_TEMPLATE = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
"\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
"What is in the image?<|im_end|>\n"
"<|im_start|>assistant\n")

IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start with `A`.
EXPECTED_OUTPUT = [
"A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.", # noqa: E501
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
)

inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]

outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@pytest.mark.xfail(current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm"
)
def test_qwen2vl_lora(qwen2vl_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_num_seqs=2,
enable_lora=True,
max_loras=2,
max_lora_rank=16,
trust_remote_code=True,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
},
max_model_len=4096,
)
output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
9 changes: 6 additions & 3 deletions vllm/lora/models.py
@@ -28,7 +28,7 @@
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.utils import is_pin_memory_available

logger = init_logger(__name__)
@@ -113,13 +113,14 @@ def from_lora_tensors(
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
tensor_name)
tensor_name, weights_mapper)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
@@ -187,6 +188,7 @@ def from_local_checkpoint(
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
@@ -289,7 +291,8 @@ def from_local_checkpoint(
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules)
embedding_padding_modules=embedding_padding_modules,
weights_mapper=weights_mapper)


class LoRAModelManager(AdapterModelManager):
37 changes: 33 additions & 4 deletions vllm/lora/utils.py
@@ -1,3 +1,4 @@
import copy
import os
import re
from typing import List, Optional, Set, Tuple, Type, Union
@@ -30,6 +31,8 @@
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
from vllm.utils import print_warning_once

logger = init_logger(__name__)

@@ -91,28 +94,54 @@ def replace_submodule(model: nn.Module, module_name: str,
return new_module


def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
def parse_fine_tuned_lora_name(
name: str,
weights_mapper: Optional[WeightsMapper] = None
) -> Tuple[str, bool, bool]:
"""Parse the name of lora weights.
args:
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
weights_mapper: maps the name of the weight, e.g.
`model.` -> `language_model.model.`
return:
Tuple(module_name, is_lora_a, is_bias):
module_name: the name of the module, e.g. model.dense1,
is_lora_a: whether the tensor is lora_a or lora_b.
is_bias: whether the tensor is a LoRA bias.
"""

w_mapper = None
if weights_mapper:
w_mapper = copy.deepcopy(weights_mapper)
# TODO: Currently only prefix mapping is supported; substr and
# suffix mapping will be supported in the future.
for attr, mapping in [
("orig_to_new_substr", w_mapper.orig_to_new_substr),
("orig_to_new_suffix", w_mapper.orig_to_new_suffix),
]:
if mapping:
print_warning_once(
f"vLLM currently does not support mapping of LoRA weights "
f"for {mapping}.")
setattr(w_mapper, attr, {})

mapper = (lambda name: w_mapper._map_name(name)
if w_mapper is not None else name)
parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A"
or parts[-2] == "lora_B"):
return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
new_name = ".".join(parts[2:-2])
return mapper(new_name), parts[-2] == "lora_A", False

if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
new_name = ".".join(parts[2:-1])
return mapper(new_name), parts[-1] == "lora_embedding_A", False

if parts[-1] == "bias":
return ".".join(parts[2:-2]), False, True
new_name = ".".join(parts[2:-2])
return mapper(new_name), False, True

raise ValueError(f"{name} is unsupported LoRA weight")

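To make the new mapping path concrete, here is a minimal usage sketch (not part of the commit) of parse_fine_tuned_lora_name with a prefix mapper, based on the logic in the hunk above; the example tensor name and the expected results are assumptions for illustration.

# Minimal sketch, assuming the signatures shown in this diff.
from vllm.lora.utils import parse_fine_tuned_lora_name
from vllm.model_executor.models.utils import WeightsMapper

# Prefix mapping as used for Qwen2-VL: "model." -> "language_model.model."
mapper = WeightsMapper(orig_to_new_prefix={"model.": "language_model.model."})

# Hypothetical LoRA tensor name as it would appear in a HF PEFT checkpoint.
name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"

module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(name, mapper)
# The leading "base_model.model." is stripped, then the prefix map is applied:
# module_name == "language_model.model.layers.0.self_attn.q_proj"
# is_lora_a == True, is_bias == False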
11 changes: 10 additions & 1 deletion vllm/lora/worker_manager.py
@@ -92,6 +92,14 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
else:
expected_lora_modules.append(module)
lora_path = get_adapter_absolute_path(lora_request.lora_path)

# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
hf_to_vllm_mapper = None
if (hasattr(model, "hf_to_vllm_mapper")
and model.hf_to_vllm_mapper is not None):
hf_to_vllm_mapper = model.hf_to_vllm_mapper

lora = self._lora_model_cls.from_local_checkpoint(
lora_path,
expected_lora_modules,
@@ -103,7 +111,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
self.lora_config.lora_extra_vocab_size,
embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules,
)
weights_mapper=hf_to_vllm_mapper)

except Exception as e:
raise RuntimeError(f"Loading lora {lora_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
12 changes: 6 additions & 6 deletions vllm/model_executor/models/qwen2_vl.py
@@ -901,6 +901,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
]
embedding_modules = {}
embedding_padding_modules = []
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -1190,11 +1195,6 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})

loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
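Taken together, the pattern these changes establish is: the model class exposes a class-level hf_to_vllm_mapper, load_weights reuses it, and the LoRA worker manager forwards it to LoRAModel.from_local_checkpoint. A hedged sketch of how another model could opt in; the class name here is hypothetical and not part of this commit.

# Hedged sketch: a hypothetical multimodal model class opting into the
# same LoRA weight remapping pattern. Only the class-level attribute is
# needed; WorkerLoRAManager._load_adapter checks for it via
# hasattr(model, "hf_to_vllm_mapper") and forwards it to
# LoRAModel.from_local_checkpoint as weights_mapper.
import torch.nn as nn
from vllm.model_executor.models.utils import WeightsMapper

class MyVLModel(nn.Module):  # hypothetical model class, not in vLLM
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
        "model.": "language_model.model.",
    })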
