Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (arith) add SignlessIntegerBinaryOperation canonicalization #3583

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/filecheck/dialects/arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,13 @@ func.func @test_const_var_const() {
%9 = arith.cmpi uge, %int, %int : i32

"test.op"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %int) : (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i32) -> ()

// Subtraction is not commutative so should not have the constant swapped to the right
// CHECK: arith.subi %c2, %a : i32
%10 = arith.subi %c2, %a : i32
"test.op"(%10) : (i32) -> ()

// CHECK: %{{.*}} = arith.constant false
%11 = arith.constant true
%12 = arith.addi %11, %11 : i1
"test.op"(%12) : (i1) -> ()
10 changes: 5 additions & 5 deletions tests/interactive/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,11 @@ async def test_rewrites():
await pilot.click("#condense_button")

addi_pass = AvailablePass(
display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:AddiIdentityRight",
display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:SignlessIntegerBinaryOperationZeroOrUnitRight",
module_pass=individual_rewrite.ApplyIndividualRewritePass,
pass_spec=list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}'
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
)
)[0],
)
Expand All @@ -359,7 +359,7 @@ async def test_rewrites():
individual_rewrite.ApplyIndividualRewritePass,
list(
parse_pipeline(
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}'
'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
)
)[0],
),
Expand Down Expand Up @@ -568,7 +568,7 @@ async def test_apply_individual_rewrite():
n.data is not None
and n.data[1] is not None
and str(n.data[1])
== 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiConstantProp"}'
== 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationConstantProp"}'
):
node = n

Expand Down Expand Up @@ -598,7 +598,7 @@ async def test_apply_individual_rewrite():
n.data is not None
and n.data[1] is not None
and str(n.data[1])
== 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}'
== 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}'
):
node = n

Expand Down
213 changes: 180 additions & 33 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from xdsl.pattern_rewriter import RewritePattern
from xdsl.printer import Printer
from xdsl.traits import (
Commutative,
ConditionallySpeculatable,
ConstantLike,
HasCanonicalizationPatternsTrait,
Expand Down Expand Up @@ -195,6 +196,29 @@ class SignlessIntegerBinaryOperation(IRDLOperation, abc.ABC):

assembly_format = "$lhs `,` $rhs attr-dict `:` type($result)"

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return None

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
"""
Returns True only when 'attr' is a right zero for the operation
https://en.wikipedia.org/wiki/Absorbing_element

Note that this depends on the operation and does *not* imply that
attr.value.data == 0
"""
return False

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
"""
Return True only when 'attr' is a right unit/identity for the operation
https://en.wikipedia.org/wiki/Identity_element
Comment on lines +204 to +218
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooh, if there are wikipedia pages with those titles then maybe using those for function names is fair game

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the wikipedia page for Absorbing element says that Absorbing element/annihilating element/zero are also used in various contexts

"""
return False

def __init__(
self,
operand1: Operation | SSAValue,
Expand All @@ -209,6 +233,22 @@ def __hash__(self) -> int:
return id(self)


class SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(
HasCanonicalizationPatternsTrait
):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.arith import (
SignlessIntegerBinaryOperationConstantProp,
SignlessIntegerBinaryOperationZeroOrUnitRight,
)

return (
SignlessIntegerBinaryOperationConstantProp(),
SignlessIntegerBinaryOperationZeroOrUnitRight(),
)


class SignlessIntegerBinaryOperationWithOverflow(
SignlessIntegerBinaryOperation, abc.ABC
):
Expand Down Expand Up @@ -318,22 +358,23 @@ def print(self, printer: Printer):
printer.print_attribute(self.result.type)


class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.arith import (
AddiConstantProp,
AddiIdentityRight,
)

return (AddiIdentityRight(), AddiConstantProp())


@irdl_op_definition
class AddiOp(SignlessIntegerBinaryOperationWithOverflow):
name = "arith.addi"

traits = traits_def(Pure(), AddiOpHasCanonicalizationPatternsTrait())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs + rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
Expand Down Expand Up @@ -400,19 +441,27 @@ def infer_overflow_type(input_type: Attribute) -> Attribute:
)


class MuliHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns import arith

return (arith.MuliIdentityRight(), arith.MuliConstantProp())


@irdl_op_definition
class MuliOp(SignlessIntegerBinaryOperationWithOverflow):
name = "arith.muli"

traits = traits_def(Pure(), MuliHasCanonicalizationPatterns())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs * rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)
Comment on lines +458 to +460
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one taking the type into account, but the one in Addi just checks the value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The one here is just to work around that the unit for i1 multiplication is -1 and not 1. This doesn't matter in the 0 case which Addi has

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it I'm not 100% sure this will work for si1

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@superlopuh is there a better way to do this check?


@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


class MulExtendedBase(IRDLOperation):
Expand Down Expand Up @@ -460,7 +509,17 @@ class MulSIExtendedOp(MulExtendedBase):
class SubiOp(SignlessIntegerBinaryOperationWithOverflow):
name = "arith.subi"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs - rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


class DivUISpeculatable(ConditionallySpeculatable):
Expand All @@ -483,7 +542,15 @@ class DivUIOp(SignlessIntegerBinaryOperation):

name = "arith.divui"

traits = traits_def(NoMemoryEffect(), DivUISpeculatable())
traits = traits_def(
NoMemoryEffect(),
DivUISpeculatable(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
Expand All @@ -495,7 +562,14 @@ class DivSIOp(SignlessIntegerBinaryOperation):

name = "arith.divsi"

traits = traits_def(NoMemoryEffect())
traits = traits_def(
NoMemoryEffect(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
Expand All @@ -506,21 +580,40 @@ class FloorDivSIOp(SignlessIntegerBinaryOperation):

name = "arith.floordivsi"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
class CeilDivSIOp(SignlessIntegerBinaryOperation):
name = "arith.ceildivsi"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
class CeilDivUIOp(SignlessIntegerBinaryOperation):
name = "arith.ceildivui"

traits = traits_def(NoMemoryEffect())
traits = traits_def(
NoMemoryEffect(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)


@irdl_op_definition
Expand Down Expand Up @@ -567,21 +660,57 @@ class MaxSIOp(SignlessIntegerBinaryOperation):
class AndIOp(SignlessIntegerBinaryOperation):
name = "arith.andi"

traits = traits_def(Pure())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs & rhs

@staticmethod
def is_right_zero(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
class OrIOp(SignlessIntegerBinaryOperation):
name = "arith.ori"

traits = traits_def(Pure())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs | rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
class XOrIOp(SignlessIntegerBinaryOperation):
name = "arith.xori"

traits = traits_def(Pure())
traits = traits_def(
Pure(),
Commutative(),
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(),
)

@staticmethod
def py_operation(lhs: int, rhs: int) -> int | None:
return lhs ^ rhs

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
Expand All @@ -593,7 +722,13 @@ class ShLIOp(SignlessIntegerBinaryOperationWithOverflow):

name = "arith.shli"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
Expand All @@ -606,7 +741,13 @@ class ShRUIOp(SignlessIntegerBinaryOperation):

name = "arith.shrui"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


@irdl_op_definition
Expand All @@ -620,7 +761,13 @@ class ShRSIOp(SignlessIntegerBinaryOperation):

name = "arith.shrsi"

traits = traits_def(Pure())
traits = traits_def(
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait()
)

@staticmethod
def is_right_unit(attr: AnyIntegerAttr) -> bool:
return attr.value.data == 0


class ComparisonOperation(IRDLOperation):
Expand Down
Loading
Loading