Can one reduce stage be fused into another reduce stage?

Hello, I am trying to fuse the computation of one convolution layer, together with its ReLU, into the next convolution layer. I tried two methods, and both fail:

1. Using one `te.sum` expression as an argument of another `te.sum`.
2. Using `s.compute_inline()`.

Is it possible to combine two reduce stages (`te.sum`) into a single reduce stage in TE? If not, can Relay or TIR express this computation? Here is the current TIR without fusion:

    primfn(args: handle, arg_type_ids: handle, num_args: int32, out_ret_value: handle, out_ret_tcode: handle, resource_handle: handle) -> int32
  attr = {"target": meta[Target][0], "tir.noalias": True, "global_symbol": "myfunc_fusion", "from_legacy_te_schedule": True, "tir.is_entry_func": True, "calling_conv": 1} {
  assert((num_args == 4), "myfunc_fusion: num_args should be 4")
  let arg0: handle = @tir.tvm_struct_get(args, 0, 12, dtype=handle)
  let arg0.code: int32 = (int32*)arg_type_ids[0]
  let arg1: handle = @tir.tvm_struct_get(args, 1, 12, dtype=handle)
  let arg1.code: int32 = (int32*)arg_type_ids[1]
  let arg2: handle = @tir.tvm_struct_get(args, 2, 12, dtype=handle)
  let arg2.code: int32 = (int32*)arg_type_ids[2]
  let arg3: handle = @tir.tvm_struct_get(args, 3, 12, dtype=handle)
  let arg3.code: int32 = (int32*)arg_type_ids[3]
  let A: Pointer(float32) = @tir.tvm_struct_get(arg0, 0, 1, dtype=handle)
  attr [A] "storage_alignment" = 128;
  let arg0.shape: handle = @tir.tvm_struct_get(arg0, 0, 2, dtype=handle)
  let arg0.strides: handle = @tir.tvm_struct_get(arg0, 0, 3, dtype=handle)
  let dev_id: int32 = @tir.tvm_struct_get(arg0, 0, 9, dtype=int32)
  let W: Pointer(float32) = @tir.tvm_struct_get(arg1, 0, 1, dtype=handle)
  attr [W] "storage_alignment" = 128;
  let arg1.shape: handle = @tir.tvm_struct_get(arg1, 0, 2, dtype=handle)
  let arg1.strides: handle = @tir.tvm_struct_get(arg1, 0, 3, dtype=handle)
  let W_2: Pointer(float32) = @tir.tvm_struct_get(arg2, 0, 1, dtype=handle)
  attr [W_2] "storage_alignment" = 128;
  let arg2.shape: handle = @tir.tvm_struct_get(arg2, 0, 2, dtype=handle)
  let arg2.strides: handle = @tir.tvm_struct_get(arg2, 0, 3, dtype=handle)
  let C: Pointer(float32) = @tir.tvm_struct_get(arg3, 0, 1, dtype=handle)
  attr [C] "storage_alignment" = 128;
  let arg3.shape: handle = @tir.tvm_struct_get(arg3, 0, 2, dtype=handle)
  let arg3.strides: handle = @tir.tvm_struct_get(arg3, 0, 3, dtype=handle)
  assert(((((arg0.code == 3) || (arg0.code == 13)) || (arg0.code == 7)) || (arg0.code == 4)), "myfunc_fusion: Expect arg[0] to be pointer")
  assert(((((arg1.code == 3) || (arg1.code == 13)) || (arg1.code == 7)) || (arg1.code == 4)), "myfunc_fusion: Expect arg[1] to be pointer")
  assert(((((arg2.code == 3) || (arg2.code == 13)) || (arg2.code == 7)) || (arg2.code == 4)), "myfunc_fusion: Expect arg[2] to be pointer")
  assert(((((arg3.code == 3) || (arg3.code == 13)) || (arg3.code == 7)) || (arg3.code == 4)), "myfunc_fusion: Expect arg[3] to be pointer")
  assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is expected to equal 4")
  assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is expected to equal 4")
  assert((((@tir.tvm_struct_get(arg0, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg0, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg0, 0, 7, dtype=uint16) == 1u16)), "arg0.dtype is expected to be float32")
  assert((56 == cast(int32, (int64*)arg0.shape[0])), "Argument arg0.shape[0] has an unsatisfied constraint: (56 == int32(arg0.shape[0]))")
  assert((56 == cast(int32, (int64*)arg0.shape[1])), "Argument arg0.shape[1] has an unsatisfied constraint: (56 == int32(arg0.shape[1]))")
  assert((64 == cast(int32, (int64*)arg0.shape[2])), "Argument arg0.shape[2] has an unsatisfied constraint: (64 == int32(arg0.shape[2]))")
  assert((3 == cast(int32, (int64*)arg0.shape[3])), "Argument arg0.shape[3] has an unsatisfied constraint: (3 == int32(arg0.shape[3]))")
   {
    if !@tir.isnullptr(arg0.strides, dtype=bool) {
      assert(((((1 == cast(int32, (int64*)arg0.strides[3])) && (3 == cast(int32, (int64*)arg0.strides[2]))) && (192 == cast(int32, (int64*)arg0.strides[1]))) && (10752 == cast(int32, (int64*)arg0.strides[0]))), "arg0.strides: expected to be compact array")
      0
    }
    assert((0u64 == @tir.tvm_struct_get(arg0, 0, 8, dtype=uint64)), "Argument arg0.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg0, 0, 8))")
    assert((1 == @tir.tvm_struct_get(arg0, 0, 10, dtype=int32)), "Argument arg0.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg0, 0, 10))")
    assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim is expected to equal 4")
    assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim is expected to equal 4")
    assert((((@tir.tvm_struct_get(arg1, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg1, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg1, 0, 7, dtype=uint16) == 1u16)), "arg1.dtype is expected to be float32")
    assert((3 == cast(int32, (int64*)arg1.shape[0])), "Argument arg1.shape[0] has an unsatisfied constraint: (3 == int32(arg1.shape[0]))")
    assert((3 == cast(int32, (int64*)arg1.shape[1])), "Argument arg1.shape[1] has an unsatisfied constraint: (3 == int32(arg1.shape[1]))")
    assert((64 == cast(int32, (int64*)arg1.shape[2])), "Argument arg1.shape[2] has an unsatisfied constraint: (64 == int32(arg1.shape[2]))")
    assert((64 == cast(int32, (int64*)arg1.shape[3])), "Argument arg1.shape[3] has an unsatisfied constraint: (64 == int32(arg1.shape[3]))")
     {
      if !@tir.isnullptr(arg1.strides, dtype=bool) {
        assert(((((1 == cast(int32, (int64*)arg1.strides[3])) && (64 == cast(int32, (int64*)arg1.strides[2]))) && (4096 == cast(int32, (int64*)arg1.strides[1]))) && (12288 == cast(int32, (int64*)arg1.strides[0]))), "arg1.strides: expected to be compact array")
        0
      }
      assert((0u64 == @tir.tvm_struct_get(arg1, 0, 8, dtype=uint64)), "Argument arg1.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg1, 0, 8))")
      assert((1 == @tir.tvm_struct_get(arg1, 0, 10, dtype=int32)), "Argument arg1.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg1, 0, 10))")
      assert((dev_id == @tir.tvm_struct_get(arg1, 0, 9, dtype=int32)), "Argument arg1.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg1, 0, 9))")
      assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)), "arg2.ndim is expected to equal 4")
      assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)), "arg2.ndim is expected to equal 4")
      assert((((@tir.tvm_struct_get(arg2, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg2, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg2, 0, 7, dtype=uint16) == 1u16)), "arg2.dtype is expected to be float32")
      assert((3 == cast(int32, (int64*)arg2.shape[0])), "Argument arg2.shape[0] has an unsatisfied constraint: (3 == int32(arg2.shape[0]))")
      assert((3 == cast(int32, (int64*)arg2.shape[1])), "Argument arg2.shape[1] has an unsatisfied constraint: (3 == int32(arg2.shape[1]))")
      assert((64 == cast(int32, (int64*)arg2.shape[2])), "Argument arg2.shape[2] has an unsatisfied constraint: (64 == int32(arg2.shape[2]))")
      assert((64 == cast(int32, (int64*)arg2.shape[3])), "Argument arg2.shape[3] has an unsatisfied constraint: (64 == int32(arg2.shape[3]))")
       {
        if !@tir.isnullptr(arg2.strides, dtype=bool) {
          assert(((((1 == cast(int32, (int64*)arg2.strides[3])) && (64 == cast(int32, (int64*)arg2.strides[2]))) && (4096 == cast(int32, (int64*)arg2.strides[1]))) && (12288 == cast(int32, (int64*)arg2.strides[0]))), "arg2.strides: expected to be compact array")
          0
        }
        assert((0u64 == @tir.tvm_struct_get(arg2, 0, 8, dtype=uint64)), "Argument arg2.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg2, 0, 8))")
        assert((1 == @tir.tvm_struct_get(arg2, 0, 10, dtype=int32)), "Argument arg2.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg2, 0, 10))")
        assert((dev_id == @tir.tvm_struct_get(arg2, 0, 9, dtype=int32)), "Argument arg2.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg2, 0, 9))")
        assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)), "arg3.ndim is expected to equal 4")
        assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)), "arg3.ndim is expected to equal 4")
        assert((((@tir.tvm_struct_get(arg3, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg3, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg3, 0, 7, dtype=uint16) == 1u16)), "arg3.dtype is expected to be float32")
        assert((54 == cast(int32, (int64*)arg3.shape[0])), "Argument arg3.shape[0] has an unsatisfied constraint: (54 == int32(arg3.shape[0]))")
        assert((54 == cast(int32, (int64*)arg3.shape[1])), "Argument arg3.shape[1] has an unsatisfied constraint: (54 == int32(arg3.shape[1]))")
        assert((64 == cast(int32, (int64*)arg3.shape[2])), "Argument arg3.shape[2] has an unsatisfied constraint: (64 == int32(arg3.shape[2]))")
        assert((3 == cast(int32, (int64*)arg3.shape[3])), "Argument arg3.shape[3] has an unsatisfied constraint: (3 == int32(arg3.shape[3]))")
         {
          if !@tir.isnullptr(arg3.strides, dtype=bool) {
            assert(((((1 == cast(int32, (int64*)arg3.strides[3])) && (3 == cast(int32, (int64*)arg3.strides[2]))) && (192 == cast(int32, (int64*)arg3.strides[1]))) && (10368 == cast(int32, (int64*)arg3.strides[0]))), "arg3.strides: expected to be compact array")
            0
          }
          assert((0u64 == @tir.tvm_struct_get(arg3, 0, 8, dtype=uint64)), "Argument arg3.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg3, 0, 8))")
          assert((1 == @tir.tvm_struct_get(arg3, 0, 10, dtype=int32)), "Argument arg3.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg3, 0, 10))")
          assert((dev_id == @tir.tvm_struct_get(arg3, 0, 9, dtype=int32)), "Argument arg3.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg3, 0, 9))")
          attr [0] "compute_scope" = "myfunc_fusion_compute_";
          attr [R: Pointer(global float32)] "storage_alignment" = 128 {
            let R = @tir.TVMBackendAllocWorkspace(1, dev_id, 2239488u64, 2, 32, dtype=handle)
             {
              if @tir.isnullptr(R, dtype=bool) {
                @tir.tvm_throw_last_error(, dtype=int32)
              }
              allocate(B: Pointer(global float32), float32, [1]), storage_scope = global {
                for (yy: int32, 0, 54) {
                  for (xx: int32, 0, 54) {
                    for (cc: int32, 0, 64) {
                      for (batch: int32, 0, 3) {
                        B[0] = 0f32
                        for (ry: int32, 0, 3) {
                          for (rx: int32, 0, 3) {
                            for (rc: int32, 0, 64) {
                              B[0] = @tir.call_llvm_pure_intrin(134u32, 3u32, (float32*)A[((((((yy*10752) + (ry*10752)) + (xx*192)) + (rx*192)) + (rc*3)) + batch)], (float32*)W[((((ry*12288) + (rx*4096)) + (rc*64)) + cc)], (float32*)B[0], dtype=float32)
                            }
                          }
                        }
                        R[((((yy*10368) + (xx*192)) + (cc*3)) + batch)] = max(0f32, (float32*)B[0])
                      }
                    }
                  }
                }
                for (yy_1: int32, 0, 54) {
                  for (xx_1: int32, 0, 54) {
                    for (ff: int32, 0, 64) {
                      for (nn: int32, 0, 3) {
                        C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)] = 0f32
                        for (ry_2: int32, 0, 3) {
                          for (rx_2: int32, 0, 3) {
                            for (rc_2: int32, 0, 64) {
                              C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)] = @tir.call_llvm_pure_intrin(134u32, 3u32, (float32*)R[((((((yy_1*10368) + (ry_2*10368)) + (xx_1*192)) + (rx_2*192)) + (rc_2*3)) + nn)], (float32*)W_2[((((ry_2*12288) + (rx_2*4096)) + (rc_2*64)) + ff)], (float32*)C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)], dtype=float32)
                            }
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
            if (@tir.TVMBackendFreeWorkspace(1, dev_id, R, dtype=int32) != 0) {
              @tir.tvm_throw_last_error(, dtype=int32)
            }
          }
        }
      }
    }
  }
}

Below is a C reference model of the fused computation (conv → ReLU → conv):

    #include <stdio.h>
    #include <stdlib.h>

    /*
     * Reference model of a fused conv2d -> ReLU -> conv2d pipeline.
     *
     * Layout: activations are NCHW ([batch][channel][row][col]) and weights
     * are OIHW ([out_channel][in_channel][ky][kx]). Both convolutions use
     * stride 1 and no padding, so OUT_SIZE_n = INPUT_SIZE_n - KERNEL_SIZE_n + 1.
     *
     * The intermediate (conv1 + ReLU) tensor is never materialized: each of
     * its elements is recomputed where conv2 consumes it. The ReLU sits
     * between the two sums, so each conv1 sum must be completed and clamped
     * before it is multiplied into the conv2 accumulator. This is why the
     * two reductions cannot be flattened into one single reduction.
     */
    #define INPUT_SIZE_1   56
    #define IN_CHANNEL_1   64
    #define OUT_CHANNEL_1  64
    #define BATCH_1        4
    #define KERNEL_SIZE_1  3
    #define OUT_SIZE_1     54
    #define PAD_1          0
    #define INPUT_SIZE_2   54
    #define IN_CHANNEL_2   64
    #define OUT_CHANNEL_2  64
    #define BATCH_2        4
    #define KERNEL_SIZE_2  3
    #define OUT_SIZE_2     52
    #define PAD_2          0

    /* Compile-time checks that the two layers actually chain together. */
    #if PAD_1 != 0 || PAD_2 != 0
    #error "this model only supports unpadded convolutions (PAD_1 == PAD_2 == 0)"
    #endif
    #if OUT_SIZE_1 != INPUT_SIZE_1 - KERNEL_SIZE_1 + 1 || OUT_SIZE_2 != INPUT_SIZE_2 - KERNEL_SIZE_2 + 1
    #error "OUT_SIZE_n must equal INPUT_SIZE_n - KERNEL_SIZE_n + 1"
    #endif
    #if INPUT_SIZE_2 != OUT_SIZE_1 || IN_CHANNEL_2 != OUT_CHANNEL_1 || BATCH_2 != BATCH_1
    #error "layer 2 input shape must match layer 1 output shape"
    #endif

    /*
     * Computes, for every b, o, y, x:
     *
     *   output[b][o][y][x] = sum over (c, ky2, kx2) of
     *       relu(mid[b][c][y + ky2][x + kx2]) * w2[o][c][ky2][kx2]
     *
     *   mid[b][c][i][j]    = sum over (ic, ky1, kx1) of
     *       input[b][ic][i + ky1][j + kx1] * w1[c][ic][ky1][kx1]
     *
     * Note that conv1's OUTPUT channel c is conv2's INPUT channel.
     *
     * All tensors are dense row-major arrays. `output` must hold
     * batch * out_ch * out_size * out_size floats, where
     * out_size = in_size - k1 - k2 + 2. It is fully overwritten, so it
     * does not need to be zeroed beforehand.
     */
    static void fused_conv_relu_conv(const float *input, const float *w1,
                                     const float *w2, float *output,
                                     int batch, int in_ch, int mid_ch, int out_ch,
                                     int in_size, int k1, int k2)
    {
        const int mid_size = in_size - k1 + 1;
        const int out_size = mid_size - k2 + 1;

        for (int b = 0; b < batch; b++) {
            for (int o = 0; o < out_ch; o++) {
                for (int y = 0; y < out_size; y++) {
                    for (int x = 0; x < out_size; x++) {
                        float acc = 0.0f;
                        for (int c = 0; c < mid_ch; c++) {
                            for (int ky2 = 0; ky2 < k2; ky2++) {
                                for (int kx2 = 0; kx2 < k2; kx2++) {
                                    /* Recompute mid[b][c][y + ky2][x + kx2]. */
                                    float mid = 0.0f;
                                    for (int ic = 0; ic < in_ch; ic++) {
                                        for (int ky1 = 0; ky1 < k1; ky1++) {
                                            for (int kx1 = 0; kx1 < k1; kx1++) {
                                                const int row = y + ky2 + ky1;
                                                const int col = x + kx2 + kx1;
                                                mid += input[((b * in_ch + ic) * in_size + row) * in_size + col]
                                                     * w1[((c * in_ch + ic) * k1 + ky1) * k1 + kx1];
                                            }
                                        }
                                    }
                                    /* ReLU applies to the finished conv1 sum. */
                                    if (mid < 0.0f)
                                        mid = 0.0f;
                                    acc += mid * w2[((o * mid_ch + c) * k2 + ky2) * k2 + kx2];
                                }
                            }
                        }
                        output[((b * out_ch + o) * out_size + y) * out_size + x] = acc;
                    }
                }
            }
        }
    }

    int main(void)
    {
        const size_t n_input  = (size_t)BATCH_1 * IN_CHANNEL_1 * INPUT_SIZE_1 * INPUT_SIZE_1;
        const size_t n_w1     = (size_t)OUT_CHANNEL_1 * IN_CHANNEL_1 * KERNEL_SIZE_1 * KERNEL_SIZE_1;
        const size_t n_w2     = (size_t)OUT_CHANNEL_2 * IN_CHANNEL_2 * KERNEL_SIZE_2 * KERNEL_SIZE_2;
        const size_t n_output = (size_t)BATCH_2 * OUT_CHANNEL_2 * OUT_SIZE_2 * OUT_SIZE_2;
        int status = 0;

        printf("conv - relu - conv fusion\n");

        /* Several MB in total: allocate on the heap rather than the stack. */
        float *input_img = malloc(n_input * sizeof *input_img);
        float *weight_1  = malloc(n_w1 * sizeof *weight_1);
        float *weight_2  = malloc(n_w2 * sizeof *weight_2);
        float *output    = malloc(n_output * sizeof *output);

        if (input_img == NULL || weight_1 == NULL || weight_2 == NULL || output == NULL) {
            fprintf(stderr, "out of memory\n");
            status = 1;
        } else {
            /* Deterministic placeholder data; replace with real activations/weights. */
            for (size_t i = 0; i < n_input; i++)
                input_img[i] = (float)((int)(i % 7) - 3) / 8.0f;
            for (size_t i = 0; i < n_w1; i++)
                weight_1[i] = (float)((int)(i % 5) - 2) / 16.0f;
            for (size_t i = 0; i < n_w2; i++)
                weight_2[i] = (float)((int)(i % 3) - 1) / 16.0f;

            fused_conv_relu_conv(input_img, weight_1, weight_2, output,
                                 BATCH_1, IN_CHANNEL_1, OUT_CHANNEL_1, OUT_CHANNEL_2,
                                 INPUT_SIZE_1, KERNEL_SIZE_1, KERNEL_SIZE_2);
            printf("output[0][0][0][0] = %f\n", output[0]);
        }

        free(input_img);
        free(weight_1);
        free(weight_2);
        free(output);
        return status;
    }

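I'm not sure exactly what you mean, but if I take it literally — merging two `te.sum` reductions into a single reduce stage — that isn't feasible: TE only allows a reduction at the top level of a compute, not nested inside another reduction's body.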

However, there is an example of scheduling a fused conv2d → conv2d pipeline in TVM's Hexagon tests: https://github.com/apache/tvm/blob/6aa5ba281ba669d01038ca67b2f6d55ba2299249/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py