From fdd89a6ee9561f64aa8ff1dea932037501af081d Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Sat, 30 Sep 2023 10:29:20 +0200 Subject: [PATCH 01/17] [Streamline] Prefer AbsorbSignBiasIntoMultiThreshold transform Flips the order of AbsorbSignBiasIntoMultiThreshold and MoveScalarLinearPastInvariants streamlining transforms to prefer absorbing adds into multi-thresholds instead of propagating them downwards. This should prevent accumulation of scalar adds in front of two-input matmuls in scaled dot-product attention operators (they cannot be moved past the matmul operation in that case). --- src/finn/transformation/streamline/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py index 2e68de698b..39ef87f81c 100644 --- a/src/finn/transformation/streamline/__init__.py +++ b/src/finn/transformation/streamline/__init__.py @@ -76,8 +76,8 @@ def apply(self, model): BatchNormToAffine(), ConvertSignToThres(), MoveMulPastMaxPool(), - MoveScalarLinearPastInvariants(), AbsorbSignBiasIntoMultiThreshold(), + MoveScalarLinearPastInvariants(), MoveAddPastMul(), MoveScalarAddPastMatMul(), MoveAddPastConv(), From be33bbc4a188c2a76738df2f484f1fc72bc1c07e Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Sat, 30 Sep 2023 11:19:12 +0200 Subject: [PATCH 02/17] [Streamline] Refactor MoveScalarMulPastMatMul to handle join-node matmul The MoveScalarMulPastMatMul transformation can now handle matmul operations with both inputs preceded by a scalar multiplication. This change is required for streamlining scaled dot-product attention operations, which are essentially two-input matmuls. --- src/finn/transformation/streamline/reorder.py | 147 ++++++++++++------ 1 file changed, 103 insertions(+), 44 deletions(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 8ac2d7dad6..24ddfb78cd 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -100,58 +100,117 @@ def apply(self, model): return (model, graph_modified) +# Tests whether a tensor is a scalar, i.e., whether all dimensions are 1 +def is_scalar(tensor): + return tensor is not None and all(x == 1 for x in tensor.shape) + + +# Tests whether a node is a scalar multiplication with a constant scale factor +def is_const_scalar_mul(node, model): + # Only handle existing Mul type nodes + if node is not None and node.op_type == "Mul": + # The constant must be an initializer + # Note: Assumes the constant parameter to always be the second input + scale = model.get_initializer(node.input[1]) + # Test for existence of a constant scale factor + return scale is not None and is_scalar(scale) + # Did not match the operator type + return False + + +# Refactored version of the MoveScalarMulPastMatMul transform capable of +# transforming two-input MatMul, like those being part of the attention operator class MoveScalarMulPastMatMul(Transformation): """Move scalar mul operations past matmul operations. 
We want to have muls next to each other such that they can be collapsed into a single mul.""" + # Applies the transform to a whole model graph def apply(self, model): + # Get the model graph out of the model wrapper object graph = model.graph - node_ind = 0 + # Keep track of whether the graph has been modified graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Mul" and not model.is_fork_node(n) and not model.is_join_node(n): - consumer = model.find_consumer(n.output[0]) - if ( - consumer is not None - and consumer.op_type == "MatMul" - and not model.is_join_node(consumer) - ): - mul_weight_name = n.input[1] - matmul_weight_name = consumer.input[1] - A = model.get_initializer(mul_weight_name) - W = model.get_initializer(matmul_weight_name) - if (A is None) or (W is None): - warnings.warn("MatMul or Mul params are not constant, skipping") - continue - start_name = n.input[0] - middle_name = n.output[0] - end_name = consumer.output[0] - mm_out_shape = model.get_tensor_shape(end_name) - if all(x == 1 for x in A.shape): - # if the mul is scalar, we can simply swap the order of ops - # make and insert new nodes - new_matmul = oh.make_node( - "MatMul", - [start_name, matmul_weight_name], - [middle_name], - name=consumer.name, - ) - new_mul = oh.make_node( - "Mul", - [middle_name, mul_weight_name], - [end_name], - name=n.name, - ) - graph.node.insert(node_ind, new_matmul) - graph.node.insert(node_ind + 1, new_mul) - model.set_tensor_shape(middle_name, mm_out_shape) - # remove old nodes - graph.node.remove(n) - graph.node.remove(consumer) - graph_modified = True + + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # First pattern matching condition: For the transform to be + # applicable, the node has to be a MatMul operator + # Note: Cannot handle fork nodes for now, as it is unclear how to + # distribute the mul into the branches (without knowing the + # operators for all branches) + if node.op_type == "MatMul": + # Get the left hand side and right hand side inputs + # Note: Assumes the ordering of left to right inputs to match + # indices 0 to 1. However, it does not "hurt" if it is + # reversed as both sides are treated equivalently. 
+ lhs = model.find_producer(node.input[0]) + rhs = model.find_producer(node.input[1]) + + # Give precedence to the left hand side input testing for the + # presence of a scalar multiplication + if (is_const_scalar_mul(lhs, model) + and not model.is_fork_node(lhs)): + # Unpack the connection pattern of a scalar mul feeding the + # lhs input of the matmul + # Names of the three input tensors to the mul-matmul complex + a, b, c = lhs.input[0], lhs.input[1], node.input[1] + # Names of the intermediate and the global output + m, o = lhs.output[0], node.output[0] # noqa: Duplicate code + # Rewire the operator connections locally, swapping mul and + # matmul operator order + matmul = oh.make_node("MatMul", [a, c], [m], node.name) + mul = oh.make_node("Mul", [m, b], [o], lhs.name) + # Insert the rewired nodes into the graph + graph.node.insert(index, matmul) + graph.node.insert(index + 1, mul) + # Adapt the shape of the intermediate tensor as it changed + # according to the output shape of the matmul + model.set_tensor_shape(m, model.get_tensor_shape(o)) + # Remove the old nodes from the graph + graph.node.remove(lhs) + graph.node.remove(node) + # The graph has been modified, this needs to be reported + # back to the caller + graph_modified = True + # Cannot further modify the node (i.e., the rhs) as the + # index and state of the nodes changed and need to be + # queried again from the graph.node at the start of the next + # iteration. + continue + + # Next try whether the right hand side matches the pattern of a + # scalar multiplication + if (is_const_scalar_mul(rhs, model) + and not model.is_fork_node(rhs)): + # Unpack the connection pattern of a scalar mul feeding the + # rhs input of the matmul + # Names of the three input tensors to the mul-matmul complex + a, b, c = node.input[0], rhs.input[0], rhs.input[1] + # Names of the intermediate and the global output + m, o = rhs.output[0], node.output[0] # noqa: Duplicate code + # Rewire the operator connections locally, swapping mul and + # matmul operator order + matmul = oh.make_node("MatMul", [a, b], [m], node.name) + mul = oh.make_node("Mul", [m, c], [o], rhs.name) + # Insert the rewired nodes into the graph + graph.node.insert(index, matmul) + graph.node.insert(index + 1, mul) + # Adapt the shape of the intermediate tensor as it changed + # according to the output shape of the matmul + model.set_tensor_shape(m, model.get_tensor_shape(o)) + # Remove the old nodes from the graph + graph.node.remove(rhs) + graph.node.remove(node) + # The graph has been modified, this needs to be reported + # back to the caller + graph_modified = True + + # Finalize the transformation by inferring shapes again (as these might + # have changed) model = model.transform(InferShapes()) - return (model, graph_modified) + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified class MoveScalarAddPastMatMul(Transformation): From 09c1993ee9787c53b3411f186a03e7b80fbe6443 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Sat, 30 Sep 2023 11:36:38 +0200 Subject: [PATCH 03/17] Remove misplaced/outdated comment --- src/finn/transformation/streamline/reorder.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 24ddfb78cd..861d9c2ac6 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -135,9 +135,6 @@ def apply(self, model): for index, node 
in enumerate(graph.node):
             # First pattern matching condition: For the transform to be
             # applicable, the node has to be a MatMul operator
-            # Note: Cannot handle fork nodes for now, as it is unclear how to
-            # distribute the mul into the branches (without knowing the
-            # operators for all branches)
             if node.op_type == "MatMul":
                 # Get the left hand side and right hand side inputs
                 # Note: Assumes the ordering of left to right inputs to match

From 9dade0c9df21b19118c87f658a28c7334d624965 Mon Sep 17 00:00:00 2001
From: Christoph Berganski
Date: Sat, 30 Sep 2023 12:15:18 +0200
Subject: [PATCH 04/17] [Streamline] Soften initializer tests in Absorb1BitMulIntoMatMul/Conv

Assertions are too restrictive, causing the program to terminate in cases
where the streamlining simply encounters nodes to which the transforms are
not applicable: Just skip those nodes.

Only the two transforms currently affecting the streamlining of scaled
dot-product attention have been changed.
---
 src/finn/transformation/streamline/absorb.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index e3e2468bba..72f49d3ac0 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -224,16 +224,22 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
+            # TODO: Maybe test for join-node here and reject?
             if n.op_type == "MatMul":
                 matmul_weight_name = n.input[1]
                 W = model.get_initializer(matmul_weight_name)
                 Wdt = model.get_tensor_datatype(matmul_weight_name)
-                assert W is not None, "Initializer for matmul weights is not set."
+                # Just skip matmuls with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # TODO: Maybe test for join-node here and reject?
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     if is_1bit:
                         Wnew = A * W
@@ -260,16 +266,22 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
+            # TODO: Maybe test for join-node here and reject?
             if n.op_type == "Conv":
                 conv_weight_name = n.input[1]
                 W = model.get_initializer(conv_weight_name)
                 Wdt = model.get_tensor_datatype(conv_weight_name)
-                assert W is not None, "Initializer for conv weights is not set."
+                # Just skip convs with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # TODO: Maybe test for join-node here and reject?
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+ # Just skip muls with non-existing scale initializers + if A is None: + continue is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1 is_scalar = np.prod(A.shape) == 1 actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape))) From 8bae5d7a6d3572e38c73ea542bd3654864f00d7d Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Thu, 19 Oct 2023 15:40:41 +0200 Subject: [PATCH 05/17] Address some linting issues --- src/finn/transformation/streamline/reorder.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 861d9c2ac6..2a520b3d28 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -136,6 +136,11 @@ def apply(self, model): # First pattern matching condition: For the transform to be # applicable, the node has to be a MatMul operator if node.op_type == "MatMul": + # Note: When touching the following code, remember to treat both + # branches equivalently! + # TODO: Can this be enforeced or at least be made easier by + # extracting common code patterns to a function? + # Get the left hand side and right hand side inputs # Note: Assumes the ordering of left to right inputs to match # indices 0 to 1. However, it does not "hurt" if it is @@ -145,8 +150,15 @@ def apply(self, model): # Give precedence to the left hand side input testing for the # presence of a scalar multiplication - if (is_const_scalar_mul(lhs, model) - and not model.is_fork_node(lhs)): + if is_const_scalar_mul(lhs, model): + # Cannto handle fork nodes: We would have to distribute the + # Mul into all branches + # TODO: Maybe reconsider this at some point, there is + # probabably nothing preventing this in general, it is just + # more difficult and apparanetly not necessary right now. + if model.is_fork_node(lhs): + # Softly skip this node + continue # Unpack the connection pattern of a scalar mul feeding the # lhs input of the matmul # Names of the three input tensors to the mul-matmul complex @@ -177,8 +189,15 @@ def apply(self, model): # Next try whether the right hand side matches the pattern of a # scalar multiplication - if (is_const_scalar_mul(rhs, model) - and not model.is_fork_node(rhs)): + if is_const_scalar_mul(rhs, model): + # Cannto handle fork nodes: We would have to distribute the + # Mul into all branches + # TODO: Maybe reconsider this at some point, there is + # probabably nothing preventing this in general, it is just + # more difficult and apparanetly not necessary right now. + if model.is_fork_node(rhs): + # Softly skip this node + continue # Unpack the connection pattern of a scalar mul feeding the # rhs input of the matmul # Names of the three input tensors to the mul-matmul complex From b22ebe3ec25bc257be5ae9e421bf3dfa9ca0253d Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Thu, 19 Oct 2023 17:54:47 +0200 Subject: [PATCH 06/17] [Tests] Add test for MoveScalarMulPastMatMul handling join nodes This is pretty much copy and paste of the existing test case, just replacing the MatMul initializer by a second top-input followed by a scalar Mul. 
--- .../test_move_scalar_past_matmul.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/transformation/streamline/test_move_scalar_past_matmul.py b/tests/transformation/streamline/test_move_scalar_past_matmul.py index e4f4357fff..515e9b9462 100644 --- a/tests/transformation/streamline/test_move_scalar_past_matmul.py +++ b/tests/transformation/streamline/test_move_scalar_past_matmul.py @@ -72,6 +72,43 @@ def test_move_scalar_mul_past_matmul(): assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] +@pytest.mark.streamline +def test_move_scalar_mul_past_join_matmul(): + top_in1 = oh.make_tensor_value_info("top_in1", TensorProto.FLOAT, [1, 2]) + top_in2 = oh.make_tensor_value_info("top_in2", TensorProto.FLOAT, [2, 1]) + mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1, 1]) + mul2_param = oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1, 1]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 1]) + modelproto = qonnx_make_model( + oh.make_graph( + name="test", + inputs=[top_in1, top_in2], + outputs=[top_out], + value_info=[mul1_param, mul2_param], + nodes=[ + oh.make_node("Mul", ["top_in1", "mul1_param"], ["middle1"]), + oh.make_node("Mul", ["top_in2", "mul2_param"], ["middle2"]), + oh.make_node("MatMul", ["middle1", "middle2"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform(InferShapes()) + model.set_initializer("mul1_param", np.asarray([[3]], dtype=np.float32)) + model.set_initializer("mul2_param", np.asarray([[3]], dtype=np.float32)) + new_model = model.transform(MoveScalarMulPastMatMul()) + inp_dict = { + "top_in1": np.asarray([[-1.0, 1.0]], dtype=np.float32), + "top_in2": np.asarray([[1.0], [-1.0]], dtype=np.float32), + } + assert ox.compare_execution(model, new_model, inp_dict) + assert new_model.graph.node[0].op_type == "MatMul" + assert new_model.graph.node[1].op_type == "Mul" + assert new_model.graph.node[2].op_type == "Mul" + assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] + assert new_model.graph.node[1].output[0] == new_model.graph.node[2].input[0] + + @pytest.mark.streamline def test_move_scalar_add_past_matmul(): top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2]) From c10fa1d1cadb212c1932090eed95e5c9967d867d Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Fri, 27 Oct 2023 14:07:48 +0200 Subject: [PATCH 07/17] [Deps] Update qonnx version to include FoldTransposeIntoQuantInit fix --- fetch-repos.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fetch-repos.sh b/fetch-repos.sh index 073c052d67..0674a3668d 100755 --- a/fetch-repos.sh +++ b/fetch-repos.sh @@ -27,7 +27,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -QONNX_COMMIT="fd61cfeebbdaba351abf7e9d54cd785d7776fa4f" +QONNX_COMMIT="cadd6b236e093f910c9e7cea623c81846cab3506" FINN_EXP_COMMIT="de99347e936d51715f5356a1b6c64e37b91c23c2" BREVITAS_COMMIT="84f42259ec869eb151af4cb8a8b23ad925f493db" PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1" From 475a27bea9b3b49c3e480c71accaa2e948fa40c9 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Mon, 13 Nov 2023 15:01:37 +0100 Subject: [PATCH 08/17] [Streamline] Fix FoldQuantWeights input order and shape annotations Folding quantized initializers into add-like nodes did not repsect the order of inputs to the add node correctly. 
This is fixed by testing for one of the two possible orders and selecting the following indices accordingly. Shape inference following the transformation is fixed by deleting the annotations instead of propagating them incorrectly. Deleting the shape annotations should not hurt, as these are redone by running shape inference after each transformation anyways. --- .../qonnx/fold_quant_weights.py | 35 +++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/src/finn/transformation/qonnx/fold_quant_weights.py b/src/finn/transformation/qonnx/fold_quant_weights.py index 0f6cbacb82..59ebe4eea3 100644 --- a/src/finn/transformation/qonnx/fold_quant_weights.py +++ b/src/finn/transformation/qonnx/fold_quant_weights.py @@ -149,7 +149,8 @@ def apply(self, model): mul_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - mul_shape, + mul_shape, # Note: This shape is known exactly as + # it is an initializer with known shape ) graph.value_info.append(mul_tensor) model.set_initializer(mul_tensor.name, scale) @@ -168,7 +169,9 @@ def apply(self, model): act_mul_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - output_shape, + None, # Note: Explicitly delete the shape + # annotation to be redone by the next shape + # inference ) graph.value_info.append(act_mul_tensor) successor.output[0] = act_mul_tensor.name @@ -186,19 +189,37 @@ def apply(self, model): div_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - mul_shape, + None, # Note: Explicitly delete the shape + # annotation to be redone by the next shape + # inference ) graph.value_info.append(div_tensor) model.set_initializer(div_tensor.name, scale) - succ_input_name = successor.input[0] + # Detect which input of the add-like successor is + # fed by the quantizer node to select the other + # branch to insert the scale factor + if successor.input[0] == node_out: + succ_input_name = successor.input[1] + else: + succ_input_name = successor.input[0] + act_mul_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, - output_shape, + None, # Note: Explicitly delete the shape + # annotation to be redone by the next shape + # inference ) graph.value_info.append(act_mul_tensor) - successor.input[0] = act_mul_tensor.name + + # Detect which input of the add-like successor is + # fed by the quantizer node to select the other + # branch to insert the scale factor + if successor.input[0] == node_out: + successor.input[1] = act_mul_tensor.name + else: + successor.input[0] = act_mul_tensor.name div_node = helper.make_node( "Div", @@ -210,6 +231,8 @@ def apply(self, model): # remove old node graph.node.remove(n) graph_modified = True + # Note: Running shape inference is necessary as shape + # annotations have been deleted above model = model.transform(InferShapes()) return (model, graph_modified) return (model, graph_modified) From bd6a8f82c5f8ec4934ad2c3f47cf0b7b0ac0c599 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Mon, 13 Nov 2023 15:58:57 +0100 Subject: [PATCH 09/17] [Streamline] Fix AbsorbAddIntoMultiThreshold assumed input order Add is commutative and thus the export does not always generate the initializer as the second input. However, this was always assumed by this transformation, failing via assertion if the inputs were simply ordered differently. The transformation now handles both of the two possible input orderings. 
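For illustration, a minimal sketch of the two orderings an exporter may emit
(using onnx.helper as in the tests of this series; the tensor names are made
up for the example):

    import onnx.helper as oh
    # Ordering previously assumed: dynamic input first, initializer second
    add_a = oh.make_node("Add", ["act_in", "add_param"], ["out_a"])
    # Equally valid ordering that used to fail the assertion: initializer
    # first, dynamic input second
    add_b = oh.make_node("Add", ["add_param", "act_in"], ["out_b"])

Both nodes describe the same computation, so the transformation now looks up
which input actually carries the initializer instead of relying on its
position.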
--- src/finn/transformation/streamline/absorb.py | 45 ++++++++++++++++---- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 72f49d3ac0..4ff1d607d6 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -30,6 +30,10 @@ import qonnx.core.data_layout as DataLayout import warnings from onnx import helper as oh +# Protobuf onnx graph node type +from onnx import NodeProto # noqa +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.datatype import DataType from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation @@ -100,6 +104,23 @@ def apply(self, model): return (model, graph_modified) +# Groups inputs by categories, i.e., groups dynamic inputs first, followed by +# initializers. Keeps order of inputs in each category. +def group_inputs_by_category(node: NodeProto, model: ModelWrapper): # noqa + # First select all dynamic inputs, which are those without initializer + # tensor + dynamics = [ + i for i in node.input if model.get_initializer(i) is None + ] + # Select all input which are initializers, which, by exclusion, are all + # those not among the dynamic inputs + initializers = [ + i for i in node.input if i not in dynamics + ] + # Return lists of dynamic anc initializer inputs + return dynamics, initializers + + class AbsorbAddIntoMultiThreshold(Transformation): """Absorb preceding Add ops into MultiThreshold by updating the threshold values. Only scalar/1D add vectors can be absorbed.""" @@ -113,13 +134,19 @@ def apply(self, model): if n.op_type == "Add" and not model.is_fork_node(n) and not model.is_join_node(n): consumer = model.find_consumer(n.output[0]) if consumer is not None and consumer.op_type == "MultiThreshold": - add_weight_name = n.input[1] - threshold_name = consumer.input[1] - A = model.get_initializer(add_weight_name) - T = model.get_initializer(threshold_name) - assert A is not None, "Initializer for add weights is not set." + # As Add is not a join node, there must be one initializer + # and one dynamic input. We do not know their order, but + # can group them accordingly to extract the tensor names + (start,), (add_weight, ) = group_inputs_by_category( + n, model + ) + threshold = consumer.input[1] + A = model.get_initializer(add_weight) + T = model.get_initializer(threshold) + # Test for the thresholds actually being initializers + # Note: No need to validate the add_weights anymore, this + # is already handled by the grouping and is_join_node test. assert T is not None, "Initializer for thresholds is not set." 
- start_name = n.input[0] # we can only absorb 0d or 1d adds is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape) actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape))) @@ -128,13 +155,13 @@ def apply(self, model): Tnew = T - A.reshape(-1, 1) # Tnew = T - A.reshape(-1, T.shape[1]) # compute new thresholds and set initializer - model.set_initializer(threshold_name, Tnew) + model.set_initializer(threshold, Tnew) # wire add input directly to MultiThreshold - consumer.input[0] = start_name + consumer.input[0] = start # remove the add node graph.node.remove(n) graph_modified = True - return (model, graph_modified) + return model, graph_modified class AbsorbMulIntoMultiThreshold(Transformation): From 1f7dd4c9ab625b2c190502597296f9222153493c Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Wed, 15 Nov 2023 11:18:48 +0100 Subject: [PATCH 10/17] [Streamline] Add support for Slice to MoveScalarLinearPastInvariants This is required for streamlining packed input projections of multi-head scaled dot-product attention. Adds support for Squeeze and Unsqueeze as well. Skip moving of fork-node producers as this is not handled correctly. However, the same effect can be attained by applying the MoveLinearPastFork transformation first. --- src/finn/transformation/streamline/reorder.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 2a520b3d28..9314774f02 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -680,6 +680,17 @@ class MoveScalarLinearPastInvariants(Transformation): GlobalAveragePool """ + # Op-types of currently supported invariants + SUPPORTED_INVARIANTS = { + "GlobalAveragePool", + "Reshape", + "Transpose", + "Flatten", + "Slice", + "Squeeze", + "Unsqueeze", + } + def apply(self, model): graph = model.graph node_ind = 0 @@ -692,13 +703,7 @@ def apply(self, model): # Extract mode and scales and input shape mode = get_by_name(n.attribute, "mode").s.decode("ascii") is_nearest_neighbor_resample = mode == "nearest" - if ( - n.op_type == "GlobalAveragePool" - or n.op_type == "Reshape" - or n.op_type == "Transpose" - or n.op_type == "Flatten" - or is_nearest_neighbor_resample - ): + if n.op_type in self.SUPPORTED_INVARIANTS or is_nearest_neighbor_resample: in0 = n.input[0] if in0 is None: continue @@ -708,6 +713,16 @@ def apply(self, model): continue if prod0.op_type in ["Mul", "Add", "Div"]: + # Cannot handle fork-nodes, try MoveLinearPastFork first + if model.is_fork_node(prod0): + warnings.warn( + f"{self.__class__.__name__}:" + f" Skipping near match: {prod0.name} is a fork-node," + f" try MoveLinearPastFork first" + ) + # Skip transforming this node as moving this would lead + # to messed up or detached graph + continue # check if second input of producer is an initializer init0 = model.get_initializer(prod0.input[1]) # if either initializer is None, skip From b3e50d70a021b8f1bd8612fed06c231f9f537aae Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Fri, 17 Nov 2023 13:18:33 +0100 Subject: [PATCH 11/17] [Streamline] Absorb1BitMulIntoMatMul/Conv does not handle fork-nodes Explicitly rejects absorbing into fork-nodes. Previously, this probably would have failed, silently resulting in a wrong model. Not sure whether this happened in any practically relevant models? 
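For reference, a minimal sketch of the rejected fork pattern (hypothetical
tensor names, onnx.helper as used in the tests of this series):

    import onnx.helper as oh
    # The MatMul output feeds two consumers: a scalar Mul and some other op
    matmul = oh.make_node("MatMul", ["inp", "weights"], ["mm_out"])
    scale = oh.make_node("Mul", ["mm_out", "scale_param"], ["scaled_out"])
    other = oh.make_node("Relu", ["mm_out"], ["relu_out"])

Absorbing the scale into the MatMul weights would silently change what the
second consumer (here the Relu) sees, so such fork nodes are now skipped.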
--- src/finn/transformation/streamline/absorb.py | 28 ++++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py index 4ff1d607d6..9d5239eb5f 100644 --- a/src/finn/transformation/streamline/absorb.py +++ b/src/finn/transformation/streamline/absorb.py @@ -242,7 +242,7 @@ def apply(self, model): class Absorb1BitMulIntoMatMul(Transformation): - """Absorb bipolar or binary multiplications into the preciding matrix + """Absorb bipolar or binary multiplications into the preceding matrix multiply.""" def apply(self, model): @@ -251,8 +251,13 @@ def apply(self, model): graph_modified = False for n in graph.node: node_ind += 1 - # TODO: Maybe test for join-node here and reject? - if n.op_type == "MatMul": + # Note: Join-node test is implicitly covered by testing for the + # initializer below + # Note: This cannot handle fork-nodes, as only the first consumer is + # considered below. + # TODO: Fork-nodes could be handled if the muls are the same in all + # branches, but this is not checked nor rewired at all right now. + if n.op_type == "MatMul" and not model.is_fork_node(n): matmul_weight_name = n.input[1] W = model.get_initializer(matmul_weight_name) Wdt = model.get_tensor_datatype(matmul_weight_name) @@ -260,7 +265,8 @@ def apply(self, model): if W is None: continue consumer = model.find_consumer(n.output[0]) - # TODO: Maybe test for join-node here and reject? + # Note: Join-node test is implicitly covered by testing for the + # initializer below if consumer is not None and consumer.op_type == "Mul": mul_weight_name = consumer.input[1] A = model.get_initializer(mul_weight_name) @@ -285,7 +291,7 @@ def apply(self, model): class Absorb1BitMulIntoConv(Transformation): - """Absorb bipolar or binary multiplications into the preciding convolution.""" + """Absorb bipolar or binary multiplications into the preceding convolution.""" def apply(self, model): graph = model.graph @@ -293,8 +299,13 @@ def apply(self, model): graph_modified = False for n in graph.node: node_ind += 1 - # TODO: Maybe test for join-node here and reject? - if n.op_type == "Conv": + # Note: Join-node test is implicitly covered by testing for the + # initializer below + # Note: This cannot handle fork-nodes, as only the first consumer is + # considered below. + # TODO: Fork-nodes could be handled if the muls are the same in all + # branches, but this is not checked nor rewired at all right now. + if n.op_type == "Conv" and not model.is_fork_node(n): conv_weight_name = n.input[1] W = model.get_initializer(conv_weight_name) Wdt = model.get_tensor_datatype(conv_weight_name) @@ -302,7 +313,8 @@ def apply(self, model): if W is None: continue consumer = model.find_consumer(n.output[0]) - # TODO: Maybe test for join-node here and reject? 
+ # Note: Join-node test is implicitly covered by testing for the + # initializer below if consumer is not None and consumer.op_type == "Mul": mul_weight_name = consumer.input[1] A = model.get_initializer(mul_weight_name) From 0413368817c60c07ec84957a815970c2cdcfda78 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Fri, 17 Nov 2023 13:34:50 +0100 Subject: [PATCH 12/17] [Deps] Temporarily switch qonnx to my fork including necessary fixes --- fetch-repos.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fetch-repos.sh b/fetch-repos.sh index 0674a3668d..acade345a9 100755 --- a/fetch-repos.sh +++ b/fetch-repos.sh @@ -27,7 +27,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -QONNX_COMMIT="cadd6b236e093f910c9e7cea623c81846cab3506" +QONNX_COMMIT="984a097645dd889782d79093d07dc757d5a6b77b" FINN_EXP_COMMIT="de99347e936d51715f5356a1b6c64e37b91c23c2" BREVITAS_COMMIT="84f42259ec869eb151af4cb8a8b23ad925f493db" PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1" @@ -40,7 +40,7 @@ RFSOC4x2_BDF_COMMIT="13fb6f6c02c7dfd7e4b336b18b959ad5115db696" KV260_BDF_COMMIT="98e0d3efc901f0b974006bc4370c2a7ad8856c79" EXP_BOARD_FILES_MD5="226ca927a16ea4ce579f1332675e9e9a" -QONNX_URL="https://github.com/fastmachinelearning/qonnx.git" +QONNX_URL="https://github.com/iksnagreb/qonnx.git" FINN_EXP_URL="https://github.com/Xilinx/finn-experimental.git" BREVITAS_URL="https://github.com/Xilinx/brevitas.git" PYVERILATOR_URL="https://github.com/maltanar/pyverilator.git" From 2bf794997c40a73dc895b8e6acd5fbf5fa99aabd Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Mon, 20 Nov 2023 15:10:03 +0100 Subject: [PATCH 13/17] Make quantized activation handlers data layout aware This probably is still rather sketchy, but at least it tries to check the data layout annotation. For now seems to be enough for getting the thresholds of multi-head attention right, IF qonnx properly annotates the 3D layouts. --- .../qonnx/qonnx_activation_handlers.py | 59 +++++++++++++++---- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py index 323e391df4..451ba52c29 100644 --- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py +++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py @@ -25,8 +25,8 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - import numpy as np +import warnings from abc import ABC, abstractmethod from onnx import TensorProto, helper from qonnx.core.modelwrapper import ModelWrapper @@ -70,7 +70,7 @@ def _check_compatibility(self): @abstractmethod def _calculate_act_bias(self): """Calculate the activation bias, - which is introduced as an Add node behind the MultiTrheshold node. + which is introduced as an Add node behind the MultiThreshold node. """ raise NotImplementedError() @@ -82,7 +82,7 @@ def _calculate_thresholds(self): @abstractmethod def _calculate_act_scale(self): """Calculate the activation scale, - which is indroduced as a Mul node behind the Add node + which is introduced as a Mul node behind the Add node for the activation bias. 
""" raise NotImplementedError() @@ -157,7 +157,7 @@ def replace_quant_node(self): # Set scale and bias # If these values are scalar then they can be set as attributes # of the MultiThreshold node, if not they get inserted as adder and mul nodes - # behind the MultiTrheshold nodes. + # behind the MultiThreshold nodes. bias_scalar = adder_bias.shape == (1,) or len(adder_bias.shape) == 0 scale_scalar = mul_scale.shape == (1,) or len(mul_scale.shape) == 0 if scale_scalar and bias_scalar and self._q_node.op_type == "BipolarQuant": @@ -355,7 +355,7 @@ def _calculate_thresholds(self): act_node = self._model.find_direct_predecessors(self._q_node) act_node = act_node[0] if act_node.op_type == "Relu": - # Calculate thersholds, see: https://github.com/Xilinx/brevitas/blob/ + # Calculate thresholds, see: https://github.com/Xilinx/brevitas/blob/ # a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/export/ # onnx/finn/handler/act.py#L21 num_distinct_values = 2**bit_width @@ -395,8 +395,27 @@ def _calculate_thresholds(self): else: thresholds[c][t] = step / selu_scale + # First try to consider the tensor layout of the output for determining + # the number of output channels + layout = self._model.get_tensor_layout(self._q_node.output[0]) + # If there is a layout annotation, use this to determine the index of + # the channel dimension + if layout is not None and "C" in layout: + # Lookup the index in list + cdim = layout.index("C") + # If no layout has been annotated or there is no channel dimension, fall + # back to the previous default assumption + else: + # Assume the channels to be in axis 1 + cdim = 1 + # Issue a warning to the user, so they are aware of this + warnings.warn( + f"No layout annotations for {self._q_node.output[0]}:" + f" Assuming channel dimension at index {cdim}" + ) + # ToDo: The index 1 needs to be changed to -1 for the channels last format - num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1] + num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim] final_shape = (num_output_channels, num_thresholds) if thresholds.shape != final_shape: thresholds = np.broadcast_to(thresholds, final_shape) @@ -417,12 +436,12 @@ def _remove_activation_node(self, multi_threshold_node): act_node = self._model.find_direct_predecessors(self._q_node) if act_node is None: raise RuntimeError( - "For handling of Relu activations a predecesor to " "the Quant node must exist." + "For handling of Relu activations a predecessor to " "the Quant node must exist." ) act_node = act_node[0] if act_node.op_type not in self.valid_predecessor_op_types(): raise RuntimeError( - "The predecesor of the Quant node must be Relu or Selu for handling " + "The predecessor of the Quant node must be Relu or Selu for handling " "of activations." 
) @@ -509,7 +528,7 @@ def _calculate_thresholds(self): else: raise RuntimeError("Got an unexpected quantizer node type") - # Calculate thersholds, see: https://github.com/Xilinx/brevitas/ + # Calculate thresholds, see: https://github.com/Xilinx/brevitas/ # blob/a5bfd6dc5e030f0047ac1ee47932b60e8e873e17/src/brevitas/ # export/onnx/finn/handler/act.py#L76 if bit_width == 1.0: @@ -537,8 +556,28 @@ def _calculate_thresholds(self): for t in range(num_thresholds): thresholds[c][t] = min_threshold[c] + step[c] * t + # First try to consider the tensor layout of the output for + # determining the number of output channels + layout = self._model.get_tensor_layout(self._q_node.output[0]) + # If there is a layout annotation, use this to determine the index + # of the channel dimension + if layout is not None and "C" in layout: + # Lookup the index in list + cdim = layout.index("C") + # If no layout has been annotated or there is no channel dimension, + # fall back to the previous default assumption + else: + # Assume the channels to be in axis 1 + cdim = 1 + # Issue a warning to the user, so they are aware of this + warnings.warn( + f"No layout annotations for {self._q_node.output[0]}:" + f" Assuming channel dimension at index {cdim}" + ) + # ToDo: The index 1 needs to be changed to -1 for the channels last format - num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[1] + num_output_channels = self._model.get_tensor_shape(self._q_node.output[0])[cdim] + final_shape = (num_output_channels, num_thresholds) if thresholds.shape != final_shape: thresholds = np.broadcast_to(thresholds, final_shape) From 8783fd4e06939cce9e69b4f4d00fdf5f5d076b30 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Mon, 20 Nov 2023 16:36:41 +0100 Subject: [PATCH 14/17] [Deps] Update qonnx --- fetch-repos.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fetch-repos.sh b/fetch-repos.sh index acade345a9..e69808bcb2 100755 --- a/fetch-repos.sh +++ b/fetch-repos.sh @@ -27,7 +27,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -QONNX_COMMIT="984a097645dd889782d79093d07dc757d5a6b77b" +QONNX_COMMIT="a32707424c97447f6d7d0cacd977100f58db88d5" FINN_EXP_COMMIT="de99347e936d51715f5356a1b6c64e37b91c23c2" BREVITAS_COMMIT="84f42259ec869eb151af4cb8a8b23ad925f493db" PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1" From 2bf37f18262f9576f4c617efaea381badc58edad Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Wed, 13 Dec 2023 17:30:11 +0100 Subject: [PATCH 15/17] [Deps] Update qonnx --- fetch-repos.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fetch-repos.sh b/fetch-repos.sh index e69808bcb2..5a6896f05c 100755 --- a/fetch-repos.sh +++ b/fetch-repos.sh @@ -27,7 +27,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-QONNX_COMMIT="a32707424c97447f6d7d0cacd977100f58db88d5" +QONNX_COMMIT="786feb4fbdde8709fd9d5bca6cc3dab931627c6a" FINN_EXP_COMMIT="de99347e936d51715f5356a1b6c64e37b91c23c2" BREVITAS_COMMIT="84f42259ec869eb151af4cb8a8b23ad925f493db" PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1" From a4fc498ccb6b2c908d2ae340327e51b6762edaa0 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Wed, 13 Mar 2024 11:32:50 +0100 Subject: [PATCH 16/17] [Deps] Update qonnx --- fetch-repos.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fetch-repos.sh b/fetch-repos.sh index 5a6896f05c..9ad51fefb0 100755 --- a/fetch-repos.sh +++ b/fetch-repos.sh @@ -27,7 +27,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -QONNX_COMMIT="786feb4fbdde8709fd9d5bca6cc3dab931627c6a" +QONNX_COMMIT="1a4957ebf2aaf139217fd56109386d4518dd6127" FINN_EXP_COMMIT="de99347e936d51715f5356a1b6c64e37b91c23c2" BREVITAS_COMMIT="84f42259ec869eb151af4cb8a8b23ad925f493db" PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1" From 6c56382003e7366361b8a4e27759877ff2909523 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Thu, 4 Apr 2024 18:08:51 +0200 Subject: [PATCH 17/17] Fix some typos --- src/finn/transformation/streamline/reorder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 9314774f02..74cef0558a 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -138,7 +138,7 @@ def apply(self, model): if node.op_type == "MatMul": # Note: When touching the following code, remember to treat both # branches equivalently! - # TODO: Can this be enforeced or at least be made easier by + # TODO: Can this be enforced or at least be made easier by # extracting common code patterns to a function? # Get the left hand side and right hand side inputs @@ -151,11 +151,11 @@ def apply(self, model): # Give precedence to the left hand side input testing for the # presence of a scalar multiplication if is_const_scalar_mul(lhs, model): - # Cannto handle fork nodes: We would have to distribute the + # Cannot handle fork nodes: We would have to distribute the # Mul into all branches # TODO: Maybe reconsider this at some point, there is - # probabably nothing preventing this in general, it is just - # more difficult and apparanetly not necessary right now. + # probably nothing preventing this in general, it is just + # more difficult and apparently not necessary right now. if model.is_fork_node(lhs): # Softly skip this node continue @@ -190,11 +190,11 @@ def apply(self, model): # Next try whether the right hand side matches the pattern of a # scalar multiplication if is_const_scalar_mul(rhs, model): - # Cannto handle fork nodes: We would have to distribute the + # Cannot handle fork nodes: We would have to distribute the # Mul into all branches # TODO: Maybe reconsider this at some point, there is - # probabably nothing preventing this in general, it is just - # more difficult and apparanetly not necessary right now. + # probably nothing preventing this in general, it is just + # more difficult and apparently not necessary right now. if model.is_fork_node(rhs): # Softly skip this node continue