[cross-compilation-and-rpc] How to setup remote RPC compute with lib build with relay.build?

Hello, I follow the tutorial cross-compilation-and-rpc for test a simple conv model on armv7a device without RPC tracker.

Remote RPC can compute with lib build with tvm.build, but not with relay.build. Here is my dev envs:

TVM version: commit 7315c9d5b46408a047409f1e047cb6b7779c668e (HEAD, tag: v0.14.0.rc0, 
tag: v0.14.0, origin/v0.14.0)
Target Device: (Hardware: RK3188, OS: Android-19 (armeabi-v7a))
TVM compile on Host Device: (Hardware: PC (x86_64), OS: Ubuntu 18.04.6)

The test step and code is below:

1. Start target device server

./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090

2. Run the code on Linux host

import numpy as np
import os
import torch

import tvm
from tvm import te
from tvm import rpc
from tvm.contrib import utils
from tvm import relay, autotvm

simple_test = True

# target = "llvm -mtriple=armv7l-linux-gnueabihf"
target = 'llvm -device=arm_cpu -mtriple=armv7a-linux-gnueabihf'

if simple_test:
    n = tvm.runtime.convert(1024)
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # func = tvm.build(s, [A, B], target=target, name="add_one")

    from tvm.ir.module import IRModule
    func = te.create_prim_func([A, B])
    ir_module_from_te = IRModule({"main": func})
    print(type(ir_module_from_te), ir_module_from_te.script())

    mod = tvm.build(ir_module_from_te, target=target)  # The module for CPU backends.
    print(type(mod))
else:
    class SingleConv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1,1,1,bias=False)
        def forward(self, input):
            return self.conv(input)

    model = SingleConv()
    model = model.eval()

    # We grab the TorchScripted model via tracing
    input_shape = [1, 1, 40, 40]
    input_data = torch.randn(input_shape)
    scripted_model = torch.jit.trace(model, input_data).eval()

    print('scripted_model:', scripted_model)

    input_name = "input0"
    shape_list = [(input_name, input_data.shape)]
    ir_mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
    print(type(ir_mod), ir_mod.script())

    lib = relay.build(ir_mod, target=target)  # The module for CPU backends.
    mod = lib.lib
    print(type(mod))

host = "192.168.1.5"
port = 9000
remote = rpc.connect(host, port)

CC = '/home/dyc/toolchains/android-ndk-r25c/toolchains/llvm/prebuilt/linux-x86_64/bin/armv7a-linux-androideabi19-clang'
from tvm.contrib import ndk
os.environ["TVM_NDK_CC"] = CC

f_so = 'mod.so'
mod.export_library(f_so, fcompile=ndk.create_shared)

print('[Upload]:', f_so)
remote.upload(f_so)
print('After upload.')

func = remote.load_module(os.path.split(f_so)[1])
print('After load_module.')

if simple_test:
    # create arrays on the remote device
    dev = remote.cpu()
    a = tvm.nd.array(np.random.uniform(size=1024).astype(np.float32), dev)
    b = tvm.nd.array(np.zeros(1024, dtype=np.float32), dev)
    # the function will run on the remote device
    func(a, b)
    np.testing.assert_equal(b.numpy(), a.numpy() + 1)
    print('ok')

    time_f = func.time_evaluator(func.entry_name, dev, number=10)
    cost = time_f(a, b).mean
    print("%g secs/op" % cost)
else:
    # create arrays on the remote device
    dev = remote.cpu()
    a = tvm.nd.array(np.random.uniform(size=[1,1,40,40]).astype(np.float32), dev)
    b = tvm.nd.array(np.zeros([1,1,40,40], dtype=np.float32), dev)
    # the function will run on the remote device
    func(a, b)
    print(a.numpy(), b.numpy())
    print('ok')
    # np.testing.assert_equal(b.numpy(), a.numpy() + 1)

    time_f = func.time_evaluator(func.entry_name, dev, number=10)
    cost = time_f(a, b).mean
    print("%g secs/op" % cost)

This output:

<class 'tvm.ir.module.IRModule'> # from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i in range(1024):
            with T.block("B"):
                v_i = T.axis.spatial(1024, i)
                T.reads(A[v_i])
                T.writes(B[v_i])
                B[v_i] = A[v_i] + T.float32(1)
<class 'tvm.driver.build_module.OperatorModule'>
[Upload]: mod.so
After upload.
After load_module.
ok
5.7458e-06 secs/op

With simple_test=True, the code test lib build by tvm.build, but when simple_test=False, I got following error:

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
<class 'tvm.runtime.module.Module'>
[Upload]: mod.so
After upload.
After load_module.
Traceback (most recent call last):
  File "test_rpc_api_android_one_so.py", line 95, in <module>
    func(a, b)
  File "/mnt/old_home/home/dyc/work_space/work_directory/model_acceleration_projects/quantize/framework/tvm_release/python/tvm/runtime/module.py", line 201, in __call__
    return self.entry_func(*args)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/mnt/old_home/home/dyc/work_space/work_directory/model_acceleration_projects/quantize/framework/tvm_release/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  2: tvm::runtime::RPCWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  1: tvm::runtime::RPCClientSession::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)> const&)
  0: tvm::runtime::RPCEndpoint::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)>)
  File "/home/dyc/work_space/work_directory/model_acceleration_projects/quantize/framework/tvm_release/src/runtime/rpc/rpc_endpoint.cc", line 818
InternalError: Check failed: (code == RPCCode::kReturn) is false: code=kShutdown

How to fix this error? Thanks.