Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose input_names and output_names when exporting to ONNX #2601

Merged
merged 4 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion composer/callbacks/export_for_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class ExportForInferenceCallback(Callback):
sample_input (Any, optional): Example model inputs used for tracing. This is needed for "onnx" export
transforms (Sequence[Transform], optional): transformations (usually optimizations) that should
be applied to the model. Each Transform should be a callable that takes a model and returns a modified model.
input_names (Sequence[str], optional): names to assign to the input nodes of the graph, in order. If set
to ``None``, the keys from the ``sample_input`` will be used. Falls back to ``["input"]``.
output_names (Sequence[str], optional): names to assign to the output nodes of the graph, in order. If set
to ``None``, it defaults to ``["output"]``.
"""

def __init__(
Expand All @@ -60,12 +64,16 @@ def __init__(
save_object_store: Optional[ObjectStore] = None,
sample_input: Optional[Any] = None,
transforms: Optional[Sequence[Transform]] = None,
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
):
self.save_format = save_format
self.save_path = save_path
self.save_object_store = save_object_store
self.sample_input = sample_input
self.transforms = transforms
self.input_names = input_names
self.output_names = output_names

def after_dataloader(self, state: State, logger: Logger) -> None:
del logger
Expand All @@ -85,4 +93,6 @@ def export_model(self, state: State, logger: Logger):
logger=logger,
save_object_store=self.save_object_store,
sample_input=(self.sample_input, {}),
transforms=self.transforms)
transforms=self.transforms,
input_names=self.input_names,
output_names=self.output_names)
10 changes: 9 additions & 1 deletion composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3099,6 +3099,8 @@ def export_for_inference(
save_object_store: Optional[ObjectStore] = None,
sample_input: Optional[Any] = None,
transforms: Optional[Sequence[Transform]] = None,
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
):
"""Export a model for inference.

Expand All @@ -3117,6 +3119,10 @@ def export_for_inference(
should accept the ``sample_input`` as is. (default: ``None``)
transforms (Sequence[Transform], optional): transformations (usually optimizations) that should
be applied to the model. Each Transform should be a callable that takes a model and returns a modified model.
input_names (Sequence[str], optional): names to assign to the input nodes of the graph, in order. If set
to ``None``, the keys from the ``sample_input`` will be used. Falls back to ``["input"]``.
output_names (Sequence[str], optional): names to assign to the output nodes of the graph, in order. If set
to ``None``, it defaults to ``["output"]``.

Returns:
None
Expand All @@ -3132,4 +3138,6 @@ def export_for_inference(
logger=self.logger,
save_object_store=save_object_store,
sample_input=(sample_input, {}),
transforms=transforms)
transforms=transforms,
input_names=input_names,
output_names=output_names)
46 changes: 33 additions & 13 deletions composer/utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def export_for_inference(
load_path: Optional[str] = None,
load_object_store: Optional[ObjectStore] = None,
load_strict: bool = False,
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
) -> None:
"""Export a model for inference.

Expand Down Expand Up @@ -132,6 +134,10 @@ def export_for_inference(
Otherwise, if the checkpoint is a local filepath, set to ``None``. (default: ``None``)
load_strict (bool): Whether the keys (i.e., model parameter names) in the model state dict should
perfectly match the keys in the model instance. (default: ``False``)
input_names (Sequence[str], optional): names to assign to the input nodes of the graph, in order. If set
to ``None``, the keys from the ``sample_input`` will be used. Falls back to ``["input"]``.
output_names (Sequence[str], optional): names to assign to the output nodes of the graph, in order. If set
to ``None``, it defaults to ``["output"]``.

Returns:
None
Expand Down Expand Up @@ -224,25 +230,29 @@ def export_for_inference(
if sample_input is None:
raise ValueError(f'sample_input argument is required for onnx export')

input_names = []
if input_names is None:
input_names = []

# assert statement for pyright error: Cannot access member "keys" for type "Tensor"
assert isinstance(sample_input, tuple)
# Extract input names from sample_input if it contains dicts
for i in range(len(sample_input)):
if isinstance(sample_input[i], dict):
input_names += list(sample_input[i].keys())
# assert statement for pyright error: Cannot access member "keys" for type "Tensor"
assert isinstance(sample_input, tuple)
# Extract input names from sample_input if it contains dicts
for i in range(len(sample_input)):
if isinstance(sample_input[i], dict):
input_names += list(sample_input[i].keys())

# Default input name if no dict present
if input_names == []:
input_names = ['input']
# Default input name if no dict present
if input_names == []:
input_names = ['input']

if output_names is None:
output_names = ['output']

torch.onnx.export(
model,
sample_input,
local_save_path,
input_names=input_names,
output_names=['output'],
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=onnx_opset_version,
)
Expand All @@ -260,6 +270,8 @@ def export_with_logger(
save_object_store: Optional[ObjectStore] = None,
sample_input: Optional[Any] = None,
transforms: Optional[Sequence[Transform]] = None,
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
) -> None:
"""Helper method for exporting a model for inference.

Expand Down Expand Up @@ -289,6 +301,10 @@ def export_with_logger(
transforms (Sequence[Transform], optional): transformations (usually optimizations) that should
be applied to the model. Each Transform should be a callable that takes a model and returns a modified model.
``transforms`` are applied after ``surgery_algs``. (default: ``None``)
input_names (Sequence[str], optional): names to assign to the input nodes of the graph, in order. If set
to ``None``, the keys from the ``sample_input`` will be used. Falls back to ``["input"]``.
output_names (Sequence[str], optional): names to assign to the output nodes of the graph, in order. If set
to ``None``, it defaults to ``["output"]``.

Returns:
None
Expand All @@ -300,12 +316,16 @@ def export_with_logger(
save_format=save_format,
save_path=temp_local_save_path,
sample_input=sample_input,
transforms=transforms)
transforms=transforms,
input_names=input_names,
output_names=output_names)
logger.upload_file(remote_file_name=save_path, file_path=temp_local_save_path)
else:
export_for_inference(model=model,
save_format=save_format,
save_path=save_path,
save_object_store=save_object_store,
sample_input=sample_input,
transforms=transforms)
transforms=transforms,
input_names=input_names,
output_names=output_names)
26 changes: 26 additions & 0 deletions tests/utils/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,32 @@ def test_export_for_inference_torchscript(model_cls, sample_input):
)


def test_export_for_inference_input_and_output_names():
    """Verify that user-supplied ``input_names``/``output_names`` are forwarded to ``torch.onnx.export``."""
    model = composer_resnet('resnet18')
    sample_input = torch.rand(4, 3, 224, 224)

    model.eval()

    input_names = ['image']
    output_names = ['prediction', 'score']

    # Must be 'onnx': with 'torchscript' the patched torch.onnx.export is never
    # invoked and call_args would be None, so the assertions below could not pass.
    save_format = 'onnx'
    with patch('torch.onnx.export'):
        with tempfile.TemporaryDirectory() as tempdir:
            save_path = os.path.join(tempdir, 'model.onnx')
            inference.export_for_inference(
                model=model,
                # NOTE(review): wrapped as (args, kwargs) to match how the trainer and
                # callback call export_for_inference for onnx — confirm against the
                # other onnx tests in this file.
                sample_input=(sample_input, {}),
                save_format=save_format,
                save_path=save_path,
                input_names=input_names,
                output_names=output_names,
            )

        # Assert (the original mistakenly *assigned* into call_args.kwargs,
        # which made the test vacuous) that the names were passed through.
        assert torch.onnx.export.call_args.kwargs['input_names'] == input_names
        assert torch.onnx.export.call_args.kwargs['output_names'] == output_names


@device('cpu', 'gpu')
@pytest.mark.parametrize('onnx_opset_version', [13, None])
def test_huggingface_export_for_inference_onnx(onnx_opset_version, tiny_bert_config, device):
Expand Down
Loading