[Heterogeneous Execution] Function parameters and result VirtualDevices do not match those of call

Hi, I'm trying to run a model heterogeneously on CPU and GPU, but I'm hitting an error. I used an ExprMutator subclass to annotate nodes as follows:

import tvm
from tvm import relay
from tvm.relay import ExprMutator


class ScheduleOps(ExprMutator):
    """Annotate selected calls with a device via relay.annotation.on_device."""

    def __init__(self, cpu_ops, gpu_ops, ops):
        self.cpu_ops = cpu_ops  # indices of the matched calls to pin to the CPU
        self.gpu_ops = gpu_ops  # indices of the matched calls to pin to the GPU
        self.ops = ops          # names of the ops that should be annotated
        self.cnt = 0            # running index over matched calls
        super().__init__()

    def visit_call(self, expr):
        visit = super().visit_call(expr)
        if str(visit.op) in self.ops:
            if self.cnt in self.cpu_ops:
                print("cpu--", self.cnt)
                self.cnt += 1
                return relay.annotation.on_device(
                    visit, tvm.cpu(0), constrain_result=False, constrain_body=True
                )
            elif self.cnt in self.gpu_ops:
                print("gpu--", self.cnt)
                self.cnt += 1
                return relay.annotation.on_device(
                    visit, tvm.cuda(0), constrain_result=False, constrain_body=True
                )
            else:
                raise ValueError("{}".format(self.cnt))
        else:
            return visit


def schedule_ops(expr, cpu_ops, gpu_ops, ops):
    sche = ScheduleOps(cpu_ops, gpu_ops, ops)
    return sche.visit(expr)
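
For completeness, this is roughly how I turn the annotated function back into a module before building (a sketch only; the op index lists and op names below are illustrative placeholders, and anno_mod is the module passed to relay.build further down):

# Sketch: wrap the annotated main function back into an IRModule and re-infer types.
# The index lists and op names here are placeholders, not my real configuration.
new_main = schedule_ops(mod["main"], cpu_ops=[0], gpu_ops=[1], ops=["reshape", "transpose"])
anno_mod = tvm.IRModule.from_expr(new_main)
anno_mod = relay.transform.InferType()(anno_mod)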

Here is the module after applying relay.annotation.on_device:

def @main(%TVMInputPH_0: Tensor[(10, 1), int64], %TVMInputPH_1: Tensor[(10, 1), int64], %TVMInputPH_2: Tensor[(10, 1), int64]) {
  %0 = reshape(%TVMInputPH_0, newshape=[-1, 1]) /* span=pre_batch/Reshape_188:0:0 */;
  %1 = transpose(%TVMInputPH_1, axes=[-1, 0]);
  %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]) /* span=pre_batch/Reshape_188:0:0 */;
  %3 = on_device(%1, virtual_device=meta[VirtualDevice][1]);
  %4 = gather_nd(%2, %3) /* span=pre_batch/GatherNd_187:0:0 */;
  %5 = on_device(%4, virtual_device=meta[VirtualDevice][2]) /* span=pre_batch/GatherNd_187:0:0 */;
  %6 = reshape(%5, newshape=[-1, 1]) /* span=pre_batch/pre_batch_queryPoiSmoothPvCtr60:0:0 */;
  %7 = on_device(%6, virtual_device=meta[VirtualDevice][3]) /* span=pre_batch/pre_batch_queryPoiSmoothPvCtr60:0:0 */;
  %8 = less(%7, 73i64) /* span=Dense_Embedding/Less_211:0:0 */;
  %9 = zeros_like(%7) /* span=Dense_Embedding/zeros_like_211:0:0 */;
  %10 = reshape(%TVMInputPH_2, newshape=[-1, 1]) /* span=pre_batch/Reshape_35:0:0 */;
  %11 = transpose(%TVMInputPH_1, axes=[-1, 0]);
  %12 = on_device(%10, virtual_device=meta[VirtualDevice][6]) /* span=pre_batch/Reshape_35:0:0 */;
  %13 = on_device(%11, virtual_device=meta[VirtualDevice][7]);
  %14 = gather_nd(%12, %13) /* span=pre_batch/GatherNd_34:0:0 */;
  %15 = on_device(%14, virtual_device=meta[VirtualDevice][8]) /* span=pre_batch/GatherNd_34:0:0 */;
  %16 = reshape(%15, newshape=[-1, 1]) /* span=pre_batch/pre_batch_queryPoiSmoothPvCtr45:0:0 */;
  %17 = on_device(%16, virtual_device=meta[VirtualDevice][9]) /* span=pre_batch/pre_batch_queryPoiSmoothPvCtr45:0:0 */;
  %18 = less(%17, 73i64) /* span=Dense_Embedding/Less_68:0:0 */;
  %19 = zeros_like(%17) /* span=Dense_Embedding/zeros_like_68:0:0 */;
  %20 = on_device(%8, virtual_device=meta[VirtualDevice][4]) /* span=Dense_Embedding/Less_211:0:0 */;
  %21 = on_device(%9, virtual_device=meta[VirtualDevice][5]) /* span=Dense_Embedding/zeros_like_211:0:0 */;
  %22 = on_device(%18, virtual_device=meta[VirtualDevice][10]) /* span=Dense_Embedding/Less_68:0:0 */;
  %23 = on_device(%19, virtual_device=meta[VirtualDevice][11]) /* span=Dense_Embedding/zeros_like_68:0:0 */;
  (%20, %21, %22, %23)
}

But when I try to relay.build the annotated module (anno_mod), the following error occurs:

with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build(anno_mod, target={"cpu": "llvm", "cuda": "cuda"}, params=params)
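
In case it is relevant, my understanding is that the same target dict could also be written with explicit Target objects, roughly like this (a sketch only; I have not confirmed whether the host target needs to be set explicitly for the cuda entry):

# Assumed-equivalent target dict built from explicit Target objects (sketch).
cuda_target = tvm.target.Target("cuda", host="llvm")
cpu_target = tvm.target.Target("llvm")
graph, lib, params = relay.build(anno_mod, target={"cpu": cpu_target, "cuda": cuda_target}, params=params)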

The traceback is as follows:

Traceback (most recent call last):
  File "main.py", line 52, in <module>
    main()
  File "main.py", line 47, in main
    result = set_device(mod, params)
  File "/workdir/tvmtest/ctr_optimizer/minicut.py", line 729, in set_device
    graph, lib, params = relay.build(anno_mod, target, params=params)
  File "/root/tvm/python/tvm/relay/build_module.py", line 364, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/root/tvm/python/tvm/relay/build_module.py", line 161, in build
    self._build(
  File "/root/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  21: TVMFuncCall
  20: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  19: tvm::relay::backend::RelayBuildModule::Build(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&, tvm::Target const&, tvm::relay::Executor const&, tvm::relay::Runtime const&, tvm::WorkspaceMemoryPools const&, tvm::ConstantMemoryPools const&, tvm::runtime::String)
  18: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  17: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule)
  16: tvm::transform::Pass::operator()(tvm::IRModule) const
  15: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  13: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  12: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  10: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay9transform12_GLOBAL__N_115PlanDevicesCoreENS_17CompilationConfigEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SH_SL_
  8: tvm::relay::transform::(anonymous namespace)::DeviceAnalyzer::Analyze()
  7: tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)
  6: void tvm::relay::ExpandDataflow<tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#2}, tvm::relay::ExpandDataflow<{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, {lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1})::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#2}, tvm::relay::ExpandDataflow, tvm::relay::ExpandDataflow<{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, {lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1})::{lambda(tvm::RelayExpr const&)#1}) [clone .isra.0]
  5: tvm::relay::MixedModeVisitor::VisitLeaf(tvm::RelayExpr const&)
  4: tvm::relay::transform::(anonymous namespace)::DeviceAnalyzer::VisitExpr_(tvm::relay::FunctionNode const*)
  3: tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)
  2: void tvm::relay::ExpandDataflow<tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#2}, tvm::relay::ExpandDataflow<{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, {lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1})::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#2}, tvm::relay::ExpandDataflow, tvm::relay::ExpandDataflow<{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, {lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeVisitor::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1})::{lambda(tvm::RelayExpr const&)#1}) [clone .isra.0]
  1: tvm::relay::MixedModeVisitor::VisitLeaf(tvm::RelayExpr const&)
  0: tvm::relay::transform::(anonymous namespace)::DeviceAnalyzer::VisitExpr_(tvm::relay::CallNode const*)
  File "/root/tvm/src/relay/transforms/device_planner.cc", line 519
TVMError: Function parameters and result VirtualDevices do not match those of call. Call:
free_var %TVMInputPH_1: Tensor[(10, 1), int64] /* ty=Tensor[(10, 1), int64] */;
%0 = transpose(%TVMInputPH_1, axes=[-1, 0]) /* ty=Tensor[(1, 10), int64] */;
on_device(%0, virtual_device=VirtualDevice(device_type=2, virtual_device_id=0)) /* ty=Tensor[(1, 10), int64] */
with function virtual devices:
fn(?454614896?VirtualDevice(device_type=2, virtual_device_id=0, target=Target(id=65b18a0, kind='cuda', keys={'cuda', 'gpu'}, attrs={'thread_warp_size': 20, 'max_num_threads': 400, 'arch': "sm_86"}, host=Target(id=1b17f510, kind='llvm', keys={'cpu'})))):?454614960?
and implied call virtual devices:
fn(?99036272?VirtualDevice(device_type=1, virtual_device_id=0, target=Target(id=65b1830, kind='llvm', keys={'cpu'}, host=Target(id=1b17f510, kind='llvm', keys={'cpu'})))):?454615088?

It seems there is something wrong with the VirtualDevices assigned to the function parameters versus the call. Has anyone run into and solved this problem? Thanks.