Skip to content

Commit

Permalink
update test.py code
Browse files — browse the repository at this point in the history
  • Loading branch information
zehuichen123 committed Jan 4, 2024
1 parent fd70415 commit 8e54050
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ def infer(dataset, llm, out_dir, tmp_folder_name='tmp', test_num = -1):
# Prepare the output directory plus a per-run scratch subfolder named after
# the output file with its extension stripped.
os.makedirs(args.out_dir, exist_ok=True)
tmp_folder_name, _ = os.path.splitext(args.out_name)
os.makedirs(os.path.join(args.out_dir, tmp_folder_name), exist_ok=True)

# Construct the language-model backend selected on the command line.
model_type = args.model_type
if model_type.startswith('gpt'):
    # To use GPT, refer to lagent for how to pass your key to the GPTAPI class.
    llm = GPTAPI(model_type)
# elif model_type.startswith('claude'):
#     llm = ClaudeAPI(model_type)
elif model_type == 'hf':
    # Local HuggingFace model; the chat meta-template is looked up by name
    # (may be None if args.meta_template is not a known key).
    meta_template = meta_template_dict.get(args.meta_template)
    llm = HFTransformerCasualLM(args.hf_path, meta_template=meta_template)
# Resume-aware accounting: how many samples are already done vs. still to run.
dataset, tested_num, total_num = load_dataset(
    args.dataset_path, args.out_dir, args.resume, tmp_folder_name=tmp_folder_name)
remaining = total_num - tested_num
if args.test_num == -1:
    # -1 means "run everything that is left".
    test_num = max(remaining, 0)
else:
    # Cap at the requested budget, minus what a resumed run already covered.
    test_num = max(min(args.test_num - tested_num, remaining), 0)
print(f"Tested {tested_num} samples, left {test_num} samples, total {total_num} samples")
# Run inference over the remaining samples; per-sample results are written
# under the tmp folder by infer (see its definition above).
prediction = infer(dataset, llm, args.out_dir, tmp_folder_name=tmp_folder_name, test_num=test_num)
# Dump the aggregated predictions to the configured output file.
# Fix: output_file_path was previously computed but unused, and the dump
# call rebuilt the identical path — use the variable instead.
output_file_path = os.path.join(args.out_dir, args.out_name)
mmengine.dump(prediction, output_file_path)

if args.eval:
if args.model_display_name == "":
Expand Down

0 comments on commit 8e54050

Please sign in to comment.