-
Notifications
You must be signed in to change notification settings - Fork 84
/
finetune.py
198 lines (170 loc) · 7.21 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
4-bit LLaMA trainer with support for Stanford Alpaca-like JSON datasets (SAD for short).
Intended to be used with https://github.com/johnsmith0031/alpaca_lora_4bit

SAD structure:
[
    {
        "instruction": "Give null hypothesis",
        "input": "6 subjects were given a drug (treatment group) and an additional 6 subjects a placebo (control group).",
        "output": "Drug is equivalent of placebo"
    },
    {
        "instruction": "What does RNA stand for?",
        "input": "",
        "output": "RNA stands for ribonucleic acid."
    }
]
"""
import os
import sys
# Early load config to replace attn if needed.
# This section deliberately runs before the transformers / peft / model-loader
# imports further down, so every patch is installed before the model is built.
from alpaca_lora_4bit.arg_parser import get_config

ft_config = get_config()

# Apply the PEFT LoRA monkeypatch for int4 models.
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model

replace_peft_model_with_int4_lora_model()

# Optional attention replacement; flash attention takes precedence when both
# flags are set.
if ft_config.flash_attention:
    from alpaca_lora_4bit.monkeypatch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
    replace_llama_attn_with_flash_attn()
elif ft_config.xformers:
    from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
    hijack_llama_attention()

# Select the autograd_4bit backend: 'triton' (case-insensitive match),
# otherwise 'cuda' for any other configured value.
from alpaca_lora_4bit import autograd_4bit

autograd_4bit.switch_backend_to('triton' if ft_config.backend.lower() == 'triton' else 'cuda')
import sys
import os
import peft
import peft.tuners.lora
import wandb
import torch
import transformers
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from peft import LoraConfig, get_peft_model, PeftModel
# ! Config
from alpaca_lora_4bit import train_data
# * Show loaded parameters (printed only by the rank-0 process).
if ft_config.local_rank == 0:
    print(f"{ft_config}\n")
# NOTE(review): this only prints a message; nothing in this block itself
# changes dropout — confirm where dropout is actually disabled.
if ft_config.gradient_checkpointing:
    print('Disable Dropout.')
# mbatch_size is the per-device step batch (see TrainingArguments below), so it
# may not exceed the full batch_size. Equal sizes are accepted.
if ft_config.mbatch_size > ft_config.batch_size:
    # ValueError subclasses Exception, so existing `except Exception` handlers
    # still catch it.
    raise ValueError(
        f"batch_size ({ft_config.batch_size}) needs to be larger than or equal to "
        f"mbatch_size ({ft_config.mbatch_size})."
    )
# Load Basic Model: base LLaMA model and tokenizer through the low-RAM 4-bit
# loader, placed according to ft_config.device_map.
base_model_kwargs = dict(
    device_map=ft_config.device_map,
    groupsize=ft_config.groupsize,
    is_v1_model=ft_config.v1,
)
model, tokenizer = load_llama_model_4bit_low_ram(
    ft_config.llama_q4_config_dir,
    ft_config.llama_q4_model,
    **base_model_kwargs,
)
# Config Lora: adapter hyper-parameters come from the run config. Adapters are
# attached only to modules named "q_proj" and "v_proj"; bias="none" leaves
# bias parameters out of the adapter; the task is causal LM.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    bias="none",
    r=ft_config.lora_r,
    lora_alpha=ft_config.lora_alpha,
    lora_dropout=ft_config.lora_dropout,
)
# Attach LoRA adapters: create fresh ones from lora_config, or load a saved
# adapter from ft_config.lora_apply_dir and keep it trainable.
if ft_config.lora_apply_dir is None:
    model = get_peft_model(model, lora_config)
else:
    # Choose where the loaded adapter is placed.
    if ft_config.ddp:
        # Under DDP each process drives its own GPU: place the adapter on this
        # process's device instead of always GPU 0 (which would pile every
        # rank onto the first card).
        # NOTE(review): assumes one process per GPU with local_rank equal to
        # the CUDA device index (torchrun convention) — confirm for custom launchers.
        device_map = {'': ft_config.local_rank}
    elif torch.cuda.device_count() > 1:
        # Single process, several visible GPUs: delegate placement.
        device_map = "auto"
    else:
        device_map = {'': 0}
    print('Device map for lora:', device_map)
    model = PeftModel.from_pretrained(model, ft_config.lora_apply_dir, device_map=device_map, torch_dtype=torch.float32, is_trainable=True)
    print(ft_config.lora_apply_dir, 'loaded')
# Scales to half: cast quantization parameters of the 4-bit layers to fp16.
print('Fitting 4bit scales and zeros to half')
for module in model.modules():
    layer_type = str(type(module))
    # Layers are recognised by type-name substring, so the bitsandbytes /
    # autograd_4bit classes need not be imported here.
    if 'Autograd4bitQuantLinear' not in layer_type and 'Linear4bitLt' not in layer_type:
        continue
    # Only modules flagged is_v1_model get their `zeros` cast.
    if getattr(module, "is_v1_model", False):
        module.zeros = module.zeros.half()
    # NOTE(review): indentation was lost in this copy; scales are assumed to be
    # cast for every matched layer (per the message above) — confirm upstream.
    module.scales = module.scales.half()
# Set tokenizer: pad with token id 0 (presumably <unk> in the LLaMA vocab — confirm).
tokenizer.pad_token_id = 0
if not ft_config.skip:
    # Load Data: build the dataset reader for the configured format. The
    # enclosing `if` already guarantees training is not skipped, so each
    # branch only needs to test the format.
    if ft_config.ds_type == "txt":
        #### LLaMa
        data = train_data.TrainTxt(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    elif ft_config.ds_type == "alpaca":
        #### Stanford Alpaca-like Data
        data = train_data.TrainSAD(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    elif ft_config.ds_type == "gpt4all":
        #### GPT4All Data
        data = train_data.TrainGPT4All(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    elif ft_config.ds_type == "bluemoon":
        #### Blue Moon Data
        data = train_data.TrainBlueMoon(ft_config.dataset, ft_config.val_set_size, tokenizer, ft_config.cutoff_len)
    else:
        raise NotImplementedError("ERROR: Unknown dataset format")
    data.prepare_data(thd=ft_config.txt_row_thd, use_eos_token=ft_config.use_eos_token)

    # Use gradient checkpointing (imported lazily: only needed when enabled).
    if ft_config.gradient_checkpointing:
        print('Applying gradient checkpointing ...')
        from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing
        apply_gradient_checkpointing(model, checkpoint_ratio=ft_config.gradient_checkpointing_ratio)

    # Disable Trainer's DataParallel for multigpu: without DDP but with several
    # visible GPUs, mark the model as already model-parallel.
    if not ft_config.ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    # Count eval count for wandb: aim for about `eval_count` evaluations per epoch.
    # Trainer measures eval_steps in optimizer update steps, and one update
    # consumes mbatch_size * gradient_accumulation_steps training samples per
    # device, so both factors belong in the divisor. Only the training split is
    # counted because only it is stepped over.
    # NOTE(review): under DDP an update also spans every rank; if
    # gradient_accumulation_steps is not already divided by the world size,
    # this still over-estimates steps per epoch — confirm.
    if ft_config.val_set_size > 0:
        eval_count = 10
        samples_per_update = ft_config.mbatch_size * ft_config.gradient_accumulation_steps
        eval_steps = max(
            ft_config.logging_steps, len(data.train_data) // (eval_count * samples_per_update)
        )
        print(f"Run eval every {eval_steps} steps")
    else:
        eval_steps = 0

    training_arguments = transformers.TrainingArguments(
        per_device_train_batch_size=ft_config.mbatch_size,
        gradient_accumulation_steps=ft_config.gradient_accumulation_steps,
        warmup_steps=ft_config.warmup_steps,
        optim="adamw_torch",
        num_train_epochs=ft_config.epochs,
        learning_rate=ft_config.lr,
        fp16=True,
        logging_steps=ft_config.logging_steps,
        evaluation_strategy="steps" if eval_steps != 0 else "no",
        save_strategy="steps",
        eval_steps=eval_steps if eval_steps != 0 else None,
        save_steps=ft_config.save_steps,
        output_dir=ft_config.lora_out_dir,
        save_total_limit=ft_config.save_total_limit,
        load_best_model_at_end=False,
        ddp_find_unused_parameters=False if ft_config.ddp else None,
    )

    trainer = transformers.Trainer(
        model=model,
        train_dataset=data.train_data,
        eval_dataset=data.val_data,
        args=training_arguments,
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    # Turn off the model config's use_cache flag before training.
    model.config.use_cache = False

    # Set Verbose
    if ft_config.verbose:
        transformers.logging.set_verbosity_info()

    # Run Trainer inside a wandb run (project "alpaca_lora_4bit"), optionally
    # resuming from a saved checkpoint.
    with wandb.init(project="alpaca_lora_4bit"):
        if ft_config.resume_checkpoint:
            print(f'Resuming from {ft_config.resume_checkpoint} ...')
            trainer.train(resume_from_checkpoint=ft_config.resume_checkpoint)
        else:
            trainer.train()

    print('Train completed.')
# Save Model: write `model` to the configured LoRA output directory
# (for a PEFT-wrapped model this is presumably the adapter only — confirm).
lora_out_dir = ft_config.lora_out_dir
model.save_pretrained(lora_out_dir)
if ft_config.checkpoint:
    # Merging the adapter into a full checkpoint is not implemented; only warn.
    print("Warning: Merge model + LoRA and save the whole checkpoint not implemented yet.")
print('Model Saved.')