-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dialects: (arith) add SignlessIntegerBinaryOperation canonicalization #3583
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
from xdsl.pattern_rewriter import RewritePattern | ||
from xdsl.printer import Printer | ||
from xdsl.traits import ( | ||
Commutative, | ||
ConditionallySpeculatable, | ||
ConstantLike, | ||
HasCanonicalizationPatternsTrait, | ||
|
@@ -195,6 +196,29 @@ class SignlessIntegerBinaryOperation(IRDLOperation, abc.ABC): | |
|
||
assembly_format = "$lhs `,` $rhs attr-dict `:` type($result)" | ||
|
||
@staticmethod | ||
def py_operation(lhs: int, rhs: int) -> int | None: | ||
return None | ||
|
||
@staticmethod | ||
def is_right_zero(attr: AnyIntegerAttr) -> bool: | ||
""" | ||
Returns True only when 'attr' is a right zero for the operation | ||
https://en.wikipedia.org/wiki/Absorbing_element | ||
|
||
Note that this depends on the operation and does *not* imply that | ||
attr.value.data == 0 | ||
""" | ||
return False | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
""" | ||
Return True only when 'attr' is a right unit/identity for the operation | ||
https://en.wikipedia.org/wiki/Identity_element | ||
""" | ||
return False | ||
|
||
def __init__( | ||
self, | ||
operand1: Operation | SSAValue, | ||
|
@@ -209,6 +233,22 @@ def __hash__(self) -> int: | |
return id(self) | ||
|
||
|
||
class SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait( | ||
HasCanonicalizationPatternsTrait | ||
): | ||
@classmethod | ||
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: | ||
from xdsl.transforms.canonicalization_patterns.arith import ( | ||
SignlessIntegerBinaryOperationConstantProp, | ||
SignlessIntegerBinaryOperationZeroOrUnitRight, | ||
) | ||
|
||
return ( | ||
SignlessIntegerBinaryOperationConstantProp(), | ||
SignlessIntegerBinaryOperationZeroOrUnitRight(), | ||
) | ||
|
||
|
||
class SignlessIntegerBinaryOperationWithOverflow( | ||
SignlessIntegerBinaryOperation, abc.ABC | ||
): | ||
|
@@ -318,22 +358,23 @@ def print(self, printer: Printer): | |
printer.print_attribute(self.result.type) | ||
|
||
|
||
class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): | ||
@classmethod | ||
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: | ||
from xdsl.transforms.canonicalization_patterns.arith import ( | ||
AddiConstantProp, | ||
AddiIdentityRight, | ||
) | ||
|
||
return (AddiIdentityRight(), AddiConstantProp()) | ||
|
||
|
||
@irdl_op_definition | ||
class AddiOp(SignlessIntegerBinaryOperationWithOverflow): | ||
name = "arith.addi" | ||
|
||
traits = traits_def(Pure(), AddiOpHasCanonicalizationPatternsTrait()) | ||
traits = traits_def( | ||
Pure(), | ||
Commutative(), | ||
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), | ||
) | ||
|
||
@staticmethod | ||
def py_operation(lhs: int, rhs: int) -> int | None: | ||
return lhs + rhs | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
@irdl_op_definition | ||
|
@@ -400,19 +441,27 @@ def infer_overflow_type(input_type: Attribute) -> Attribute: | |
) | ||
|
||
|
||
class MuliHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait): | ||
@classmethod | ||
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: | ||
from xdsl.transforms.canonicalization_patterns import arith | ||
|
||
return (arith.MuliIdentityRight(), arith.MuliConstantProp()) | ||
|
||
|
||
@irdl_op_definition | ||
class MuliOp(SignlessIntegerBinaryOperationWithOverflow): | ||
name = "arith.muli" | ||
|
||
traits = traits_def(Pure(), MuliHasCanonicalizationPatterns()) | ||
traits = traits_def( | ||
Pure(), | ||
Commutative(), | ||
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), | ||
) | ||
|
||
@staticmethod | ||
def py_operation(lhs: int, rhs: int) -> int | None: | ||
return lhs * rhs | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr == IntegerAttr(1, attr.type) | ||
Comment on lines
+458
to
+460
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this one taking the type into account, but the one in Addi just checks the value? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The one here is just to work around that the unit for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thinking about it I'm not 100% sure this will work for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @superlopuh is there a better way to do this check? |
||
|
||
@staticmethod | ||
def is_right_zero(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
class MulExtendedBase(IRDLOperation): | ||
|
@@ -460,7 +509,17 @@ class MulSIExtendedOp(MulExtendedBase): | |
class SubiOp(SignlessIntegerBinaryOperationWithOverflow): | ||
name = "arith.subi" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() | ||
) | ||
|
||
@staticmethod | ||
def py_operation(lhs: int, rhs: int) -> int | None: | ||
return lhs - rhs | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
class DivUISpeculatable(ConditionallySpeculatable): | ||
|
@@ -483,7 +542,15 @@ class DivUIOp(SignlessIntegerBinaryOperation): | |
|
||
name = "arith.divui" | ||
|
||
traits = traits_def(NoMemoryEffect(), DivUISpeculatable()) | ||
traits = traits_def( | ||
NoMemoryEffect(), | ||
DivUISpeculatable(), | ||
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), | ||
) | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr == IntegerAttr(1, attr.type) | ||
|
||
|
||
@irdl_op_definition | ||
|
@@ -495,7 +562,14 @@ class DivSIOp(SignlessIntegerBinaryOperation): | |
|
||
name = "arith.divsi" | ||
|
||
traits = traits_def(NoMemoryEffect()) | ||
traits = traits_def( | ||
NoMemoryEffect(), | ||
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), | ||
) | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr == IntegerAttr(1, attr.type) | ||
|
||
|
||
@irdl_op_definition | ||
|
@@ -506,21 +580,40 @@ class FloorDivSIOp(SignlessIntegerBinaryOperation): | |
|
||
name = "arith.floordivsi" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() | ||
) | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr == IntegerAttr(1, attr.type) | ||
|
||
|
||
@irdl_op_definition | ||
class CeilDivSIOp(SignlessIntegerBinaryOperation): | ||
name = "arith.ceildivsi" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() | ||
) | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr == IntegerAttr(1, attr.type) | ||
|
||
|
||
@irdl_op_definition | ||
class CeilDivUIOp(SignlessIntegerBinaryOperation): | ||
name = "arith.ceildivui" | ||
|
||
traits = traits_def(NoMemoryEffect()) | ||
traits = traits_def( | ||
NoMemoryEffect(), | ||
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), | ||
) | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr == IntegerAttr(1, attr.type) | ||
|
||
|
||
@irdl_op_definition | ||
|
@@ -567,21 +660,57 @@ class MaxSIOp(SignlessIntegerBinaryOperation): | |
class AndIOp(SignlessIntegerBinaryOperation): | ||
name = "arith.andi" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), | ||
Commutative(), | ||
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), | ||
) | ||
|
||
@staticmethod | ||
def py_operation(lhs: int, rhs: int) -> int | None: | ||
return lhs & rhs | ||
|
||
@staticmethod | ||
def is_right_zero(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
@irdl_op_definition | ||
class OrIOp(SignlessIntegerBinaryOperation): | ||
name = "arith.ori" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), | ||
Commutative(), | ||
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), | ||
) | ||
|
||
@staticmethod | ||
def py_operation(lhs: int, rhs: int) -> int | None: | ||
return lhs | rhs | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
@irdl_op_definition | ||
class XOrIOp(SignlessIntegerBinaryOperation): | ||
name = "arith.xori" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), | ||
Commutative(), | ||
SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), | ||
) | ||
|
||
@staticmethod | ||
def py_operation(lhs: int, rhs: int) -> int | None: | ||
return lhs ^ rhs | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
@irdl_op_definition | ||
|
@@ -593,7 +722,13 @@ class ShLIOp(SignlessIntegerBinaryOperationWithOverflow): | |
|
||
name = "arith.shli" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() | ||
) | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
@irdl_op_definition | ||
|
@@ -606,7 +741,13 @@ class ShRUIOp(SignlessIntegerBinaryOperation): | |
|
||
name = "arith.shrui" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() | ||
) | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
@irdl_op_definition | ||
|
@@ -620,7 +761,13 @@ class ShRSIOp(SignlessIntegerBinaryOperation): | |
|
||
name = "arith.shrsi" | ||
|
||
traits = traits_def(Pure()) | ||
traits = traits_def( | ||
Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() | ||
) | ||
|
||
@staticmethod | ||
def is_right_unit(attr: AnyIntegerAttr) -> bool: | ||
return attr.value.data == 0 | ||
|
||
|
||
class ComparisonOperation(IRDLOperation): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ooh, if there are wikipedia pages with those titles then maybe using those for function names is fair game
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the wikipedia page for Absorbing element says that Absorbing element/annihilating element/zero are also used in various contexts