add option DefaultTensorType to specify the default tensor type to quantize #19455

Merged · 6 commits · Feb 20, 2024
Changes from 1 commit
19 changes: 16 additions & 3 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -446,6 +446,19 @@
             return False
         return self.parent.is_valid_quantize_weight(weight_name)
 
+    def _get_default_tensor_type(self, tensor_name):
+        if "DefaultTensorType" in self.extra_options:
+            logging.info("get_tensor_type returns DefaultTensorType for tensor name %r, use %d", tensor_name, self.extra_options["DefaultTensorType"])
+            return self.extra_options["DefaultTensorType"]
+        raise RuntimeError(
+            f"Unable to find data type for weight_name={tensor_name!r}. "
+            f"shape_inference failed to return a type; probably this node is "
+            f"from a different domain or uses an input produced by such an operator. "
+            f"This may happen if you quantize a model that is already quantized. "
+            f"You may use the extra_options entry `DefaultTensorType` to indicate "
+            f"the default weight type, usually `onnx.TensorProto.FLOAT`."
+        )
+
     def get_tensor_type(self, tensor_name, mandatory=False):
         weight = find_by_name(tensor_name, self.model.initializer())
         if weight is not None:
@@ -454,11 +467,11 @@
             vi = self.value_infos[tensor_name]
             if vi.type.HasField("tensor_type"):
                 if mandatory and vi.type.tensor_type.elem_type == 0:
-                    raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+                    return self._get_default_tensor_type(tensor_name)
                 return vi.type.tensor_type.elem_type
         if (not self.enable_subgraph_quantization) or (self.parent is None):
             if mandatory:
-                raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+                return self._get_default_tensor_type(tensor_name)
             return None
         otype = self.parent.is_valid_quantize_weight(tensor_name)
         if otype is not None:
@@ -468,7 +481,7 @@
             if res is not None:
                 return res
         if mandatory:
-            raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+            return self._get_default_tensor_type(tensor_name)
         return None
 
     def is_float_tensor(self, tensor_name):
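For reference, a minimal sketch of how a caller opts into the new fallback through the public dynamic-quantization API (the model paths here are hypothetical; `extra_options` is the existing pass-through dict on `quantize_dynamic`):

```python
import onnx
from onnxruntime.quantization import quantize_dynamic

# Hypothetical paths. DefaultTensorType is only consulted when shape
# inference cannot determine a tensor's element type; FLOAT is the
# usual choice, as the new error message suggests.
quantize_dynamic(
    "model.onnx",
    "model.quant.onnx",
    extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
)
```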
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/quantize_helper.py
@@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data
         onnx_model_path,
         quantized_model_path,
         use_external_data_format=use_external_data_format,
+        extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
     )
     logger.info(f"quantized model saved to:{quantized_model_path}")
     # TODO: include external data in total model size.
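To make the control flow concrete, here is a small standalone sketch (not ONNX Runtime code; the helper name is made up) of the fallback that `_get_default_tensor_type` implements: a mandatory type lookup now consults `extra_options` before raising.

```python
import onnx

def resolve_elem_type(elem_type: int, extra_options: dict) -> int:
    """Illustrative only: mirrors the quantizer's new fallback behavior."""
    if elem_type != 0:
        # Shape inference produced a concrete element type; use it.
        return elem_type
    if "DefaultTensorType" in extra_options:
        # New behavior: fall back to the user-supplied default type.
        return extra_options["DefaultTensorType"]
    # Old behavior: no inferred type and no default means we cannot proceed.
    raise RuntimeError("Unable to find data type; set extra_options['DefaultTensorType'].")

assert resolve_elem_type(0, {"DefaultTensorType": onnx.TensorProto.FLOAT}) == onnx.TensorProto.FLOAT
```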