Make quantized activation handlers data layout aware #1183

Open
wants to merge 2 commits into base: dev
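As a rough illustration of where the layout annotations consumed by this change come from, the sketch below assumes the standard qonnx API (ModelWrapper, InferDataLayouts, get_tensor_layout); the model file name is a placeholder and the snippet is not part of the diff:

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_data_layouts import InferDataLayouts

# Placeholder model path; annotate data layouts before the QONNX-to-FINN
# activation conversion so the handlers in this file can query them per tensor.
model = ModelWrapper("model.onnx")
model = model.transform(InferDataLayouts())
# Returns e.g. ["N", "C", "H", "W"], ["N", "H", "W", "C"], or None if unknown
print(model.get_tensor_layout(model.graph.input[0].name))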
66 changes: 56 additions & 10 deletions src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -25,8 +25,8 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings
from abc import ABC, abstractmethod
from onnx import TensorProto, helper
from qonnx.core.modelwrapper import ModelWrapper
@@ -70,7 +70,7 @@ def _check_compatibility(self):
@abstractmethod
def _calculate_act_bias(self):
"""Calculate the activation bias,
which is introduced as an Add node behind the MultiTrheshold node.
which is introduced as an Add node behind the MultiThreshold node.
"""
raise NotImplementedError()

@@ -82,7 +82,7 @@ def _calculate_thresholds(self):
@abstractmethod
def _calculate_act_scale(self):
"""Calculate the activation scale,
which is indroduced as a Mul node behind the Add node
which is introduced as a Mul node behind the Add node
for the activation bias.
"""
raise NotImplementedError()
@@ -139,6 +139,8 @@ def replace_quant_node(self):
graph.value_info.append(thresh_tensor)
model.set_initializer(thresh_tensor.name, thresholds)

data_layout = model.get_tensor_layout(n.input[0])

# Insert MultiThreshold node
outp_trans_node = helper.make_node(
"MultiThreshold",
@@ -154,10 +156,15 @@
mt_node = graph.node[running_node_index - 1]
mt_inst = getCustomOp(mt_node)

# Inherit the data layout from the input tensor if available
if data_layout is not None:
# Convert list to string representation of the data layout
mt_inst.set_nodeattr("data_layout", "".join(data_layout))

# Set scale and bias
# If these values are scalar then they can be set as attributes
# of the MultiThreshold node, if not they get inserted as adder and mul nodes
# behind the MultiTrheshold nodes.
# behind the MultiThreshold nodes.
bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant":
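For readability, the layout inheritance added above boils down to joining the annotation list into the string form stored on the MultiThreshold custom op; a minimal, purely illustrative sketch (variable names chosen to mirror the diff):

# Illustrative only: get_tensor_layout returns a list of axis letters or None.
data_layout = ["N", "H", "W", "C"]
if data_layout is not None:
    # mt_inst.set_nodeattr("data_layout", ...) receives the joined string
    print("".join(data_layout))  # -> "NHWC"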
@@ -355,7 +362,7 @@ def _calculate_thresholds(self):
act_node = self._model.find_direct_predecessors(self._q_node)
act_node = act_node[0]
if act_node.op_type == "Relu":
# Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/
# Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/
# a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/
# onnx/finn/handler/act.py#L21
num_distinct_values = 2**bit_width
@@ -395,8 +402,27 @@
else:
thresholds[c][t] = step / selu_scale

# First try to consider the tensor layout of the input for determining
# the number of output channels
layout = self._model.get_tensor_layout(self._q_node.input[0])
# If there is a layout annotation, use this to determine the index of
# the channel dimension
if layout is not None and "C" in layout:
# Lookup the index in list
cdim = layout.index("C")
# If no layout has been annotated or there is no channel dimension, fall
# back to the previous default assumption
else:
# Assume the channels to be in axis 1
cdim = 1
# Issue a warning to the user, so they are aware of this
warnings.warn(
f"No layout annotations for {self._q_node.input[0]}:"
f" Assuming channel dimension at index {cdim}"
)

# ToDo: The index 1 needs to be changed to -1 for the channels last format
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]
final_shape = (num_output_channels, num_thresholds)
if thresholds.shape != final_shape:
thresholds = np.broadcast_to(thresholds, final_shape)
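The fallback logic added above can be read as a small standalone helper; the sketch below only illustrates that control flow (the real handler reads the annotation from the ModelWrapper rather than taking it as an argument):

import warnings

def channel_axis(layout, tensor_name="<tensor>"):
    """Return the channel axis from a layout annotation, defaulting to axis 1."""
    if layout is not None and "C" in layout:
        return layout.index("C")
    warnings.warn(
        f"No layout annotations for {tensor_name}: Assuming channel dimension at index 1"
    )
    return 1

assert channel_axis(["N", "H", "W", "C"]) == 3  # channels-last
assert channel_axis(["N", "C", "H", "W"]) == 1  # channels-first
assert channel_axis(None) == 1                  # warns and falls back to axis 1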
@@ -417,12 +443,12 @@ def _remove_activation_node(self, multi_threshold_node):
act_node = self._model.find_direct_predecessors(self._q_node)
if act_node is None:
raise RuntimeError(
"For handling of Relu activations a predecesor to " "the Quant node must exist."
"For handling of Relu activations a predecessor to " "the Quant node must exist."
)
act_node = act_node[0]
if act_node.op_type not in self.valid_predecessor_op_types():
raise RuntimeError(
"The predecesor of the Quant node must be Relu or Selu for handling "
"The predecessor of the Quant node must be Relu or Selu for handling "
"of activations."
)

@@ -509,7 +535,7 @@ def _calculate_thresholds(self):
else:
raise RuntimeError("Got an unexpected quantizer node type")

# Calculate thersholds, see: https://github.com/Xilinx/brevitas/
# Calculate thresholds, see: https://github.com/Xilinx/brevitas/
# blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/
# export/onnx/finn/handler/act.py#L76
if bit_width == 1.0:
@@ -537,8 +563,28 @@
for t in range(num_thresholds):
thresholds[c][t] = min_threshold[c] + step[c] * t

# First try to consider the tensor layout of the input for
# determining the number of output channels
layout = self._model.get_tensor_layout(self._q_node.input[0])
# If there is a layout annotation, use this to determine the index
# of the channel dimension
if layout is not None and "C" in layout:
# Lookup the index in list
cdim = layout.index("C")
# If no layout has been annotated or there is no channel dimension,
# fall back to the previous default assumption
else:
# Assume the channels to be in axis 1
cdim = 1
# Issue a warning to the user, so they are aware of this
warnings.warn(
f"No layout annotations for {self._q_node.input[0]}:"
f" Assuming channel dimension at index {cdim}"
)

# ToDo: The index 1 needs to be changed to -1 for the channels last format
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1]
num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim]

final_shape = (num_output_channels, num_thresholds)
if thresholds.shape != final_shape:
thresholds = np.broadcast_to(thresholds, final_shape)
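As in the Relu/Selu handler above, per-tensor thresholds are finally expanded to a per-channel shape with np.broadcast_to; a tiny worked example of that last step, with made-up numbers:

import numpy as np

# Made-up values: one threshold row broadcast to 4 output channels x 3 steps,
# matching the (num_output_channels, num_thresholds) final_shape convention.
thresholds = np.array([[0.5, 1.5, 2.5]])        # shape (1, 3)
final_shape = (4, 3)
print(np.broadcast_to(thresholds, final_shape).shape)  # (4, 3)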