# Porting conditional tensor transformation from TE to TensorIR

Hello,

I’ve recently been trying to get a better understanding of TensorIR by porting a TE conv2d schedule I wrote. However, I can't figure out how to express the following pattern, where the variable `B` conditionally refers to either the transformed or the unchanged weight matrix:

```python
if transform_B:
    B_flat = te.compute(
        B_flat_shape,
        lambda x, y: B_in[...],
        name="B_flat"
    )
    B = te.compute(
        B_shape,
        lambda xo, yo, xi, yi: B_flat[...],
        name="B_t_interleaved"
    )
else:
    B = B_in
```

My current attempt is the following:

```python
@T.prim_func
def compute(a: T.handle, b: T.handle, c: T.handle):
    A = T.match_buffer(a, data_a.shape, dtype=data_a.dtype)
    B_in = T.match_buffer(b, data_b.shape, dtype=data_b.dtype)
    out = T.match_buffer(c, (OH, OW, OC), dtype=out_dtype)

    B = T.decl_buffer((b_xo, b_yo, tile_cols_B, tile_rows_B), dtype=data_b.dtype)

    # B matrix transformation
    if transform_B:
        with T.block("B_transform"):
            # ...
            B_flat = T.alloc_buffer(...)
            for ...:
                B_flat[...] = B_in[...]
            for ...:
                B[...] = B_flat[...]
            # ...
    else:
        B = T.match_buffer(B_in[:, :, :, :], shape=data_b.shape)

    # Matrix-Matrix multiplication
    with T.block("gemm"):
        v_x, v_y, v_k = T.axis.remap("SSR", [x, y, k])
        with T.init():
            C[v_x, v_y] = T.float32(0)
        C[v_x, v_y] += A[v_x, v_k] * B[...]
    # ...
```

Although the STIR compiles successfully when `transform_B = False`, `B` does not point to the unchanged weight matrix as I expected; the GEMM reads from a separately allocated `B` buffer instead. This is a dump of the TIR:

```python
@T.prim_func
def tvmgen_default_fused_nn_contrib_conv2d_gemm_without_weight_transform(data: T.Buffer((T.int64(1), T.int64(128), T.int64(128), T.int64(8)), "float32"), B_in: T.Buffer((T.int64(1), T.int64(20), T.int64(4), T.int64(16)), "float32")):
    # ...
    B = T.allocate([288], "float32x4", "global")
    # ...
    for b_x_0_fused in T.parallel(1024):
        # ...
        for k_0, k_1 in T.grid(18, 4):
            B_1 = T.Buffer((T.int64(288),), "float32x4", data=B)
            C_1[b_x_0_fused * 64:b_x_0_fused * 64 + 4] = C_1[b_x_0_fused * 64:b_x_0_fused * 64 + 4] + T.Broadcast(data_im2col_1[b_x_0_fused * 288 + k_0 * 4 + k_1], 4) * B_1[k_0 * 16 + k_1 * 4]
        # ...
```

I noticed that moving `B = T.match_buffer(...)` outside of the if statement produces the behaviour I expected, but of course that also makes the if block pointless:

```python
@T.prim_func
def tvmgen_default_fused_nn_contrib_conv2d_gemm_without_weight_transform(data: T.Buffer((T.int64(1), T.int64(128), T.int64(128), T.int64(8)), "float32"), B_in: T.Buffer((T.int64(1), T.int64(20), T.int64(4), T.int64(16)), "float32")):
    # ...
    for b_x_0_fused in T.parallel(1024):
        # ...
        for k_0, k_1 in T.grid(18, 4):
            B_in_1 = T.Buffer((T.int64(1280),), data=B_in.data)
            C_1[b_x_0_fused * 64:b_x_0_fused * 64 + 4] = C_1[b_x_0_fused * 64:b_x_0_fused * 64 + 4] + T.Broadcast(data_im2col_1[b_x_0_fused * 288 + k_0 * 4 + k_1], 4) * B_in_1[k_0 * 64 + k_1 * 16:k_0 * 64 + k_1 * 16 + 4]
        # ...
```

I'm not sure how else to redefine `B` so that it refers to either buffer from that point on in the schedule. Any help would be greatly appreciated. Thank you!

@Hzfengsy @wrongtest is this a bug? Any suggestions?

All branches in TIR are RUNTIME branches. That means it’s hard (or even impossible) to write compile-time branches inside a PrimFunc.

On the other hand, we still encourage users to use TE compute (not TE schedule) as the frontend for TensorIR. The API `te.create_prim_func` converts a TE compute definition into STIR.


Building on @Hzfengsy’s reply: if you still need to write something in TIR directly for some reason (perhaps the op is too complex to express in TE, like the NMS op), I think we can take advantage of `T.macro` to write conditional macros and build the PrimFunc with them.

For example, something like the following might work:

```python
import tvm
from tvm import tir
from tvm.script import tir as T


def get_func(some_condition):
    if some_condition:
        @T.macro
        def assign(i, *args, t1, **kwargs):
            vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]])
            kwargs["t3"][vi, vj] = kwargs["t3"][vi, vj] + t1[vi, vk] * kwargs["t2"][vj, vk]
    else:
        @T.macro
        def assign(i, *args, t1, **kwargs):
            T.evaluate(0)

    @T.prim_func(private=True)
    def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, [128, 128])
        B = T.match_buffer(b, [128, 128])
        C = T.match_buffer(c, [128, 128])
        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                assign(i, j, k, t1=A, t2=B, t3=C)

    return matmul


def main():
    func = get_func(some_condition=True)
    print(func)
    func = get_func(some_condition=False)
    print(func)


if __name__ == '__main__':
    main()
```