Add providers param to ONNX Session in tests (#2553)
* Add providers arg to ONNX Session in tests

* Add providers arg to all ort.InferenceSession calls
nik-mosaic authored Sep 21, 2023
1 parent 99c833b commit 7917482
Showing 4 changed files with 8 additions and 7 deletions.
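For context, a minimal sketch of why the argument is needed (assumptions: onnxruntime >= 1.9 is installed and 'model.onnx' is a placeholder path, not a file from this repo). Since ORT 1.9, builds that ship more than one execution provider raise a ValueError when an InferenceSession is created without an explicit providers list, so every session in these tests now pins CPUExecutionProvider:

# Sketch only; 'model.onnx' is a hypothetical placeholder model file.
import onnxruntime as ort

# Providers available in this build, e.g. on a GPU build:
# ['CUDAExecutionProvider', 'CPUExecutionProvider']
print(ort.get_available_providers())

# Before: can fail with a ValueError on multi-provider builds since ORT 1.9.
# ort_session = ort.InferenceSession('model.onnx')

# After: pin the CPU provider explicitly, as done throughout the diffs below.
ort_session = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])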
4 changes: 2 additions & 2 deletions examples/exporting_for_inference.ipynb
@@ -408,7 +408,7 @@
 "import numpy as np\n",
 "\n",
 "# run inference\n",
-"ort_session = ort.InferenceSession(model_save_path)\n",
+"ort_session = ort.InferenceSession(model_save_path, providers=['CPUExecutionProvider'])\n",
 "outputs = ort_session.run(\n",
 "    None,\n",
 "    {'input': input[0].numpy()})\n",
@@ -513,7 +513,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"ort_session = ort.InferenceSession(model_save_path)\n",
+"ort_session = ort.InferenceSession(model_save_path, providers=['CPUExecutionProvider'])\n",
 "new_outputs = ort_session.run(\n",
 "    None,\n",
 "    {'input': input[0].numpy()},\n",
2 changes: 1 addition & 1 deletion tests/algorithms/test_torch_export.py
@@ -169,7 +169,7 @@ def test_surgery_onnx(
     onnx.checker.check_model(onnx_model)  # type: ignore (third-party)
 
     # run inference
-    ort_session = ort.InferenceSession(onnx_path)
+    ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
     outputs = ort_session.run(
         None,
         {'input': input[0].numpy()},
3 changes: 2 additions & 1 deletion tests/test_full_nlp.py
@@ -172,7 +172,8 @@ def inference_test_helper(finetuning_output_path, rud, finetuning_model, algorit
     ort = pytest.importorskip('onnxruntime')
     loaded_inference_model = onnx.load(str(tmp_path / 'inference_checkpoints' / 'exported_model.onnx'))
     onnx.checker.check_model(loaded_inference_model)
-    ort_session = ort.InferenceSession(str(tmp_path / 'inference_checkpoints' / 'exported_model.onnx'))
+    ort_session = ort.InferenceSession(str(tmp_path / 'inference_checkpoints' / 'exported_model.onnx'),
+                                       providers=['CPUExecutionProvider'])
 
     for key, value in copied_batch.items():
         copied_batch[key] = value.numpy()
6 changes: 3 additions & 3 deletions tests/utils/test_inference.py
@@ -155,7 +155,7 @@ def test_huggingface_export_for_inference_onnx(onnx_opset_version, tiny_bert_con
 
     onnx.checker.check_model(loaded_model)
 
-    ort_session = ort.InferenceSession(save_path)
+    ort_session = ort.InferenceSession(save_path, providers=['CPUExecutionProvider'])
 
     for key, value in sample_input.items():
         sample_input[key] = cpu_device.tensor_to_device(value).numpy()
@@ -217,7 +217,7 @@ def test_export_for_inference_onnx(model_cls, sample_input, onnx_opset_version,
     loaded_model = onnx.load(save_path)
     onnx.checker.check_model(loaded_model)
 
-    ort_session = ort.InferenceSession(save_path)
+    ort_session = ort.InferenceSession(save_path, providers=['CPUExecutionProvider'])
     loaded_model_out = ort_session.run(
         None,
         {'input': cpu_device.tensor_to_device(sample_input[0]).numpy()},
@@ -355,7 +355,7 @@ def test_export_for_inference_onnx_ddp(model_cls, sample_input, onnx_opset_versi
 
     loaded_model = onnx.load(save_path)
     onnx.checker.check_model(loaded_model)
-    ort_session = ort.InferenceSession(save_path)
+    ort_session = ort.InferenceSession(save_path, providers=['CPUExecutionProvider'])
     loaded_model_out = ort_session.run(
         None,
         {'input': sample_input[0].numpy()},
