[pre-RFC][Dynamic Shape] Use SizeVar instead of Var when converting Any in the GetShape function

Background and Motivation

Currently, TVM uses Any to represent an unknown dimension when an input has a dynamic shape. When building the schedule for a Relay primitive function, each AnyNode is converted to a Var named "any_dim". However, an element of a shape array can never be negative, and representing shape elements with Var causes redundant bound checks because the compiler cannot deduce the sign of a Var. For example, consider a simple network with only a softmax operation:

import numpy as np
import tvm
from tvm import relay
from time import time

# actual input shape
dim0 = 2
dim1 = 10
dim2 = 8
dim3 = 16

# relay var shapes
v_dim0 = relay.Any()
v_dim1 = relay.Any()
v_dim2 = relay.Any()
v_dim3 = relay.Any()

# rt settings
exec_mod = "vm"
tgt = "cuda"
dev = tvm.device(tgt)

def get_mod():
    x = relay.var("x", shape=(v_dim0, v_dim1, v_dim2, v_dim3))
    y = relay.nn.softmax(x)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([x], y)
    return mod

mod = get_mod()
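
To actually run the module, here is a minimal sketch using the VM executor settings defined above; it relies on the standard relay.create_executor API, though the exact invocation may differ across TVM versions:

# Hedged sketch: compile and run the dynamic-shape module with the VM executor.
data = np.random.rand(dim0, dim1, dim2, dim3).astype("float32")

executor = relay.create_executor(exec_mod, mod=mod, device=dev, target=tgt)
func = executor.evaluate()

start = time()
out = func(data)
print("output shape:", out.shape, "time:", time() - start)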

The corresponding TIR looks like this (only part of the function is shown, since the full output is long):

primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [any_dim: int32, any_dim_1: int32, any_dim_2: int32, any_dim_3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [any_dim, any_dim_1, any_dim_2, any_dim_3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  attr [T_softmax_maxelem: Pointer(global float32)] "storage_scope" = "global";
  allocate(T_softmax_maxelem, float32, [((any_dim*any_dim_1)*any_dim_2)]);
  attr [T_softmax_exp: Pointer(global float32)] "storage_scope" = "global";
  allocate(T_softmax_exp, float32, [(((any_dim*any_dim_1)*any_dim_2)*any_dim_3)]) {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((any_dim*any_dim_1)*any_dim_2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512;
    if (blockIdx.x < floordiv(((any_dim*any_dim_1)*any_dim_2), 512)) {
      if (floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < any_dim) {
        if (floormod(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < any_dim_1) {
          if (floormod(((blockIdx.x*512) + threadIdx.x), any_dim_2) < any_dim_2) {
            if (floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2) < (any_dim*any_dim_1)) {
              T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
            }
          }
        }
      }
...

If you look carefully, you will find some unnecessary if-conditions such as floormod(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < any_dim_1. This happens because floormod can take a negative divisor: if y < 0, then y < floormod(x, y) <= 0. Since we do not know the sign of any_dim_1, and the output of floormod would be greater than any_dim_1 if any_dim_1 < 0, the compiler cannot prove the condition always holds, so a redundant if-condition is kept.
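
The effect can be reproduced in isolation with the arithmetic analyzer. This is a minimal sketch; the exact simplified forms may vary by TVM version:

import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
x = tir.Var("x", "int32")

# With a plain Var the sign of the divisor is unknown, so the analyzer
# cannot prove floormod(x, n) < n and must keep the bound check.
n = tir.Var("n", "int32")
print(ana.simplify(tir.floormod(x, n) < n))

# With a SizeVar the divisor is known to be non-negative, which gives
# the simplifier a chance to fold the comparison away.
n_sz = tir.SizeVar("n", "int32")
print(ana.simplify(tir.floormod(x, n_sz) < n_sz))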

Solution

I do not think we currently need negative values in shape arrays. Therefore, using SizeVar, which is guaranteed to be non-negative, instead of Var is a better choice here. Replacing Var with SizeVar removes these redundant if-conditions:

  attr = {"global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [any_dim: int32, any_dim_1: int32, any_dim_2: int32, any_dim_3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [any_dim, any_dim_1, any_dim_2, any_dim_3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  attr [T_softmax_maxelem: Pointer(global float32)] "storage_scope" = "global";
  allocate(T_softmax_maxelem, float32, [((any_dim*any_dim_1)*any_dim_2)]);
  attr [T_softmax_exp: Pointer(global float32)] "storage_scope" = "global";
  allocate(T_softmax_exp, float32, [(((any_dim*any_dim_1)*any_dim_2)*any_dim_3)]) {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((any_dim*any_dim_1)*any_dim_2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512;
    if (blockIdx.x < floordiv(((any_dim*any_dim_1)*any_dim_2), 512)) {
      if (floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < any_dim) {
        if (floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2) < (any_dim*any_dim_1)) {
          T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
        }
      }
...
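
The actual change lives in the GetShape helper on the C++ side (see the PR linked below). A Python-level analogue of the switch, purely for illustration, would look like:

from tvm import tir

def get_shape(shape):
    # Illustrative Python analogue of the C++ GetShape helper: convert each
    # Any dimension to a TIR variable. The proposal is to emit a SizeVar,
    # which the arithmetic simplifier treats as non-negative, instead of Var.
    res = []
    for dim in shape:
        if isinstance(dim, tir.Any):
            # res.append(tir.Var("any_dim", "int32"))    # before
            res.append(tir.SizeVar("any_dim", "int32"))  # after
        else:
            res.append(dim)
    return res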

We tried this modification and ran experiments on the BERT-Base model. With an actual input shape of (64, 128, 256), the end-to-end time drops from 0.5578s to 0.4799s (TF time: 0.5208s).

PR for this: https://github.com/apache/tvm/pull/8555
