Skip to content

Commit

Permalink
add option DefaultTensorType to specify the default tensor type to qu…
Browse files Browse the repository at this point in the history
…antize (#19455)

### Description
The current quantization tool relies on shape inference to provide the
type of every intermediate tensor, then the tool knows which type it
must dequantize into (float32, float16). However, this information is
not available if shape inference fails. That happens every time the
model include an operator from a custom domain such as com.microsoft.

This PR introduces an extra option `DefaultTensorType` as a fall back
when the quantizer cannot find the type it needs.

### Motivation and Context
This fixes issue #19409.
  • Loading branch information
xadupre authored Feb 20, 2024
1 parent e832562 commit 7efb0db
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 5 deletions.
25 changes: 21 additions & 4 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def add_new_nodes(self, nodes):
def quantize_model(self):
if self.has_QDQ_nodes():
logging.warning(
"Please check if the model is already quantized."
"Please check if the model is already quantized. "
"Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly."
)

Expand Down Expand Up @@ -442,6 +442,23 @@ def is_valid_quantize_weight(self, weight_name):
return False
return self.parent.is_valid_quantize_weight(weight_name)

def _get_default_tensor_type(self, tensor_name):
if "DefaultTensorType" in self.extra_options:
logging.info(
"get_tensor_type returns DefaultTensorType for tensor name %r, use %d",
tensor_name,
self.extra_options["DefaultTensorType"],
)
return self.extra_options["DefaultTensorType"]
raise RuntimeError(
f"Unable to find data type for weight_name={tensor_name!r}. "
f"shape_inference failed to return a type probably this node is "
f"from a different domain or using an input produced by such an operator. "
f"This may happen if you quantize a model already quantized. "
f"You may use extra_options `DefaultTensorType` to indicate "
f"the default weight type, usually `onnx.TensorProto.FLOAT`."
)

def get_tensor_type(self, tensor_name, mandatory=False):
weight = find_by_name(tensor_name, self.model.initializer())
if weight is not None:
Expand All @@ -450,11 +467,11 @@ def get_tensor_type(self, tensor_name, mandatory=False):
vi = self.value_infos[tensor_name]
if vi.type.HasField("tensor_type"):
if mandatory and vi.type.tensor_type.elem_type == 0:
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
return self._get_default_tensor_type(tensor_name)
return vi.type.tensor_type.elem_type
if (not self.enable_subgraph_quantization) or (self.parent is None):
if mandatory:
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
return self._get_default_tensor_type(tensor_name)
return None
otype = self.parent.is_valid_quantize_weight(tensor_name)
if otype is not None:
Expand All @@ -464,7 +481,7 @@ def get_tensor_type(self, tensor_name, mandatory=False):
if res is not None:
return res
if mandatory:
raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
return self._get_default_tensor_type(tensor_name)
return None

def is_float_tensor(self, tensor_name):
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/transformers/quantize_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import os

import onnx # noqa: F401
import onnx
import torch
from transformers.modeling_utils import Conv1D

Expand Down Expand Up @@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data
onnx_model_path,
quantized_model_path,
use_external_data_format=use_external_data_format,
extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
)
logger.info(f"quantized model saved to:{quantized_model_path}")
# TODO: inlcude external data in total model size.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import unittest

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh

from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType


class TestQuantizerShapeInference(unittest.TestCase):
def test_com_microsoft(self):
model = oh.make_model(
oh.make_graph(
[
oh.make_node("MatMul", ["X", "W1"], ["T1"]),
oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"),
oh.make_node("MatMul", ["T2", "W3"], ["T3"]),
oh.make_node("MatMul", ["T3", "W4"], ["Y"]),
],
"name",
[oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])],
[oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])],
[
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"),
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"),
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"),
onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"),
],
),
opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
)
model_shaped = onnx.shape_inference.infer_shapes(model)
shaped_results = set(t.name for t in model_shaped.graph.value_info)
# every result after T1 depends on T2 coming from a node com.microsoft,
# shape_inference cannot go beyond this point
self.assertEqual(shaped_results, {"T1"})

# first try: checks it raises an exception
quantizer = ONNXQuantizer(
model,
False, # per_channel
False, # reduce_range
QuantizationMode.IntegerOps, # mode
False, # static
QuantType.QInt8, # weight_type,
QuantType.QUInt8, # dynamic activation only supports uint8
None,
[], # nodes_to_quantize,
[], # nodes_to_exclude
["MatMul"], # op_types_to_quantize,
{"MatMulConstBOnly": True}, # extra_options,
# {'DefaultTensorType': 1, }
)

with self.assertRaises(RuntimeError) as e:
quantizer.quantize_model()
self.assertIn("Unable to find data type for weight_name=", str(e))

# second try: checks it works
quantizer = ONNXQuantizer(
model,
False, # per_channel
False, # reduce_range
QuantizationMode.IntegerOps, # mode
False, # static
QuantType.QInt8, # weight_type,
QuantType.QUInt8, # dynamic activation only supports uint8
None,
[], # nodes_to_quantize,
[], # nodes_to_exclude
["MatMul"], # op_types_to_quantize,
{
"MatMulConstBOnly": True,
"DefaultTensorType": 1,
},
)

model = quantizer.quantize_model()
ops = {n.op_type for n in model.graph.node}
self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"})


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 7efb0db

Please sign in to comment.