Conv2D CUDA Performance without tuning

Hi,

I am trying to build an AI-driven performance predictor that can predict the power consumption, memory allocation, and inference time of a network based on its Relay description. The final goal is to enable AI-driven scheduling of inference executions across the targets of heterogeneous or distributed systems to achieve the best possible performance. (The project is still at an early stage, as I am currently struggling with the following problem.)

To do this, I benchmark and profile a large number of random hyperparameter configurations for each relevant layer type and train a regression model for each combination of layer type, performance characteristic, and target device.

This seems to be working reasonably well for Fully Connected and Pooling Layers. For Conv2D workloads on CUDA targets, however, I am not able to fit my models to the measurement data.

To investigate, I manually profiled very similar hyperparameter configurations and realized that small configuration changes can have a huge impact on performance, which I cannot explain:

One example:

  • Input Tensor (NCHW): (1, 3, 225, 225)
  • Kernel: 3x3, 32 Output Channels
  • Dilation: 1
  • Strides: (1, 1)
  • Groups: 1
  • Padding: 0
  • Measured Runtime: 178 µs
  • Measured Power Consumption: 225 W
  • Measured Memory Allocation: 430 MB

Now, when I add a padding of 1:

  • Measured Runtime: 30 µs
  • Measured Power Consumption: 270 W
  • Measured Memory Allocation: 430 MB
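
For reference, here is a minimal sketch of how such a configuration can be reproduced and timed (this assumes a recent TVM with the graph_executor API; the helper conv2d_workload and the measurement loop are just for illustration, not my exact harness):

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

def conv2d_workload(padding):
    # Build a Relay module containing a single Conv2D with the hyperparameters above.
    data = relay.var("data", shape=(1, 3, 225, 225), dtype="float32")
    weight = relay.var("weight", shape=(32, 3, 3, 3), dtype="float32")
    out = relay.nn.conv2d(
        data, weight,
        channels=32, kernel_size=(3, 3),
        strides=(1, 1), padding=(padding, padding),
        dilation=(1, 1), groups=1,
    )
    return tvm.IRModule.from_expr(relay.Function([data, weight], out))

target = tvm.target.cuda()
dev = tvm.cuda(0)

for pad in (0, 1):
    mod = conv2d_workload(pad)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target)
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input("data", np.random.uniform(size=(1, 3, 225, 225)).astype("float32"))
    module.set_input("weight", np.random.uniform(size=(32, 3, 3, 3)).astype("float32"))
    # Average the kernel runtime over several runs on the device.
    timer = module.module.time_evaluator("run", dev, number=100, repeat=3)
    print("padding=%d: %.1f us" % (pad, timer().mean * 1e6))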

I am not performing tuning, as it takes too much time and I would like the performance prediction to provide an untuned baseline. I looked into TOPI and how the CUDA backend selects the schedules/templates, but it looks like both layer configurations should be executed using topi.cuda.conv2d_nchw.

How can this huge difference in execution time be explained?

EDIT: To showcase the inconsistency, I measured the execution time of the same workload with increasing padding:

There is no performance guarantee without tuning. Even small changes like padding can cause a great difference. For example, padding can affect the size of the shared memory loads. Imperfect tiling could introduce predicated statements that are slower. You can check the generated CUDA code.


Thank you for your response :slight_smile:

I looked into the generated CUDA code, and it looks like you are right: the slower configurations do not seem to utilize the shared memory, and there are some differences in the overall code.
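
In case it helps others, this is roughly how the generated device code can be dumped (assuming lib is the factory module returned by relay.build in the sketch above):

# The CUDA source is attached as an imported module of the host library
# produced by relay.build for the CUDA target.
cuda_source = lib.get_lib().imported_modules[0].get_source()
print(cuda_source)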

Slower kernel:

float compute_local[16];
__shared__ float pad_temp_shared[540];
__shared__ float placeholder_shared[864];
for (int yy_c_init = 0; yy_c_init < 2; ++yy_c_init) {
    compute_local[(yy_c_init)] = 0.000000e+00f;
    compute_local[((yy_c_init + 4))] = 0.000000e+00f;
    compute_local[((yy_c_init + 8))] = 0.000000e+00f;
    compute_local[((yy_c_init + 12))] = 0.000000e+00f;
    compute_local[((yy_c_init + 2))] = 0.000000e+00f;
    compute_local[((yy_c_init + 6))] = 0.000000e+00f;
    compute_local[((yy_c_init + 10))] = 0.000000e+00f;
    compute_local[((yy_c_init + 14))] = 0.000000e+00f;
}

Faster kernel:

float compute_local[4];
__shared__ float pad_temp_shared[27];
__shared__ float placeholder_shared[864];
compute_local[(0)] = 0.000000e+00f;
compute_local[(1)] = 0.000000e+00f;
compute_local[(2)] = 0.000000e+00f;
compute_local[(3)] = 0.000000e+00f;
for (int ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner = 0; ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner < 4; ++ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) {
  if (((((int)threadIdx.z) * 4) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) < 27){
    pad_temp_shared[(((((int)threadIdx.z) * 4) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner))] = placeholder[(((((((((((int)threadIdx.z) * 4) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) / 9) * 50625) + (((((((int)threadIdx.z) * 4) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) % 9) / 3) * 225)) + (((int)blockIdx.y) * 225)) + ((int)blockIdx.x)) + (((((int)threadIdx.z) * 4) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) % 3)))];

The difference between these two is just 1 pixel in the width of the input tensor.

I was hoping that it might have another cause, as the measured data seems very structured when only one parameter is changed at a time:

This example shows the impact of increasing the input feature map height.

Now I need to find a way to collect a large number of samples without spending a lot of time on tuning. Or is there another way to see why TVM generates the kernels the way it does?

EDIT:

  • Is there a way to access the knobs of the schedule after compiling the module? Maybe it is possible to use them as additional inputs for the performance prediction, just as AutoTVM does during tuning (see the sketch after this list).
  • For testing, I tried to disable the download from tophub so that I always get an untuned schedule, by always returning an empty context during relay.build_module.build; however, it did not change the results.
  • I tested with a different tensor layout: using NHWC and HWIO instead of NCHW and OIHW. The noise between runs is gone, but this layout seems much slower, and execution fails for larger tensors.
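
Regarding the first bullet, this is a sketch of how the AutoTVM knobs for a workload can be inspected (it assumes the AutoTVM task-extraction API and reuses conv2d_workload and target from the earlier sketch):

from tvm import autotvm

# Extract the tunable tasks of the workload; without tuning, the fallback
# configuration of each task's config space is what gets compiled.
tasks = autotvm.task.extract_from_program(conv2d_workload(0), params={}, target=target)
for task in tasks:
    print(task.name, task.args)
    print(task.config_space)  # knobs such as tile_f, tile_y, tile_x, ... and their ranges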