Check failed: (func) is false: Expected the operator to be a global var, but got relay.Let

Hello!

I'm new to TVM and am trying to tune and compile my PyTorch PageRank model with TVM, but I hit a problem when compiling the Relay model: Check failed: (func) is false: Expected the operator to be a global var, but got relay.Let.

My model contains control flow (a while loop), and I'm wondering whether the problem is caused by that; in the Relay dump below the loop is lowered to a let-bound local function %while_loop.

Please help me check my workflow. Thanks for your kindness and patience.

My PyTorch model:

import torch
from torch import nn

class PageRank_CPU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, source, target, init_vertex, iteration, vertex_num):
        source = source.int()
        target = target.int()
        # out-degree of each vertex: count how often it appears as a source
        V_out_deg = torch.zeros_like(init_vertex, dtype=torch.int32)
        V_out_deg = V_out_deg.scatter_add(0, source, torch.ones_like(source, dtype=torch.int32))
        # dangling vertices: no outgoing edges
        mask = (V_out_deg == 0)
        # normalize the initial rank vector
        V_old = init_vertex
        total = torch.sum(V_old)
        V_old = V_old / total
        # start iteration
        round = torch.tensor(0)
        while round < iteration:
            V_new = torch.zeros_like(init_vertex)
            V_old_temp = V_old / V_out_deg
            # rank mass held by dangling vertices, redistributed uniformly
            blind_sum = torch.masked_select(V_old, mask).sum()
            V_new = V_new.scatter_add(0, target, V_old_temp[source])
            V_new = V_new * 0.85 + (0.15 + blind_sum * 0.85) / vertex_num
            diff = torch.abs(V_new - V_old).sum()
            V_old = V_new
            round += 1
            # early exit once the ranks converge
            if torch.lt(diff, 1e-7):
                break
        return V_old
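
For context, edge_list below is a 2 x E integer tensor: row 0 holds the source vertex of each edge and row 1 the target vertex. A toy example (hypothetical data, just to show the layout):

import torch

# 3 vertices, 4 directed edges: 0->1, 0->2, 1->2, 2->0
edge_list = torch.tensor([[0, 0, 1, 2],
                          [1, 2, 2, 0]])
vertex_num = 3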

Use TorchScript to script the model and convert it to Relay:

dummy_input = (edge_list[0, :], edge_list[1, :], torch.rand(vertex_num),
               torch.tensor(30), torch.tensor(vertex_num))
scripted_model = torch.jit.script(model, example_inputs=[dummy_input])
# print(scripted_model.code)
import tvm
from tvm import relay

# graph input names, in the order of forward()'s arguments
# (inferred from the Relay module printed below)
input_names = ["source", "target", "init_vertex", "iter", "vertex_num"]
shape_list = list(zip(input_names, [i.shape for i in dummy_input]))
print(shape_list)
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
print(mod)

The resulting Relay module (note how the while loop is lowered to a let-bound local function, %while_loop):

def @main(%iter: float32, %init_vertex: Tensor[(8298), float32], %target: Tensor[(103688), float32], %source: Tensor[(103688), float32], %vertex_num: float32) {
  %29 = sum(%init_vertex);
  %30 = zeros_like(%init_vertex);
  %31 = cast(%source, dtype="int32");
  %32 = ones_like(%31);
  %33 = cast(%30, dtype="int32");
  %34 = cast(%32, dtype="int32");
  %35 = scatter_add(%33, %31, %34);
  %36 = less(0f, %iter);
  %37 = divide(%init_vertex, %29);
  %38 = cast(%target, dtype="int32");
  %39 = equal(%35, 0);
  %40 = (
    let %while_loop = fn (%v28: bool, %V_old.17: Tensor[(8298), float32], %round.11: int32, %target.5: Tensor[(103688), int32], %mask.1: Tensor[(8298), bool], %source.5: Tensor[(103688), int32], %V_out_deg.5: Tensor[(8298), int32], %vertex_num.1: float32, %init_vertex.1: Tensor[(8298), float32]) {
      %0 = equal(%v28, True);
      if (%0) {
        %1 = cast(%V_out_deg.5, dtype="float32");
        %2 = divide(%V_old.17, %1);
        %3 = (%2, %source.5);
        %4 = zeros_like(%init_vertex.1);
        %5 = adv_index(%3);
        %6 = scatter_add(%4, %target.5, %5);
        %7 = argwhere(%mask.1);
        %8 = split(%7, indices_or_sections=1, axis=1);
        %9 = %8.0;
        %10 = squeeze(%9, axis=[1]);
        %11 = (%10,);
        %12 = %11.0;
        %13 = (%V_old.17, %12);
        %14 = adv_index(%13);
        %15 = sum(%14);
        %16 = multiply(%15, 0.85f);
        %17 = add(%16, 0.15f);
        %18 = multiply(%6, 0.85f);
        %19 = divide(%17, %vertex_num.1);
        %20 = add(%18, %19);
        %21 = multiply(1f, %V_old.17);
        %22 = subtract(%20, %21);
        %23 = abs(%22);
        %24 = sum(%23);
        %25 = less(%24, 1e-07f);
        %28 = if (%25) {
          False
        } else {
          %26 = add(%round.11, 1);
          %27 = cast(%26, dtype="float32");
          less(%27, %iter)
        };
        %while_loop(%28, %20, %26, %target.5, %mask.1, %source.5, %V_out_deg.5, %vertex_num.1, %init_vertex.1)
      } else {
        (%v28, %V_old.17, %round.11, %target.5, %mask.1, %source.5, %V_out_deg.5, %vertex_num.1, %init_vertex.1)
      }
    };
    %while_loop
  );
  %41 = %40(%36, %37, 0, %38, %39, %31, %35, %vertex_num, %init_vertex);
  %41.1
}
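
Since the error mentions relay.Let, I can confirm that the loop really is a let-bound closure by counting the Let nodes in the module (a small sketch using relay.analysis.post_order_visit):

lets = []
relay.analysis.post_order_visit(
    mod["main"].body,
    lambda expr: lets.append(expr) if isinstance(expr, relay.Let) else None)
print("relay.Let nodes in main:", len(lets))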

Compile the Relay model:

target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

This fails with the following error:

Traceback (most recent call last):
  File "/home/zht/GraphDemos/test_dynamic.py", line 150, in <module>
    lib = relay.build(mod, target=target, params=params)
  File "/usr/local/relax/python/tvm/relay/build_module.py", line 364, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/usr/local/relax/python/tvm/relay/build_module.py", line 161, in build
    self._build(
  File "/usr/local/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  11: TVMFuncCall
  10: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  9: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::GraphExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  7: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
  6: tvm::relay::backend::MemoizedExprTranslator<std::vector<tvm::relay::backend::GraphNodeRef, std::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&)
  5: _ZZN3tvm5relay11ExprFunctorIFSt6vectorINS0_7backend12GraphNodeRefESaIS4_EER
  4: tvm::relay::backend::GraphExecutorCodegen::VisitExpr_(tvm::relay::TupleGetItemNode const*)
  3: tvm::relay::backend::MemoizedExprTranslator<std::vector<tvm::relay::backend::GraphNodeRef, std::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relay11ExprFunctorIFSt6vectorINS0_7backend12GraphNodeRefESaIS4_EER
  1: tvm::relay::backend::GraphExecutorCodegen::VisitExpr_(tvm::relay::CallNode const*)
  0: tvm::relay::backend::GraphExecutorCodegen::GraphAddCallNode(tvm::relay::CallNode const*, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, dmlc::any, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, dmlc::any> > >)
  File "/home/zht/relax/src/relay/backend/graph_executor_codegen.cc", line 452
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (func) is false: Expected the operator to be a global var, but got relay.Let
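
Judging from the stack trace, the failure happens in GraphExecutorCodegen, so I suspect the graph executor simply cannot handle the let-bound %while_loop closure. Would the VM executor, which as far as I know supports control flow, be the right way to run such a model? A minimal sketch of what I would try (the input names, order, and dtypes are my assumptions, based on the Relay main() signature above):

from tvm.runtime.vm import VirtualMachine

# compile with the Relay VM instead of the graph executor
with tvm.transform.PassContext(opt_level=3):
    vm_exec = relay.vm.compile(mod, target=target, params=params)

vm = VirtualMachine(vm_exec, dev)
# pass inputs by name so they line up with the Relay main() parameters;
# dtypes may need casting to match the printed signature (e.g. float32 iter)
result = vm.invoke(
    "main",
    source=tvm.nd.array(dummy_input[0].numpy()),
    target=tvm.nd.array(dummy_input[1].numpy()),
    init_vertex=tvm.nd.array(dummy_input[2].numpy()),
    iter=tvm.nd.array(dummy_input[3].numpy()),
    vertex_num=tvm.nd.array(dummy_input[4].numpy()),
)

Alternatively, relay.create_executor("vm", mod=mod, device=dev, target=target) should take the same VM-based path.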

My configuration: TVM 0.8; PyTorch 1.10.

My code: my code and experiment data (link)

I got the same error. Did you solve it? Thank you so much!