Ops become slow when using te.var

I wrote some ops with TE and TOPI and found that, when I use te.var to represent an op's input shape, the op becomes much slower than with a constant shape, even with the same schedule.

When I look at the generated CUDA code, there are lots of if statements involving the var, which make it much slower. So I wonder if there is a way to get rid of those if statements, because I think most of them are useless. Maybe some kind of assertion? But how do I add an assertion in TE?

Thank you for any advice!

I guess what you are seeing is that the variable can take any value in the range of the data type you selected, and therefore, for correctness, all those if statements are necessary?
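To see why a code generator keeps a bounds guard when an extent is symbolic, here is a small pure-Python model (not TVM code) of a 1-D launch. It assumes the flat index is `blockIdx.x * 64 + threadIdx.x`, as in the kernel posted further down, and that the grid uses a ceil-div block count; the helper names are hypothetical.

```python
# Pure-Python model of a 1-D CUDA launch; not TVM code.
# Assumption: 64 threads per block and a ceil-div block count, matching the
# "blockIdx.x * 64 + threadIdx.x" indexing in the generated kernel.
THREADS_PER_BLOCK = 64


def num_blocks(extent: int) -> int:
    """Smallest number of 64-thread blocks covering `extent` elements."""
    return (extent + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK


def out_of_range_threads(extent: int) -> int:
    """Count launched threads whose flat index is >= extent.

    These are the threads an unguarded kernel would let write past the
    end of the output buffer.
    """
    count = 0
    for block in range(num_blocks(extent)):
        for thread in range(THREADS_PER_BLOCK):
            if block * THREADS_PER_BLOCK + thread >= extent:
                count += 1
    return count


if __name__ == "__main__":
    for extent in (64, 100, 128, 130):
        # prints "64 0", "100 28", "128 0", "130 62"
        print(extent, out_of_range_threads(extent))
```

Unless the extent is known to be a multiple of 64, some launched threads fall outside the buffer, so a guard is the safe default when N is an unconstrained te.var.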

Any reason you want to use te.var? I am guessing it is because of some dynamic shape you want to support, but I think the standard TVM way of doing this is to JIT-compile again each time the values change.
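The "recompile when the values change" approach can be sketched as a cache keyed on the concrete shape. Here `build_for_shape` is a hypothetical stand-in for code that declares the TE compute with constant shapes and builds it; it only records calls, so the caching logic runs without TVM.

```python
from functools import lru_cache
from typing import Callable, List, Tuple

# Records which shapes were "compiled", so the caching behaviour is visible.
build_log: List[Tuple[int, ...]] = []


def build_for_shape(shape: Tuple[int, ...]) -> Callable[[list], list]:
    """Hypothetical stand-in for building a kernel specialized to `shape`.

    In real code this is where the TE compute would be declared with
    constant dimensions (no te.var) and compiled.
    """
    build_log.append(shape)

    def kernel(data: list) -> list:
        # Placeholder "kernel": the shape is fixed when it is built.
        assert len(data) == shape[0]
        return [x * 2 for x in data]

    return kernel


@lru_cache(maxsize=None)
def get_kernel(shape: Tuple[int, ...]) -> Callable[[list], list]:
    """Build once per distinct shape and reuse the result afterwards."""
    return build_for_shape(shape)


def run(data: list) -> list:
    """Look up (or build) the kernel for this input's shape and call it."""
    return get_kernel((len(data),))(data)
```

Calling `run([1, 2, 3])`, then `run([4, 5, 6])`, then `run([1, 2])` builds only twice, for shapes `(3,)` and `(2,)`. The cost is one compile per new shape, which pays off when only a few distinct shapes occur.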

Yes, I guess so. However, in most cases I think those ifs are unnecessary, so I want to know how to write assertions that avoid them.

Here's an example. N is the te.var. You can clearly see the duplicated if statements:
```cuda
extern "C" __global__ void default_function_kernel0(float* __restrict__ T_subtract1_red, float* __restrict__ Aa, float* __restrict__ ke, int k, int N) {
  if (((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) >> 18) < N) {
    if (((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) >> 18) < N) {
      if (((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) >> 9) < (N * 512)) {
        if ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) < (N * 262144)) {
          T_subtract1_red[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)))] = 3.402823e+38f;
        }
      }
    }
  }
  for (int k4 = 0; k4 < k; ++k4) {
    if (((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) >> 18) < N) {
      if (((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) >> 18) < N) {
        if (((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) >> 9) < (N * 512)) {
          if ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) < (N * 262144)) {
            T_subtract1_red[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)))] = min(T_subtract1_red[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)))], (((((((k >> 1) <= ((((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) & 262143) >> 9) + k4)) & (((((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) & 262143) >> 9) + k4) < ((k + 511) - (k >> 1)))) & ((k >> 1) <= ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) % k) + ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) & 511)))) & (((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) % k) + ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k) & 511)) < ((k + 511) - (k >> 1)))) ? Aa[(((((k4 * 512) + (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) / k)) + (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) % k)) - ((k >> 1) * 513)))] : 1.000000e+04f) - ((float)((ke[(((k4 * k) + (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) % k)))] == 0.000000e+00f) ? -10000 : 0))));
          }
        }
      }
    }
  }
}
```
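A brute-force check in plain Python suggests the nested guards in this kernel are all the same condition. Writing x for `((blockIdx.x * 64 + threadIdx.x) / k)`, for integer N the tests `(x >> 18) < N`, `(x >> 9) < (N * 512)` and `x < (N * 262144)` all reduce to `x < N * 2^18`, because `floor(x / 2^s) < m` holds exactly when `x < m * 2^s`. This sketch assumes unbounded integers; in the CUDA code these are 32-bit `int`s, so `N * 262144` could overflow for large N, which the check does not model.

```python
import random


def guards(x: int, n: int) -> tuple:
    """The three distinct guard conditions from the generated kernel.

    x stands for ((blockIdx.x * 64 + threadIdx.x) / k); Python's >> is an
    arithmetic (flooring) shift, like >> on a signed int in CUDA.
    """
    return (
        (x >> 18) < n,        # outermost guard (emitted twice)
        (x >> 9) < n * 512,   # third guard
        x < n * 262144,       # innermost guard; 262144 == 1 << 18
    )


def all_equivalent(samples: int = 100_000, seed: int = 0) -> bool:
    """Randomly sample (x, n) and check that the guards always agree."""
    rng = random.Random(seed)
    for _ in range(samples):
        n = rng.randint(-8, 8)
        if rng.random() < 0.5:
            # Values close to the N * 2**18 boundary, where bugs would show.
            x = n * 262144 + rng.randint(-1024, 1024)
        else:
            x = rng.randint(-(1 << 22), 1 << 22)
        a, b, c = guards(x, n)
        if not (a == b == c):
            return False
    return True
```

At the boundary for N = 1, `guards(262143, 1)` is all `True` and `guards(262144, 1)` is all `False`, so keeping only one of the four ifs would give the same result under these assumptions.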

Just inline one stage into the other one?
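In a TE schedule, inlining is done with `s[stage].compute_inline()`, which substitutes the stage's expression into its consumers instead of materializing an intermediate buffer. The pure-Python sketch below does not use TVM and its stage names are hypothetical; it only illustrates the semantics: both versions produce the same value, but the inlined one never allocates the intermediate array.

```python
from typing import List


def two_stages(a: List[float], c: List[float]) -> float:
    """Producer stage writes an intermediate buffer; consumer reduces it."""
    # Stage 1 (hypothetical "subtract" stage): materialized intermediate.
    sub = [a[i] - c[i] for i in range(len(a))]
    # Stage 2: min-reduction over the intermediate buffer.
    return min(sub)


def inlined(a: List[float], c: List[float]) -> float:
    """Same computation with stage 1's expression substituted into stage 2."""
    result = float("inf")
    for i in range(len(a)):
        # No intermediate buffer: the subtraction happens inside the reduction.
        result = min(result, a[i] - c[i])
    return result
```

With `a = [3.0, 1.0, 2.0]` and `c = [0.5, 0.0, 1.5]`, both functions return `0.5`.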

EDIT: wait, your if statements use variables that are not defined (blockIdx.x and threadIdx.x).

You mean combine the two kernels? But what I actually want is to get rid of the if statements.
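Under one extra assumption, that this kernel is launched with exactly N * 262144 * k / 64 blocks of 64 threads (the launch configuration is not shown in the thread), the innermost guard `idx / k < N * 262144` can never be false: 262144 is a multiple of 64, so the grid covers the N * 262144 * k outputs exactly. Here is a small pure-Python check of that claim; the helper names are hypothetical, and whether TVM's simplifier can prove this by itself depends on what it knows about N and k, which this sketch does not model.

```python
BLOCK = 64
ROWS = 262144  # the constant 1 << 18 that appears in the guards


def grid_blocks(n: int, k: int) -> int:
    """Assumed launch: exactly enough 64-thread blocks for n * ROWS * k outputs."""
    total = n * ROWS * k
    # Always divisible, because ROWS is a multiple of BLOCK.
    assert total % BLOCK == 0
    return total // BLOCK


def guard_can_fail(n: int, k: int) -> bool:
    """True if some launched thread would fail `idx / k < n * ROWS`.

    idx // k never decreases as idx grows, so checking the largest
    launched index is enough.
    """
    last_idx = grid_blocks(n, k) * BLOCK - 1
    return not (last_idx // k < n * ROWS)
```

For every positive N and k tried, `guard_can_fail` returns `False`: with that launch, the guard only costs instructions and never changes the result.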