The definition of offset_factor can be found in the tvm.tir API docs (tvm 0.15.dev0 documentation).
Setting offset_factor has two effects:
- a Var is created for elem_offset (instead of fixing it to 0);
- elem_offset is required to be a multiple of offset_factor.
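Both rules can be modeled with plain integers. The sketch below is a hypothetical, simplified model (`match_elem_offset` is a made-up name, not a TVM API): TVM checks these conditions symbolically with its arithmetic analyzer, whereas this sketch only checks one concrete offset.

```python
def match_elem_offset(source_offset: int, offset_factor: int = 0) -> int:
    """Hypothetical model of how a matched buffer's elem_offset is bound.

    source_offset: element offset of the region being matched.
    offset_factor: 0 means "not set", so elem_offset is the constant 0;
                   a positive value means elem_offset is a free Var that
                   must be a multiple of offset_factor.
    Returns the value bound to elem_offset, or raises ValueError.
    """
    if offset_factor == 0:
        # No Var is created: elem_offset is fixed to 0.
        if source_offset != 0:
            raise ValueError(f"elem_offset constraint unmet: 0 == {source_offset}")
        return 0
    # A Var is created and bound to the source offset, but only if the
    # offset is a multiple of offset_factor.
    if source_offset % offset_factor != 0:
        raise ValueError(
            f"elem_offset {source_offset} does not satisfy offset_factor {offset_factor}"
        )
    return source_offset


print(match_elem_offset(0))         # 0
print(match_elem_offset(1040, 16))  # 1040
```

In this model `match_elem_offset(1040)` and `match_elem_offset(1040, 3)` both raise, mirroring the two failure cases: a fixed offset of 0, and an offset that is not a multiple of the factor.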
Here is an example from the wmma_load_a tensorization:
```python
from tvm.script import tir as T


# Description: the semantics of the intrinsic, a plain 16x16 copy
# from shared memory into the wmma.matrix_a fragment.
@T.prim_func
def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")
    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("load"):
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = A[vii, vjj]


# Implementation: the same copy expressed as one tvm_load_matrix_sync
# call, which needs the offsets of both A and C.
@T.prim_func
def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.int32()  # row stride of A
    s0 = T.int32()  # innermost stride of A
    A = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")
    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                16,
                16,
                16,
                # fragment index derived from C.elem_offset
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.access_ptr("r"),  # pointer into A, includes A.elem_offset
                s1,  # row stride of A
                "row_major",
                dtype="handle",
            )
        )
```
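Note that wmma_load_a_impl uses C.elem_offset directly and A.elem_offset indirectly: A.access_ptr("r") builds a pointer from A.data advanced by A.elem_offset. Both offsets therefore have to be Vars that get bound to the real offsets of the matched regions.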
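To see why the real offset matters, here is a hypothetical pure-Python model of a 16x16 tile inside a larger row-major shared buffer. It assumes, as in wmma_load_a_impl, that the innermost stride s0 is 1 and the row stride is s1; `tile_element` is an illustrative name, not a TVM API. The address of tile element (i, j) is elem_offset + i * s1 + j, so with elem_offset fixed to 0 every tile would be read from the start of the buffer.

```python
def tile_element(flat, elem_offset, s1, i, j):
    """Read element (i, j) of a tile that starts at flat[elem_offset].

    Hypothetical model of A.access_ptr("r"): the data pointer advanced by
    elem_offset, with rows s1 elements apart and an innermost stride of 1.
    """
    return flat[elem_offset + i * s1 + j]


# A 32x32 row-major "shared memory" buffer holding the values 0..1023.
N = 32
shared = list(range(N * N))

# Tile starting at row 16, column 16: elem_offset = 16 * 32 + 16 = 528.
offset = 16 * N + 16
print(tile_element(shared, offset, N, 0, 0))  # 528
print(tile_element(shared, offset, N, 1, 2))  # 562
```

Here the tile's offset 528 is a multiple of 16, so it would also satisfy offset_factor=16.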
If offset_factor is not set, elem_offset defaults to 0, so the buffer can only match a region that starts at element 0. Matching a tile at a nonzero offset then fails with:
TVMError: Check failed: (analyzer_.CanProve(lhs == rhs)) is false: The buffer match constraint for A.elem_offset unmet: 0==threadIdx_y // 4 * 2048 + ax0_0 * 1024 + k_0_1 * 32 + ax1_0 * 16.
If offset_factor is set to 3 instead, the check fails because the matched offset is only guaranteed to be a multiple of 16 and cannot be proven divisible by 3:
TVMError: Check failed: (analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) is false: The source elem_offset threadIdx_y // 4 * 2048 + ax0_0 * 1024 + k_0_1 * 32 + ax1_0 * 16 does not satisfy the offset_factor 3.
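The divisibility argument can be checked by brute force. Every term of threadIdx_y // 4 * 2048 + ax0_0 * 1024 + k_0_1 * 32 + ax1_0 * 16 is a multiple of 16, so offset_factor=16 (or any divisor of 16) always holds, while many offsets are not multiples of 3. The loop-variable ranges below are assumptions chosen for illustration, not the extents from the actual schedule.

```python
from itertools import product


def source_offset(ty, ax0_0, k_0_1, ax1_0):
    # Offset expression copied from the error message.
    return ty // 4 * 2048 + ax0_0 * 1024 + k_0_1 * 32 + ax1_0 * 16


# Assumed (illustrative) ranges for threadIdx_y, ax0_0, k_0_1, ax1_0.
offsets = [
    source_offset(ty, a, k, b)
    for ty, a, k, b in product(range(8), range(2), range(4), range(2))
]

print(all(o % 16 == 0 for o in offsets))       # True
print([o for o in offsets if o % 3 != 0][:3])  # [16, 32, 64]
```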