run_inference.py
import bitsandbytes as bnb
from functools import partial
from peft import AutoPeftModelForCausalLM
import torch
from transformers import AutoTokenizer
import sys
import argparse
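

# Interactive inference: load a tuned PEFT adapter from the Hugging Face Hub, merge it into
# its base model, and chat with it in a loop, applying the chat template read from
# --thread_template to the growing conversation thread.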
def main():
    parser = argparse.ArgumentParser(description="Script to run inference on a tuned model.")
    parser.add_argument('model_name', type=str,
                        help='The name of the tuned model that you pushed to the Hugging Face Hub after finetuning or DPO.')
    parser.add_argument('instruction_prompt', type=str,
                        help='An instruction message added to every prompt given to the chatbot to force it to answer in the target language.')
    parser.add_argument('--cpu', action='store_true',
                        help="Force usage of the CPU. By default the GPU is used if available.")
    parser.add_argument('--thread_template', type=str, default="threads/template_default.txt",
                        help='A file containing the thread template to use. Default is threads/template_default.txt')
    parser.add_argument('--padding', type=str, default="left",
                        help='Which padding side to use; can be either left or right.')
    args = parser.parse_args()

    model_name = args.model_name
    instruction_prompt = args.instruction_prompt
    thread_template_file = args.thread_template
    force_cpu = args.cpu
    device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
    padding = args.padding

    # Get the chat template
    with open(thread_template_file, 'r', encoding="utf8") as f:
        chat_template = f.read()

    # Load the adapter and merge it with the base model
    model = AutoPeftModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)
    model = model.merge_and_unload()
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if padding == 'left':
        tokenizer.pad_token_id = 0
    else:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = padding

    thread = [
        {'role': 'system', 'content': instruction_prompt}
    ]
    while True:
        user_input = input("Enter your input, use ':n' for a new thread or ':q' to quit: ")
        if user_input.lower() == ':q':
            break
        elif user_input.lower() == ':n':
            thread = [{'role': 'system', 'content': instruction_prompt}]
            continue

        # Prepare input in LLaMa3 chat format
        thread.append({
            'role': 'user', 'content': user_input
        })
        input_chat = tokenizer.apply_chat_template(thread, tokenize=False, chat_template=chat_template)
        inputs = tokenizer(input_chat, return_tensors="pt").to(device)

        # Generate response and decode
        output_sequences = model.generate(
            input_ids=inputs['input_ids'],
            max_length=200,
            repetition_penalty=1.2  # LLaMa3 is sensitive to repetition
        )
        generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
        print(generated_text)

        # Strip the prompt (minus the BOS token) from the decoded output to keep only the new answer
        answer = generated_text[(len(input_chat) - len(tokenizer.bos_token) + 1):]
        thread.append({
            'role': 'assistant', 'content': answer
        })
if __name__ == "__main__":
    main()
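
# Example invocation (the model name and instruction prompt below are placeholders, not from
# this repo; substitute your own tuned model and target language):
#   python run_inference.py your-username/llama3-tuned-model \
#       "You are a chatbot that always answers in the target language." \
#       --thread_template threads/template_default.txt --padding left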