Hello, I’m working with Relax and TensorIR. When I run the CompactBufferAllocation pass, the result is not what I expect.
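For reference, this is roughly how I apply the pass in isolation (a sketch only; `mod` stands for my lowered IRModule at this point in the pipeline, which I have not included here):

import tvm

# Sketch: `mod` is assumed to be the lowered IRModule right before this pass.
print(mod)  # the "before" dump below
mod = tvm.tir.transform.CompactBufferAllocation()(mod)
print(mod)  # the "after" dump below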
Before this pass, the TIR looks like this:
for (ax0_2: int32, 0, 1) {
  for (ax1_0: int32, 0, 2) {
    for (ax2_0: int32, 0, 1) {
      for (ax3_0: int32, 0, 1) {
        block([], "") {
          tir.reads([rxplaceholder[0, ((ax1_0*41) - 1):(((ax1_0*41) - 1) + 43), 0:64, 0:66], rxplaceholder_1[0:256, 0:66, 0:3, 0:3], rxplaceholder_shared_1[0:256], rxplaceholder_shared[0, 0, 0, 0:256]])
          tir.writes([Conv2dOutput11[ax0_2, (ax1_0*41):((ax1_0*41) + 41), (ax2_0*64):((ax2_0*64) + 64), (ax3_0*256):((ax3_0*256) + 256)]])
          PaddedInput = alloc_buffer(int8[1, 84, 66, 66])
          rxplaceholder_shared_2 = alloc_buffer(int8[256, 66, 3, 3])
          Conv2dOutput11_shared = alloc_buffer(int8[1, 82, 64, 256])
          rxplaceholder_shared_3 = alloc_buffer(int8[1, 82, 64, 66])
          {
            for (ax0_3: int32, 0, 43) {
              for (ax1_1: int32, 0, 64) {
                for (ax2_1: int32, 0, 66) {
                  block([], "rxplaceholder_shared") {
                    where(((1 <= ((ax1_0*41) + ax0_3)) && (((ax1_0*41) + ax0_3) < 83)))
                    tir.reads([rxplaceholder[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]])
                    tir.writes([rxplaceholder_shared_3[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]])
                    rxplaceholder_shared_3[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1] = rxplaceholder[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]
                  }
                }
              }
            }
            for (ax0_4: int32, 0, 43) {
              for (ax1_2: int32, 0, 66) {
                for (ax2_2: int32, 0, 66) {
                  block([], "PaddedInput") {
                    tir.reads([rxplaceholder_shared_3[0, (((ax1_0*41) + ax0_4) - 1), (ax1_2 - 1), ax2_2]])
                    tir.writes([PaddedInput[0, ((ax1_0*41) + ax0_4), ax1_2, ax2_2]])
                    PaddedInput[0, ((ax1_0*41) + ax0_4), ax1_2, ax2_2] = @tir.if_then_else(((((1 <= ((ax1_0*41) + ax0_4)) && (((ax1_0*41) + ax0_4) < 83)) && (1 <= ax1_2)) && (ax1_2 < 65)), rxplaceholder_shared_3[0, (((ax1_0*41) + ax0_4) - 1), (ax1_2 - 1), ax2_2], 0i8, dtype=int8)
                  }
                }
              }
            }
            for (ax0_5: int32, 0, 256) {
              for (ax1_3: int32, 0, 66) {
                for (ax2_3: int32, 0, 3) {
                  for (ax3_1: int32, 0, 3) {
                    block([], "rxplaceholder_shared") {
                      tir.reads([rxplaceholder_1[ax0_5, ax1_3, ax2_3, ax3_1]])
                      tir.writes([rxplaceholder_shared_2[ax0_5, ax1_3, ax2_3, ax3_1]])
                      rxplaceholder_shared_2[ax0_5, ax1_3, ax2_3, ax3_1] = rxplaceholder_1[ax0_5, ax1_3, ax2_3, ax3_1]
                    }
                  }
                }
              }
            }
            for (ax0_6: int32, 0, 41) {
              for (ax1_4: int32, 0, 64) {
                for (ax2_4: int32, 0, 256) {
                  for (ax3_2: int32, 0, 3) {
                    for (ax4: int32, 0, 3) {
                      for (ax5: int32, 0, 66) {
                        block([], "Conv2dOutput11") {
                          tir.reads([PaddedInput[0, (((ax1_0*41) + ax0_6) + ax3_2), (ax1_4 + ax4), ax5], rxplaceholder_shared_2[ax2_4, ax5, ax3_2, ax4], rxplaceholder_shared_1[ax2_4], rxplaceholder_shared[0, 0, 0, ax2_4]])
                          tir.writes([Conv2dOutput11_shared[0, ((ax1_0*41) + ax0_6), ax1_4, ax2_4]])
                          tir.attrs({"stride_h": 1, "stride_w": 1, "kernel_h": 3, "has_pad": True, "has_bias": True, "shift": 10, "kernel_w": 3, "layout": "NHWC", "in_channel": 66, "dilation_h": 1, "has_mul": True, "relu_type": 1, "dilation_w": 1, "out_channel": 256})
                          {
                            if (((ax3_2 == 0) && (ax4 == 0)) && (ax5 == 0)) {
                              Conv2dOutput11_shared[0, ((ax1_0*41) + ax0_6), ax1_4, ax2_4] = 0i8
                            }
                            Conv2dOutput11_shared[0, ((ax1_0*41) + ax0_6), ax1_4, ax2_4] = (Conv2dOutput11_shared[0, ((ax1_0*41) + ax0_6), ax1_4, ax2_4] + cast(int8, (cast(int32, (PaddedInput[0, (((ax1_0*41) + ax0_6) + ax3_2), (ax1_4 + ax4), ax5]*rxplaceholder_shared_2[ax2_4, ax5, ax3_2, ax4])) + (rxplaceholder_shared_1[ax2_4]*rxplaceholder_shared[0, 0, 0, ax2_4]))))
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
            for (ax1_1_1: int32, 0, 41) {
              for (ax2_1_1: int32, 0, 64) {
                for (ax3_1_1: int32, 0, 256) {
                  block([], "Conv2dOutput11_shared") {
                    tir.reads([Conv2dOutput11_shared[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]])
                    tir.writes([Conv2dOutput11[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]])
                    Conv2dOutput11[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)] = Conv2dOutput11_shared[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
After the pass, it becomes:
for (ax0_2: int32, 0, 1) {
  for (ax1_0: int32, 0, 2) {
    for (ax2_0: int32, 0, 1) {
      for (ax3_0: int32, 0, 1) {
        block([], "") {
          tir.reads([rxplaceholder[0, ((ax1_0*41) - 1):(((ax1_0*41) - 1) + 43), 0:64, 0:66], rxplaceholder_1[0:256, 0:66, 0:3, 0:3], rxplaceholder_shared_1[0:256], rxplaceholder_shared[0, 0, 0, 0:256]])
          tir.writes([Conv2dOutput11[ax0_2, (ax1_0*41):((ax1_0*41) + 41), (ax2_0*64):((ax2_0*64) + 64), (ax3_0*256):((ax3_0*256) + 256)]])
          PaddedInput = alloc_buffer(int8[1, 43, 66, 66])
          rxplaceholder_shared_2 = alloc_buffer(int8[256, 66, 3, 3])
          Conv2dOutput11_shared = alloc_buffer(int8[1, 41, 64, 256])
          rxplaceholder_shared_3 = alloc_buffer(int8[1, 82, 64, 66])
          {
            for (ax0_3: int32, 0, 43) {
              for (ax1_1: int32, 0, 64) {
                for (ax2_1: int32, 0, 66) {
                  block([], "rxplaceholder_shared") {
                    where(((1 <= ((ax1_0*41) + ax0_3)) && (((ax1_0*41) + ax0_3) < 83)))
                    tir.reads([rxplaceholder[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]])
                    tir.writes([rxplaceholder_shared_3[0, ((((ax1_0*41) - 1) + ax0_3) - max(0, ((ax1_0*41) - 1))), ax1_1, ax2_1]])
                    rxplaceholder_shared_3[0, ((((ax1_0*41) - 1) + ax0_3) - max(0, ((ax1_0*41) - 1))), ax1_1, ax2_1] = rxplaceholder[0, (((ax1_0*41) - 1) + ax0_3), ax1_1, ax2_1]
                  }
                }
              }
            }
            for (ax0_4: int32, 0, 43) {
              for (ax1_2: int32, 0, 66) {
                for (ax2_2: int32, 0, 66) {
                  block([], "PaddedInput") {
                    tir.reads([rxplaceholder_shared_3[0, ((((ax1_0*41) + ax0_4) - 1) - max(0, ((ax1_0*41) - 1))), (ax1_2 - 1), ax2_2]])
                    tir.writes([PaddedInput[0, (((ax1_0*41) + ax0_4) - (ax1_0*41)), ax1_2, ax2_2]])
                    PaddedInput[0, (((ax1_0*41) + ax0_4) - (ax1_0*41)), ax1_2, ax2_2] = @tir.if_then_else(((((1 <= ((ax1_0*41) + ax0_4)) && (((ax1_0*41) + ax0_4) < 83)) && (1 <= ax1_2)) && (ax1_2 < 65)), rxplaceholder_shared_3[0, ((((ax1_0*41) + ax0_4) - 1) - max(0, ((ax1_0*41) - 1))), (ax1_2 - 1), ax2_2], 0i8, dtype=int8)
                  }
                }
              }
            }
            for (ax0_5: int32, 0, 256) {
              for (ax1_3: int32, 0, 66) {
                for (ax2_3: int32, 0, 3) {
                  for (ax3_1: int32, 0, 3) {
                    block([], "rxplaceholder_shared") {
                      tir.reads([rxplaceholder_1[ax0_5, ax1_3, ax2_3, ax3_1]])
                      tir.writes([rxplaceholder_shared_2[ax0_5, ax1_3, ax2_3, ax3_1]])
                      rxplaceholder_shared_2[ax0_5, ax1_3, ax2_3, ax3_1] = rxplaceholder_1[ax0_5, ax1_3, ax2_3, ax3_1]
                    }
                  }
                }
              }
            }
            for (ax0_6: int32, 0, 41) {
              for (ax1_4: int32, 0, 64) {
                for (ax2_4: int32, 0, 256) {
                  for (ax3_2: int32, 0, 3) {
                    for (ax4: int32, 0, 3) {
                      for (ax5: int32, 0, 66) {
                        block([], "Conv2dOutput11") {
                          tir.reads([PaddedInput[0, ((((ax1_0*41) + ax0_6) + ax3_2) - (ax1_0*41)), (ax1_4 + ax4), ax5], rxplaceholder_shared_2[ax2_4, ax5, ax3_2, ax4], rxplaceholder_shared_1[ax2_4], rxplaceholder_shared[0, 0, 0, ax2_4]])
                          tir.writes([Conv2dOutput11_shared[0, (((ax1_0*41) + ax0_6) - (ax1_0*41)), ax1_4, ax2_4]])
                          tir.attrs({"stride_h": 1, "stride_w": 1, "kernel_h": 3, "has_pad": True, "has_bias": True, "shift": 10, "kernel_w": 3, "layout": "NHWC", "in_channel": 66, "dilation_h": 1, "has_mul": True, "relu_type": 1, "dilation_w": 1, "out_channel": 256})
                          {
                            if (((ax3_2 == 0) && (ax4 == 0)) && (ax5 == 0)) {
                              Conv2dOutput11_shared[0, (((ax1_0*41) + ax0_6) - (ax1_0*41)), ax1_4, ax2_4] = 0i8
                            }
                            Conv2dOutput11_shared[0, (((ax1_0*41) + ax0_6) - (ax1_0*41)), ax1_4, ax2_4] = (Conv2dOutput11_shared[0, (((ax1_0*41) + ax0_6) - (ax1_0*41)), ax1_4, ax2_4] + cast(int8, (cast(int32, (PaddedInput[0, ((((ax1_0*41) + ax0_6) + ax3_2) - (ax1_0*41)), (ax1_4 + ax4), ax5]*rxplaceholder_shared_2[ax2_4, ax5, ax3_2, ax4])) + (rxplaceholder_shared_1[ax2_4]*rxplaceholder_shared[0, 0, 0, ax2_4]))))
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
            for (ax1_1_1: int32, 0, 41) {
              for (ax2_1_1: int32, 0, 64) {
                for (ax3_1_1: int32, 0, 256) {
                  block([], "Conv2dOutput11_shared") {
                    tir.reads([Conv2dOutput11_shared[ax0_2, (((ax1_0*41) + ax1_1_1) - (ax1_0*41)), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]])
                    tir.writes([Conv2dOutput11[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]])
                    Conv2dOutput11[ax0_2, ((ax1_0*41) + ax1_1_1), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)] = Conv2dOutput11_shared[ax0_2, (((ax1_0*41) + ax1_1_1) - (ax1_0*41)), ((ax2_0*64) + ax2_1_1), ((ax3_0*256) + ax3_1_1)]
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
The problem is the shape of the rxplaceholder_shared_3 buffer: the shape I expect is [1, 43, 64, 66], but I got [1, 82, 64, 66]. It seems the pass does not compact this buffer effectively, so the allocated buffer is about twice as large as necessary. For either value of ax1_0, the producer block only touches a 43-row span along axis 1 (ax0_3 runs over [0, 43)), so the compacted extent should be 43.
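This expectation can be checked with the arithmetic analyzer directly (a sketch I wrote for this post; the variables mirror the pre-pass IR above):

import tvm
from tvm import tir

ax1_0 = tir.Var("ax1_0", "int32")
ax0_3 = tir.Var("ax0_3", "int32")

# Axis-1 index of rxplaceholder_shared_3 as written in the pre-pass IR.
idx = ((ax1_0 * 41) - 1) + ax0_3

ana = tvm.arith.Analyzer()
ana.bind(ax0_3, tvm.ir.Range(0, 43))  # for (ax0_3, 0, 43)

# For each fixed value of ax1_0 the index covers only 43 rows.
for v in (0, 1):
    fixed = tvm.tir.stmt_functor.substitute(idx, {ax1_0: tir.IntImm("int32", v)})
    bound = ana.const_int_bound(fixed)
    print(v, bound.min_value, bound.max_value)  # 0: [-1, 41], 1: [40, 82]

Either way the span is 43 rows, so after subtracting the region start the compacted shape should be [1, 43, 64, 66].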
I tried to debug the pass (in src/tir/transforms/compact_buffer_region.cc) by adding some printing:
printf("extent: %s \n", PrettyPrint(extent).c_str());
which printed:
extent: min(82, (max(((((max((ax1_0: int32*41), 1) + 82) - max((ax1_0*41), 40)) - max((1 - (ax1_0*41)), 0)) - (ax1_0*41)), 42) + 1))
Finally, I found that the upper bound of this extent is computed as 82, even though the expression evaluates to 43 whether ax1_0 equals 0 or 1. The bound comes from:
int64_t upperbound = analyzer->const_int_bound(extent)->max_value;
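I can reproduce the over-approximation outside the pass (again a sketch; the extent expression is transcribed by hand from the printout above):

import tvm
from tvm import tir

ax1_0 = tir.Var("ax1_0", "int32")
ana = tvm.arith.Analyzer()
ana.bind(ax1_0, tvm.ir.Range(0, 2))  # the surrounding loop: for (ax1_0, 0, 2)

# The extent from the printout, rebuilt by hand.
extent = tir.min(
    82,
    tir.max(
        ((tir.max(ax1_0 * 41, 1) + 82) - tir.max(ax1_0 * 41, 40))
        - tir.max(1 - ax1_0 * 41, 0)
        - ax1_0 * 41,
        42,
    ) + 1,
)

# ConstIntBound does interval arithmetic: every occurrence of ax1_0 is
# bounded independently, so the min/max terms never cancel against each
# other and the upper bound inflates to 82.
print(ana.const_int_bound(extent))  # max_value == 82

# Substituting each concrete loop value shows the exact extent is 43.
for v in (0, 1):
    fixed = tvm.tir.stmt_functor.substitute(extent, {ax1_0: tir.IntImm("int32", v)})
    print(v, ana.simplify(fixed))  # 43 in both cases

So the per-iteration region itself seems fine; it appears to be the constant bound of the extent that is too loose, because interval analysis cannot see that all occurrences of ax1_0 take the same value.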
Can someone help me look into this problem?