[relay.op.nn] AvgPool2D in PyTorch gives different results than TVM?

I ran into this behavior, didn't expect it, and could use some help understanding it. Here is a script that reproduces it; note that with strides=(1, 1) it produces the same results as PyTorch, while strides=(2, 2) gives different results. Since this came up recently, I suspect the behavior was introduced by https://github.com/apache/tvm/pull/9835. I tried modifying the calculation a little, and below the script there is a diff that makes TVM produce exactly the same output as PyTorch.

import torch
from tvm import relay
import tvm
import numpy as np

input_shape = (1, 16, 32, 24)
pool_size = (3, 3)
strides = (2, 2)
dilation = (1, 1)
padding = (0, 0)
ceil_mode = True
layout = "NHWC"
count_include_pad = True
inputs = np.random.randint(-10, 10, input_shape)
inputs = inputs.astype('float32')

data = relay.var("data", shape=input_shape)
out = relay.nn.avg_pool2d(data, pool_size=pool_size, strides=strides, dilation=dilation, padding=padding, ceil_mode=ceil_mode, layout=layout, count_include_pad=count_include_pad)

# wrap into a function and make a module
f = relay.Function(relay.analysis.free_vars(out), out)
mod = tvm.IRModule.from_expr(f)

# compile for CPU (AVX2) and run through the Relay VM executor
ctx = tvm.device("llvm -mcpu=core-avx2")
executor = tvm.relay.create_executor('vm', mod, ctx, "llvm -mcpu=core-avx2")
prediction = executor.evaluate()(inputs)
prediction = prediction.asnumpy().astype('float')

# PyTorch reference: AvgPool2d expects NCHW, so transpose in and back out
inputs = np.transpose(inputs, (0, 3, 1, 2))
m = torch.nn.AvgPool2d(pool_size, strides, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
output = m(torch.from_numpy(inputs))
output = output.numpy().astype('float')
output = np.transpose(output, (0, 2, 3, 1))

print(prediction.shape)
print(output.shape)
print(prediction.dtype)
print(output.dtype)

if not np.allclose(prediction, output):
    max_abs_err = np.max(abs(prediction.astype(np.float32) - output.astype(np.float32)))
    raise Exception(f"Error: Outputs do not match within tolerance. Max abs error: {max_abs_err}")
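
For reference, here is a minimal 1-D sketch of what I believe is the corner case: with kernel 3, stride 2, ceil_mode=True and no padding, the last pooling window only partially overlaps the input, so everything hinges on the divisor used for that window. The toy input below is just for illustration.

import torch
import torch.nn.functional as F

# Toy 1-D input, purely illustrative: [0, 1, 2, 3, 4, 5]
x = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6)
# kernel=3, stride=2, ceil_mode=True: the last window starts at index 4 and
# runs past the end of the input, so only two in-bounds elements remain.
out = F.avg_pool1d(x, kernel_size=3, stride=2, ceil_mode=True, count_include_pad=True)
print(out)  # the last element shows which divisor PyTorch applies to that clipped window

The first windows are ordinary 3-element averages; only the last one is affected by ceil_mode. With strides=(1, 1) ceil_mode adds no extra partial window at all, which is consistent with the two frameworks agreeing in that case.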

###################

diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index c81c7cda7..1c6688793 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -612,14 +612,14 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
             auto num_el = make_const(DataType::Int(32), 1);
             for (int i = 0; i < k_size; i++) {
               int ii = axis[i];
-              start[i] = output[ii] * stride[i] - pad_head[i];
+              start[i] = (output[ii] - 1) * stride[i] - pad_head[i];
               // When computing the output shape in ceil_mode,
               // we have added the extra padding of offset[i],
               // so now in order to calculate the correct boundary ,
               // we need to substract the offset[i].
-              end[i] = start[i] + (kernel[i] - 1) * dilation[i];
+              end[i] = start[i] + (kernel[i] - 1) * dilation[i] + 1;
               end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
-              num_el *= (end[i] - start[i]) / dilation[i] + 1;
+              num_el *= (end[i] - start[i]) / dilation[i];
             }
             return div(pool_sum(indices), num_el);
           } else {
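
If it helps with reviewing, a rough way to exercise the patch beyond the single shape above is to sweep a few pool sizes and strides and compare against PyTorch each time. The sketch below reuses the same API calls as the repro script; the shapes and tolerances are arbitrary choices on my side.

import itertools
import numpy as np
import torch
import tvm
from tvm import relay

def check_case(input_shape, pool_size, strides):
    # TVM side: same calls as the repro script above.
    inputs = np.random.uniform(-10, 10, input_shape).astype("float32")
    data = relay.var("data", shape=input_shape)
    out = relay.nn.avg_pool2d(data, pool_size=pool_size, strides=strides,
                              padding=(0, 0), ceil_mode=True, layout="NHWC",
                              count_include_pad=True)
    mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))
    executor = relay.create_executor("vm", mod, tvm.cpu(0), "llvm")
    got = executor.evaluate()(inputs).asnumpy()

    # PyTorch reference (NHWC -> NCHW -> NHWC).
    m = torch.nn.AvgPool2d(pool_size, strides, ceil_mode=True, count_include_pad=True)
    ref = m(torch.from_numpy(np.transpose(inputs, (0, 3, 1, 2)))).numpy()
    ref = np.transpose(ref, (0, 2, 3, 1))
    return np.allclose(got, ref, rtol=1e-5, atol=1e-5)

for pool, stride in itertools.product([(2, 2), (3, 3)], [(1, 1), (2, 2), (3, 3)]):
    print(pool, stride, check_case((1, 8, 17, 13), pool, stride))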

@xiaolong18 @masahi Tagging you because I saw you worked on/reviewed the PR. Could you please provide some feedback on this thread?

OK, I’ll take a look.

Can you also verify that the change in https://github.com/apache/tvm/pull/9835 is a good one for its purpose?

Hi @masahi, thanks for the prompt response. I am not familiar enough with the padding calculation to verify it. What we found is that the results differ between PyTorch and TVM. We will do our best to take a deeper look. Maybe @xiaolong18 can add more comments on it.

Can you open an issue for this? So that I won’t forget.

No problem. Does this look good to you @masahi https://github.com/apache/tvm/issues/10614?

Yes, thanks. I will be a bit busy for up to 2 months, so I don’t know when I can get to it.