Hello,
I’ve recently been trying to gain a better understanding of TensorIR by porting a TE conv2d schedule I have written. However, I cannot figure out how to express the following use of the variable `B` to conditionally represent either the transformed or the unchanged weight matrix:
if transform_B:
B_flat = te.compute(
B_flat_shape,
lambda x, y: B_in[...],
name="B_flat"
)
B = te.compute(
B_shape,
lambda xo, yo, xi, yi: B_flat[...],
name="B_t_interleaved"
)
else:
B = B_in
My current attempt is the following:
@T.prim_func
def compute(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, data_a.shape, dtype=data_a.dtype)
B_in = T.match_buffer(b, data_b.shape, dtype=data_b.dtype)
out = T.match_buffer(c, (OH, OW, OC), dtype=out_dtype)
B = T.decl_buffer((b_xo, b_yo, tile_cols_B, tile_rows_B), dtype=data_b.dtype)
# B matrix transformation
if transform_B:
with T.block("B_transform"):
.......................
B_flat = T.alloc_buffer(...)
for ...
B_flat[...] = B_in[...]
for ...
B[...] = B_flat[...]
.......................
else:
B = T.match_buffer(B_in[:,:,:,:], shape = data_b.shape)
# Matrix-Matrix multiplication
for x, y, k in T.grid(M_padded, N_padded, K_padded):
with T.block("gemm"):
v_x, v_y, v_k = T.axis.remap("SSR", [x, y, k])
with T.init():
C[v_x, v_y] = T.float32(0)
C[v_x, v_y] += A[v_x, v_k] * B[...]
.......................
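Although the STIR successfully compiles when `transform_B = False`, the `B` buffer does not seem to point to the unchanged weight matrix as I expected. Instead, as the dump below shows, `B` is lowered to a newly allocated buffer (`T.allocate`) rather than a view of `B_in`. This is a dump of the TIR: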
@T.prim_func
def tvmgen_default_fused_nn_contrib_conv2d_gemm_without_weight_transform(data: T.Buffer((T.int64(1), T.int64(128), T.int64(128), T.int64(8)), "float32"), B_in: T.Buffer((T.int64(1), T.int64(20), T.int64(4), T.int64(16)), "float32")):
.......................
B = T.allocate([288], "float32x4", "global")
.......................
for b_x_0_fused in T.parallel(1024):
.......................
for k_0, k_1 in T.grid(18, 4):
B_1 = T.Buffer((T.int64(288),), "float32x4", data=B)
C_1[b_x_0_fused * 64:b_x_0_fused * 64 + 4] = C_1[b_x_0_fused * 64:b_x_0_fused * 64 + 4] + T.Broadcast(data_im2col_1[b_x_0_fused * 288 + k_0 * 4 + k_1], 4) * B_1[k_0 * 16 + k_1 * 4]
.......................
I noticed that moving `B = T.match_buffer(...)` outside of the `if` statement produces the behaviour I expected: the GEMM then reads directly from `B_in`, as shown below. Of course, that also makes the `if` block pointless.
@T.prim_func
def tvmgen_default_fused_nn_contrib_conv2d_gemm_without_weight_transform(data: T.Buffer((T.int64(1), T.int64(128), T.int64(128), T.int64(8)), "float32"), B_in: T.Buffer((T.int64(1), T.int64(20), T.int64(4), T.int64(16)), "float32")):
.......................
for b_x_0_fused in T.parallel(1024):
.......................
for k_0, k_1 in T.grid(18, 4):
B_in_1 = T.Buffer((T.int64(1280),), data=B_in.data)
C_1[b_x_0_fused * 64:b_x_0_fused * 64 + 4] = C_1[b_x_0_fused * 64:b_x_0_fused * 64 + 4] + T.Broadcast(data_im2col_1[b_x_0_fused * 288 + k_0 * 4 + k_1], 4) * B_in_1[k_0 * 64 + k_1 * 16:k_0 * 64 + k_1 * 16 + 4]
.......................
I am not sure how else to redefine `B` so that it conditionally refers to either buffer/tensor for the rest of the schedule. Any help would be greatly appreciated. Thank you!