Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix Qwen2-VL LoRA weight loading #11430

Merged
merged 9 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")


@pytest.fixture(scope="session")
def qwen2vl_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
Expand Down
30 changes: 30 additions & 0 deletions tests/lora/test_lora_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from vllm.lora.models import LoRAModel
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

lora_lst = [
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
Expand Down Expand Up @@ -71,3 +72,32 @@ def test_load_checkpoints(
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)


def test_lora_weights_mapping(baichuan_lora_files, ):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)

hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"model.": "language_model.model.",
}, )

lora_model = LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules,
weights_mapper=hf_to_vllm_mapper,
)
for name in lora_model.loras:
assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
81 changes: 81 additions & 0 deletions tests/lora/test_qwen2vl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import List

import pytest

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"

PROMPT_TEMPLATE = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
"\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
"What is in the image?<|im_end|>\n"
"<|im_start|>assistant\n")

IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.", # noqa: E501
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
)

inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]

outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@pytest.mark.xfail(current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm"
)
def test_qwen2vl_lora(qwen2vl_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_num_seqs=2,
enable_lora=True,
max_loras=2,
max_lora_rank=16,
trust_remote_code=True,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
},
max_model_len=4096,
)
output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output2[i])
9 changes: 6 additions & 3 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.utils import is_pin_memory_available

logger = init_logger(__name__)
Expand Down Expand Up @@ -113,13 +113,14 @@ def from_lora_tensors(
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
tensor_name)
tensor_name, weights_mapper)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
Expand Down Expand Up @@ -187,6 +188,7 @@ def from_local_checkpoint(
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.

Expand Down Expand Up @@ -289,7 +291,8 @@ def from_local_checkpoint(
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules)
embedding_padding_modules=embedding_padding_modules,
weights_mapper=weights_mapper)


class LoRAModelManager(AdapterModelManager):
Expand Down
25 changes: 21 additions & 4 deletions vllm/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

Expand Down Expand Up @@ -91,28 +92,44 @@ def replace_submodule(model: nn.Module, module_name: str,
return new_module


def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
def parse_fine_tuned_lora_name(
name: str,
weights_mapper: Optional[WeightsMapper] = None
) -> Tuple[str, bool, bool]:
"""Parse the name of lora weights.

args:
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
weights_mapper: maps the name of weight, e.g.
`model.` -> `language_model.model.`,
return:
Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias.
"""
# TODO: Currently only supports mapping for prefix, mapping for substr and
# subfix will be supported in the future.
if weights_mapper is not None:
weights_mapper.orig_to_new_substr = {}
weights_mapper.orig_to_new_suffix = {}
jeejeelee marked this conversation as resolved.
Show resolved Hide resolved

mapper = (lambda name: weights_mapper._map_name(name)
if weights_mapper is not None else name)
parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A"
or parts[-2] == "lora_B"):
return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
new_name = ".".join(parts[2:-2])
return mapper(new_name), parts[-2] == "lora_A", False

if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
new_name = ".".join(parts[2:-1])
return mapper(new_name), parts[-1] == "lora_embedding_A", False

if parts[-1] == "bias":
return ".".join(parts[2:-2]), False, True
new_name = ".".join(parts[2:-2])
return mapper(new_name), False, True

raise ValueError(f"{name} is unsupported LoRA weight")

Expand Down
11 changes: 10 additions & 1 deletion vllm/lora/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
else:
expected_lora_modules.append(module)
lora_path = get_adapter_absolute_path(lora_request.lora_path)

# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
hf_to_vllm_mapper = None
if (hasattr(model, "hf_to_vllm_mapper")
and model.hf_to_vllm_mapper is not None):
hf_to_vllm_mapper = model.hf_to_vllm_mapper

lora = self._lora_model_cls.from_local_checkpoint(
lora_path,
expected_lora_modules,
Expand All @@ -103,7 +111,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
self.lora_config.lora_extra_vocab_size,
embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules,
)
weights_mapper=hf_to_vllm_mapper)

except Exception as e:
raise RuntimeError(f"Loading lora {lora_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
Expand Down
12 changes: 6 additions & 6 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
]
embedding_modules = {}
embedding_padding_modules = []
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
Expand Down Expand Up @@ -1190,11 +1195,6 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})

loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
Loading