Compile error for CUDA target

The error:
TVMError: Check failed: !use_count_.count(v): variable dw has been used before definition!

It looks like this has been introduced by commit:
c846d17c Mon Sep 16 13:03:32 2019 -0700 [TOPI] Improve conv2d_transpose schedule on X86 and CUDA (#3948)

A build prior to this commit compiles the same graph with no errors.

Backtrace:
Traceback (most recent call last):

File "from_mxnet.py", line 149, in <module>
graph, lib, params = relay.build(func, target, params=params)
File "/mnt/efs/src/tvm/python/tvm/relay/build_module.py", line 207, in build
graph_json, mod, params = bld_mod.build(func, target, target_host, params)
File "/mnt/efs/src/tvm/python/tvm/relay/build_module.py", line 108, in build
self._build(func, target, target_host)
File "tvm/_ffi/_cython/./function.pxi", line 310, in tvm._ffi._cy3.core.FunctionBase.__call__
File "tvm/_ffi/_cython/./function.pxi", line 245, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./function.pxi", line 234, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 171, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) /mnt/efs/src/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr(tvm::relay::Expr const&)+0x4a7) [0x7fa448c05f97]
[bt] (7) /mnt/efs/src/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::TupleNode const*)+0xc5) [0x7fa448bffbd5]
[bt] (6) /mnt/efs/src/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr(tvm::relay::Expr const&)+0x630) [0x7fa448c06120]
[bt] (5) /mnt/efs/src/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::CallNode const*)+0xb03) [0x7fa448c0b613]
[bt] (4) /mnt/efs/src/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr(tvm::relay::Expr const&)+0x630) [0x7fa448c06120]
[bt] (3) /mnt/efs/src/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::CallNode const*)+0x6b6) [0x7fa448c0b1c6]
[bt] (2) /mnt/efs/src/tvm/build/libtvm.so(+0x111deac) [0x7fa448be0eac]
[bt] (1) /mnt/efs/src/tvm/build/libtvm.so(tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&)+0x811) [0x7fa448beb7f1]
[bt] (0) /mnt/efs/src/tvm/build/libtvm.so(+0x12ceaeb) [0x7fa448d91aeb]
File "/mnt/efs/src/tvm/python/tvm/relay/backend/_backend.py", line 51, in lower
f = _build.lower(sch, inputs, name=func_name)
File "/mnt/efs/src/tvm/python/tvm/build_module.py", line 416, in lower
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
File "tvm/_ffi/_cython/./function.pxi", line 310, in tvm._ffi._cy3.core.FunctionBase.__call__
File "tvm/_ffi/_cython/./function.pxi", line 255, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./base.pxi", line 171, in tvm._ffi._cy3.core.CALL
[bt] (8) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRUseDefAnalysis::Mutate_(tvm::ir::For const*, tvm::Stmt const&)+0x53) [0x7fa4488712a3]
[bt] (7) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate_(tvm::ir::For const*, tvm::Stmt const&)+0xb6) [0x7fa4487ed5c6]
[bt] (6) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate(tvm::Stmt)+0x5d) [0x7fa448622c5d]
[bt] (5) /mnt/efs/src/tvm/build/libtvm.so(tvm::IRFunctor<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::operator()(tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*) const+0x136) [0x7fa448622b16]
[bt] (4) /mnt/efs/src/tvm/build/libtvm.so(std::_Function_handler<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*), tvm::IRFunctor<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::set_dispatch<tvm::ir::For>(std::function<tvm::Stmt (tvm::ir::For const*, tvm::Stmt const&, tvm::ir::IRMutator*)>)::{lambda(tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*&&)+0x53) [0x7fa4487f4493]
[bt] (3) /mnt/efs/src/tvm/build/libtvm.so(+0xd273d4) [0x7fa4487ea3d4]
[bt] (2) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRUseDefAnalysis::Mutate_(tvm::ir::For const*, tvm::Stmt const&)+0x42) [0x7fa448871292]
[bt] (1) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRUseDefAnalysis::HandleDef(tvm::Variable const*)+0x15f) [0x7fa448870caf]
[bt] (0) /mnt/efs/src/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7fa4485b8032]
File "/mnt/efs/src/tvm/src/pass/split_host_device.cc", line 135
File "tvm/_ffi/_cython/./function.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
File "/mnt/efs/src/tvm/python/tvm/relay/backend/_backend.py", line 59, in lower
raise RuntimeError(msg)
File "/mnt/efs/src/tvm/python/tvm/relay/backend/_backend.py", line 51, in lower
f = _build.lower(sch, inputs, name=func_name)
File "/mnt/efs/src/tvm/python/tvm/build_module.py", line 416, in lower
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
File "tvm/_ffi/_cython/./function.pxi", line 310, in tvm._ffi._cy3.core.FunctionBase.__call__
File "tvm/_ffi/_cython/./function.pxi", line 255, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./base.pxi", line 171, in tvm._ffi._cy3.core.CALL
[bt] (8) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRUseDefAnalysis::Mutate_(tvm::ir::For const*, tvm::Stmt const&)+0x53) [0x7fa4488712a3]
[bt] (7) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate_(tvm::ir::For const*, tvm::Stmt const&)+0xb6) [0x7fa4487ed5c6]
[bt] (6) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate(tvm::Stmt)+0x5d) [0x7fa448622c5d]
[bt] (5) /mnt/efs/src/tvm/build/libtvm.so(tvm::IRFunctor<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::operator()(tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*) const+0x136) [0x7fa448622b16]
[bt] (4) /mnt/efs/src/tvm/build/libtvm.so(std::_Function_handler<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*), tvm::IRFunctor<tvm::Stmt (tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)>::set_dispatch<tvm::ir::For>(std::function<tvm::Stmt (tvm::ir::For const*, tvm::Stmt const&, tvm::ir::IRMutator*)>)::{lambda(tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, tvm::Stmt const&, tvm::ir::IRMutator*&&)+0x53) [0x7fa4487f4493]
[bt] (3) /mnt/efs/src/tvm/build/libtvm.so(+0xd273d4) [0x7fa4487ea3d4]
[bt] (2) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRUseDefAnalysis::Mutate_(tvm::ir::For const*, tvm::Stmt const&)+0x42) [0x7fa448871292]
[bt] (1) /mnt/efs/src/tvm/build/libtvm.so(tvm::ir::IRUseDefAnalysis::HandleDef(tvm::Variable const*)+0x15f) [0x7fa448870caf]
[bt] (0) /mnt/efs/src/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7fa4485b8032]
File "/mnt/efs/src/tvm/src/pass/split_host_device.cc", line 135
TVMError: Check failed: !use_count_.count(v): variable dw has been used before definition!
During handling of the above exception, another exception occurred:

TVMError: Check failed: !use_count_.count(v): variable dw has been used before definition!
Error during compile function

v0.0.4
fn (%p0: Tensor[(1, 512, 1, 128), float32], %p1: Tensor[(512, 1, 1, 512), float32], Primitive=1) -> Tensor[(1, 1, 1, 16256), float32] {
nn.conv2d_transpose(%p0, %p1, channels=1, kernel_size=[1, 512], strides=[1, 128], padding=[0, 256]) /* ty=Tensor[(1, 1, 1, 16256), float32] */
}

Could you paste a code snippet that reproduces the error?

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
import tvm
import tvm.relay as relay
import numpy as np

def get_net(ctx):
  net = nn.HybridSequential()
  with net.name_scope():
    net.add(gluon.nn.Conv2DTranspose(channels=1, kernel_size=(1, 512), strides=(1, 5), padding=(0, 256)))
    net.collect_params().initialize(ctx=ctx)
  return net

target = "cuda"
net = get_net(mx.gpu(0))
net.hybridize()
x = np.random.rand(1, 32, 512, 128)
netx = net(mx.nd.array(x, ctx=mx.gpu(0)))
shape_dict = {'data': x.shape}
func, params = relay.frontend.from_mxnet(net, shape_dict)
with relay.build_config(opt_level=3):
  graph, lib, params = relay.build(func, target, params=params)
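
For reference, here is a Relay-only sketch of the same failing operator, with shapes and attributes copied from the fused function printed above. It bypasses the MXNet frontend entirely; it is untested, so treat it as an assumption rather than a verified reproducer.

# Sketch: build just the conv2d_transpose from the fused function above.
# Shapes and attributes come from that printout, not from a fresh trace.
import tvm
import tvm.relay as relay

data = relay.var("p0", shape=(1, 512, 1, 128), dtype="float32")
weight = relay.var("p1", shape=(512, 1, 1, 512), dtype="float32")
out = relay.nn.conv2d_transpose(data, weight, channels=1, kernel_size=(1, 512),
                                strides=(1, 128), padding=(0, 256))
func = relay.Function([data, weight], out)
with relay.build_config(opt_level=3):
  graph, lib, params = relay.build(func, "cuda")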

I see the same error. A quick temporary fix is to turn off the fallback schedule.
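
If "turning off the fallback schedule" means having AutoTVM pick tuned configs instead of its untuned fallback, one possible sketch (reusing func, target, and params from the snippet above; the tuning log name is hypothetical, and this may not be the exact fix meant here) is:

# Sketch: apply tuned AutoTVM records so the fallback schedule is not selected.
# "conv2d_transpose_tuning.log" is a hypothetical tuning log produced beforehand.
from tvm import autotvm

with autotvm.apply_history_best("conv2d_transpose_tuning.log"):
  with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(func, target, params=params)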

But the error does expose a bug in IR lowering and requires further investigation. @Laurawly @yzhliu could you also take a look?

Issue: https://github.com/apache/incubator-tvm/issues/4470

I fixed the formatting of Alex's reproduction example above and made it work on a non-Nvidia box:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
import tvm
import tvm.relay as relay
import numpy as np

target = tvm.target.cuda()
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
set_cuda_target_arch('sm_70')

ctx = mx.cpu(0)

net = nn.HybridSequential()
with net.name_scope():
  net.add(gluon.nn.Conv2DTranspose(channels=1, kernel_size=(1, 512), strides=(1, 5), padding=(0,256)))
  net.collect_params().initialize(ctx=ctx)

net.hybridize()
x = np.random.rand(1, 32, 512, 128)
netx=net(mx.nd.array(x, ctx=ctx))

shape_dict = {'data': x.shape}
func, params = relay.frontend.from_mxnet(net, shape_dict)
with relay.build_config(opt_level=3):
  graph, lib, params = relay.build(func, target, params=params)

This example shows that compilation fails when the number of output channels is 1. A temporary workaround is https://github.com/apache/incubator-tvm/pull/4472.
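
As a sanity check of that observation (assumption: only channels=1 triggers the failure, so a larger channel count should build cleanly), the layer in the script above can be swapped for a variant like the one below, reusing ctx and the imports from that script.

# Hypothetical variant of the layer above: identical except channels > 1.
# If the diagnosis is right, this builds without the use-before-definition error.
net = nn.HybridSequential()
with net.name_scope():
  net.add(gluon.nn.Conv2DTranspose(channels=8, kernel_size=(1, 512), strides=(1, 5), padding=(0, 256)))
  net.collect_params().initialize(ctx=ctx)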