Using 'vthread' leads to assertion error: Check failed: is_zero(dom->min) == false

Hello everyone,

I am trying to understand more about the te.bind operator, and more specifically the use of 'vthread'. Sadly, I am getting confused by something fairly early in my experiments.

The setup

I have the following script, in which I try to bind the batch dimension of a fused conv2d + relu to a vthread:

import tvm
from tvm import topi
from tvm import te

input_shape = (2, 16, 16, 16)  # NCHW
input_ph = te.placeholder(input_shape, name="data")

kernel_shape = (16, 16, 3, 3)  # OIHW (topi.nn.conv2d defaults to layout="NCHW", which expects an OIHW kernel)
kernel_ph = te.placeholder(kernel_shape, name="kernel")

res_conv = topi.nn.conv2d(input_ph, kernel_ph, 1, 0, 1)  # strides=1, padding=0, dilation=1
res_relu = topi.nn.relu(res_conv)

# 1 Vanilla schedule
s = te.create_schedule(res_relu.op)
print(tvm.lower(s, [input_ph, kernel_ph, res_relu], simple_mode=True))

# 2 Thread-level parallelism of the batch (no fusion)
s[res_conv].bind(s[res_conv].op.axis[0], te.thread_axis("vthread"))
print(tvm.lower(s, [input_ph, kernel_ph, res_relu], simple_mode=True))

# 3 Fuse conv2d into relu and thread-level parallelism of the batch of the relu stage
s = te.create_schedule(res_relu.op)
s[res_conv].compute_at(s[res_relu], s[res_relu].op.axis[-1])
s[res_relu].bind(s[res_relu].op.axis[0], te.thread_axis("vthread"))
# NOTE: this pragma works
s[res_conv].pragma(s[res_conv].op.axis[0], "this is a pragma")
print(tvm.lower(s, [input_ph, kernel_ph, res_relu], simple_mode=True))

# 4 Fuse conv2d into relu and thread-level parallelism of the batch of the conv2d stage
s = te.create_schedule(res_relu.op)
s[res_conv].compute_at(s[res_relu], s[res_relu].op.axis[-1])
s[res_conv].bind(s[res_conv].op.axis[0], te.thread_axis("vthread"))
# NOTE: the following line fails
print(tvm.lower(s, [input_ph, kernel_ph, res_relu], simple_mode=True))

My comments

  1. Vanilla schedule: the printout contains all the stages (including a padding stage that is not strictly necessary here, since padding is 0) and looks as expected.
  2. Thread-level parallelism of batch: without fusing the conv2d and the relu, I can bind the batch axis of the conv2d and print the lowered representation.
  3. Fuse conv2d into relu and thread-level parallelism of the batch of the relu stage: I want the "lowest"-level fusion between them, so I use compute_at(... axis[-1]). Because the conv2d now lives "inside" the scope of the relu operator, I figured I had to bind the batch dimension of the relu to the vthread. This works, but note that the pragma is still attached to the 0th axis of the conv2d, so I guess the conv2d, although fused, still preserves information about its own axes.
  4. Fuse conv2d into relu and thread-level parallelism of the batch of the conv2d stage: this is actually what I was doing before I tried #3. Executing tvm.lower throws the following error (I have also isolated this case into a self-contained snippet right after the error dump):
File "/home/tvm/src/te/operation/op_utils.cc", line 138
TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: is_zero(dom->min) == false: 
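
For convenience, here is case #4 isolated into a self-contained snippet (same ops and schedule calls as in the script above, only the variable names differ):

# Self-contained repro of case 4: fails at tvm.lower with the error above
import tvm
from tvm import te, topi

data = te.placeholder((2, 16, 16, 16), name="data")     # NCHW
kernel = te.placeholder((16, 16, 3, 3), name="kernel")  # OIHW
conv = topi.nn.conv2d(data, kernel, 1, 0, 1)
relu = topi.nn.relu(conv)

s = te.create_schedule(relu.op)
s[conv].compute_at(s[relu], s[relu].op.axis[-1])             # fuse conv2d into the innermost relu axis
s[conv].bind(s[conv].op.axis[0], te.thread_axis("vthread"))  # bind the fused stage's own batch axis
print(tvm.lower(s, [data, kernel, relu], simple_mode=True))  # raises: Check failed: is_zero(dom->min)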

Questions

  • I don't see why the minimum value of the batch dimension of the conv2d operator would not be 0: all prints of the TIR show this loop running over [0, 2). What is happening here? And why the restriction that it has to start at 0? (A small check of the declared domains is included right after this list.)
  • Do fused operators (via compute_at) lose their axis information at some point during lowering if those axes are equivalent (or equal) to some axis of the higher-scoped operator? If so, why is the pragma not affected by this?
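
For reference, here is the small check mentioned in the first question. Assuming I am reading the IterVar attributes correctly, the declared domains of the conv2d axes all start at 0:

# Inspect the declared iteration domains of the conv2d compute op
for iv in res_conv.op.axis:
    print(iv.var, iv.dom)   # the batch axis reports min=0, extent=2; all axes start at 0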

Thanks

@thierry, any ideas? I saw you use vthreads in the VTA examples and would love some help.
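
P.S. For completeness, the variant that does lower for me (binding the batch axis of the relu stage instead, i.e. case #3 above) is:

s = te.create_schedule(res_relu.op)
s[res_conv].compute_at(s[res_relu], s[res_relu].op.axis[-1])
s[res_relu].bind(s[res_relu].op.axis[0], te.thread_axis("vthread"))     # bind on the consumer stage instead
print(tvm.lower(s, [input_ph, kernel_ph, res_relu], simple_mode=True))  # lowers without errors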