A Relay VM question: while loop and free vars

Hi, I have the following IR with some control flow, which comes from translating a TorchScript RNN (see a related discussion). My PyTorch parser and the RNN test case are available here.

v0.0.4
def @main(%X: Tensor[(10, 10, 4), float32], %v26: Tensor[(4, 4), float32], %v25: Tensor[(4), float32]) -> Tensor[(10, 4), float32] {
  %0 = full(0 /* ty=int32 */, shape=[10, 4], dtype="float32") /* ty=Tensor[(10, 4), float32] */;
  %1 = full(0 /* ty=int32 */, shape=[10, 4], dtype="float32") /* ty=Tensor[(10, 4), float32] */;
  %14 = (
    let %while_loop: fn (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) -> (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) = fn (%i.1: int32, %y.5: Tensor[(10, 4), float32], %h.5: Tensor[(10, 4), float32]) -> (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) {
      %2 = less(%i.1, 10 /* ty=int32 */) /* ty=bool */;
      if (%2) {
        %3 = add(%i.1, 1 /* ty=int32 */) /* ty=int32 */;
        %4 = take(%X, %i.1, axis=0) /* ty=Tensor[(10, 4), float32] */;
        %5 = multiply(%4, 1f /* ty=float32 */) /* ty=Tensor[(10, 4), float32] */;
        %6 = transpose(%v26, axes=[1, 0]) /* ty=Tensor[(4, 4), float32] */;
        %7 = multiply(%6, 1f /* ty=float32 */) /* ty=Tensor[(4, 4), float32] */;
        %8 = transpose(%7, axes=[1, 0]) /* ty=Tensor[(4, 4), float32] */;
        %9 = nn.dense(%5, %8, units=4) /* ty=Tensor[(10, 4), float32] */;
        %10 = nn.bias_add(%9, %v25) /* ty=Tensor[(10, 4), float32] */;
        %11 = negative(%10) /* ty=Tensor[(10, 4), float32] */;
        %12 = add(%11, %h.5) /* ty=Tensor[(10, 4), float32] */;
        %13 = tanh(%12) /* ty=Tensor[(10, 4), float32] */;
        %while_loop(%3, %13, %13) /* ty=(int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) */
      } else {
        (%i.1, %y.5, %h.5)
      }
    };
    %while_loop
  );
  %15 = %14(0 /* ty=int32 */, %0, %1) /* ty=(int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) */;
  %15.1
}
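
For reference, I compile and run this with the VM executor roughly as follows (a sketch: func stands for the @main above, and the exact executor API may differ between TVM versions):

import tvm
from tvm import relay

# func is the Relay function shown above (the @main emitted by my parser)
mod = tvm.IRModule.from_expr(func)
# The "vm" executor compiles the module when evaluate() is called;
# that is where the TVMError below comes from.
exe = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
run = exe.evaluate()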

The IR looks reasonable to me, but it is rejected by the VM compiler:

File "/home/masa/projects/dev/tvm/src/relay/backend/vm/compiler.cc", line 643
TVMError: internal error: unreachable code,should be transformed away by previous passesfree_var %x: int32
free_var %x1: Tensor[(10, 4), float32]
free_var %X: Tensor[(10, 10, 4), float32]
free_var %v26: Tensor[(4, 4), float32]
free_var %v25: Tensor[(4), float32]
%0 = @lifted_name7090198356732118303(%X, %v26, %v25) /* ty=fn (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) -> (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) */;
%0(%x, %x1, %x1) /* ty=(int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) */

I’m not completely sure what this means, but I guess it is complaining about variables that are not bound inside the while loop. They are bound as arguments of @main, though, so I’d expect this IR to compile without problems.

Is this error expected? If it is, do I need to add all the free vars to while_loop’s parameters and rename them appropriately?

cc @jroesch @MarisaKirisame @haichen @zhiics

And here is the corresponding TorchScript IR. My Relay IR above is a direct translation of it, including the handling of bound and free vars.

graph(%self : __torch__.RNNLoop,
      %xs.1 : Tensor):
  %2 : bool = prim::Constant[value=1]() # rnn_test.py:37:8
  %3 : None = prim::Constant()
  %4 : int = prim::Constant[value=6]() # rnn_test.py:36:40
  %5 : int = prim::Constant[value=10]() # rnn_test.py:36:27
  %6 : int = prim::Constant[value=4]() # rnn_test.py:36:31
  %7 : int = prim::Constant[value=0]() # rnn_test.py:37:31
  %8 : int[] = prim::ListConstruct(%5, %6)
  %h.1 : Tensor = aten::zeros(%8, %4, %3, %3, %3) # rnn_test.py:36:15
  %10 : int[] = prim::ListConstruct(%5, %6)
  %y.1 : Tensor = aten::zeros(%10, %4, %3, %3, %3) # rnn_test.py:36:54
  %12 : int = aten::size(%xs.1, %7) # rnn_test.py:37:23
  %y : Tensor, %h : Tensor = prim::Loop(%12, %2, %y.1, %h.1) # rnn_test.py:37:8
    block0(%i.1 : int, %y.5 : Tensor, %h.5 : Tensor):
      %18 : __torch__.torch.nn.modules.module.___torch_mangle_1.Module = prim::GetAttr[name="cell"](%self)
      %19 : Tensor = aten::select(%xs.1, %7, %i.1) # rnn_test.py:38:29
      %23 : __torch__.torch.nn.modules.module.Module = prim::GetAttr[name="dg"](%18)
      %24 : __torch__.torch.nn.modules.module.___torch_mangle_0.Module = prim::GetAttr[name="linear"](%18)
      %25 : Tensor = prim::GetAttr[name="bias"](%24)
      %26 : Tensor = prim::GetAttr[name="weight"](%24)
      %27 : Float(4, 4) = aten::t(%26), scope: __module.linear # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py:1370:0
      %28 : int = prim::Constant[value=1](), scope: __module.linear # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py:1370:0
      %29 : int = prim::Constant[value=1](), scope: __module.linear # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py:1370:0
      %x : Float(10, 4) = aten::addmm(%25, %19, %27, %28, %29), scope: __module.linear # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py:1370:0
      %31 : Float(10, 4) = aten::neg(%x), scope: __module.dg # rnn_test.py:14:0
      %32 : int = prim::Constant[value=1]() # rnn_test.py:24:0
      %33 : Float(10, 4) = aten::add(%31, %h.5, %32) # rnn_test.py:24:0
      %34 : Float(10, 4) = aten::tanh(%33) # rnn_test.py:24:0
      %35 : (Float(10, 4), Float(10, 4)) = prim::TupleConstruct(%34, %34)
      %y.2 : Tensor, %h.3 : Tensor = prim::TupleUnpack(%35)
      -> (%2, %y.2, %h.3)
  return (%y)
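
For context, the model behind this graph looks roughly like the following (a reconstruction from the graph above, not the exact test case; the real one is in the repo linked at the top):

import torch
import torch.nn as nn

class DecisionGate(nn.Module):
    # "dg" in the graph above; in this reconstruction it only negates (aten::neg)
    def forward(self, x):
        return -x

class Cell(nn.Module):
    def __init__(self):
        super().__init__()
        self.dg = DecisionGate()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

class RNNLoop(nn.Module):
    def __init__(self):
        super().__init__()
        self.cell = Cell()

    def forward(self, xs):
        h = torch.zeros(10, 4)
        y = torch.zeros(10, 4)
        # This loop becomes the prim::Loop above and the while_loop in Relay
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y

script_module = torch.jit.script(RNNLoop())
print(script_module.graph)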

@masahi I’m a bit busy right now; I can look into this next week. At first glance, the Relay program looks legit. It could be some bug in lambda lifting.


@haichen thanks for pointing me at the lambda lift pass. If I understand its purpose correctly, it should do the job here, rather than me doing the same thing in the frontend.
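
Applying it manually looks roughly like this (a sketch; I’m assuming the pass is exposed in Python as relay.transform.LambdaLift, which is the same pass the VM compiler runs internally):

import tvm
from tvm import relay

# func is the @main from my first post
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
# Lift the closure bound to %while_loop into a global function
mod = relay.transform.LambdaLift()(mod)
print(mod)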

Here is the result of manually applying the lambda lift pass.

v0.0.4
def @main(%X: Tensor[(10, 10, 4), float32], %v26: Tensor[(4, 4), float32], %v25: Tensor[(4), float32]) -> Tensor[(10, 4), float32] {
  %0 = full(0 /* ty=int32 */, shape=[10, 4], dtype="float32") /* ty=Tensor[(10, 4), float32] */;
  %1 = full(0 /* ty=int32 */, shape=[10, 4], dtype="float32") /* ty=Tensor[(10, 4), float32] */;
  %2 = (
    let %while_loop: fn (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) -> (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) = @lifted_name3474153724846960558(%X, %v26, %v25) /* ty=fn (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) -> (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) */;
    %while_loop
  );
  %3 = %2(0 /* ty=int32 */, %0, %1) /* ty=(int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) */;
  %3.1
}

def @lifted_name3474153724846960558(%X1: Tensor[(10, 10, 4), float32], %v261: Tensor[(4, 4), float32], %v251: Tensor[(4), float32], Closure=1) -> fn (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) -> (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) {
  fn (%i.1: int32, %y.5: Tensor[(10, 4), float32], %h.5: Tensor[(10, 4), float32]) -> (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) {
    %4 = less(%i.1, 10 /* ty=int32 */) /* ty=bool */;
    if (%4) {
      %5 = add(%i.1, 1 /* ty=int32 */) /* ty=int32 */;
      %6 = take(%X1, %i.1, axis=0) /* ty=Tensor[(10, 4), float32] */;
      %7 = multiply(%6, 1f /* ty=float32 */) /* ty=Tensor[(10, 4), float32] */;
      %8 = transpose(%v261, axes=[1, 0]) /* ty=Tensor[(4, 4), float32] */;
      %9 = multiply(%8, 1f /* ty=float32 */) /* ty=Tensor[(4, 4), float32] */;
      %10 = transpose(%9, axes=[1, 0]) /* ty=Tensor[(4, 4), float32] */;
      %11 = nn.dense(%7, %10, units=4) /* ty=Tensor[(10, 4), float32] */;
      %12 = nn.bias_add(%11, %v251) /* ty=Tensor[(10, 4), float32] */;
      %13 = add(%12, %h.5) /* ty=Tensor[(10, 4), float32] */;
      %14 = tanh(%13) /* ty=Tensor[(10, 4), float32] */;
      %15 = @lifted_name3474153724846960558(%X1, %v261, %v251) /* ty=fn (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) -> (int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) */;
      %15(%5, %14, %14) /* ty=(int32, Tensor[(10, 4), float32], Tensor[(10, 4), float32]) */
    } else {
      (%i.1, %y.5, %h.5)
    }
  }
}

@haichen I found a fix which solves this issue and lets me compile and run the above IR, getting results identical to Torch! But I’m not sure whether my solution is correct.
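
What I checked is roughly this (a sketch; mod holds the Relay module from above, the patched TVM build is used, and script_module, w_np, b_np are illustrative names for the TorchScript model and the Linear weight/bias):

import numpy as np
import torch
import tvm
from tvm import relay

xs = torch.randn(10, 10, 4)
torch_out = script_module(xs).detach().numpy()  # TorchScript reference output

# Arguments follow @main's parameters (%X, %v26 = weight, %v25 = bias)
exe = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
relay_out = exe.evaluate()(xs.numpy(), w_np, b_np).asnumpy()

np.testing.assert_allclose(relay_out, torch_out, rtol=1e-5, atol=1e-5)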

I found that op in Expr op = call_node->op at the beginning of void VisitExpr_(const CallNode* call_node) is itself another CallNode, which is not covered by the existing if/else branches. So I added another case for CallNode, like below:

    ...
    } else if (auto inner_call_node = op.as<CallNode>()) {
      // The callee is itself a call: after lambda lifting, the op is a call to
      // the lifted global function, which returns the closure. Compile that
      // inner call first, then invoke the resulting closure (now in
      // last_register_) with the outer call's arguments.
      VisitExpr(GetRef<Call>(inner_call_node));
      Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
    } else {
      // Finally if there are any other cases this is a bug.
      LOG(FATAL) << "internal error: unreachable code,"
                 << "should be transformed away by previous passes"
                 << PrettyPrint(GetRef<Expr>(call_node));
    }

In my case, the op at the top is

CallNode(GlobalVar(lifted_name5306398090079843206), [Var(X, ty=TensorType([10, 10, 4], float32)), Var(26, ty=TensorType([4, 4], float32)), Var(25, ty=TensorType([4], float32))], (nullptr), [])

i.e. a nested CallNode.

Do my situation and the fix above make sense? If so, I can send the fix as a PR. Also cc @wweic @jroesch

I remember there is a bug here. I have a local fix but forgot to upstream it. I think this is a correct fix.

Btw, if it is a call node, we can probably just visit it without emitting the invoke instruction, because you will see the closure later.

If I do not add the Emit there, I get an error at runtime when the VM tries to downcast the closure object to an ADT (in my case this closure is supposed to return a tuple).

The fix looks right. Maybe you can try to find a minimal counterexample in lambda lift without firing up the whole VM?