[SOLVED] Layout transform not bijective affine

Can anyone help me understand why this layout transformation is not bijective affine?

sch = tvm.tir.Schedule(Module)
input_vol = matrixwdm * update_rate
block_mm = sch.get_block("T_matmul")

sch.transform_block_layout(
  block_mm,
  index_map=lambda i, j, k: (i // 8, j // 8, k // 8, i % 8, j % 8, k % 8))

Error:

Error message: The index map lambda i, j, k: (i // T.int64(8), j // T.int64(8), k // T.int64(8), i % T.int64(8), j % T.int64(8), k % T.int64(8),) is not bijective affine.

Context:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle, var_T_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        A = T.match_buffer(var_A, (T.int64(1), T.int64(262144)))
        B = T.match_buffer(var_B, (T.int64(1024), T.int64(262144)))
        T_matmul = T.match_buffer(var_T_matmul, (T.int64(1), T.int64(1024)))
        with T.block("root"):
            T.reads()
            T.writes()
            for ax0 in range(T.int64(1)):
                for ax1 in range(T.int64(1024)):
                    for k in range(T.int64(262144)):
                        with T.block("T_matmul"):
                            v_ax0 = T.axis.spatial(T.int64(1), ax0)
                            v_ax1 = T.axis.spatial(T.int64(1024), ax1)
                            v_k = T.axis.reduce(T.int64(262144), k)
                            T.reads(A[v_ax0, v_k], B[v_ax1, v_k])
                            T.writes(T_matmul[v_ax0, v_ax1])
                            with T.init():
                                T_matmul[v_ax0, v_ax1] = T.float32(0)
                            T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_ax1, v_k]

IIUC, that's because in your block "T_matmul", ax0 (which is i in the index_map you pass to transform_block_layout) has an extent of just 1. When you map it to i // 8 and i % 8, i % 8 only ever takes the value 0, so the other seven positions along that output axis are never reached and the mapping is not bijective.

To get a bijective mapping, the extent of ax0, and with it the first dimension of buffer A, has to be a multiple of 8, so that i % 8 covers every index from 0 to 7.
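To make the counting argument concrete, here is a small pure-Python sketch. It does not call TVM and is not how TVM performs the check internally; the helper name `is_bijective_split` and the assumption that the target domain is the full box ceil(extent / 8) x 8 are mine, for illustration only.

```python
from itertools import product


def is_bijective_split(extent: int, factor: int = 8) -> bool:
    """Return True if i -> (i // factor, i % factor) maps range(extent)
    one-to-one onto the full box ceil(extent / factor) x factor.

    Assumption: the target domain is that full box (illustrative model only).
    """
    outer = -(-extent // factor)  # ceiling division
    target = set(product(range(outer), range(factor)))
    image = {(i // factor, i % factor) for i in range(extent)}
    return image == target


# Extents taken from the thread: ax0 = 1, the suggested ax0 = 8, and ax1 = 1024.
for extent in (1, 8, 1024):
    print(extent, is_bijective_split(extent))
```

The split is always one-to-one, so the only thing that can fail is coverage of the target box, and that happens exactly when the extent is not a multiple of the factor.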

You can quickly try this out by changing the extent of ax0, together with the first dimension of A and of T_matmul (both are indexed by v_ax0), to 8, as shown below

@I.ir_module
class Module:
    @T.prim_func
    def main(var_A: T.handle, var_B: T.handle, var_T_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        A = T.match_buffer(var_A, (T.int64(8), T.int64(262144)))
        B = T.match_buffer(var_B, (T.int64(1024), T.int64(262144)))
        T_matmul = T.match_buffer(var_T_matmul, (T.int64(8), T.int64(1024)))
        with T.block("root"):
            T.reads()
            T.writes()
            for ax0 in range(T.int64(8)):
                for ax1 in range(T.int64(1024)):
                    for k in range(T.int64(262144)):
                        with T.block("T_matmul"):
                            v_ax0 = T.axis.spatial(T.int64(8), ax0)
                            v_ax1 = T.axis.spatial(T.int64(1024), ax1)
                            v_k = T.axis.reduce(T.int64(262144), k)
                            T.reads(A[v_ax0, v_k], B[v_ax1, v_k])
                            T.writes(T_matmul[v_ax0, v_ax1])
                            with T.init():
                                T_matmul[v_ax0, v_ax1] = T.float32(0)
                            T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_ax1, v_k]

and you should not get the non-bijective error anymore.
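If changing the shapes is not an option, one thing that might be worth trying is an index map that leaves i alone and only splits j and k, whose extents (1024 and 262144) are already multiples of 8. The sketch below only checks the counting argument in plain Python on small stand-in extents (1, 16, 24); the helper `is_bijection` and both map functions are hypothetical names of mine, and I have not verified that transform_block_layout accepts the alternative map.

```python
from itertools import product
from math import prod


def is_bijection(fn, extents, target_extents):
    """Check that fn maps the box `extents` one-to-one onto the box `target_extents`."""
    domain = list(product(*(range(e) for e in extents)))
    image = {fn(*idx) for idx in domain}
    in_bounds = all(
        0 <= c < e for point in image for c, e in zip(point, target_extents)
    )
    return in_bounds and len(image) == len(domain) == prod(target_extents)


def original_map(i, j, k):
    # The map from the question: splits every axis by 8.
    return (i // 8, j // 8, k // 8, i % 8, j % 8, k % 8)


def alternative_map(i, j, k):
    # Hypothetical alternative: keep i as-is, split only j and k.
    return (i, j // 8, k // 8, j % 8, k % 8)


# Small stand-in extents: i has extent 1, j and k are multiples of 8.
extents = (1, 16, 24)
print(is_bijection(original_map, extents, (1, 2, 3, 8, 8, 8)))
print(is_bijection(alternative_map, extents, (1, 2, 3, 8, 8)))
```

The per-axis argument is the same for the real extents, since 1024 and 262144 are both divisible by 8.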

Awesome, thank you for the detailed answer :slight_smile: