Background and Motivation
Currently, TVM uses Any to represent an unknown dimension when the input has a dynamic shape. When building the schedule for a Relay primitive function, each AnyNode is converted to a Var named "any_dim". However, an element of a shape array cannot be negative, and using Var to represent shape elements causes redundant boundary checks because the sign of a Var cannot be deduced. For example, consider a simple network with only a softmax operation:
import numpy as np
import tvm
from tvm import relay
from time import time
# actual input shape
dim0 = 2
dim1 = 10
dim2 = 8
dim3 = 16
# relay var shapes
v_dim0 = relay.Any()
v_dim1 = relay.Any()
v_dim2 = relay.Any()
v_dim3 = relay.Any()
# rt settings
exec_mod = "vm"
tgt = "cuda"
dev = tvm.device(tgt)
def get_mod():
    x = relay.var("x", shape=(v_dim0, v_dim1, v_dim2, v_dim3))
    y = relay.nn.softmax(x)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([x], y)
    return mod
mod = get_mod()
The corresponding TIR looks like the following (only part of the function is shown because it is long):
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"global_symbol": "fused_nn_softmax", "tir.noalias": True}
buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [any_dim: int32, any_dim_1: int32, any_dim_2: int32, any_dim_3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
placeholder: Buffer(placeholder_2: Pointer(float32), float32, [any_dim, any_dim_1, any_dim_2, any_dim_3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
attr [T_softmax_maxelem: Pointer(global float32)] "storage_scope" = "global";
allocate(T_softmax_maxelem, float32, [((any_dim*any_dim_1)*any_dim_2)]);
attr [T_softmax_exp: Pointer(global float32)] "storage_scope" = "global";
allocate(T_softmax_exp, float32, [(((any_dim*any_dim_1)*any_dim_2)*any_dim_3)]) {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((any_dim*any_dim_1)*any_dim_2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512;
if (blockIdx.x < floordiv(((any_dim*any_dim_1)*any_dim_2), 512)) {
  if (floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < any_dim) {
    if (floormod(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < any_dim_1) {
      if (floormod(((blockIdx.x*512) + threadIdx.x), any_dim_2) < any_dim_2) {
        if (floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2) < (any_dim*any_dim_1)) {
          T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
        }
      }
    }
  }
...
Looking carefully, you will find some unnecessary if-conditions such as floormod(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < any_dim_1. They appear because floormod can have negative inputs: y < floormod(x, y) <= 0 if y < 0. Since the sign of any_dim_1 is unknown, and the output of floormod is greater than any_dim_1 whenever any_dim_1 < 0, the condition cannot be proven to always hold, so a redundant if-condition is added here.
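The sign argument can be checked with a small pure-Python sketch. It assumes that Python's % operator on integers behaves like floormod (both round the quotient toward negative infinity); the helper names floormod and guard_always_holds are made up for this illustration and are not TVM APIs.

```python
def floormod(x: int, y: int) -> int:
    """Floor modulo. Python's % on ints already rounds toward -inf,
    which is assumed here to mirror the TIR floormod semantics."""
    return x % y


def guard_always_holds(divisor: int, samples=range(-64, 65)) -> bool:
    """Return True if floormod(x, divisor) < divisor for every sampled x."""
    return all(floormod(x, divisor) < divisor for x in samples)


# Positive divisor: the result lies in [0, divisor), so the guard is redundant.
print(guard_always_holds(10))

# Negative divisor: the result lies in (divisor, 0], so the guard never holds.
print(guard_always_holds(-10))
```

With a positive divisor the guard is always true and can be dropped; with a negative divisor it is false, which is why a simplifier that does not know the sign has to keep it.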
Solution
I think we currently have no need for negative values in shape arrays. Therefore, using SizeVar instead of Var is probably a better choice here. By replacing Var with SizeVar, we can remove these redundant if-conditions:
attr = {"global_symbol": "fused_nn_softmax", "tir.noalias": True}
buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [any_dim: int32, any_dim_1: int32, any_dim_2: int32, any_dim_3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
placeholder: Buffer(placeholder_2: Pointer(float32), float32, [any_dim, any_dim_1, any_dim_2, any_dim_3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
attr [T_softmax_maxelem: Pointer(global float32)] "storage_scope" = "global";
allocate(T_softmax_maxelem, float32, [((any_dim*any_dim_1)*any_dim_2)]);
attr [T_softmax_exp: Pointer(global float32)] "storage_scope" = "global";
allocate(T_softmax_exp, float32, [(((any_dim*any_dim_1)*any_dim_2)*any_dim_3)]) {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((any_dim*any_dim_1)*any_dim_2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512;
if (blockIdx.x < floordiv(((any_dim*any_dim_1)*any_dim_2), 512)) {
  if (floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2), any_dim_1) < any_dim) {
    if (floordiv(((blockIdx.x*512) + threadIdx.x), any_dim_2) < (any_dim*any_dim_1)) {
      T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
    }
  }
...
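As a rough sanity check of which guards depend only on the sign of the dimensions, the sketch below brute-forces the flat index used in the kernel for the example shape (2, 10, 8, 16) with 512 threads per block. It is plain Python rather than TVM code, and the helper names mod_guards_hold and div_guards_hold are hypothetical.

```python
def mod_guards_hold(idx: int, d1: int, d2: int) -> bool:
    """The two floormod-based guards that the SizeVar version no longer emits."""
    return (idx // d2) % d1 < d1 and idx % d2 < d2


def div_guards_hold(idx: int, d0: int, d1: int, d2: int) -> bool:
    """The two floordiv-based guards that remain in the SizeVar version."""
    return (idx // d2) // d1 < d0 and idx // d2 < d0 * d1


d0, d1, d2 = 2, 10, 8          # leading dims of the example input
total = d0 * d1 * d2           # number of T_softmax_maxelem elements
threads = 512                  # thread extent from the TIR dump
blocks = (total + threads - 1) // threads
indices = range(blocks * threads)  # every blockIdx.x*512 + threadIdx.x

# With positive dims, the floormod guards hold for every index.
always_mod = all(mod_guards_hold(i, d1, d2) for i in indices)

# The floordiv guards fail once the flat index runs past the buffer.
failing_div = [i for i in indices if not div_guards_hold(i, d0, d1, d2)]

print(always_mod, failing_div[0], len(failing_div))
```

In this sketch the floormod guards hold for every launched thread once the dimensions are known to be positive, while the floordiv guards depend on the range of the index and so cannot be dropped from sign information alone.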
We tried this modification and ran experiments on the BERT-Base model. With an actual input shape of (64, 128, 256), the end-to-end time drops from 0.5578s to 0.4799s (TF time: 0.5208s).
PR for this: https://github.com/apache/tvm/pull/8555