Hello,
Currently, the Relax AlterOpImpl pass does not support non-bijective layout transformations. When I ran the example below, I hit an assertion error:
```python
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


@I.ir_module
class Before:
    @T.prim_func
    def add(arg0: T.Buffer((1, 17, 32, 64), "float32"), arg1: T.Buffer((1, 17, 32, 64), "float32"), output: T.Buffer((1, 17, 32, 64), "float32")):
        T.func_attr({"operator_name": "relax.add"})
        for ax0, ax1, ax2, ax3 in T.grid(1, 17, 32, 64):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg0[v_ax0, v_ax1, v_ax2, v_ax3], arg1[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(output[v_ax0, v_ax1, v_ax2, v_ax3])
                output[v_ax0, v_ax1, v_ax2, v_ax3] = arg0[v_ax0, v_ax1, v_ax2, v_ax3] + arg1[v_ax0, v_ax1, v_ax2, v_ax3]

    @R.function
    def main(x: R.Tensor((1, 17, 32, 64), dtype="float32"), y: R.Tensor((1, 17, 32, 64), dtype="float32")) -> R.Tensor((1, 17, 32, 64), dtype="float32"):
        with R.dataflow():
            lv = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((1, 17, 32, 64), dtype="float32"))
            gv: R.Tensor((1, 17, 32, 64), dtype="float32") = lv
            R.output(gv)
        return gv


@T.prim_func
def add_2d(arg0: T.Buffer((1, 5, 32, 64, 4), "float32"), arg1: T.Buffer((1, 5, 32, 64, 4), "float32"), output: T.Buffer((1, 5, 32, 64, 4), "float32")):
    T.func_attr({"operator_name": "relax.add"})
    for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
        with T.block("buffer_arg1_assumptions"):
            v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
            T.reads(arg1[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
            T.writes()
            T.assume(not (v_axis1 == 4 and 1 <= v_axis4) or arg1[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] == T.float32(0))
    for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
        with T.block("buffer_arg0_assumptions"):
            v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
            T.reads(arg0[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
            T.writes()
            T.assume(not (v_axis1 == 4 and 1 <= v_axis4) or arg0[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] == T.float32(0))
    for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
        with T.block("T_add"):
            v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
            T.reads(arg0[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4], arg1[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4])
            T.writes(output[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
            output[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] = T.if_then_else(v_axis1 == 4 and 1 <= v_axis4, T.float32(0), arg0[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4] + arg1[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4])


index_map = lambda n, c, h, w: (n, c // 4, h, w, c % 4)
after = relax.transform.AlterOpImpl(
    {"relax.add": add_2d},
    {"relax.add": [index_map, index_map, index_map]},
)(Before)
```
So I tried commenting out the checks in the AlterOpImpl pass and running the example, which produced the code below:
```python
@R.function
def main(x: R.Tensor((1, 17, 32, 64), dtype="float32"), y: R.Tensor((1, 17, 32, 64), dtype="float32")) -> R.Tensor((1, 17, 32, 64), dtype="float32"):
    cls = Module
    with R.dataflow():
        lv: R.Tensor((1, 5, 32, 64, 4), dtype="float32") = R.layout_transform(x, index_map=T.index_map(lambda n, c, h, w: (n, c // 4, h, w, c % 4)), pad_value=None)
        lv1: R.Tensor((1, 5, 32, 64, 4), dtype="float32") = R.layout_transform(y, index_map=T.index_map(lambda n, c, h, w: (n, c // 4, h, w, c % 4)), pad_value=None)
        lv2 = R.call_tir(cls.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
        lv_1: R.Tensor((1, 20, 32, 64), dtype="float32") = R.layout_transform(lv2, index_map=T.index_map(lambda axis0, axis1, axis2, axis3, axis4: (axis0, axis1 * 4 + axis4, axis2, axis3)), pad_value=None)
        gv: R.Tensor((1, 20, 32, 64), dtype="float32") = lv_1
        R.output(gv)
    return gv
```
The above code is invalid: the transformed output shape (1, 20, 32, 64) does not match the expected output shape (1, 17, 32, 64).
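For reference, here is the shape arithmetic in plain Python (an illustration only, not part of the pass): the forward map pads the channel extent 17 up to the next multiple of 4, and the inverse map can only recover the padded extent.

```python
# Plain-Python illustration of why the inverse map yields channel
# extent 20 rather than the original 17.
orig_c, tile = 17, 4

# Forward map c -> (c // 4, c % 4): the outer extent must cover all 17
# channels, so it rounds up to ceil(17 / 4) = 5 tiles.
outer = -(-orig_c // tile)
print(outer)         # 5

# Inverse map (c1, c4) -> c1 * 4 + c4 ranges over 0..19, so it can only
# reconstruct the padded extent 5 * 4 = 20, never the original 17.
print(outer * tile)  # 20
```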
To fix this, I tried introducing a remove_pad op, which copies the non-padded buffer elements back into an output buffer of the original shape, as shown in the example below.
```python
@R.function
def main(x: R.Tensor((1, 17, 32, 64), dtype="float32"), y: R.Tensor((1, 17, 32, 64), dtype="float32")) -> R.Tensor((1, 17, 32, 64), dtype="float32"):
    cls = Module
    with R.dataflow():
        lv: R.Tensor((1, 5, 32, 64, 4), dtype="float32") = R.layout_transform(x, index_map=T.index_map(lambda n, c, h, w: (n, c // 4, h, w, c % 4)), pad_value=None)
        lv1: R.Tensor((1, 5, 32, 64, 4), dtype="float32") = R.layout_transform(y, index_map=T.index_map(lambda n, c, h, w: (n, c // 4, h, w, c % 4)), pad_value=None)
        lv2 = R.call_tir(cls.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
        lv3: R.Tensor((1, 20, 32, 64), dtype="float32") = R.layout_transform(lv2, index_map=T.index_map(lambda axis0, axis1, axis2, axis3, axis4: (axis0, axis1 * 4 + axis4, axis2, axis3)), pad_value=None)
        lv_1: R.Tensor((1, 17, 32, 64), dtype="float32") = R.remove_pad(lv3, orig_shape=[1, 17, 32, 64])
        gv: R.Tensor((1, 17, 32, 64), dtype="float32") = lv_1
        R.output(gv)
    return gv
```
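The intended semantics of remove_pad are just a copy of the leading orig_shape region; a minimal TE sketch for the 4-D case (the helper name remove_pad_te is mine, not the actual implementation) would be:

```python
# Minimal TE sketch of the remove_pad semantics for the 4-D case.
# remove_pad_te is an illustrative name, not the actual op legalization.
from tvm import te

def remove_pad_te(data: te.Tensor, orig_shape):
    # Copy only the leading orig_shape region; the padded tail of the
    # input buffer is simply dropped.
    return te.compute(
        orig_shape,
        lambda n, c, h, w: data[n, c, h, w],
        name="T_remove_pad",
    )
```

The legalized remove_pad PrimFunc in the final TIR below matches this behavior.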
After legalizing the above code, the final TIR looks like this:
```python
@I.ir_module
class Module:
    @T.prim_func
    def relax_add_replacement(arg0: T.Buffer((1, 5, 32, 64, 4), "float32"), arg1: T.Buffer((1, 5, 32, 64, 4), "float32"), output: T.Buffer((1, 5, 32, 64, 4), "float32")):
        T.func_attr({"operator_name": "relax.add"})
        for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
            with T.block("buffer_arg1_assumptions"):
                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
                T.reads(arg1[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
                T.writes()
                T.assume(not (v_axis1 == 4 and 1 <= v_axis4) or arg1[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] == T.float32(0))
        for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
            with T.block("buffer_arg0_assumptions"):
                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
                T.reads(arg0[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
                T.writes()
                T.assume(not (v_axis1 == 4 and 1 <= v_axis4) or arg0[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] == T.float32(0))
        for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
            with T.block("T_add"):
                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
                T.reads(arg0[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4], arg1[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4])
                T.writes(output[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
                output[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] = T.if_then_else(v_axis1 == 4 and 1 <= v_axis4, T.float32(0), arg0[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4] + arg1[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4])

    @T.prim_func
    def remove_pad(A: T.Buffer((T.int64(1), T.int64(20), T.int64(32), T.int64(64)), "float32"), T_remove_pad: T.Buffer((T.int64(1), T.int64(17), T.int64(32), T.int64(64)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(17), T.int64(32), T.int64(64)):
            with T.block("T_remove_pad"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_remove_pad[v_ax0, v_ax1, v_ax2, v_ax3])
                T_remove_pad[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]

    @T.prim_func
    def te_layout_transform(A: T.Buffer((T.int64(1), T.int64(17), T.int64(32), T.int64(64)), "float32"), te_layout_transform_with_pad: T.Buffer((T.int64(1), T.int64(5), T.int64(32), T.int64(64), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3, i4 in T.grid(T.int64(1), T.int64(5), T.int64(32), T.int64(64), T.int64(4)):
            with T.block("te_layout_transform_with_pad"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(A[v_i0, v_i1 * T.int64(4) + v_i4, v_i2, v_i3])
                T.writes(te_layout_transform_with_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
                te_layout_transform_with_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(v_i1 == T.int64(4) and T.int64(1) <= v_i4, T.float32(0), A[v_i0, v_i1 * T.int64(4) + v_i4, v_i2, v_i3])

    @T.prim_func
    def te_layout_transform1(A: T.Buffer((T.int64(1), T.int64(5), T.int64(32), T.int64(64), T.int64(4)), "float32"), te_layout_transform: T.Buffer((T.int64(1), T.int64(20), T.int64(32), T.int64(64)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(20), T.int64(32), T.int64(64)):
            with T.block("te_layout_transform"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1 // T.int64(4), v_i2, v_i3, v_i1 % T.int64(4)])
                T.writes(te_layout_transform[v_i0, v_i1, v_i2, v_i3])
                te_layout_transform[v_i0, v_i1, v_i2, v_i3] = A[v_i0, v_i1 // T.int64(4), v_i2, v_i3, v_i1 % T.int64(4)]

    @R.function
    def main(x: R.Tensor((1, 17, 32, 64), dtype="float32"), y: R.Tensor((1, 17, 32, 64), dtype="float32")) -> R.Tensor((1, 17, 32, 64), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.te_layout_transform, (x,), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
            lv1 = R.call_tir(cls.te_layout_transform, (y,), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
            lv2 = R.call_tir(cls.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
            lv3 = R.call_tir(cls.te_layout_transform1, (lv2,), out_sinfo=R.Tensor((1, 20, 32, 64), dtype="float32"))
            lv_1 = R.call_tir(cls.remove_pad, (lv3,), out_sinfo=R.Tensor((1, 17, 32, 64), dtype="float32"))
            gv: R.Tensor((1, 17, 32, 64), dtype="float32") = lv_1
            R.output(gv)
        return gv
```
Can someone please tell me whether introducing a new op (remove_pad) to eliminate the padded entries is an appropriate design for handling non-bijective transformations in AlterOpImpl?