Questions about conv2d weight transform

Sorry, I have just started learning TVM. I noticed that in order to speed up conv2d, certain transforms of the input and weight are required:

from tvm import te
from tvm.topi.nn import pad
from tvm.topi.utils import get_const_tuple


def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
    KH, KW, IC, OC = get_const_tuple(kernel.shape)
    K = KH * KW * IC
    N = OC

    kernel_flat = te.compute(
        (K, N), lambda x, y: kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y], "weight_flatten"
    )

    pad_K = 0
    pad_N = 0

    if N % tile_rows != 0:
        pad_N = tile_rows - (N % tile_rows)

    if K % tile_cols != 0:
        pad_K = tile_cols - (K % tile_cols)

    N_padded = N + pad_N
    K_padded = K + pad_K

    if pad_K != 0 or pad_N != 0:
        kernel_flat = pad(
            kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding"
        )

    return te.compute(
        (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols),
        lambda x, y, z, w: kernel_flat[w + tile_cols * y, z + tile_rows * x],
        name="weight_block_reshape",
    )
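For intuition, here is a NumPy reconstruction of what this transform computes (my own sketch derived from the `te.compute` expressions above, not TVM code): flatten the HWIO kernel to a `(K, N)` matrix, pad each axis up to a tile multiple, then reorder into `(N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols)` blocks.

```python
import numpy as np

def weight_transform_ref(kernel, tile_rows, tile_cols):
    """NumPy reference for the weight transform above (assumed semantics)."""
    KH, KW, IC, OC = kernel.shape
    K, N = KH * KW * IC, OC
    # Flatten (KH, KW, IC, OC) -> (K, N); a row-major reshape matches the
    # index math kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y]
    flat = kernel.reshape(K, N)
    # Pad K up to a multiple of tile_cols and N up to a multiple of tile_rows
    pad_K = (-K) % tile_cols
    pad_N = (-N) % tile_rows
    flat = np.pad(flat, ((0, pad_K), (0, pad_N)))
    Kp, Np = K + pad_K, N + pad_N
    # Block reshape: out[x, y, z, w] = flat[w + tile_cols * y, z + tile_rows * x]
    return (
        flat.reshape(Kp // tile_cols, tile_cols, Np // tile_rows, tile_rows)
        .transpose(2, 0, 3, 1)
        .copy()
    )
```

Running this on a small 1x1 kernel with IC = OC = 3 and 4x4 tiles produces the same padded, blocked layout that the TIR below materializes at runtime.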

You can see that this is a te.compute, and the computation is still visible at the TIR level. Below is the printed TIR:

  // attr [A_padded] storage_scope = "global"
  allocate A_padded[uint8 * 1605632]
  // attr [weight_block_reshape] storage_scope = "global"
  allocate weight_block_reshape[uint8 * 1024]
  // attr [C] storage_scope = "global"
  allocate C[int32 * 1605632]
  // attr [iter_var(pipeline, , pipeline)] pipeline_exec_scope = 1
  for (i1, 0, 50176) {
    for (i2, 0, 32) {
      A_padded[((i1*32) + i2)] = tir.if_then_else((i2 < 12), placeholder[((((floordiv(i1, 224)*675) + (floordiv(i2, 6)*675)) + (floormod(i1, 224)*3)) + floormod(i2, 6))], (uint8)0)
    }
  }
  for (x.y.fused, 0, 8) {
    for (z, 0, 32) {
      weight_block_reshape[((x.y.fused*128) + (z*4))] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[((x.y.fused*12) + z)], (uint8)0)
      weight_block_reshape[(((x.y.fused*128) + (z*4)) + 1)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 3)], (uint8)0)
      weight_block_reshape[(((x.y.fused*128) + (z*4)) + 2)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 6)], (uint8)0)
      weight_block_reshape[(((x.y.fused*128) + (z*4)) + 3)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 9)], (uint8)0)
    }
  }
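As a side note on the `(uint8)0` entries written by the if_then_else above: zero-padding the GEMM operands along the reduction and output dimensions is safe, because the extra entries contribute nothing to the valid part of the product. A small NumPy check (my own illustration, not part of the post):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.integers(0, 5, (4, 3))   # (M, K)
B = rng.integers(0, 5, (3, 2))   # (K, N)

# Pad K up to 8 on both operands and N up to 4 on B, as the transforms do
A_p = np.pad(A, ((0, 0), (0, 5)))
B_p = np.pad(B, ((0, 5), (0, 2)))

C_p = A_p @ B_p
# The valid region of the padded product equals the unpadded product
assert (C_p[:4, :2] == A @ B).all()
```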

This is just a data transformation, so why isn't it optimized away? In other words, will LLVM optimize it later?

Looking at the generated LLVM code, the weight transform still exists; LLVM did not optimize it away. Did I forget to turn on an optimization switch, or is the weight transform simply not optimized?

If you’re wondering why TVM/LLVM doesn’t evaluate this code at compile time (since the weights are constants), the answer is that TIR doesn’t treat weight data as constant. It expects you to pass in the weights at runtime, so the compiler can’t do the weight transformation ahead of time.
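Since the transform is a pure function of the weights, the folding you want can be sketched by hand: evaluate the transform once, offline, and hand the runtime the pre-transformed buffer. A minimal illustration of the idea (the `transform` here is a simplified stand-in for `conv2d_gemm_weight_transform`, not TVM's implementation):

```python
import numpy as np

def transform(w, tile_cols=4):
    """Simplified stand-in: flatten HWIO weights and pad K to a tile multiple."""
    flat = w.reshape(-1, w.shape[-1])        # (KH*KW*IC, OC)
    pad_k = (-flat.shape[0]) % tile_cols     # pad K up to a multiple of tile_cols
    return np.pad(flat, ((0, pad_k), (0, 0)))

weights = np.ones((3, 3, 3, 8), dtype=np.uint8)  # fixed once training is done
folded = transform(weights)                      # evaluated once, ahead of time
# At runtime the kernel would consume `folded` directly, skipping the transform.
```

This is exactly the behaviour that representing constant weights in TIR would let the compiler perform automatically.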

I do agree this isn’t ideal and especially for convolutions/GEMMs it means we get worse performance than we otherwise could. We have an RFC up at the moment which would add the ability to represent constant weights directly in TIR: [RFC][TIR] TIR Non-scalar Constants by manupa-arm · Pull Request #22 · apache/tvm-rfcs · GitHub. This could be the first step to enabling the behaviour you want, so feel free to take a look and comment.