diff --git a/.gitignore b/.gitignore index 0d0fcb729..483af7deb 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,5 @@ dmypy.json # Cython debug symbols cython_debug/ +# Images +images/ diff --git a/README.md b/README.md index 2a1e2123e..948fc9ff9 100644 --- a/README.md +++ b/README.md @@ -42,10 +42,10 @@ **Supported PyTorch / HF Models** -| [**IDEFICS-9B-Instruct**๐ŸŽž๏ธ๐Ÿš…](https://huggingface.co/HuggingFaceM4/idefics-9b-instruct), [**IDEFICS-80B-Instruct**๐ŸŽž๏ธ๐Ÿš…](https://huggingface.co/HuggingFaceM4/idefics-80b-instruct) | [**InstructBLIP-[7B/13B]**](https://github.com/salesforce/LAVIS/blob/main/projects/instructblip/README.md) | [**LLaVA-[v1-7B/v1.5-7B/v1.5-13B]**](https://github.com/haotian-liu/LLaVA) | [**MiniGPT-4-[v1-7B/v1-13B/v2-7B]**](https://github.com/Vision-CAIR/MiniGPT-4) | [**mPLUG-Owl2**๐ŸŽž๏ธ](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2) | +| [**IDEFICS-9B-Instruct**](https://huggingface.co/HuggingFaceM4/idefics-9b-instruct)๐ŸŽž๏ธ๐Ÿš…, [**IDEFICS-80B-Instruct**](https://huggingface.co/HuggingFaceM4/idefics-80b-instruct)๐ŸŽž๏ธ๐Ÿš… | [**InstructBLIP-[7B/13B]**](https://github.com/salesforce/LAVIS/blob/main/projects/instructblip/README.md) | [**LLaVA-[v1-7B/v1.5-7B/v1.5-13B]**](https://github.com/haotian-liu/LLaVA) | [**MiniGPT-4-[v1-7B/v1-13B/v2-7B]**](https://github.com/Vision-CAIR/MiniGPT-4) | [**mPLUG-Owl2**](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2)๐ŸŽž๏ธ | | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | -| [**OpenFlamingo-v2**](https://github.com/mlfoundations/open_flamingo)๐ŸŽž๏ธ | [**PandaGPT-13B**](https://github.com/yxuansu/PandaGPT) | [**Qwen-VL**๐ŸŽž๏ธ๐Ÿš…](https://huggingface.co/Qwen/Qwen-VL), [**Qwen-VL-Chat**๐ŸŽž๏ธ๐Ÿš…](https://huggingface.co/Qwen/Qwen-VL-Chat) | [**VisualGLM-6B**๐Ÿš…](https://huggingface.co/THUDM/visualglm-6b) | [**InternLM-XComposer-7B**๐ŸŽž๏ธ๐Ÿš…](https://huggingface.co/internlm/internlm-xcomposer-7b) | -| [**ShareGPT4V-7B**๐Ÿš…](https://sharegpt4v.github.io) | [**TransCore-M**](https://github.com/PCIResearch/TransCore-M) | | | | +| [**OpenFlamingo-v2**](https://github.com/mlfoundations/open_flamingo)๐ŸŽž๏ธ | [**PandaGPT-13B**](https://github.com/yxuansu/PandaGPT) | [**Qwen-VL**](https://huggingface.co/Qwen/Qwen-VL)๐ŸŽž๏ธ๐Ÿš…, [**Qwen-VL-Chat**](https://huggingface.co/Qwen/Qwen-VL-Chat)๐ŸŽž๏ธ๐Ÿš… | [**VisualGLM-6B**](https://huggingface.co/THUDM/visualglm-6b)๐Ÿš… | [**InternLM-XComposer-7B**](https://huggingface.co/internlm/internlm-xcomposer-7b)๐ŸŽž๏ธ๐Ÿš… | +| [**ShareGPT4V-7B**](https://sharegpt4v.github.io)๐Ÿš… | [**TransCore-M**](https://github.com/PCIResearch/TransCore-M) | [**LLaVA (XTuner)**](https://huggingface.co/xtuner/llava-internlm-7b)๐Ÿš… | | | ๐ŸŽž๏ธ: Support multiple images as inputs, via the `multi_generate` interface. @@ -83,7 +83,7 @@ pip install -e . Following VLMs require the configuration step: -**Code Preparation & Installation**: InstructBLIP ([LAVIS](https://github.com/salesforce/LAVIS)), LLaVA ([LLaVA](https://github.com/haotian-liu/LLaVA)), MiniGPT-4 ([MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)), mPLUG-Owl2 ([mPLUG-Owl2](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2)), OpenFlamingo-v2 ([OpenFlamingo](https://github.com/mlfoundations/open_flamingo)), PandaGPT-13B ([PandaGPT](https://github.com/yxuansu/PandaGPT)), TransCore-M ([TransCore-M](https://github.com/PCIResearch/TransCore-M)). +**Code Preparation & Installation**: InstructBLIP ([LAVIS](https://github.com/salesforce/LAVIS)), LLaVA ([LLaVA](https://github.com/haotian-liu/LLaVA)), MiniGPT-4 ([MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)), mPLUG-Owl2 ([mPLUG-Owl2](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2)), OpenFlamingo-v2 ([OpenFlamingo](https://github.com/mlfoundations/open_flamingo)), PandaGPT-13B ([PandaGPT](https://github.com/yxuansu/PandaGPT)), TransCore-M ([TransCore-M](https://github.com/PCIResearch/TransCore-M)), LLaVA-XTuner ([XTuner](https://github.com/InternLM/xtuner)). **Manual Weight Preparation & Configuration**: InstructBLIP, LLaVA-v1-7B, MiniGPT-4, OpenFlamingo-v2, PandaGPT-13B diff --git a/vlmeval/config.py b/vlmeval/config.py index 6dfd9a3b6..7d47bc0c1 100644 --- a/vlmeval/config.py +++ b/vlmeval/config.py @@ -45,5 +45,8 @@ 'GPT4V': partial(GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low'), 'GPT4V_INT': partial(GPT4V_Internal, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10), 'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10), - 'QwenVLPlus': partial(QwenVLPlus, temperature=0, retry=10) + 'QwenVLPlus': partial(QwenVLPlus, temperature=0, retry=10), + 'llava-internlm-7b': partial(LLaVA_XTuner, llm_path='internlm/internlm-chat-7b', llava_path='xtuner/llava-internlm-7b', visual_encoder_path='openai/clip-vit-large-patch14-336', visual_select_layer=-2, prompt_template='internlm_chat'), + 'llava-v1.5-7b-xtuner': partial(LLaVA_XTuner, llm_path='lmsys/vicuna-7b-v1.5', llava_path='xtuner/llava-v1.5-7b-xtuner', visual_encoder_path='openai/clip-vit-large-patch14-336', visual_select_layer=-2, prompt_template='vicuna'), + 'llava-v1.5-13b-xtuner': partial(LLaVA_XTuner, llm_path='lmsys/vicuna-13b-v1.5', llava_path='xtuner/llava-v1.5-13b-xtuner', visual_encoder_path='openai/clip-vit-large-patch14-336', visual_select_layer=-2, prompt_template='vicuna'), } diff --git a/vlmeval/vlm/__init__.py b/vlmeval/vlm/__init__.py index d6fb17149..761fae8a5 100644 --- a/vlmeval/vlm/__init__.py +++ b/vlmeval/vlm/__init__.py @@ -12,3 +12,4 @@ from .minigpt4 import MiniGPT4 from .xcomposer import XComposer from .mplug_owl2 import mPLUG_Owl2 +from .llava_xtuner import LLaVA_XTuner diff --git a/vlmeval/vlm/llava.py b/vlmeval/vlm/llava.py index 4d7540428..2b599a457 100644 --- a/vlmeval/vlm/llava.py +++ b/vlmeval/vlm/llava.py @@ -64,7 +64,7 @@ def build_prompt(self, line, dataset=None): question = line['question'] hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None if hint is not None: - question + hint + '\n' + question + question = hint + '\n' + question options = { cand: line[cand] @@ -107,4 +107,4 @@ def generate(self, image_path, prompt, dataset=None): with torch.inference_mode(): output_ids = self.model.generate(input_ids, images=image_tensor, stopping_criteria=[stopping_criteria], **self.kwargs) output = self.tokenizer.decode(output_ids[0, input_ids.shape[1]: ]).strip().split("")[0] - return output \ No newline at end of file + return output diff --git a/vlmeval/vlm/llava_xtuner.py b/vlmeval/vlm/llava_xtuner.py new file mode 100644 index 000000000..34668df50 --- /dev/null +++ b/vlmeval/vlm/llava_xtuner.py @@ -0,0 +1,213 @@ +import os +import os.path as osp +import string +import warnings + +import pandas as pd +import torch +from huggingface_hub import snapshot_download +from PIL import Image +from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, + CLIPImageProcessor, CLIPVisionModel, + GenerationConfig) + +from ..smp import cn_string, get_cache_path +from ..utils import DATASET_TYPE, CustomPrompt + + +class LLaVA_XTuner(CustomPrompt): + + INSTALL_REQ = True + + def __init__(self, + llava_path, + llm_path=None, + visual_encoder_path=None, + visual_select_layer=-2, + prompt_template=None, + torch_dtype=torch.float16): + try: + from peft import PeftModel + from xtuner.tools.utils import get_chat_utils + from xtuner.utils import PROMPT_TEMPLATE + except Exception: + warnings.warn( + 'Please install xtuner with `pip install -U xtuner` before ' + 'using LLaVA_XTuner') + exit(-1) + + if not osp.isdir(llava_path): + cache_path = get_cache_path(llava_path) + if cache_path is not None: + llava_path = cache_path + else: + llava_path = snapshot_download(repo_id=llava_path) + assert osp.exists(llava_path) and osp.isdir(llava_path) + + # build visual_encoder + if 'llm' in os.listdir(llava_path): + assert llm_path is None, ( + "Please don't specify the `llm_path` since passed " + '`llava_path` contains a LLM!') + llm_path = osp.join(llava_path, 'llm') + else: + assert llm_path is not None, 'Please specify the `llm_path`!' + + llm = AutoModelForCausalLM.from_pretrained(llm_path, + trust_remote_code=True, + torch_dtype=torch_dtype, + device_map='cpu') + tokenizer = AutoTokenizer.from_pretrained(llm_path, + trust_remote_code=True, + encode_special_tokens=True) + print(f'Load LLM from {llm_path}') + + # build visual_encoder + if 'visual_encoder' in os.listdir(llava_path): + assert visual_encoder_path is None, ( + "Please don't specify the `visual_encoder_path` since passed " + '`llava_path` contains a visual encoder!') + visual_encoder_path = osp.join(llava_path, 'visual_encoder') + else: + assert visual_encoder_path is not None, ( + 'Please specify the `visual_encoder_path`!') + visual_encoder = CLIPVisionModel.from_pretrained( + visual_encoder_path, torch_dtype=torch_dtype, device_map='cpu') + image_processor = CLIPImageProcessor.from_pretrained( + visual_encoder_path) + print(f'Load visual_encoder from {visual_encoder_path}') + + # load adapter + if 'llm_adapter' in os.listdir(llava_path): + adapter_path = osp.join(llava_path, 'llm_adapter') + llm = PeftModel.from_pretrained(llm, + adapter_path, + device_map='cpu') + print(f'Load LLM adapter from {llava_path}') + if 'visual_encoder_adapter' in os.listdir(llava_path): + adapter_path = osp.join(llava_path, 'visual_encoder_adapter') + visual_encoder = PeftModel.from_pretrained(visual_encoder, + adapter_path, + device_map='cpu') + print(f'Load visual_encoder adapter from {llava_path}') + + # build projector + projector_path = osp.join(llava_path, 'projector') + projector = AutoModel.from_pretrained(projector_path, + torch_dtype=torch_dtype, + device_map='cpu') + print(f'Load projector from {llava_path}') + + llm.eval() + visual_encoder.eval() + projector.eval() + + self.llm = llm.cuda() + self.tokenizer = tokenizer + self.visual_encoder = visual_encoder.cuda() + self.image_processor = image_processor + self.projector = projector.cuda() + self.visual_select_layer = visual_select_layer + if prompt_template is not None: + self.prompt_template = PROMPT_TEMPLATE[prompt_template] + else: + self.prompt_template = None + + _, self.stop_criteria = get_chat_utils(self.llm) + + def build_gen_config(self, dataset): + gen_kwargs = dict(max_new_tokens=1024, + do_sample=True, + temperature=1, + num_beams=5, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id is not None else + self.tokenizer.eos_token_id) + # For single word generation + if (dataset is not None + and DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']): + gen_kwargs.update( + dict(max_new_tokens=5, do_sample=False, num_beams=1)) + return GenerationConfig(**gen_kwargs) + + def use_custom_prompt(self, dataset): + assert dataset is not None + if DATASET_TYPE(dataset) == 'multi-choice': + return True + return False + + def build_prompt(self, line, dataset=None): + assert self.use_custom_prompt(dataset) + assert dataset is None or isinstance(dataset, str) + tgt_path = self.dump_image(line, dataset) + + question = line['question'] + hint = line['hint'] if ('hint' in line + and not pd.isna(line['hint'])) else None + if hint is not None: + question = hint + '\n' + question + + options = { + cand: line[cand] + for cand in string.ascii_uppercase + if cand in line and not pd.isna(line[cand]) + } + for key, item in options.items(): + question += f'\n{key}. {item}' + + if not cn_string(question): + prompt = question + '\n' + ("Answer with the option's letter " + 'from the given choices directly.') + else: + prompt = question + '\n' + '่ฏท็›ดๆŽฅๅ›ž็ญ”้€‰้กนๅญ—ๆฏใ€‚' + + return {'image': tgt_path, 'text': prompt} + + def generate(self, image_path, prompt, dataset=None): + from xtuner.dataset.utils import expand2square + from xtuner.model.utils import prepare_inputs_labels_for_multimodal + from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX + image = Image.open(image_path).convert('RGB') + image = expand2square( + image, + tuple(int(x * 255) for x in self.image_processor.image_mean)) + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + image = image.cuda().unsqueeze(0) + visual_outputs = self.visual_encoder(image, output_hidden_states=True) + pixel_values = self.projector( + visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) + + inputs = DEFAULT_IMAGE_TOKEN + '\n' + prompt + + if self.prompt_template: + inputs = self.prompt_template['INSTRUCTION'].format(input=inputs) + + chunk_encode = [] + for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): + if idx == 0: + cur_encode = self.tokenizer(chunk) + else: + cur_encode = self.tokenizer(chunk, add_special_tokens=False) + chunk_encode.append(cur_encode) + assert len(chunk_encode) == 2 + ids = [] + for idx, cur_chunk_encode in enumerate(chunk_encode): + ids.extend(cur_chunk_encode['input_ids']) + if idx != len(chunk_encode) - 1: + ids.append(IMAGE_TOKEN_INDEX) + ids = torch.tensor(ids).cuda().unsqueeze(0) + mm_inputs = prepare_inputs_labels_for_multimodal( + llm=self.llm, input_ids=ids, pixel_values=pixel_values) + + gen_config = self.build_gen_config(dataset) + generate_output = self.llm.generate( + **mm_inputs, + generation_config=gen_config, + streamer=None, + bos_token_id=self.tokenizer.bos_token_id, + stopping_criteria=self.stop_criteria) + predict = self.tokenizer.decode(generate_output[0], + skip_special_tokens=True).strip() + return predict