Non-bijective transformation support in AlterOpImpl pass

Hello,

Currently, the Relax AlterOpImpl pass does not support non-bijective layout transformations. When I ran the example below, I hit an assertion error.

from tvm import relax
from tvm.script import ir as I, relax as R, tir as T

@I.ir_module
class Before:
    @T.prim_func
    def add(arg0: T.Buffer((1, 17, 32, 64), "float32"), arg1: T.Buffer((1, 17, 32, 64), "float32"), output: T.Buffer((1, 17, 32, 64), "float32")):
        T.func_attr({"operator_name": "relax.add"})
        for ax0, ax1, ax2, ax3 in T.grid(1, 17, 32, 64):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg0[v_ax0, v_ax1, v_ax2, v_ax3], arg1[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(output[v_ax0, v_ax1, v_ax2, v_ax3])
                output[v_ax0, v_ax1, v_ax2, v_ax3] = arg0[v_ax0, v_ax1, v_ax2, v_ax3] + arg1[v_ax0, v_ax1, v_ax2, v_ax3]

    @R.function
    def main(x: R.Tensor((1, 17, 32, 64), dtype="float32"), y: R.Tensor((1, 17, 32, 64), dtype="float32")) -> R.Tensor((1, 17, 32, 64), dtype="float32"):
        with R.dataflow():
            lv = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((1, 17, 32, 64), dtype="float32"))
            gv: R.Tensor((1, 17, 32, 64), dtype="float32") = lv
            R.output(gv)
        return gv

@T.prim_func
def add_2d(arg0: T.Buffer((1, 5, 32, 64, 4), "float32"), arg1: T.Buffer((1, 5, 32, 64, 4), "float32"), output: T.Buffer((1, 5, 32, 64, 4), "float32")):
    T.func_attr({"operator_name": "relax.add"})
    for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
        with T.block("buffer_arg1_assumptions"):
            v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
            T.reads(arg1[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
            T.writes()
            T.assume(not (v_axis1 == 4 and 1 <= v_axis4) or arg1[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] == T.float32(0))
    for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
        with T.block("buffer_arg0_assumptions"):
            v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
            T.reads(arg0[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
            T.writes()
            T.assume(not (v_axis1 == 4 and 1 <= v_axis4) or arg0[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] == T.float32(0))
    for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
        with T.block("T_add"):
            v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
            T.reads(arg0[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4], arg1[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4])
            T.writes(output[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
            output[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] = T.if_then_else(v_axis1 == 4 and 1 <= v_axis4, T.float32(0), arg0[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4] + arg1[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4])

index_map = lambda n, c, h, w: (n, c // 4, h, w, c % 4)
after = relax.transform.AlterOpImpl(
    {"relax.add": add_2d}, {"relax.add": [index_map, index_map, index_map]}
)(Before)
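
For context, the transform is non-bijective here because the channel extent 17 is not a multiple of 4: packing it as (c // 4, c % 4) needs 5 * 4 = 20 slots, so 3 of them are padding. A minimal sketch of how this shows up with the tir.IndexMap API (assuming non_surjective_inverse, which returns the inverse map together with a predicate selecting the padded elements):

from tvm import tir

imap = tir.IndexMap.from_func(lambda n, c, h, w: (n, c // 4, h, w, c % 4))
# Inverse over the original (1, 17, 32, 64) shape, plus a predicate that is
# True exactly on the padded entries (the last 3 slots of the packed c axis).
inverse_map, pad_predicate = imap.non_surjective_inverse([1, 17, 32, 64])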

So, I tried commenting out the checks in the AlterOpImpl pass and ran the example, which resulted in the code below.

@R.function
def main(x: R.Tensor((1, 17, 32, 64), dtype="float32"), y: R.Tensor((1, 17, 32, 64), dtype="float32")) -> R.Tensor((1, 17, 32, 64), dtype="float32"):
    cls = Module
    with R.dataflow():
        lv: R.Tensor((1, 5, 32, 64, 4), dtype="float32") = R.layout_transform(x, index_map=T.index_map(lambda n, c, h, w: (n, c // 4, h, w, c % 4)), pad_value=None)
        lv1: R.Tensor((1, 5, 32, 64, 4), dtype="float32") = R.layout_transform(y, index_map=T.index_map(lambda n, c, h, w: (n, c // 4, h, w, c % 4)), pad_value=None)
        lv2 = R.call_tir(cls.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
        lv_1: R.Tensor((1, 20, 32, 64), dtype="float32") = R.layout_transform(lv2, index_map=T.index_map(lambda axis0, axis1, axis2, axis3, axis4: (axis0, axis1 * 4 + axis4, axis2, axis3)), pad_value=None)
        gv: R.Tensor((1, 20, 32, 64), dtype="float32") = lv_1
        R.output(gv)
    return gv

The above code is invalid: the inverse index map maps (axis0, axis1, axis2, axis3, axis4) back to (axis0, axis1 * 4 + axis4, axis2, axis3), so it can only produce the padded shape (1, 20, 32, 64), which does not match the expected output shape (1, 17, 32, 64).

I have tried introducing a remove_pad op, which maps the non-padded buffer elements back to the output buffer, as shown in the example below.

@R.function
def main(x: R.Tensor((1, 17, 32, 64), dtype="float32"), y: R.Tensor((1, 17, 32, 64), dtype="float32")) -> R.Tensor((1, 17, 32, 64), dtype="float32"):
    cls = Module
    with R.dataflow():
        lv: R.Tensor((1, 5, 32, 64, 4), dtype="float32") = R.layout_transform(x, index_map=T.index_map(lambda n, c, h, w: (n, c // 4, h, w, c % 4)), pad_value=None)
        lv1: R.Tensor((1, 5, 32, 64, 4), dtype="float32") = R.layout_transform(y, index_map=T.index_map(lambda n, c, h, w: (n, c // 4, h, w, c % 4)), pad_value=None)
        lv2 = R.call_tir(cls.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
        lv3: R.Tensor((1, 20, 32, 64), dtype="float32") = R.layout_transform(lv2, index_map=T.index_map(lambda axis0, axis1, axis2, axis3, axis4: (axis0, axis1 * 4 + axis4, axis2, axis3)), pad_value=None)
        lv_1: R.Tensor((1, 17, 32, 64), dtype="float32") = R.remove_pad(lv3, orig_shape=[1, 17, 32, 64])
        gv: R.Tensor((1, 17, 32, 64), dtype="float32") = lv_1
        R.output(gv)
    return gv

After legalizing the above code, the final TIR looks like this:

@I.ir_module
class Module:
    @T.prim_func
    def relax_add_replacement(arg0: T.Buffer((1, 5, 32, 64, 4), "float32"), arg1: T.Buffer((1, 5, 32, 64, 4), "float32"), output: T.Buffer((1, 5, 32, 64, 4), "float32")):
        T.func_attr({"operator_name": "relax.add"})
        for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
            with T.block("buffer_arg1_assumptions"):
                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
                T.reads(arg1[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
                T.writes()
                T.assume(not (v_axis1 == 4 and 1 <= v_axis4) or arg1[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] == T.float32(0))
        for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
            with T.block("buffer_arg0_assumptions"):
                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
                T.reads(arg0[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
                T.writes()
                T.assume(not (v_axis1 == 4 and 1 <= v_axis4) or arg0[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] == T.float32(0))
        for axis0, axis1, axis2, axis3, axis4 in T.grid(1, 5, 32, 64, 4):
            with T.block("T_add"):
                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap("SSSSS", [axis0, axis1, axis2, axis3, axis4])
                T.reads(arg0[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4], arg1[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4])
                T.writes(output[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4])
                output[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4] = T.if_then_else(v_axis1 == 4 and 1 <= v_axis4, T.float32(0), arg0[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4] + arg1[v_axis0, (v_axis1 * 4 + v_axis4) // 4, v_axis2, v_axis3, (v_axis1 * 4 + v_axis4) % 4])

    @T.prim_func
    def remove_pad(A: T.Buffer((T.int64(1), T.int64(20), T.int64(32), T.int64(64)), "float32"), T_remove_pad: T.Buffer((T.int64(1), T.int64(17), T.int64(32), T.int64(64)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(17), T.int64(32), T.int64(64)):
            with T.block("T_remove_pad"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_remove_pad[v_ax0, v_ax1, v_ax2, v_ax3])
                T_remove_pad[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]

    @T.prim_func
    def te_layout_transform(A: T.Buffer((T.int64(1), T.int64(17), T.int64(32), T.int64(64)), "float32"), te_layout_transform_with_pad: T.Buffer((T.int64(1), T.int64(5), T.int64(32), T.int64(64), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3, i4 in T.grid(T.int64(1), T.int64(5), T.int64(32), T.int64(64), T.int64(4)):
            with T.block("te_layout_transform_with_pad"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(A[v_i0, v_i1 * T.int64(4) + v_i4, v_i2, v_i3])
                T.writes(te_layout_transform_with_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
                te_layout_transform_with_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(v_i1 == T.int64(4) and T.int64(1) <= v_i4, T.float32(0), A[v_i0, v_i1 * T.int64(4) + v_i4, v_i2, v_i3])

    @T.prim_func
    def te_layout_transform1(A: T.Buffer((T.int64(1), T.int64(5), T.int64(32), T.int64(64), T.int64(4)), "float32"), te_layout_transform: T.Buffer((T.int64(1), T.int64(20), T.int64(32), T.int64(64)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(20), T.int64(32), T.int64(64)):
            with T.block("te_layout_transform"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1 // T.int64(4), v_i2, v_i3, v_i1 % T.int64(4)])
                T.writes(te_layout_transform[v_i0, v_i1, v_i2, v_i3])
                te_layout_transform[v_i0, v_i1, v_i2, v_i3] = A[v_i0, v_i1 // T.int64(4), v_i2, v_i3, v_i1 % T.int64(4)]

    @R.function
    def main(x: R.Tensor((1, 17, 32, 64), dtype="float32"), y: R.Tensor((1, 17, 32, 64), dtype="float32")) -> R.Tensor((1, 17, 32, 64), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.te_layout_transform, (x,), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
            lv1 = R.call_tir(cls.te_layout_transform, (y,), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
            lv2 = R.call_tir(cls.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((1, 5, 32, 64, 4), dtype="float32"))
            lv3 = R.call_tir(cls.te_layout_transform1, (lv2,), out_sinfo=R.Tensor((1, 20, 32, 64), dtype="float32"))
            lv_1 = R.call_tir(cls.remove_pad, (lv3,), out_sinfo=R.Tensor((1, 17, 32, 64), dtype="float32"))
            gv: R.Tensor((1, 17, 32, 64), dtype="float32") = lv_1
            R.output(gv)
        return gv
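
For reference, the generated remove_pad PrimFunc above is just an identity copy over the original (1, 17, 32, 64) extent. A minimal TE sketch of that compute (te_remove_pad is a hypothetical helper written for illustration, not an existing TVM API) could look like:

from tvm import te

def te_remove_pad(data: te.Tensor, orig_shape):
    # Copy only the elements inside the original extents; the padded tail of
    # the packed axis is dropped because the output shape is smaller.
    return te.compute(orig_shape, lambda *indices: data(*indices), name="T_remove_pad")

In a legalization rule, such a compute could be emitted with the block builder's call_te, similar to how the layout_transform calls above are lowered.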

Can someone please tell me whether introducing a new op (remove_pad) to eliminate the padded entries is an appropriate design for handling non-bijective transformations in AlterOpImpl?

cc: @psrivas2 @sanirudh @jverma @abhikran-quic

Hi @junrushao,

I have gone through your PR (https://github.com/apache/tvm/pull/15262). Is there a follow-up PR planned to handle the non-surjective inverse case?

Thanks, Rahul