How to properly compact the buffer in TensorIR

Hello, I’m working with Relax and TensorIR. When I run the CompactBufferAllocation pass, the result is not what I expected.

Before this pass, the TIR looks like this:

for (ax0_2: int32, 0, 1) {
        for (ax1_0: int32, 0, 2) {
          for (ax2_0: int32, 0, 1) {
            for (ax3_0: int32, 0, 1) {
              block([], "") {
                tir.reads([rxplaceholder[0, ((ax1_0*41) - 1):(((ax1_0*41) - 1) + 43), 0:64, 0:66], rxplaceholder_1[0:256, 0:66, 0:3, 0:3], rxplaceholder_shared_1[0:256], rxplaceholder_shared[0, 0, 0, 0:256]])
                tir.writes([Conv2dOutput11[ax0_2, (ax1_0*41):((ax1_0*41) + 41), (ax2_0*64):((ax2_0*64) + 64), (ax3_0*256):((ax3_0*256) + 256)]])
                PaddedInput = alloc_buffer(int8[1, 84, 66, 66])
                rxplaceholder_shared_2 = alloc_buffer(int8[256, 66, 3, 3])
                Conv2dOutput11_shared = alloc_buffer(int8[1, 82, 64, 256])
                rxplaceholder_shared_3 = alloc_buffer(int8[1, 82, 64, 66])
                 {
                  for (ax0_3: int32, 0, 43) {
                    for (ax1_1: int32, 0, 64) {
                      for (ax2_1: int32, 0, 66) {
                        block([], "rxplaceholder_shared") {
                          where(((1 <= ((ax1_0*41) + ax0_3)) && (((ax1_0*41) + ax0_3) < 83)))
                          tir.reads([rxplaceholder[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]])
                          tir.writes([rxplaceholder_shared_3[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]])
                          rxplaceholder_shared_3[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1] = rxplaceholder[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]
                      }
                    }
                  }
                  for (ax0_4: int32, 0, 43) {
                    for (ax1_2: int32, 0, 66) {
                      for (ax2_2: int32, 0, 66) {
                        block([], "PaddedInput") {
                          tir.reads([rxplaceholder_shared_3[0, (((ax1_0*41) + ax0_4) - 1), (ax1_2 - 1), ax2_2]])
                          tir.writes([PaddedInput[0, ((ax1_0*41) + ax0_4), ax1_2, ax2_2]])
                          PaddedInput[0, ((ax1_0*41) + ax0_4), ax1_2, ax2_2] = @tir.if_then_else(((((1 <= ((ax1_0*41) + ax0_4)) && (((ax1_0*41) + ax0_4) < 83)) && (1 <= ax1_2)) && (ax1_2 < 65)), rxplaceholder_shared_3[0, (((ax1_0*41) + ax0_4) - 1), (ax1_2 - 1), ax2_2], 0i8, dtype=int8)
                      }
                    }
                  }
                  for (ax0_5: int32, 0, 256) {
                    for (ax1_3: int32, 0, 66) {
                      for (ax2_3: int32, 0, 3) {
                        for (ax3_1: int32, 0, 3) {
                          block([], "rxplaceholder_shared") {
                            tir.reads([rxplaceholder_1[ax0_5, ax1_3, ax2_3, ax3_1]])
                            tir.writes([rxplaceholder_shared_2[ax0_5, ax1_3, ax2_3, ax3_1]])
                            rxplaceholder_shared_2[ax0_5, ax1_3, ax2_3, ax3_1] = rxplaceholder_1[ax0_5, ax1_3, ax2_3, ax3_1]
                        }
                      }
                    }
                  }
                  for (ax0_6: int32, 0, 41) {
                    for (ax1_4: int32, 0, 64) {
                      for (ax2_4: int32, 0, 256) {
                        for (ax3_2: int32, 0, 3) {
                          for (ax4: int32, 0, 3) {
                            for (ax5: int32, 0, 66) {
                              block([], "Conv2dOutput11") {
                                tir.reads([PaddedInput[0, (((ax1_0*41) + ax0_6) + ax3_2), (ax1_4 + ax4), ax5], rxplaceholder_shared_2[ax2_4, ax5, ax3_2, ax4], rxplaceholder_shared_1[ax2_4], rxplaceholder_shared[0, 0, 0, ax2_4]])
                                tir.writes([Conv2dOutput11_shared[0, ((ax1_0*41) + ax0_6), ax1_4, ax2_4]])
                                tir.attrs({"stride_h": 1, "stride_w": 1, "kernel_h": 3, "has_pad": True, "has_bias": True, "shift": 10, "kernel_w": 3, "layout": "NHWC", "in_channel": 66, "dilation_h": 1, "has_mul": True, "relu_type": 1, "dilation_w": 1, "out_channel": 256})
                                 {
                                  if (((ax3_2 == 0) && (ax4 == 0)) && (ax5 == 0)) {
                                    Conv2dOutput11_shared[0, ((ax1_0*41) + ax0_6), ax1_4, ax2_4] = 0i8
                                  }
                                  Conv2dOutput11_shared[0, ((ax1_0*41) + ax0_6), ax1_4, ax2_4] = (Conv2dOutput11_shared[0, ((ax1_0*41) + ax0_6), ax1_4, ax2_4] + cast(int8, (cast(int32, (PaddedInput[0, (((ax1_0*41) + ax0_6) + ax3_2), (ax1_4 + ax4), ax5]*rxplaceholder_shared_2[ax2_4, ax5, ax3_2, ax4])) + (rxplaceholder_shared_1[ax2_4]*rxplaceholder_shared[0, 0, 0, ax2_4]))))
                                }
                            }
                          }
                        }
                      }
                    }
                  }
                  for (ax1_1_1: int32, 0, 41) {
                    for (ax2_1_1: int32, 0, 64) {
                      for (ax3_1_1: int32, 0, 256) {
                        block([], "Conv2dOutput11_shared") {
                          tir.reads([Conv2dOutput11_shared[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]])
                          tir.writes([Conv2dOutput11[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]])
                          Conv2dOutput11[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)] = Conv2dOutput11_shared[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]
                      }
                    }
                  }
                }
            }
          }
        }
      }
    }

After the pass, it looks like this:

      for (ax0_2: int32, 0, 1) {
        for (ax1_0: int32, 0, 2) {
          for (ax2_0: int32, 0, 1) {
            for (ax3_0: int32, 0, 1) {
              block([], "") {
                tir.reads([rxplaceholder[0, ((ax1_0*41) - 1):(((ax1_0*41) - 1) + 43), 0:64, 0:66], rxplaceholder_1[0:256, 0:66, 0:3, 0:3], rxplaceholder_shared_1[0:256], rxplaceholder_shared[0, 0, 0, 0:256]])
                tir.writes([Conv2dOutput11[ax0_2, (ax1_0*41):((ax1_0*41) + 41), (ax2_0*64):((ax2_0*64) + 64), (ax3_0*256):((ax3_0*256) + 256)]])
                PaddedInput = alloc_buffer(int8[1, 43, 66, 66])
                rxplaceholder_shared_2 = alloc_buffer(int8[256, 66, 3, 3])
                Conv2dOutput11_shared = alloc_buffer(int8[1, 41, 64, 256])
                rxplaceholder_shared_3 = alloc_buffer(int8[1, 82, 64, 66])
                 {
                  for (ax0_3: int32, 0, 43) {
                    for (ax1_1: int32, 0, 64) {
                      for (ax2_1: int32, 0, 66) {
                        block([], "rxplaceholder_shared") {
                          where(((1 <= ((ax1_0*41) + ax0_3)) && (((ax1_0*41) + ax0_3) < 83)))
                          tir.reads([rxplaceholder[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]])
                          tir.writes([rxplaceholder_shared_3[0, ((((ax1_0*41) - 1) + ax0_3) - max(0, ((ax1_0*41) - 1))), ax1_1, ax2_1]])
						  rxplaceholder_shared_3[0, ((((ax1_0*41) - 1) + ax0_3) - max(0, ((ax1_0*41) - 1))), ax1_1, ax2_1] = rxplaceholder[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]
                      }
                    }
                  }
                  for (ax0_4: int32, 0, 43) {
                    for (ax1_2: int32, 0, 66) {
                      for (ax2_2: int32, 0, 66) {
                        block([], "PaddedInput") {
                          tir.reads([rxplaceholder_shared_3[0, ((((ax1_0*41) + ax0_4) - 1) - max(0, ((ax1_0*41) - 1))), (ax1_2 - 1), ax2_2]])
                          tir.writes([PaddedInput[0, (((ax1_0*41) + ax0_4) - (ax1_0*41)), ax1_2, ax2_2]])
                          PaddedInput[0, (((ax1_0*41) + ax0_4) - (ax1_0*41)), ax1_2, ax2_2] = @tir.if_then_else(((((1 <= ((ax1_0*41) + ax0_4)) && (((ax1_0*41) + ax0_4) < 83)) && (1 <= ax1_2)) && (ax1_2 < 65)), rxplaceholder_shared_3[0, ((((ax1_0*41) + ax0_4) - 1) - max(0, ((ax1_0*41) - 1))), (ax1_2 - 1), ax2_2], 0i8, dtype=int8)
                      }
                    }
                  }
                  for (ax0_5: int32, 0, 256) {
                    for (ax1_3: int32, 0, 66) {
                      for (ax2_3: int32, 0, 3) {
                        for (ax3_1: int32, 0, 3) {
                          block([], "rxplaceholder_shared") {
                            tir.reads([rxplaceholder_1[ax0_5, ax1_3, ax2_3, ax3_1]])
                            tir.writes([rxplaceholder_shared_2[ax0_5, ax1_3, ax2_3, ax3_1]])
                            rxplaceholder_shared_2[ax0_5, ax1_3, ax2_3, ax3_1] = rxplaceholder_1[ax0_5, ax1_3, ax2_3, ax3_1]
                        }
                      }
                    }
                  }
                  for (ax0_6: int32, 0, 41) {
                    for (ax1_4: int32, 0, 64) {
                      for (ax2_4: int32, 0, 256) {
                        for (ax3_2: int32, 0, 3) {
                          for (ax4: int32, 0, 3) {
                            for (ax5: int32, 0, 66) {
                              block([], "Conv2dOutput11") {
                                tir.reads([PaddedInput[0, ((((ax1_0*41) + ax0_6) + ax3_2) - (ax1_0*41)), (ax1_4 + ax4), ax5], rxplaceholder_shared_2[ax2_4, ax5, ax3_2, ax4], rxplaceholder_shared_1[ax2_4], rxplaceholder_shared[0, 0, 0, ax2_4]])
                                tir.writes([Conv2dOutput11_shared[0, (((ax1_0*41) + ax0_6) - (ax1_0*41)), ax1_4, ax2_4]])
                                tir.attrs({"stride_h": 1, "stride_w": 1, "kernel_h": 3, "has_pad": True, "has_bias": True, "shift": 10, "kernel_w": 3, "layout": "NHWC", "in_channel": 66, "dilation_h": 1, "has_mul": True, "relu_type": 1, "dilation_w": 1, "out_channel": 256})
                                 {
                                  if (((ax3_2 == 0) && (ax4 == 0)) && (ax5 == 0)) {
                                    Conv2dOutput11_shared[0, (((ax1_0*41) + ax0_6) - (ax1_0*41)), ax1_4, ax2_4] = 0i8
                                  }
                                  Conv2dOutput11_shared[0, (((ax1_0*41) + ax0_6) - (ax1_0*41)), ax1_4, ax2_4] = (Conv2dOutput11_shared[0, (((ax1_0*41) + ax0_6) - (ax1_0*41)), ax1_4, ax2_4] + cast(int8, (cast(int32, (PaddedInput[0, ((((ax1_0*41) + ax0_6) + ax3_2) - (ax1_0*41)), (ax1_4 + ax4), ax5]*rxplaceholder_shared_2[ax2_4, ax5, ax3_2, ax4])) + (rxplaceholder_shared_1[ax2_4]*rxplaceholder_shared[0, 0, 0, ax2_4]))))
                                }
                            }
                          }
                        }
                      }
                    }
                  }
                  for (ax1_1_1: int32, 0, 41) {
                    for (ax2_1_1: int32, 0, 64) {
                      for (ax3_1_1: int32, 0, 256) {
                        block([], "Conv2dOutput11_shared") {
                          tir.reads([Conv2dOutput11_shared[ax0_2, (((ax1_0*41) + ax1_1_1) - (ax1_0*41)), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]])
                          tir.writes([Conv2dOutput11[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]])
                          Conv2dOutput11[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)] = Conv2dOutput11_shared[ax0_2, (((ax1_0*41) + ax1_1_1) - (ax1_0*41)), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]
                      }
                    }
                  }
                }
            }
          }

The problem is the shape of the rxplaceholder_shared_3 buffer: I expected [1, 43, 64, 66], but got [1, 82, 64, 66]. It seems that the pass does not compact this buffer effectively, which makes the buffer larger than necessary.

I tried to debug the pass and added some printing:

printf("extent: %s \n", PrettyPrint(extent).c_str());

got:

extent: min(82, (max(((((max((ax1_0: int32*41), 1) + 82) - max((ax1_0*41), 40)) - max((1 - (ax1_0*41)), 0)) - (ax1_0*41)), 42) + 1))

Finally, I found that the resulting extent is 82, but whether ax1_0 equals 0 or 1, the extent should be 43:

int64_t upperbound = analyzer->const_int_bound(extent)->max_value;

Can someone help me look at this problem?

@Hzfengsy @wrongtest @junrushao I’m sorry to bother you, but could you help me with this?

Could you give a minimal example printed as TVMScript? I’m sorry, but the current code is too long to analyze.

Thanks for your reply, here’s the original script:

    @T.prim_func
    def conv2d_integer(rxplaceholder: T.Buffer[(1, 82, 64, 66), "int8"], rxplaceholder_1: T.Buffer[(256, 66, 3, 3), "int8"], rxplaceholder_2: T.Buffer[256, "int32"], rxplaceholder_3: T.Buffer[(1, 1, 1, 256), "int32"], Conv2dOutput11: T.Buffer[(1, 82, 64, 256), "int8"]):
        # function attr dict
        T.func_attr({"global_symbol": "conv2d_integer", "tir.noalias": True, "layout_free_buffers": [1]})
        # body
        # with T.block("root")
        PaddedInput = T.alloc_buffer([1, 84, 66, 66], dtype="int8")
        Conv2dOutput11_shared = T.alloc_buffer([1, 82, 64, 256], dtype="int8", scope="shared")
        rxplaceholder_shared = T.alloc_buffer([256, 66, 3, 3], dtype="int8", scope="shared")
        rxplaceholder_shared_1 = T.alloc_buffer([256], dtype="int32", scope="shared")
        rxplaceholder_shared_2 = T.alloc_buffer([1, 1, 1, 256], dtype="int32", scope="shared")
        rxplaceholder_shared_3 = T.alloc_buffer([1, 82, 64, 66], dtype="int8", scope="shared")
        for ax0 in T.serial(1, annotations={"pragma_dma_in":"scale"}):
            for ax1, ax2, ax3 in T.grid(1, 1, 256):
                with T.block("rxplaceholder_shared"):
                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(rxplaceholder_3[v0, v1, v2, v3])
                    T.writes(rxplaceholder_shared_2[v0, v1, v2, v3])
                    rxplaceholder_shared_2[v0, v1, v2, v3] = rxplaceholder_3[v0, v1, v2, v3]
        for ax0 in T.serial(256, annotations={"pragma_dma_in":"bias"}):
            with T.block("rxplaceholder_shared"):
                v0 = T.axis.spatial(256, ax0)
                T.reads(rxplaceholder_2[v0])
                T.writes(rxplaceholder_shared_1[v0])
                rxplaceholder_shared_1[v0] = rxplaceholder_2[v0]
        for ax0 in T.serial(1, annotations={"pragma_outer_loop":"l1"}):
            for ax1_0 in T.serial(2, annotations={"pragma_outer_loop":"l1"}):
                for ax2_0 in T.serial(1, annotations={"pragma_outer_loop":"l1"}):
                    for ax3_0 in T.serial(1, annotations={"pragma_outer_loop":"l1"}):
                        for ax0_1 in T.serial(43, annotations={"pragma_dma_in":"input"}):
                            for ax1, ax2 in T.grid(64, 66):
                                with T.block("rxplaceholder_shared"):
                                    T.where(1 <= ax1_0 * 41 + ax0_1 and ax1_0 * 41 + ax0_1 < 83)
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(82, ax1_0 * 41 - 1 + ax0_1)
                                    v2, v3 = T.axis.remap("SS", [ax1, ax2])
                                    T.reads(rxplaceholder[v0, v1, v2, v3])
                                    T.writes(rxplaceholder_shared_3[v0, v1, v2, v3])
                                    rxplaceholder_shared_3[v0, v1, v2, v3] = rxplaceholder[v0, v1, v2, v3]
                        for ax0_2 in T.serial(43, annotations={"pragma_pad_in_place":"1"}):
                            for ax1, ax2 in T.grid(66, 66):
                                with T.block("PaddedInput"):
                                    i0 = T.axis.spatial(1, 0)
                                    i1 = T.axis.spatial(84, ax1_0 * 41 + ax0_2)
                                    i2, i3 = T.axis.remap("SS", [ax1, ax2])
                                    T.reads(rxplaceholder_shared_3[i0, i1 - 1, i2 - 1, i3])
                                    T.writes(PaddedInput[i0, i1, i2, i3])
                                    PaddedInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 83 and 1 <= i2 and i2 < 65, rxplaceholder_shared_3[i0, i1 - 1, i2 - 1, i3], T.int8(0), dtype="int8")
                        for ax0_3 in T.serial(256, annotations={"pragma_dma_in":"weight"}):
                            for ax1, ax2, ax3 in T.grid(66, 3, 3):
                                with T.block("rxplaceholder_shared"):
                                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0_3, ax1, ax2, ax3])
                                    T.reads(rxplaceholder_1[v0, v1, v2, v3])
                                    T.writes(rxplaceholder_shared[v0, v1, v2, v3])
                                    rxplaceholder_shared[v0, v1, v2, v3] = rxplaceholder_1[v0, v1, v2, v3]
                        for ax0_4 in T.serial(41, annotations={"pragma_compute_attr_conv":"conv2d"}):
                            for ax1, ax2, ax3, ax4, ax5 in T.grid(64, 256, 3, 3, 66):
                                with T.block("Conv2dOutput11"):
                                    nn = T.axis.spatial(1, 0)
                                    yy = T.axis.spatial(82, ax1_0 * 41 + ax0_4)
                                    xx, ff, ry, rx, rc = T.axis.remap("SSRRR", [ax1, ax2, ax3, ax4, ax5])
                                    T.reads(PaddedInput[nn, yy + ry, xx + rx, rc], rxplaceholder_shared[ff, rc, ry, rx], rxplaceholder_shared_1[ff], rxplaceholder_shared_2[0, 0, 0, ff])
                                    T.writes(Conv2dOutput11_shared[nn, yy, xx, ff])
                                    T.block_attr({"dilation_h":1, "dilation_w":1, "has_bias":True, "has_mul":True, "has_pad":True, "in_channel":66, "kernel_h":3, "kernel_w":3, "layout":"NHWC", "out_channel":256, "relu_type":1, "shift":10, "stride_h":1, "stride_w":1})
                                    with T.init():
                                        Conv2dOutput11_shared[nn, yy, xx, ff] = T.int8(0)
                                    Conv2dOutput11_shared[nn, yy, xx, ff] = Conv2dOutput11_shared[nn, yy, xx, ff] + T.Cast("int8", T.Cast("int32", PaddedInput[nn, yy + ry, xx + rx, rc] * rxplaceholder_shared[ff, rc, ry, rx]) + rxplaceholder_shared_1[ff] * rxplaceholder_shared_2[0, 0, 0, ff])
                        for ax1_1 in T.serial(41, annotations={"pragma_dma_out":"output"}):
                            for ax2_1, ax3_1 in T.grid(64, 256):
                                with T.block("Conv2dOutput11_shared"):
                                    v0 = T.axis.spatial(1, ax0)
                                    v1 = T.axis.spatial(82, ax1_0 * 41 + ax1_1)
                                    v2 = T.axis.spatial(64, ax2_0 * 64 + ax2_1)
                                    v3 = T.axis.spatial(256, ax3_0 * 256 + ax3_1)
                                    T.reads(Conv2dOutput11_shared[v0, v1, v2, v3])
                                    T.writes(Conv2dOutput11[v0, v1, v2, v3])
                                    Conv2dOutput11[v0, v1, v2, v3] = Conv2dOutput11_shared[v0, v1, v2, v3]

My main concern is:

rxplaceholder_shared_3 = T.alloc_buffer([1, 82, 64, 66], dtype="int8", scope="shared")

My debugging environment is relatively closed off, and I’m still trying to build a minimal example. Thanks.

This is a simplified version of the script:

    @T.prim_func
    def conv2d_integer(rxplaceholder: T.Buffer[(1, 82, 64, 66), "int8"], rxplaceholder_1: T.Buffer[(256, 66, 3, 3), "int8"], rxplaceholder_2: T.Buffer[256, "int32"], rxplaceholder_3: T.Buffer[(1, 1, 1, 256), "int32"], Conv2dOutput11: T.Buffer[(1, 82, 64, 256), "int8"]):
        # function attr dict
        T.func_attr({"global_symbol": "conv2d_integer", "tir.noalias": True, "layout_free_buffers": [1]})
        # body
        # with T.block("root")
        PaddedInput = T.alloc_buffer([1, 84, 66, 66], dtype="int8")
        Conv2dOutput11_shared = T.alloc_buffer([1, 82, 64, 256], dtype="int8", scope="shared")
        rxplaceholder_shared = T.alloc_buffer([1, 82, 64, 66], dtype="int8", scope="shared")
        for ax0, ax1_0 in T.grid(1, 2):
            for ax0_1 in T.serial(43, annotations={"pragma_dma_in":"input"}):
                for ax1, ax2 in T.grid(64, 66):
                    with T.block("rxplaceholder_shared"):
                        T.where(1 <= ax1_0 * 41 + ax0_1 and ax1_0 * 41 + ax0_1 < 83)
                        v0 = T.axis.spatial(1, 0)
                        v1 = T.axis.spatial(82, ax1_0 * 41 - 1 + ax0_1)
                        v2, v3 = T.axis.remap("SS", [ax1, ax2])
                        T.reads(rxplaceholder[v0, v1, v2, v3])
                        T.writes(rxplaceholder_shared[v0, v1, v2, v3])
                        rxplaceholder_shared[v0, v1, v2, v3] = rxplaceholder[v0, v1, v2, v3]
            for ax0_2 in T.serial(43, annotations={"pragma_pad_in_place":"1"}):
                for ax1, ax2 in T.grid(66, 66):
                    with T.block("PaddedInput"):
                        i0 = T.axis.spatial(1, 0)
                        i1 = T.axis.spatial(84, ax1_0 * 41 + ax0_2)
                        i2, i3 = T.axis.remap("SS", [ax1, ax2])
                        T.reads(rxplaceholder_shared[i0, i1 - 1, i2 - 1, i3])
                        T.writes(PaddedInput[i0, i1, i2, i3])
                        PaddedInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 83 and 1 <= i2 and i2 < 65, rxplaceholder_shared[i0, i1 - 1, i2 - 1, i3], T.int8(0), dtype="int8")
            for ax0_4 in T.serial(41, annotations={"pragma_compute_attr_conv":"conv2d"}):
                for ax1, ax2, ax3, ax4, ax5 in T.grid(64, 256, 3, 3, 66):
                    with T.block("Conv2dOutput11"):
                        nn = T.axis.spatial(1, 0)
                        yy = T.axis.spatial(82, ax1_0 * 41 + ax0_4)
                        xx, ff, ry, rx, rc = T.axis.remap("SSRRR", [ax1, ax2, ax3, ax4, ax5])
                        T.reads(PaddedInput[nn, yy + ry, xx + rx, rc], rxplaceholder_1[ff, rc, ry, rx], rxplaceholder_2[ff], rxplaceholder_3[0, 0, 0, ff])
                        T.writes(Conv2dOutput11_shared[nn, yy, xx, ff])
                        T.block_attr({"dilation_h":1, "dilation_w":1, "has_bias":True, "has_mul":True, "has_pad":True, "in_channel":66, "kernel_h":3, "kernel_w":3, "layout":"NHWC", "out_channel":256, "relu_type":1, "shift":10, "stride_h":1, "stride_w":1})
                        with T.init():
                            Conv2dOutput11_shared[nn, yy, xx, ff] = T.int8(0)
                        Conv2dOutput11_shared[nn, yy, xx, ff] = Conv2dOutput11_shared[nn, yy, xx, ff] + T.Cast("int8", T.Cast("int32", PaddedInput[nn, yy + ry, xx + rx, rc] * rxplaceholder_1[ff, rc, ry, rx]) + rxplaceholder_2[ff] * rxplaceholder_3[0, 0, 0, ff])
            for ax1_1 in T.serial(41, annotations={"pragma_dma_out":"output"}):
                for ax2_1, ax3_1 in T.grid(64, 256):
                    with T.block("Conv2dOutput11_shared"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(82, ax1_0 * 41 + ax1_1)
                        v2 = T.axis.spatial(64, ax2_1)
                        v3 = T.axis.spatial(256, ax3_1)
                        T.reads(Conv2dOutput11_shared[v0, v1, v2, v3])
                        T.writes(Conv2dOutput11[v0, v1, v2, v3])
                        Conv2dOutput11[v0, v1, v2, v3] = Conv2dOutput11_shared[v0, v1, v2, v3]

It looks like an arithmetic problem. It would be great if you could dive into the analyzer and fix it.

Thanks for your reply. I’ll learn about the corresponding implementation in analyzer. I’ll be glad to fix it if I can.

Hi :slight_smile: Thanks for the finding~ The compact buffer pass has a fallback implementation: when the inferred shape extents have a loop-carried dependency, it tries to use const_int_bound to get a constant upper bound and eliminate the dynamic shape. Refer to https://github.com/apache/tvm/pull/11428

This usually happens when we have non-uniform tiling conditions. Currently, const_int_bound may return an upper bound that is too loose, since it is not clever enough to achieve what you described.

I’m trying to handle this issue by independently inferring upper bounds on different iter partitions, thus avoiding an upper bound that is too loose with respect to the whole iter region. We have also suffered from similar workloads, so we would be very glad to hear any ideas for improving this!

Oh, that’s great. It seems that this situation is reasonable. I am very concerned about the buffer compaction here, because it will affect the subsequent mem allocation and optimization. I will also learn about the specific implementation of this pass and see if I can provide some help.

I appreciate your work and look forward to your fix here.

Thanks.