How to stitch kernels with shared memory in TVM

I wrote a simple Relay program and dumped its TIR. Currently five kernels are generated, and I wonder whether they can be stitched into a single kernel (similar to the stitch fusion work of Alibaba's team). This is my Relay program:
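Roughly, it is a normalization-like computation over a [32, 32, 32] float32 tensor. The sketch below is reconstructed from the TIR further down, so the parameter names, shapes, and exact op order are approximations rather than the literal script:

```python
import tvm
from tvm import relay

# Approximate reconstruction, inferred from the TIR dumps below:
# add -> mean -> subtract -> mean of squares -> rsqrt-normalize.
x = relay.var("x", shape=(32, 32, 32), dtype="float32")
gamma = relay.var("gamma", shape=(32, 32), dtype="float32")   # name/shape assumed
eps = relay.var("eps", shape=(1, 32), dtype="float32")        # name/shape assumed

y = relay.add(x, x)                                           # tvmgen_default_fused_add
mu = relay.mean(y, axis=0)                                    # tvmgen_default_fused_mean
centered = relay.subtract(y, mu)                              # tvmgen_default_fused_subtract
var = relay.mean(relay.power(centered, relay.const(2.0)), axis=0)  # tvmgen_default_fused_power_mean
out = relay.multiply(relay.multiply(centered, relay.rsqrt(relay.add(eps, var))), gamma)
out = relay.add(out, gamma)                                   # tvmgen_default_fused_add_rsqrt_multiply_multiply_add

mod = tvm.IRModule.from_expr(relay.Function([x, gamma, eps], out))
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
```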

Visualization of the Relay output:

TIR

primfn(p0_1: handle, T_add_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_add", "tir.noalias": True}
  buffers = {p0: Buffer(p0_2: Pointer(float32), float32, [32768], ),
             T_add: Buffer(T_add_2: Pointer(float32), float32, [32768], )}
  buffer_map = {p0_1: p0, T_add_1: T_add}
  preflattened_buffer_map = {p0_1: p0_3: Buffer(p0_2, float32, [32, 32, 32], ),
                             T_add_1: T_add_3: Buffer(T_add_2, float32, [32, 32, 32], )} {
  for (ax0.ax1.fused: int32, 0, 1024) "parallel" {
    for (ax2.outer: int32, 0, 2) {
      T_add[ramp(((ax0.ax1.fused*32) + (ax2.outer*16)), 1, 16)] = (p0[ramp(((ax0.ax1.fused*32) + (ax2.outer*16)), 1, 16)] + p0[ramp(((ax0.ax1.fused*32) + (ax2.outer*16)), 1, 16)])
    }
  }
}

primfn(p0_1: handle, T_divide_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_mean", "tir.noalias": True}
  buffers = {p0: Buffer(p0_2: Pointer(float32), float32, [32768], ),
             T_divide: Buffer(T_divide_2: Pointer(float32), float32, [1024], )}
  buffer_map = {p0_1: p0, T_divide_1: T_divide}
  preflattened_buffer_map = {p0_1: p0_3: Buffer(p0_2, float32, [32, 32, 32], ),
                             T_divide_1: T_divide_3: Buffer(T_divide_2, float32, [32, 32], )} {
  allocate(p0_red: Pointer(global float32), float32, [1024]), storage_scope = global {
    for (ax0.ax1.fused: int32, 0, 1024) "parallel" {
      p0_red_1: Buffer(p0_red, float32, [1024], )[ax0.ax1.fused] = 0f32
      for (k0: int32, 0, 32) {
        p0_red_1[ax0.ax1.fused] = (p0_red_1[ax0.ax1.fused] + p0[((k0*1024) + ax0.ax1.fused)])
      }
    }
    for (ax0: int32, 0, 32) "parallel" {
      for (ax1.outer: int32, 0, 2) {
        T_divide[ramp(((ax0*32) + (ax1.outer*16)), 1, 16)] = (p0_red_1[ramp(((ax0*32) + (ax1.outer*16)), 1, 16)]*broadcast(0.03125f32, 16))
      }
    }
  }
}

primfn(p0_1: handle, p1_1: handle, T_subtract_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_subtract", "tir.noalias": True}
  buffers = {p0: Buffer(p0_2: Pointer(float32), float32, [32768], ),
             p1: Buffer(p1_2: Pointer(float32), float32, [1024], ),
             T_subtract: Buffer(T_subtract_2: Pointer(float32), float32, [32768], )}
  buffer_map = {p0_1: p0, p1_1: p1, T_subtract_1: T_subtract}
  preflattened_buffer_map = {p0_1: p0_3: Buffer(p0_2, float32, [32, 32, 32], ),
                             p1_1: p1_3: Buffer(p1_2, float32, [32, 32], ),
                             T_subtract_1: T_subtract_3: Buffer(T_subtract_2, float32, [32, 32, 32], )} {
  for (ax0.ax1.fused: int32, 0, 1024) "parallel" {
    for (ax2.outer: int32, 0, 2) {
      T_subtract[ramp(((ax0.ax1.fused*32) + (ax2.outer*16)), 1, 16)] = (p0[ramp(((ax0.ax1.fused*32) + (ax2.outer*16)), 1, 16)] - p1[ramp(((floormod(ax0.ax1.fused, 32)*32) + (ax2.outer*16)), 1, 16)])
    }
  }
}

primfn(p0_1: handle, T_divide_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_power_mean", "tir.noalias": True}
  buffers = {p0: Buffer(p0_2: Pointer(float32), float32, [32768], ),
             T_divide: Buffer(T_divide_2: Pointer(float32), float32, [1024], )}
  buffer_map = {p0_1: p0, T_divide_1: T_divide}
  preflattened_buffer_map = {p0_1: p0_3: Buffer(p0_2, float32, [32, 32, 32], ),
                             T_divide_1: T_divide_3: Buffer(T_divide_2, float32, [32, 32], )} {
  allocate(T_power_red: Pointer(global float32), float32, [1024]), storage_scope = global {
    for (ax0.ax1.fused: int32, 0, 1024) "parallel" {
      T_power_red_1: Buffer(T_power_red, float32, [1024], )[ax0.ax1.fused] = 0f32
      for (k0: int32, 0, 32) {
        T_power_red_1[ax0.ax1.fused] = (T_power_red_1[ax0.ax1.fused] + @tir.pow(p0[((k0*1024) + ax0.ax1.fused)], 2f32, dtype=float32))
      }
    }
    for (ax0: int32, 0, 32) "parallel" {
      for (ax1.outer: int32, 0, 2) {
        T_divide[ramp(((ax0*32) + (ax1.outer*16)), 1, 16)] = (T_power_red_1[ramp(((ax0*32) + (ax1.outer*16)), 1, 16)]*broadcast(0.03125f32, 16))
      }
    }
  }
}

primfn(p0_1: handle, p1_1: handle, p2_1: handle, p3_1: handle, T_add_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_add_rsqrt_multiply_multiply_add", "tir.noalias": True}
  buffers = {p0: Buffer(p0_2: Pointer(float32), float32, [32], ),
             p1: Buffer(p1_2: Pointer(float32), float32, [1024], ),
             p2: Buffer(p2_2: Pointer(float32), float32, [32768], ),
             p3: Buffer(p3_2: Pointer(float32), float32, [1024], ),
             T_add: Buffer(T_add_2: Pointer(float32), float32, [32768], )}
  buffer_map = {p0_1: p0, p1_1: p1, p2_1: p2, p3_1: p3, T_add_1: T_add}
  preflattened_buffer_map = {p2_1: p2_3: Buffer(p2_2, float32, [32, 32, 32], ),
                             p3_1: p3_3: Buffer(p3_2, float32, [32, 32], ),
                             p0_1: p0_3: Buffer(p0_2, float32, [1, 32], ),
                             p1_1: p1_3: Buffer(p1_2, float32, [32, 32], ),
                             T_add_1: T_add_3: Buffer(T_add_2, float32, [32, 32, 32], )} {
  for (ax0.ax1.fused: int32, 0, 1024) "parallel" {
    for (ax2.outer: int32, 0, 2) {
      T_add[ramp(((ax0.ax1.fused*32) + (ax2.outer*16)), 1, 16)] = (((p2[ramp(((ax0.ax1.fused*32) + (ax2.outer*16)), 1, 16)]*(broadcast(1f32, 16) / @tir.sqrt((p0[ramp((ax2.outer*16), 1, 16)] + p1[ramp(((floormod(ax0.ax1.fused, 32)*32) + (ax2.outer*16)), 1, 16)]), dtype=float32x16)))*p3[ramp(((floormod(ax0.ax1.fused, 32)*32) + (ax2.outer*16)), 1, 16)]) + p3[ramp(((floormod(ax0.ax1.fused, 32)*32) + (ax2.outer*16)), 1, 16)])
    }
  }
}

The stitch fusion work I am referring to is the one from Alibaba's team.
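For concreteness, here is the kind of stitching I have in mind, hand-written in TE for a simplified two-stage stand-in (an elementwise add followed by a row reduction). This is only my own sketch under my own assumptions, not something relay.build emits today: the intermediate is given "shared" scope and computed inside the consumer's block, so both stages land in a single CUDA kernel and the reduction reads the intermediate from shared memory instead of going through global memory.

```python
import tvm
from tvm import te

# Simplified two-stage stand-in for the chain above; names and schedule are
# my own sketch, not what relay.build generates for the program.
n, m = 1024, 32

A = te.placeholder((n, m), name="A", dtype="float32")
B = te.compute((n, m), lambda i, j: A[i, j] + A[i, j], name="B")   # stand-in for fused_add
k = te.reduce_axis((0, m), name="k")
C = te.compute((n,), lambda i: te.sum(B[i, k], axis=k), name="C")  # stand-in for the reduction

s = te.create_schedule(C.op)

bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")

# One thread block per output row; the reduction is done cooperatively by the
# block's threads (cross-thread reduction, lowered to tvm_thread_allreduce).
s[C].bind(C.op.axis[0], bx)
s[C].bind(C.op.reduce_axis[0], tx)

# The "stitch": keep the intermediate B in shared memory and compute it inside
# the consumer's block, so there is no global-memory round trip for B.
s[B].set_scope("shared")
s[B].compute_at(s[C], C.op.axis[0])
s[B].bind(B.op.axis[1], tx)

print(tvm.lower(s, [A, C], simple_mode=True))
# func = tvm.build(s, [A, C], target="cuda")  # requires a CUDA-enabled TVM build
```

Is there an existing or planned mechanism in TVM (at the Relay or TIR level) that can produce this kind of stitched kernel automatically for the whole five-kernel chain above?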