Skip to content

Commit

Permalink
add load_in_8bit
Browse files Browse the repository at this point in the history
  • Loading branch information
airaria committed Jul 17, 2023
1 parent 67d8d93 commit 41580b1
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion scripts/inference/inference_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
parser.add_argument('--predictions_file', default='./predictions.json', type=str)
parser.add_argument('--gpus', default="0", type=str)
parser.add_argument('--only_cpu',action='store_true',help='only use CPU for inference')
+ parser.add_argument('--load_in_8bit',action='store_true', help="Load the LLM in the 8bit mode")

args = parser.parse_args()
if args.only_cpu is True:
args.gpus = ""
Expand Down Expand Up @@ -91,7 +93,7 @@ def generate_prompt(instruction, input=None):

base_model = LlamaForCausalLM.from_pretrained(
args.base_model,
- load_in_8bit=False,
+ load_in_8bit=args.load_in_8bit,
torch_dtype=load_type,
low_cpu_mem_usage=True,
device_map='auto',
Expand Down

0 comments on commit 41580b1

Please sign in to comment.