Strange cuDNN latency on conv2d

I want to measure the latency of a conv2d operator with cuDNN. The input data shape is (16, 128, 130, 130), the kernel shape is (128, 128, 3, 3), with stride = 1, padding = 0, and dilation = 1. I use the following code for the measurement:

import numpy as np
import tvm
import tvm.contrib.cudnn
from tvm import te
from tvm.topi.testing import conv2d_nchw_python

# conv2d parameters: stride = 1, padding = 0, dilation = 1
sh, sw = 1, 1
padding = (0, 0)

# reference result on CPU
data_np = np.random.uniform(size=(16, 128, 130, 130)).astype(np.float32)
weight_np = np.random.uniform(size=(128, 128, 3, 3)).astype(np.float32)
out_np = conv2d_nchw_python(data_np, weight_np, (sh, sw), padding)

dev = tvm.cuda()
data_tvm = tvm.nd.array(data_np, device=dev)
weight_tvm = tvm.nd.array(weight_np, device=dev)

# build the cuDNN conv2d kernel
X = te.placeholder((16, 128, 130, 130), name='X')
W = te.placeholder((128, 128, 3, 3), name='W')
Y = tvm.contrib.cudnn.conv_forward(X, W, padding, (1, 1), (1, 1), 0, 0, -1, None, groups=1)
sched_cudnn = te.create_schedule(Y.op)
cudnn_kernel = tvm.build(sched_cudnn, [X, W, Y], target=tvm.target.Target("cuda"))
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
cudnn_kernel(data_tvm, weight_tvm, out_tvm)

# check results of cudnn
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# warm up, then measure
warmup_evaluator = cudnn_kernel.time_evaluator(cudnn_kernel.entry_name, dev, number=3, repeat=1, min_repeat_ms=300)
warmup_evaluator(data_tvm, weight_tvm, out_tvm)
time_evaluator = cudnn_kernel.time_evaluator(cudnn_kernel.entry_name, dev, number=3, repeat=1, min_repeat_ms=300)
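The time itself is then read from the evaluator, roughly like this (mirroring the call shown later in the thread):

cost = np.average(time_evaluator(data_tvm, weight_tvm, out_tvm).results)
print(f"the avg run time of the cuDNN kernel: {cost}")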

However, the measured time is strange: it is 0.0008706393072164948 s on an A100 GPU, which works out to 88796.14173998704 GFLOPS, while the FP32 peak of an A100 is less than 20000 GFLOPS.
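For context, that GFLOPS figure is just the direct-convolution FLOP count divided by the measured time; a quick sketch of the arithmetic (using the shapes above):

# output shape with padding 0 is (16, 128, 128, 128); each output element costs 128*3*3 MACs = 2304 FLOPs
flops = 16 * 128 * 128 * 128 * (128 * 3 * 3 * 2)   # ~7.73e10 FLOPs for a direct conv
print(flops / 0.0008706393072164948 / 1e9)          # ~88796 GFLOPS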

Can anyone point out what’s wrong here? Thanks a lot!

Is it possible that cuDNN is using Tensor Cores internally via TF32?

I guess so. Do you know how to disable Tensor Cores in cuDNN, or how to check whether they are being used?

Also, when I check correctness, the “np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)” call can fail for some conv2d shapes. Is that because cuDNN is using TF32?

Yeah, if the performance is better than expected and there is also an accuracy issue, it is likely that cuDNN is using lower-precision Tensor Core math. That should not be happening by default.

Can you try commenting out the line https://github.com/apache/tvm/blob/1cf0c0a5bfa6b0c61ce253142b66f6235d694e07/src/runtime/contrib/cudnn/cudnn_utils.cc#L258?

Hi, I downloaded the latest TVM version and commented out that line, but the result does not change.

>>> np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)
AssertionError:
Not equal to tolerance rtol=0.001, atol=0

Mismatched elements: 31520965 / 33554432 (93.9%)
Max absolute difference: 21.738129
Max relative difference: 0.07819495
x: array([[[[294.9323 , 292.7774 , 291.48584, ..., 289.28162, 286.24948,
      283.29086],
     [290.47208, 288.9986 , 286.26587, ..., 289.65204, 279.58716,...
y: array([[[[293.80002, 293.36163, 294.78125, ..., 286.68457, 282.50385,
      283.06412],
     [294.82407, 292.03946, 287.39087, ..., 289.7172 , 275.12323,...
>>> warmup_evaluator = cudnn_kernel.time_evaluator(cudnn_kernel.entry_name, dev, number=300, repeat=1, min_repeat_ms=300)
>>> warmup_evaluator(data_tvm, weight_tvm, out_tvm)
BenchmarkResult(min=0.0008357166350515464, mean=0.0008357166350515464, median=0.0008357166350515464, max=0.0008357166350515464, std=0.0, results=(0.0008357166350515464,))

How about replacing CUDNN_TENSOR_OP_MATH with CUDNN_FMA_MATH? See the cuDNN API reference (API Reference :: NVIDIA Deep Learning cuDNN Documentation). As the doc suggests, you can also try setting NVIDIA_TF32_OVERRIDE=0.
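For the environment-variable route, NVIDIA_TF32_OVERRIDE likely needs to be in the environment before the CUDA libraries initialize, so either export it in the shell before launching Python, or set it at the very top of the script. A minimal sketch of the in-script approach:

import os
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"   # assumption: must be set before cuDNN/CUDA is initialized

import tvm                                  # import TVM (and thus CUDA/cuDNN) only after setting it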

Hi, I tried both methods. Now the latency becomes longer, but the correctness check still fails:

Not equal to tolerance rtol=0.001, atol=0

Mismatched elements: 31518164 / 33554432 (93.9%)
Max absolute difference: 20.986359
Max relative difference: 0.07764272
x: array([[[[281.5848 , 277.72504, 287.6785 , ..., 283.46228, 282.42328,
      284.68054],
     [281.51117, 282.23495, 282.1859 , ..., 282.11166, 277.6587 ,...
y: array([[[[275.6013 , 281.1124 , 282.56717, ..., 283.83157, 285.45538,
      286.98752],
     [283.07947, 286.64365, 283.8292 , ..., 285.82422, 278.8415 ,...

>>> warmup_evaluator = cudnn_kernel.time_evaluator(cudnn_kernel.entry_name, dev, number=300, repeat=1, min_repeat_ms=300)
>>> warmup_evaluator(data_tvm, weight_tvm, out_tvm)
BenchmarkResult(min=0.0018765057033333335, mean=0.0018765057033333335, median=0.0018765057033333335, max=0.0018765057033333335, std=0.0, results=(0.0018765057033333335,))
>>> time_evaluator = cudnn_kernel.time_evaluator(cudnn_kernel.entry_name, dev,
...                     number=100, repeat=10, min_repeat_ms=100)
>>> vendor_cost = np.average(time_evaluator(data_tvm, weight_tvm, out_tvm).results)
>>> print(f"the avg run time of vendor kernel: {vendor_cost}")
the avg run time of vendor kernel: 0.00176073022

The GFLOPS number now becomes 43904.823392968334, which is still above the FP32 peak.

I think a possible explanation is that cuDNN chose the Winograd algorithm, so the GFLOPS number I computed (which assumes a direct convolution) overstates the arithmetic that is actually performed.
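That would be consistent with the numbers: Winograd performs far fewer multiplications than a direct 3x3 convolution, so dividing the direct-conv FLOP count by the runtime inflates the apparent GFLOPS. A rough illustration (the 2.25x factor is the theoretical multiplication reduction for an F(2x2, 3x3) Winograd tile and is only indicative; larger tiles reduce the count further):

direct_flops = 16 * 128 * 128 * 128 * (128 * 3 * 3 * 2)   # FLOPs assuming a direct convolution
winograd_reduction = 2.25                                   # illustrative: F(2x2, 3x3) multiplication saving
approx_actual_flops = direct_flops / winograd_reduction
print(approx_actual_flops / 0.00176073022 / 1e9)            # effective GFLOPS drops to roughly 19500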