# (Factually Augmented) RL from Human Feedback

This RLHF codebase is mainly adapted from the SALMON codebase, which in turn builds on AlpacaFarm and QLoRA.

## 0. Setup

Please refer to llava_setup for instructions on how to set up the customized llava package.

Additionally, run the following commands to make sure the versions of several essential packages are correct:

```bash
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
pip install deepspeed==0.9.3
pip install peft==0.4.0
pip install transformers==4.31.0
pip install bitsandbytes==0.41.0
pip install datasets
```

Note: please install PyTorch 2.0.1 following the guidelines here. We found that the flash-attention implementation in the latest stable PyTorch release (2.1.0) can produce buggy results. The codebase is tested with `torch==2.0.1+cu118`.
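
After installation, a quick sanity check of the pinned versions and the CUDA build can save debugging time later, for example:

```bash
# Confirm the pinned versions were installed
pip show torch torchvision deepspeed peft transformers bitsandbytes | grep -E "^(Name|Version)"

# Confirm the CUDA 11.8 build of PyTorch is active and a GPU is visible
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
```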

## 1. Training the Instruction-Following Reward Model

We first train an instruction-following reward model based on the following judging criteria:

1. Accurate: The AI should provide factual and accurate information from the image, and refrain from making statements that are not supported by the image or inconsistent with the image.
2. Helpful: The AI’s response should precisely serve the user's needs and interests, while grounding the response in the image.
3. Language Natural: The AI should employ language that flows smoothly and is free from repetitive or awkward constructs.
4. Concise: The AI should efficiently address the task or answer the question, communicating the necessary information with brevity and clarity.

After downloading the SFT model checkpoint from LLaVA-RLHF-13b-v1.5-336, the human preference data from LLaVA-Human-Preference-10K, and the image captions from LLaVA-RLHF-Data/image_to_caption.json, you can run the training script for the reward model:

```bash
bash scripts/13b-v1.5-336/train_reward_model.sh
```
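
If the checkpoints and data above are hosted on the Hugging Face Hub, they can be fetched with something like the sketch below; the organization name is a placeholder, the `--repo-type dataset` flags assume the data releases are dataset repositories, and `huggingface-cli download` requires a recent `huggingface_hub`:

```bash
# Placeholder namespace; substitute the actual Hugging Face organization hosting the releases.
ORG=YOUR_HF_ORG

# SFT / reward model checkpoint
huggingface-cli download $ORG/LLaVA-RLHF-13b-v1.5-336 --local-dir LLaVA-RLHF-13b-v1.5-336
# Human preference data and image captions (assumed to be dataset repos)
huggingface-cli download $ORG/LLaVA-Human-Preference-10K --repo-type dataset --local-dir LLaVA-Human-Preference-10K
huggingface-cli download $ORG/LLaVA-RLHF-Data image_to_caption.json --repo-type dataset --local-dir LLaVA-RLHF-Data
```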

Note: for both the 7b and 13b policy models, we use the same 13b reward model. We also provide a pretrained reward model checkpoint at LLaVA-RLHF-13b-v1.5-336/rm_lora_adapter_model. To use the pretrained LoRA checkpoint, the `base_model_name_or_path` in `adapter_config.json` needs to be modified to the actual path of the SFT model, as in the sketch below.
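
A minimal way to apply that edit, assuming the adapter was downloaded to the directory above and using a placeholder path for the local SFT checkpoint:

```bash
# Rewrite base_model_name_or_path so the LoRA adapter loads on top of the local SFT model.
# Replace /path/to/sft_model with the actual SFT checkpoint directory.
python -c "
import json
cfg_path = 'LLaVA-RLHF-13b-v1.5-336/rm_lora_adapter_model/adapter_config.json'
cfg = json.load(open(cfg_path))
cfg['base_model_name_or_path'] = '/path/to/sft_model'
json.dump(cfg, open(cfg_path, 'w'), indent=2)
"
```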

## 2. Initializing the RL Model

We initialize the LoRA weights of the policy model by fine-tuning the SFT model for one epoch on the combination of:

  1. Our preference modeling split of the LLaVA data (10k)
  2. A-OKVQA in the CoT format (5k)

We provide the processed data in LLaVA-RLHF-Data/llava_reward10k-aokvqa5k.json. After downloading the data (and, for the 7b setting, the 7b SFT model checkpoint from LLaVA-RLHF-7b-v1.5-224), you can run the corresponding script to initialize the policy model:

```bash
bash scripts/7b-v1.5-224/initialize_policy_model.sh
bash scripts/13b-v1.5-336/initialize_policy_model.sh
```
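
Optionally, the downloaded training data can be inspected first; the snippet below is a minimal sketch that assumes the file is a JSON list of LLaVA-style conversation records:

```bash
python -c "
import json
# Adjust the path to wherever LLaVA-RLHF-Data was downloaded.
data = json.load(open('LLaVA-RLHF-Data/llava_reward10k-aokvqa5k.json'))
print(len(data), 'policy-initialization examples')  # roughly 15k: 10k LLaVA + 5k A-OKVQA
print(json.dumps(data[0], indent=2)[:500])
"
```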

## 3. Training the RL Model with PPO

The PPO training of the policy model uses prompts drawn from a combination of:

  1. Our RL split of the LLaVA data (50k)
  2. A-OKVQA in the CoT format (12k)
  3. Yes/No Questions from VQA-v2 (10k)

We provide the processed data in LLaVA-RLHF-Data/llava_ppo50k-aokvqa12k-vqa10k.json. After downloading the data, you can run the corresponding script to train the RL model:

```bash
bash scripts/7b-v1.5-224/train_rl_model.sh
bash scripts/13b-v1.5-336/train_rl_model.sh
```