Skip to content

Commit

Permalink
transforms: use rewriter and listener in convert-stencil-to-csl-stencil (xdslproject#3538)
Browse files Browse the repository at this point in the history

Stacked PRs:
 * xdslproject#3540
 * xdslproject#3539
 * __->__#3538
 * xdslproject#3537


--- --- ---

### transforms: use rewriter and listener in convert-stencil-to-csl-stencil


The pass was not propagating the listener from the PatternRewriter, and
thus some operations were modified without notifying the rewrite
worklist.
  • Loading branch information
math-fehr authored and EdmundGoodman committed Dec 6, 2024
1 parent 66e761e commit d03b6e8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ builtin.module {

// CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>]}> ({
// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>):
// CHECK-NEXT: %5 = arith.constant dense<1.234500e-01> : tensor<510xf32>
// CHECK-NEXT: %6 = arith.constant dense<2.345678e-01> : tensor<510xf32>
Expand Down Expand Up @@ -191,7 +191,7 @@ builtin.module {
// CHECK-NEXT: %2 = arith.constant 1 : index
// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) {
// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32>
// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [2, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [-2, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, 2]>, #csl_stencil.exchange<to [0, -1]>, #csl_stencil.exchange<to [0, -2]>], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>]}> ({
// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [2, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [-2, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, 2]>, #csl_stencil.exchange<to [0, -1]>, #csl_stencil.exchange<to [0, -2]>], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>]}> ({
// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>):
// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32>
// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<600xf32>
Expand Down
13 changes: 8 additions & 5 deletions xdsl/transforms/convert_stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
# replace stencil.access (operating on stencil.temp at arg_index)
# with csl_stencil.access (operating on memref at last arg index)
nested_rewriter = PatternRewriteWalker(
ConvertAccessOpFromPrefetchPattern(arg_idx)
ConvertAccessOpFromPrefetchPattern(arg_idx), listener=rewriter
)

nested_rewriter.rewrite_region(new_apply_op.region)
Expand Down Expand Up @@ -415,6 +415,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /):
PatternRewriteWalker(
SplitVarithOpPattern(op.region.block.args[prefetch_idx]),
apply_recursively=False,
listener=rewriter,
).rewrite_region(op.region)

# determine how ops should be split across the two regions
Expand Down Expand Up @@ -505,12 +506,13 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /):
# add operations from list to receive_chunk, use translation table to rebuild operands
for o in chunk_region_ops:
if isinstance(o, stencil.ReturnOp | csl_stencil.YieldOp):
rewriter.erase_op(o)
break
o.operands = [chunk_region_oprnd_table.get(x, x) for x in o.operands]
receive_chunk.block.add_op(o)
rewriter.insert_op(o, InsertPoint.at_end(receive_chunk.block))

# put `chunk_res` into `accumulator` (using tensor.insert_slice) and yield the result
receive_chunk.block.add_ops(
rewriter.insert_op(
[
insert_slice_op := tensor.InsertSliceOp.get(
source=chunk_res,
Expand All @@ -519,13 +521,14 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /):
static_sizes=(prefetch.type.get_shape()[1] // self.num_chunks,),
),
csl_stencil.YieldOp(insert_slice_op.result),
]
],
InsertPoint.at_end(receive_chunk.block),
)

# add operations from list to done_exchange, use translation table to rebuild operands
for o in done_exchange_ops:
o.operands = [done_exchange_oprnd_table.get(x, x) for x in o.operands]
done_exchange.block.add_op(o)
rewriter.insert_op(o, InsertPoint.at_end(done_exchange.block))
if isinstance(o, stencil.ReturnOp):
rewriter.replace_op(o, csl_stencil.YieldOp(*o.operands))

Expand Down

0 comments on commit d03b6e8

Please sign in to comment.