I ran into some unexpected behavior and could use some help understanding it. The script below reproduces it. Note that with strides=(1, 1), TVM produces the same results as PyTorch, but with strides=(2, 2) the results differ. Since this appeared recently, I suspect it was introduced by https://github.com/apache/tvm/pull/9835. I tried modifying the calculation slightly; the diff below the script makes TVM produce exactly the same output as PyTorch.
import torch
from tvm import relay
import tvm
import numpy as np
# Repro: compare TVM's relay.nn.avg_pool2d against torch.nn.AvgPool2d with
# ceil_mode=True and count_include_pad=True on the same random input.
# With strides=(1, 1) the two agree; with strides=(2, 2) they differ.
input_shape = (1, 16, 32, 24)  # laid out per `layout` below (NHWC)
pool_size = (3, 3)
strides = (2, 2)
dilation = (1, 1)
padding = (0, 0)  # (pad_h, pad_w), passed to both frameworks
ceil_mode = True
layout = "NHWC"
count_include_pad = True
target = "llvm -mcpu=core-avx2"

# torch.nn.AvgPool2d has no dilation argument, so the comparison below is
# only meaningful for unit dilation; fail loudly instead of reporting a
# spurious mismatch.
if tuple(dilation) != (1, 1):
    raise ValueError("torch.nn.AvgPool2d does not support dilation; use dilation=(1, 1)")

# Fixed seed so that a failing run can be reproduced exactly.
rng = np.random.default_rng(0)
inputs = rng.integers(-10, 10, size=input_shape).astype("float32")

# --- TVM (NHWC) ---
data = relay.var("data", shape=input_shape)
out = relay.nn.avg_pool2d(
    data,
    pool_size=pool_size,
    strides=strides,
    dilation=dilation,
    padding=padding,
    ceil_mode=ceil_mode,
    layout=layout,
    count_include_pad=count_include_pad,
)
# wrap into a function and make a module
f = relay.Function(relay.analysis.free_vars(out), out)
mod = tvm.IRModule.from_expr(f)
dev = tvm.device(target)
executor = tvm.relay.create_executor("vm", mod, dev, target)
prediction = executor.evaluate()(inputs).numpy().astype("float")

# --- PyTorch reference (expects NCHW) ---
inputs_nchw = np.transpose(inputs, (0, 3, 1, 2))
m = torch.nn.AvgPool2d(
    pool_size,
    stride=strides,
    padding=padding,  # previously omitted: must match the TVM op's padding
    ceil_mode=ceil_mode,
    count_include_pad=count_include_pad,
)
output = m(torch.from_numpy(inputs_nchw)).numpy().astype("float")
output = np.transpose(output, (0, 2, 3, 1))  # back to NHWC for comparison

print(prediction.shape)
print(output.shape)
print(prediction.dtype)
print(output.dtype)

# Compare shapes first: np.allclose would otherwise raise or silently
# broadcast on mismatched output shapes.
if prediction.shape != output.shape:
    raise Exception(f"Error: Output shapes differ. TVM {prediction.shape} vs PyTorch {output.shape}")
if not np.allclose(prediction, output):
    # Largest element-wise absolute difference (L-infinity norm of the error).
    max_abs_err = np.max(np.abs(prediction.astype(np.float32) - output.astype(np.float32)))
    raise Exception(f"Error: Outputs do not match within tolerance. Got {max_abs_err}")
###################
diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index c81c7cda7..1c6688793 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -612,14 +612,14 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
auto num_el = make_const(DataType::Int(32), 1);
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
- start[i] = output[ii] * stride[i] - pad_head[i];
+ start[i] = (output[ii] - 1) * stride[i] - pad_head[i];
// When computing the output shape in ceil_mode,
// we have added the extra padding of offset[i],
// so now in order to calculate the correct boundary ,
// we need to substract the offset[i].
- end[i] = start[i] + (kernel[i] - 1) * dilation[i];
+ end[i] = start[i] + (kernel[i] - 1) * dilation[i] + 1;
end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
- num_el *= (end[i] - start[i]) / dilation[i] + 1;
+ num_el *= (end[i] - start[i]) / dilation[i];
}
return div(pool_sum(indices), num_el);
} else {