-
Notifications
You must be signed in to change notification settings - Fork 92
Minimal Example for HeteroCL MLIR
Hongzheng Chen edited this page Nov 28, 2022
·
1 revision
In this example, we will use general matrix multiplication (GEMM) to showcase the basic usage of HeteroCL.
# Imports needed by the snippets on this page: `hcl` for the HeteroCL API
# and `np` for the host-side arrays used when executing the module.
import heterocl as hcl
import numpy as np

# 0. Initialize the data type
# If `hcl.init` is not called, hcl.Int() will be used
hcl.init(hcl.Float())
# 1. Declare placeholders for the input
# Both inputs are 32x32 tensors; the string is the tensor's name.
A = hcl.placeholder((32, 32), "A")
B = hcl.placeholder((32, 32), "B")
# 2. Declare the algorithm
def gemm(A, B):
    """Declare the 32x32 matrix product of A and B as a compute stage "C".

    Args:
        A: Left operand, indexed as A[i, k].
        B: Right operand, indexed as B[k, j].

    Returns:
        The tensor produced by `hcl.compute`, named "C", where each element
        C[i, j] is the sum over k of A[i, k] * B[k, j].
    """
    # Reduction axis named "k", covering the range 0..32.
    k = hcl.reduce_axis(0, 32, "k")
    C = hcl.compute(
        (32, 32),
        lambda i, j: hcl.sum(A[i, k] * B[k, j], axis=k),
        "C",
    )
    return C
# 3. Create schedule
s = hcl.create_schedule([A, B], gemm)
# 4. Build the module
f = hcl.build(s)
print(s.device_module)
# 5. Execute the module
# HeteroCL needs to use the destination passing style,
# i.e., the output of the function is passed as the last argument
# The host-side NumPy arrays get their own names (np_A, np_B, np_C) so the
# placeholders A and B stay bound for the later `hcl.create_schedule` calls.
np_A = np.random.randint(10, size=(32, 32)).astype(np.float32)
np_B = np.random.randint(10, size=(32, 32)).astype(np.float32)
np_C = np.zeros((32, 32), dtype=np.float32)
hcl_A = hcl.asarray(np_A)
hcl_B = hcl.asarray(np_B)
hcl_C = hcl.asarray(np_C)
f(hcl_A, hcl_B, hcl_C)
# Check the HeteroCL result against NumPy's matrix product.
golden = np.matmul(np_A, np_B)
assert np.allclose(golden, hcl_C.asnumpy())
After execution, we will see the following MLIR module.
module {
func.func @top(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> attributes {itypes = "__", llvm.emit_c_interface, otypes = "_", top} {
%0 = memref.alloc() {name = "C"} : memref<32x32xf32>
affine.for %arg2 = 0 to 32 {
affine.for %arg3 = 0 to 32 {
%1 = memref.alloc() {name = "op_0"} : memref<1xf32>
%c0_i32 = arith.constant 0 : i32
%2 = arith.sitofp %c0_i32 : i32 to f32
affine.store %2, %1[0] {to = "op_0"} : memref<1xf32>
affine.for %arg4 = 0 to 32 {
%4 = affine.load %arg0[%arg2, %arg4] {from = "A"} : memref<32x32xf32>
%5 = affine.load %arg1[%arg4, %arg3] {from = "B"} : memref<32x32xf32>
%6 = arith.mulf %4, %5 : f32
%7 = affine.load %1[0] {from = "op_0"} : memref<1xf32>
%8 = arith.addf %6, %7 : f32
affine.store %8, %1[0] {to = "op_0"} : memref<1xf32>
} {loop_name = "k", reduction}
%3 = affine.load %1[0] {from = "op_0"} : memref<1xf32>
affine.store %3, %0[%arg2, %arg3] {to = "C"} : memref<32x32xf32>
} {loop_name = "j"}
} {loop_name = "i", op_name = "C"}
return %0 : memref<32x32xf32>
}
}
We can leverage the primitives provided by HeteroCL to easily apply optimizations.
# 3. Create schedule
# NOTE: A and B are expected to be the hcl.placeholder objects from step 1.
s = hcl.create_schedule([A, B], gemm)
# Handle to the compute stage named "C" that gemm declares.
op_C = gemm.C
# Split the first loop axis of C (loop "i") with factor 8; the two results
# are bound as the outer and inner loops.
x_out, x_in = s[op_C].split(op_C.axis[0], factor=8)
# Rebuild and print the module so the transformed loop nest can be inspected.
f = hcl.build(s)
print(s.device_module)
The corresponding MLIR module is shown below. We can see that the `i` loop is split into two loops, `i.outer` and `i.inner`.
#map = affine_map<(d0, d1) -> (d0 + d1 * 8)>
module {
func.func @top(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> attributes {itypes = "__", llvm.emit_c_interface, otypes = "_", top} {
%0 = memref.alloc() {name = "C"} : memref<32x32xf32>
affine.for %arg2 = 0 to 4 {
affine.for %arg3 = 0 to 8 {
affine.for %arg4 = 0 to 32 {
%2 = affine.apply #map(%arg3, %arg2)
%3 = memref.alloc() {name = "op_0"} : memref<1xf32>
%c0_i32 = arith.constant 0 : i32
%4 = arith.sitofp %c0_i32 : i32 to f32
affine.store %4, %3[0] {to = "op_0"} : memref<1xf32>
affine.for %arg5 = 0 to 32 {
%6 = affine.load %arg0[%2, %arg5] {from = "A"} : memref<32x32xf32>
%7 = affine.load %arg1[%arg5, %arg4] {from = "B"} : memref<32x32xf32>
%8 = arith.mulf %6, %7 : f32
%9 = affine.load %3[0] {from = "op_0"} : memref<1xf32>
%10 = arith.addf %8, %9 : f32
affine.store %10, %3[0] {to = "op_0"} : memref<1xf32>
} {loop_name = "k", reduction}
%5 = affine.load %3[0] {from = "op_0"} : memref<1xf32>
affine.store %5, %0[%2, %arg4] {to = "C"} : memref<32x32xf32>
} {loop_name = "j"}
} {loop_name = "i.inner"}
} {loop_name = "i.outer", op_name = "C"}
%1 = hcl.create_op_handle "C"
return %0 : memref<32x32xf32>
}
}
We can further split the second axis as well and reorder the resulting loops.
s = hcl.create_schedule([A, B], gemm)
op_C = gemm.C
# Split both loop axes of C with factor 8.
x_out, x_in = s[op_C].split(op_C.axis[0], factor=8)
y_out, y_in = s[op_C].split(op_C.axis[1], factor=8)
# Put the two outer loops outermost, followed by the two inner loops.
s[op_C].reorder(x_out, y_out, x_in, y_in)
f = hcl.build(s)
# Print the transformed module, as the earlier snippets do, so the listing
# that follows can be reproduced.
print(s.device_module)
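And we obtain the following MLIR module, in which the loops are ordered `i.outer`, `j.outer`, `i.inner`, `j.inner`.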
#map = affine_map<(d0, d1) -> (d0 + d1 * 8)>
module {
func.func @top(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> attributes {itypes = "__", llvm.emit_c_interface, otypes = "_", top} {
%0 = memref.alloc() {name = "C"} : memref<32x32xf32>
affine.for %arg2 = 0 to 4 {
affine.for %arg3 = 0 to 4 {
affine.for %arg4 = 0 to 8 {
affine.for %arg5 = 0 to 8 {
%2 = affine.apply #map(%arg5, %arg3)
%3 = affine.apply #map(%arg4, %arg2)
%4 = memref.alloc() {name = "op_0"} : memref<1xf32>
%c0_i32 = arith.constant 0 : i32
%5 = arith.sitofp %c0_i32 : i32 to f32
affine.store %5, %4[0] {to = "op_0"} : memref<1xf32>
affine.for %arg6 = 0 to 32 {
%7 = affine.load %arg0[%3, %arg6] {from = "A"} : memref<32x32xf32>
%8 = affine.load %arg1[%arg6, %2] {from = "B"} : memref<32x32xf32>
%9 = arith.mulf %7, %8 : f32
%10 = affine.load %4[0] {from = "op_0"} : memref<1xf32>
%11 = arith.addf %9, %10 : f32
affine.store %11, %4[0] {to = "op_0"} : memref<1xf32>
} {loop_name = "k", reduction}
%6 = affine.load %4[0] {from = "op_0"} : memref<1xf32>
affine.store %6, %0[%3, %2] {to = "C"} : memref<32x32xf32>
} {loop_name = "j.inner"}
} {loop_name = "i.inner"}
} {loop_name = "j.outer"}
} {loop_name = "i.outer", op_name = "C"}
%1 = hcl.create_op_handle "C"
return %0 : memref<32x32xf32>
}
}