Question about offset_factor in decl_buffer

Hello,

when defining access semantics for a tensor I would expect a dependence on stride and offset. So we would get the correct element by using this formula:

idx = offset + stride_i * i + stride_j * j (assuming an ixj tensor)
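To make the formula concrete, here is a minimal plain-Python sketch. It is not TVM code; the helper name `flat_index` and the sample strides are assumptions chosen only for illustration.

```python
def flat_index(indices, strides, elem_offset=0):
    """Flat position of a multi-dimensional index in the underlying storage.

    Implements idx = elem_offset + sum(stride_k * index_k).
    """
    if len(indices) != len(strides):
        raise ValueError("indices and strides must have the same length")
    return elem_offset + sum(s * i for s, i in zip(strides, indices))


# A row-major 16x16 tile: stride 16 along rows, stride 1 along columns.
print(flat_index((2, 3), (16, 1)))       # 35
# The same element when the tile starts 256 elements into the storage.
print(flat_index((2, 3), (16, 1), 256))  # 291
```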

If we look at decl_buffer, then offset corresponds to elem_offset and strides to strides. The definition of decl_buffer also exposes an offset_factor parameter. How does this fit into the formula above? The examples mention that it is relevant to tensorization, which is also the use case that interests me, but they do not explain the exact meaning of the value.

The definition of offset_factor is given on the tvm.tir page of the tvm 0.15.dev0 documentation.

It has two effects:

  1. when it is non-zero, a Var is created for elem_offset (unless elem_offset is given explicitly);
  2. elem_offset is required to be a multiple of offset_factor.

Here is an example taken from the wmma_load_a tensorization.


from tvm.script import tir as T


@T.prim_func
def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("load"):
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = A[vii, vjj]


@T.prim_func
def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.int32()
    s0 = T.int32()
    A = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                16,
                16,
                16,
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.access_ptr("r"),
                s1,
                "row_major",
                dtype="handle",
            )
        )

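We can see that both A.access_ptr("r"), which is computed from A.elem_offset, and C.elem_offset are used in wmma_load_a_impl, so both offsets have to be Vars that can be bound to the actual offset of the matched region.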
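The index expression built from C.elem_offset can be checked with ordinary integers. The sketch below is plain Python, not TVM code, and the helper name `fragment_index` is hypothetical. It only illustrates the arithmetic: 256 = 16 × 16 is the number of elements in one tile, and because offset_factor=16 forces the offset to be a multiple of 16, the remainder modulo 256 is also a multiple of 16, so the inner division by 16 is exact.

```python
def fragment_index(elem_offset: int) -> int:
    """The index expression from wmma_load_a_impl, on Python integers."""
    return elem_offset // 256 + (elem_offset % 256) // 16


# Offsets that are whole tiles (multiples of 256) map to 0, 1, 2, 3.
for off in (0, 256, 512, 768):
    print(off, fragment_index(off))

# For any multiple of 16, the remainder modulo 256 divides evenly by 16.
exact = all((off % 256) % 16 == 0 for off in range(0, 4096, 16))
print(exact)  # True
```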

If we don't set offset_factor, elem_offset is fixed to the constant 0 by default instead of being a Var. Matching a region that starts at a non-zero offset then fails with:

TVMError: Check failed: (analyzer_.CanProve(lhs == rhs)) is false: The buffer match constraint for A.elem_offset unmet: 0==threadIdx_y // 4 * 2048 + ax0_0 * 1024 + k_0_1 * 32 + ax1_0 * 16.

If we set offset_factor = 3 and elem_offset cannot be proven divisible by 3, we get:

TVMError: Check failed: (analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) is false: The source elem_offset threadIdx_y // 4 * 2048 + ax0_0 * 1024 + k_0_1 * 32 + ax1_0 * 16 does not satisfy the offset_factor 3.
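Both failures can be reproduced with a toy model in plain Python. This is only a sketch of the two checks quoted in the error messages, not TVM's implementation. The helper names and the loop extents are assumptions, and "can prove" is approximated by "holds for every sampled value".

```python
from itertools import product


def offset_matches(actual_offset: int, offset_factor: int) -> bool:
    """Toy version of the buffer-match checks on elem_offset."""
    if offset_factor == 0:
        # No Var is created: elem_offset is the constant 0,
        # so only a region starting at offset 0 can match.
        return actual_offset == 0
    # A Var is created and bound to the actual offset,
    # which must be a multiple of offset_factor.
    return actual_offset % offset_factor == 0


def source_offset(thread_y: int, ax0_0: int, k_0_1: int, ax1_0: int) -> int:
    # Offset expression copied from the error messages above.
    return thread_y // 4 * 2048 + ax0_0 * 1024 + k_0_1 * 32 + ax1_0 * 16


# Assumed, purely illustrative ranges for the loop variables.
offsets = {
    source_offset(*v)
    for v in product(range(8), range(2), range(2), range(2))
}


def always_matches(offset_factor: int) -> bool:
    return all(offset_matches(o, offset_factor) for o in offsets)


print(always_matches(0))   # False: most offsets are non-zero
print(always_matches(3))   # False: e.g. 16 is not a multiple of 3
print(always_matches(16))  # True: every term is a multiple of 16
```

Every term of the offset expression carries a factor that is a multiple of 16, which is why offset_factor=16 is satisfied for any values of the loop variables.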