Hello, I am trying to fuse a convolution layer, together with its ReLU result, into the next layer's convolution. I tried two methods: using one te.sum expression as an argument of another te.sum, and using s.compute_inline(). Both failed. Is it possible in TE to combine two reduce stages (te.sum) into a single reduce stage? If not, can Relay or TIR express this computation? Here is the current TIR without fusion:
primfn(args: handle, arg_type_ids: handle, num_args: int32, out_ret_value: handle, out_ret_tcode: handle, resource_handle: handle) -> int32
attr = {"target": meta[Target][0], "tir.noalias": True, "global_symbol": "myfunc_fusion", "from_legacy_te_schedule": True, "tir.is_entry_func": True, "calling_conv": 1} {
assert((num_args == 4), "myfunc_fusion: num_args should be 4")
let arg0: handle = @tir.tvm_struct_get(args, 0, 12, dtype=handle)
let arg0.code: int32 = (int32*)arg_type_ids[0]
let arg1: handle = @tir.tvm_struct_get(args, 1, 12, dtype=handle)
let arg1.code: int32 = (int32*)arg_type_ids[1]
let arg2: handle = @tir.tvm_struct_get(args, 2, 12, dtype=handle)
let arg2.code: int32 = (int32*)arg_type_ids[2]
let arg3: handle = @tir.tvm_struct_get(args, 3, 12, dtype=handle)
let arg3.code: int32 = (int32*)arg_type_ids[3]
let A: Pointer(float32) = @tir.tvm_struct_get(arg0, 0, 1, dtype=handle)
attr [A] "storage_alignment" = 128;
let arg0.shape: handle = @tir.tvm_struct_get(arg0, 0, 2, dtype=handle)
let arg0.strides: handle = @tir.tvm_struct_get(arg0, 0, 3, dtype=handle)
let dev_id: int32 = @tir.tvm_struct_get(arg0, 0, 9, dtype=int32)
let W: Pointer(float32) = @tir.tvm_struct_get(arg1, 0, 1, dtype=handle)
attr [W] "storage_alignment" = 128;
let arg1.shape: handle = @tir.tvm_struct_get(arg1, 0, 2, dtype=handle)
let arg1.strides: handle = @tir.tvm_struct_get(arg1, 0, 3, dtype=handle)
let W_2: Pointer(float32) = @tir.tvm_struct_get(arg2, 0, 1, dtype=handle)
attr [W_2] "storage_alignment" = 128;
let arg2.shape: handle = @tir.tvm_struct_get(arg2, 0, 2, dtype=handle)
let arg2.strides: handle = @tir.tvm_struct_get(arg2, 0, 3, dtype=handle)
let C: Pointer(float32) = @tir.tvm_struct_get(arg3, 0, 1, dtype=handle)
attr [C] "storage_alignment" = 128;
let arg3.shape: handle = @tir.tvm_struct_get(arg3, 0, 2, dtype=handle)
let arg3.strides: handle = @tir.tvm_struct_get(arg3, 0, 3, dtype=handle)
assert(((((arg0.code == 3) || (arg0.code == 13)) || (arg0.code == 7)) || (arg0.code == 4)), "myfunc_fusion: Expect arg[0] to be pointer")
assert(((((arg1.code == 3) || (arg1.code == 13)) || (arg1.code == 7)) || (arg1.code == 4)), "myfunc_fusion: Expect arg[1] to be pointer")
assert(((((arg2.code == 3) || (arg2.code == 13)) || (arg2.code == 7)) || (arg2.code == 4)), "myfunc_fusion: Expect arg[2] to be pointer")
assert(((((arg3.code == 3) || (arg3.code == 13)) || (arg3.code == 7)) || (arg3.code == 4)), "myfunc_fusion: Expect arg[3] to be pointer")
assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is expected to equal 4")
assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is expected to equal 4")
assert((((@tir.tvm_struct_get(arg0, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg0, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg0, 0, 7, dtype=uint16) == 1u16)), "arg0.dtype is expected to be float32")
assert((56 == cast(int32, (int64*)arg0.shape[0])), "Argument arg0.shape[0] has an unsatisfied constraint: (56 == int32(arg0.shape[0]))")
assert((56 == cast(int32, (int64*)arg0.shape[1])), "Argument arg0.shape[1] has an unsatisfied constraint: (56 == int32(arg0.shape[1]))")
assert((64 == cast(int32, (int64*)arg0.shape[2])), "Argument arg0.shape[2] has an unsatisfied constraint: (64 == int32(arg0.shape[2]))")
assert((3 == cast(int32, (int64*)arg0.shape[3])), "Argument arg0.shape[3] has an unsatisfied constraint: (3 == int32(arg0.shape[3]))")
{
if !@tir.isnullptr(arg0.strides, dtype=bool) {
assert(((((1 == cast(int32, (int64*)arg0.strides[3])) && (3 == cast(int32, (int64*)arg0.strides[2]))) && (192 == cast(int32, (int64*)arg0.strides[1]))) && (10752 == cast(int32, (int64*)arg0.strides[0]))), "arg0.strides: expected to be compact array")
0
}
assert((0u64 == @tir.tvm_struct_get(arg0, 0, 8, dtype=uint64)), "Argument arg0.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg0, 0, 8))")
assert((1 == @tir.tvm_struct_get(arg0, 0, 10, dtype=int32)), "Argument arg0.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg0, 0, 10))")
assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim is expected to equal 4")
assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim is expected to equal 4")
assert((((@tir.tvm_struct_get(arg1, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg1, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg1, 0, 7, dtype=uint16) == 1u16)), "arg1.dtype is expected to be float32")
assert((3 == cast(int32, (int64*)arg1.shape[0])), "Argument arg1.shape[0] has an unsatisfied constraint: (3 == int32(arg1.shape[0]))")
assert((3 == cast(int32, (int64*)arg1.shape[1])), "Argument arg1.shape[1] has an unsatisfied constraint: (3 == int32(arg1.shape[1]))")
assert((64 == cast(int32, (int64*)arg1.shape[2])), "Argument arg1.shape[2] has an unsatisfied constraint: (64 == int32(arg1.shape[2]))")
assert((64 == cast(int32, (int64*)arg1.shape[3])), "Argument arg1.shape[3] has an unsatisfied constraint: (64 == int32(arg1.shape[3]))")
{
if !@tir.isnullptr(arg1.strides, dtype=bool) {
assert(((((1 == cast(int32, (int64*)arg1.strides[3])) && (64 == cast(int32, (int64*)arg1.strides[2]))) && (4096 == cast(int32, (int64*)arg1.strides[1]))) && (12288 == cast(int32, (int64*)arg1.strides[0]))), "arg1.strides: expected to be compact array")
0
}
assert((0u64 == @tir.tvm_struct_get(arg1, 0, 8, dtype=uint64)), "Argument arg1.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg1, 0, 8))")
assert((1 == @tir.tvm_struct_get(arg1, 0, 10, dtype=int32)), "Argument arg1.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg1, 0, 10))")
assert((dev_id == @tir.tvm_struct_get(arg1, 0, 9, dtype=int32)), "Argument arg1.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg1, 0, 9))")
assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)), "arg2.ndim is expected to equal 4")
assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)), "arg2.ndim is expected to equal 4")
assert((((@tir.tvm_struct_get(arg2, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg2, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg2, 0, 7, dtype=uint16) == 1u16)), "arg2.dtype is expected to be float32")
assert((3 == cast(int32, (int64*)arg2.shape[0])), "Argument arg2.shape[0] has an unsatisfied constraint: (3 == int32(arg2.shape[0]))")
assert((3 == cast(int32, (int64*)arg2.shape[1])), "Argument arg2.shape[1] has an unsatisfied constraint: (3 == int32(arg2.shape[1]))")
assert((64 == cast(int32, (int64*)arg2.shape[2])), "Argument arg2.shape[2] has an unsatisfied constraint: (64 == int32(arg2.shape[2]))")
assert((64 == cast(int32, (int64*)arg2.shape[3])), "Argument arg2.shape[3] has an unsatisfied constraint: (64 == int32(arg2.shape[3]))")
{
if !@tir.isnullptr(arg2.strides, dtype=bool) {
assert(((((1 == cast(int32, (int64*)arg2.strides[3])) && (64 == cast(int32, (int64*)arg2.strides[2]))) && (4096 == cast(int32, (int64*)arg2.strides[1]))) && (12288 == cast(int32, (int64*)arg2.strides[0]))), "arg2.strides: expected to be compact array")
0
}
assert((0u64 == @tir.tvm_struct_get(arg2, 0, 8, dtype=uint64)), "Argument arg2.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg2, 0, 8))")
assert((1 == @tir.tvm_struct_get(arg2, 0, 10, dtype=int32)), "Argument arg2.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg2, 0, 10))")
assert((dev_id == @tir.tvm_struct_get(arg2, 0, 9, dtype=int32)), "Argument arg2.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg2, 0, 9))")
assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)), "arg3.ndim is expected to equal 4")
assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)), "arg3.ndim is expected to equal 4")
assert((((@tir.tvm_struct_get(arg3, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg3, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg3, 0, 7, dtype=uint16) == 1u16)), "arg3.dtype is expected to be float32")
assert((54 == cast(int32, (int64*)arg3.shape[0])), "Argument arg3.shape[0] has an unsatisfied constraint: (54 == int32(arg3.shape[0]))")
assert((54 == cast(int32, (int64*)arg3.shape[1])), "Argument arg3.shape[1] has an unsatisfied constraint: (54 == int32(arg3.shape[1]))")
assert((64 == cast(int32, (int64*)arg3.shape[2])), "Argument arg3.shape[2] has an unsatisfied constraint: (64 == int32(arg3.shape[2]))")
assert((3 == cast(int32, (int64*)arg3.shape[3])), "Argument arg3.shape[3] has an unsatisfied constraint: (3 == int32(arg3.shape[3]))")
{
if !@tir.isnullptr(arg3.strides, dtype=bool) {
assert(((((1 == cast(int32, (int64*)arg3.strides[3])) && (3 == cast(int32, (int64*)arg3.strides[2]))) && (192 == cast(int32, (int64*)arg3.strides[1]))) && (10368 == cast(int32, (int64*)arg3.strides[0]))), "arg3.strides: expected to be compact array")
0
}
assert((0u64 == @tir.tvm_struct_get(arg3, 0, 8, dtype=uint64)), "Argument arg3.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg3, 0, 8))")
assert((1 == @tir.tvm_struct_get(arg3, 0, 10, dtype=int32)), "Argument arg3.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg3, 0, 10))")
assert((dev_id == @tir.tvm_struct_get(arg3, 0, 9, dtype=int32)), "Argument arg3.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg3, 0, 9))")
attr [0] "compute_scope" = "myfunc_fusion_compute_";
attr [R: Pointer(global float32)] "storage_alignment" = 128 {
let R = @tir.TVMBackendAllocWorkspace(1, dev_id, 2239488u64, 2, 32, dtype=handle)
{
if @tir.isnullptr(R, dtype=bool) {
@tir.tvm_throw_last_error(, dtype=int32)
}
allocate(B: Pointer(global float32), float32, [1]), storage_scope = global {
for (yy: int32, 0, 54) {
for (xx: int32, 0, 54) {
for (cc: int32, 0, 64) {
for (batch: int32, 0, 3) {
B[0] = 0f32
for (ry: int32, 0, 3) {
for (rx: int32, 0, 3) {
for (rc: int32, 0, 64) {
B[0] = @tir.call_llvm_pure_intrin(134u32, 3u32, (float32*)A[((((((yy*10752) + (ry*10752)) + (xx*192)) + (rx*192)) + (rc*3)) + batch)], (float32*)W[((((ry*12288) + (rx*4096)) + (rc*64)) + cc)], (float32*)B[0], dtype=float32)
}
}
}
R[((((yy*10368) + (xx*192)) + (cc*3)) + batch)] = max(0f32, (float32*)B[0])
}
}
}
}
for (yy_1: int32, 0, 54) {
for (xx_1: int32, 0, 54) {
for (ff: int32, 0, 64) {
for (nn: int32, 0, 3) {
C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)] = 0f32
for (ry_2: int32, 0, 3) {
for (rx_2: int32, 0, 3) {
for (rc_2: int32, 0, 64) {
C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)] = @tir.call_llvm_pure_intrin(134u32, 3u32, (float32*)R[((((((yy_1*10368) + (ry_2*10368)) + (xx_1*192)) + (rx_2*192)) + (rc_2*3)) + nn)], (float32*)W_2[((((ry_2*12288) + (rx_2*4096)) + (rc_2*64)) + ff)], (float32*)C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)], dtype=float32)
}
}
}
}
}
}
}
}
}
if (@tir.TVMBackendFreeWorkspace(1, dev_id, R, dtype=int32) != 0) {
@tir.tvm_throw_last_error(, dtype=int32)
}
}
}
}
}
}
}
Below is a C reference model of the fused computation (conv -> ReLU -> conv):
#include <stddef.h>
#include <stdio.h>

/* Layer 1: convolution (stride 1, no padding) followed by ReLU. */
#define INPUT_SIZE_1 56
#define IN_CHANNEL_1 64
#define OUT_CHANNEL_1 64
#define BATCH_1 4
#define KERNEL_SIZE_1 3
#define OUT_SIZE_1 54
#define PAD_1 0
/* Layer 2: convolution (stride 1, no padding) over layer 1's ReLU output. */
#define INPUT_SIZE_2 54
#define IN_CHANNEL_2 64
#define OUT_CHANNEL_2 64
#define BATCH_2 4
#define KERNEL_SIZE_2 3
#define OUT_SIZE_2 52
#define PAD_2 0

/* The loops below implement no padding; they also assume the two layers'
 * shape macros describe one consistent chain. Reject mismatches at compile
 * time instead of silently reading out of bounds. */
#if PAD_1 != 0 || PAD_2 != 0
#error "this reference model does not implement padding"
#endif
#if OUT_SIZE_1 != INPUT_SIZE_1 - KERNEL_SIZE_1 + 1
#error "OUT_SIZE_1 must equal INPUT_SIZE_1 - KERNEL_SIZE_1 + 1"
#endif
#if INPUT_SIZE_2 != OUT_SIZE_1
#error "layer 2 input size must equal layer 1 output size"
#endif
#if IN_CHANNEL_2 != OUT_CHANNEL_1
#error "layer 2 input channels must equal layer 1 output channels"
#endif
#if BATCH_2 != BATCH_1
#error "both layers must use the same batch size"
#endif
#if OUT_SIZE_2 != INPUT_SIZE_2 - KERNEL_SIZE_2 + 1
#error "OUT_SIZE_2 must equal INPUT_SIZE_2 - KERNEL_SIZE_2 + 1"
#endif

/* Fill `count` floats at `dst` with a deterministic pattern in [-0.5, 0.5)
 * so the model reads defined values and its output is reproducible. */
static void fill_pattern(float *dst, size_t count, unsigned seed)
{
    for (size_t i = 0; i < count; i++) {
        dst[i] = (float)((i * 37u + seed) % 101u) / 101.0f - 0.5f;
    }
}

/* Reference model of conv -> ReLU -> conv computed in one fused loop nest.
 *
 * Tensors are laid out [batch][channel][y][x]. For every layer-2 output
 * element, each layer-1 (post-ReLU) value it needs is recomputed on the fly
 * instead of being stored in an intermediate buffer. This trades extra
 * arithmetic for zero intermediate storage: roughly
 * BATCH * OUT_CHANNEL_2 * OUT_SIZE_2^2 * (IN_CHANNEL_2 * K2^2)
 * * (IN_CHANNEL_1 * K1^2) multiply-adds, about 2.3e11 with the sizes above.
 * Prints a checksum of the output so the result is observable. */
int main(void)
{
    /* static: these tensors total about 6.3 MB, which can overflow a
     * typical thread stack as automatic variables. */
    static float input_img[BATCH_1][IN_CHANNEL_1][INPUT_SIZE_1][INPUT_SIZE_1];
    static float weight_1[OUT_CHANNEL_1][IN_CHANNEL_1][KERNEL_SIZE_1][KERNEL_SIZE_1];
    static float weight_2[OUT_CHANNEL_2][IN_CHANNEL_2][KERNEL_SIZE_2][KERNEL_SIZE_2];
    static float output[BATCH_2][OUT_CHANNEL_2][OUT_SIZE_2][OUT_SIZE_2];
    double checksum = 0.0;

    printf("conv - relu - conv fusion\n");

    fill_pattern((float *)input_img, sizeof(input_img) / sizeof(float), 1u);
    fill_pattern((float *)weight_1, sizeof(weight_1) / sizeof(float), 2u);
    fill_pattern((float *)weight_2, sizeof(weight_2) / sizeof(float), 3u);

    for (int batch_2 = 0; batch_2 < BATCH_2; batch_2++) {
        for (int out_size_2_y = 0; out_size_2_y < OUT_SIZE_2; out_size_2_y++) {
            for (int out_size_2_x = 0; out_size_2_x < OUT_SIZE_2; out_size_2_x++) {
                for (int out_channel_2 = 0; out_channel_2 < OUT_CHANNEL_2; out_channel_2++) {
                    /* Layer-2 reduction over (in_channel_2, ky2, kx2). */
                    float acc = 0.0f;
                    for (int in_channel_2 = 0; in_channel_2 < IN_CHANNEL_2; in_channel_2++) {
                        for (int kernel_size_2_y = 0; kernel_size_2_y < KERNEL_SIZE_2; kernel_size_2_y++) {
                            for (int kernel_size_2_x = 0; kernel_size_2_x < KERNEL_SIZE_2; kernel_size_2_x++) {
                                /* Recompute layer-1 output element
                                 * (batch_2, in_channel_2, y + ky2, x + kx2).
                                 * Layer 2's input channel IS layer 1's output
                                 * channel, so weight_1 is indexed by
                                 * in_channel_2; there is no separate
                                 * out_channel_1 reduction. */
                                float output_med = 0.0f;
                                for (int in_channel_1 = 0; in_channel_1 < IN_CHANNEL_1; in_channel_1++) {
                                    for (int kernel_size_1_y = 0; kernel_size_1_y < KERNEL_SIZE_1; kernel_size_1_y++) {
                                        for (int kernel_size_1_x = 0; kernel_size_1_x < KERNEL_SIZE_1; kernel_size_1_x++) {
                                            output_med += input_img[batch_2][in_channel_1]
                                                                   [out_size_2_y + kernel_size_2_y + kernel_size_1_y]
                                                                   [out_size_2_x + kernel_size_2_x + kernel_size_1_x]
                                                        * weight_1[in_channel_2][in_channel_1]
                                                                  [kernel_size_1_y][kernel_size_1_x];
                                        }
                                    }
                                }
                                /* ReLU on the layer-1 value before layer 2 consumes it. */
                                output_med = output_med > 0.0f ? output_med : 0.0f;
                                acc += output_med
                                     * weight_2[out_channel_2][in_channel_2]
                                               [kernel_size_2_y][kernel_size_2_x];
                            }
                        }
                    }
                    /* Plain assignment: every output element is written
                     * exactly once, so no prior zeroing is needed. */
                    output[batch_2][out_channel_2][out_size_2_y][out_size_2_x] = acc;
                }
            }
        }
    }

    for (int b = 0; b < BATCH_2; b++) {
        for (int c = 0; c < OUT_CHANNEL_2; c++) {
            for (int y = 0; y < OUT_SIZE_2; y++) {
                for (int x = 0; x < OUT_SIZE_2; x++) {
                    checksum += output[b][c][y][x];
                }
            }
        }
    }
    printf("checksum: %f\n", checksum);
    return 0;
}