Sorry, I just started working with TVM. I noticed that in order to speed up conv2d, certain layout transforms of the input and weight are required:
def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
    """Rearrange a HWIO conv2d kernel into the blocked layout used by the
    GEMM-based conv2d implementation.

    The kernel is first flattened into a (K, N) matrix with
    K = KH * KW * IC and N = OC, zero-padded so that the tile sizes evenly
    divide both dimensions, and finally reshaped into
    (N // tile_rows, K // tile_cols, tile_rows, tile_cols) tiles.

    Parameters
    ----------
    kernel : te.Tensor
        Convolution kernel in HWIO layout, shape (KH, KW, IC, OC).
    tile_rows : int
        Tile size along the N (output-channel) dimension.
    tile_cols : int
        Tile size along the K (reduction) dimension.

    Returns
    -------
    te.Tensor
        The padded, tiled weight tensor.
    """
    kh, kw, ic, oc = get_const_tuple(kernel.shape)
    k_dim = kh * kw * ic
    n_dim = oc

    # Flatten (KH, KW, IC, OC) -> (K, N); the row index enumerates (kh, kw, ic).
    flat = te.compute(
        (k_dim, n_dim),
        lambda row, col: kernel[(row // ic) // kw, (row // ic) % kw, row % ic, col],
        "weight_flatten",
    )

    # Zero padding needed so each dimension becomes a multiple of its tile size.
    # NOTE: in this schedule's convention tile_rows tiles N and tile_cols tiles K.
    pad_k = (-k_dim) % tile_cols
    pad_n = (-n_dim) % tile_rows
    k_padded = k_dim + pad_k
    n_padded = n_dim + pad_n

    if pad_k or pad_n:
        flat = pad(
            flat, pad_before=(0, 0), pad_after=(pad_k, pad_n), name="weight_padding"
        )

    # Carve the padded (K, N) matrix into (N/tr, K/tc, tr, tc) tiles.
    return te.compute(
        (n_padded // tile_rows, k_padded // tile_cols, tile_rows, tile_cols),
        lambda no, ko, ni, ki: flat[ki + tile_cols * ko, ni + tile_rows * no],
        name="weight_block_reshape",
    )
You can see that this is a te.compute, and the computation expression is still visible at the TIR level. Below is the printed TIR:
// attr [A_padded] storage_scope = "global"
allocate A_padded[uint8 * 1605632]
// attr [weight_block_reshape] storage_scope = "global"
allocate weight_block_reshape[uint8 * 1024]
// attr [C] storage_scope = "global"
allocate C[int32 * 1605632]
// attr [iter_var(pipeline, , pipeline)] pipeline_exec_scope = 1
for (i1, 0, 50176) {
for (i2, 0, 32) {
A_padded[((i1*32) + i2)] = tir.if_then_else((i2 < 12), placeholder[((((floordiv(i1, 224)*675) + (floordiv(i2, 6)*675)) + (floormod(i1, 224)*3)) + floormod(i2, 6))], (uint8)0)
}
}
for (x.y.fused, 0, 8) {
for (z, 0, 32) {
weight_block_reshape[((x.y.fused*128) + (z*4))] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[((x.y.fused*12) + z)], (uint8)0)
weight_block_reshape[(((x.y.fused*128) + (z*4)) + 1)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 3)], (uint8)0)
weight_block_reshape[(((x.y.fused*128) + (z*4)) + 2)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 6)], (uint8)0)
weight_block_reshape[(((x.y.fused*128) + (z*4)) + 3)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 9)], (uint8)0)
}
}
This is just a data-layout transformation — why isn't it optimized away (e.g. pre-computed at compile time)? Or, in other words, will LLVM optimize it later in the pipeline?