[AutoTVM] CUDA runtime error when tuning extern ops

@merrymercy @comaniac

This is an issue related to my WIP PR (see this comment https://github.com/apache/tvm/pull/7233#issuecomment-757000179 for context).

I have two implementations of GPU scatter op, and both of them are extern ops. I want AutoTVM to run both of them and pick the faster one for a given input.

I thought this should be a simple matter, but I couldn’t get it working. Even though there are no tuning parameters and the generated kernel looks fine, I get the error

Check failed: e == cudaSuccess || e == cudaErrorCudartUnloading ==
false: CUDA: an illegal memory access was encountered

Any idea what could go wrong in such cases? Since the op is an extern op and the argument shapes look correct, I don’t understand why I get a CUDA runtime error.

This is the tuning script I’m using. The module consists of a single scatter op.

This might or might not be related to your question, but I have also been seeing this recently with CUDA 10.2, and when I switched to CUDA 10.0 it went away.

So the ops work well if you directly build the model without running AutoTVM?

Yes. Outside of AutoTVM tuning, both implementations work without any problem.

Since they are extern ops, the compiled kernels are the same inside or outside of AutoTVM. And the input shapes are correct (otherwise TVM would give an error). So it’s weird; I have no idea what could go wrong inside AutoTVM.

Hmm, it’s weird. Not sure if this helps, but maybe you can add a dummy parameter such as cfg.define_knob("dummy", [1]) and see what happens? If it still doesn’t work, maybe you could share your branch with the AutoTVM task registration so that I could try to reproduce the problem.

Adding a dummy knob didn’t help. I tried on both CUDA 11 and 10.2 and got the same error:

DEBUG:autotvm:No: 1	GFLOPS: 0.00/0.00	result: MeasureResult(costs=('Traceback (most recent call last):\n  [bt] (8) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(+0x1360897) [0x7f5734e74897]\n  [bt] (7) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x20b) [0x7f5734e7996b]\n  [bt] (6) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCClientSession::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)> const&)+0x57) [0x7f5734e6c5b7]\n  [bt] (5) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCEndpoint::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)>)+0x3a1) [0x7f5734e62781]\n  [bt] (4) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void (tvm::runtime::TVMArgs)>)+0x2dd) [0x7f5734e613ed]\n  [bt] (3) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCEndpoint::EventHandler::HandleNextEvent(bool, bool, std::function<void (tvm::runtime::TVMArgs)>)+0xd7) [0x7f5734e6c247]\n  [bt] (2) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCEndpoint::EventHandler::HandleProcessPacket(std::function<void (tvm::runtime::TVMArgs)>)+0x125) [0x7f5734e6c005]\n  [bt] (1) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCEndpoint::EventHandler::HandleReturn(tvm::runtime::RPCCode, std::function<void (tvm::runtime::TVMArgs)>)+0x13f) [0x7f5734e6be0f]\n  [bt] (0) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(+0x134c27b) [0x7f5734e6027b]\n  [bt] (8) 
/mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void (tvm::runtime::TVMArgs)>)+0x2dd) [0x7f85195993ed]\n  [bt] (7) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCEndpoint::EventHandler::HandleNextEvent(bool, bool, std::function<void (tvm::runtime::TVMArgs)>)+0xd7) [0x7f85195a4247]\n  [bt] (6) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCEndpoint::EventHandler::HandleProcessPacket(std::function<void (tvm::runtime::TVMArgs)>)+0x1cb) [0x7f85195a40ab]\n  [bt] (5) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::RPCSession::AsyncCallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::RPCCode, tvm::runtime::TVMArgs)>)+0x54) [0x7f85195b6284]\n  [bt] (4) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::LocalSession::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)> const&)+0x66) [0x7f85195a9d06]\n  [bt] (3) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(+0x135f2ab) [0x7f85195ab2ab]\n  [bt] (2) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(+0x135ef64) [0x7f85195aaf64]\n  [bt] (1) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::CUDADeviceAPI::StreamSync(DLContext, void*)+0xd4) [0x7f85195dbb14]\n  [bt] (0) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(+0x138f04b) [0x7f85195db04b]\n  File "/home/masa/projects/dev/tvm/src/runtime/rpc/rpc_endpoint.cc", line 378\nRPCError: Error caught from RPC call:\n[06:55:06] /home/masa/projects/dev/tvm/src/runtime/cuda/cuda_device_api.cc:195: \n---------------------------------------------------------------\nAn internal invariant 
was violated during the execution of TVM.\nPlease read TVM\'s error reporting guidelines.\nMore details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.\n---------------------------------------------------------------\n  Check failed: e == cudaSuccess || e == cudaErrorCudartUnloading == false: CUDA: an illegal memory access was encountered\n',), error_no=7, all_cost=4, timestamp=1610402106.017409)	[('dummy', 1)],None,0
WARNING:root:Could not find any valid schedule for task Task(func_name=scatter.cuda, args=(('TENSOR', (5000,), 'float32'), ('TENSOR', (5000,), 'int64'), ('TENSOR', (5000,), 'float32'), 0), kwargs={}, workload=('scatter.cuda', ('TENSOR', (5000,), 'float32'), ('TENSOR', (5000,), 'int64'), ('TENSOR', (5000,), 'float32'), 0)). A file containing the errors has been written to /tmp/tvm_tuning_errors_6mrav52x.log.

These are my branch and test script. The error should be reproducible by running that script under my branch. I’d really appreciate it if you could take a look at them.

@comaniac @FrozenGene @merrymercy

I figured out what’s going on. The issue is the random_fill function used by AutoTVM:

I’m trying to tune the scatter op, which takes an integer tensor as one of its inputs. This poses two problems for random_fill.

  1. The random_fill implementation is broken if the input is a 32- or 64-bit integer tensor, see

As a result of this, the index tensor is filled with values like

array([1091617754, 1090832347, 1091032493, ..., 1084215127, 1073655395,
       1090813385], dtype=int32)
  2. The values of the index tensor cannot be arbitrary random values, since each entry is an index into the scattered tensor. In our Relay test, for example, the index tensor is initialized as follows:

How should we solve this issue? I think I’ll simply skip the random init when the workload is from scatter op.
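
An alternative to skipping initialization entirely would be to generate random indices that are guaranteed in-bounds. A minimal numpy sketch, assuming a 1-D scatter like the failing task above (`fill_scatter_inputs` is a hypothetical helper, not TVM API):

```python
import numpy as np

def fill_scatter_inputs(data_shape, index_shape, axis=0, seed=0):
    """Hypothetical helper: generate valid random inputs for a scatter op."""
    rng = np.random.RandomState(seed)
    data = rng.uniform(size=data_shape).astype("float32")
    # Each index entry must lie in [0, data_shape[axis]); out-of-range
    # values are what trigger the illegal memory access on the GPU.
    indices = rng.randint(0, data_shape[axis], size=index_shape).astype("int64")
    updates = rng.uniform(size=index_shape).astype("float32")
    return data, indices, updates

# Shapes matching the failing task: (5000,) float32 data, (5000,) int64 indices.
data, indices, updates = fill_scatter_inputs((5000,), (5000,))
```

This keeps the measured workload closer to a real one than all-zero indices, at the cost of needing op-specific knowledge during tuning.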

Ah, this is tricky… I would say the simplest solution is using zeros for the inputs that are not float32/64, although this might still hurt the accuracy of the measured latency, as before.
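
The zero-fallback idea could be sketched like this in numpy terms (`safe_fill` is an illustrative name, not an existing TVM function):

```python
import numpy as np

def safe_fill(shape, dtype, rng):
    # Random-fill floating-point tensors as usual, but fall back to
    # zeros for integer dtypes: zero is always a legal index, so ops
    # like scatter cannot touch out-of-bounds memory.
    if np.issubdtype(np.dtype(dtype), np.floating):
        return rng.uniform(size=shape).astype(dtype)
    return np.zeros(shape, dtype=dtype)

rng = np.random.RandomState(0)
weights = safe_fill((128, 64), "float32", rng)  # random floats
indices = safe_fill((5000,), "int64", rng)      # all zeros, always valid
```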

On the other hand, we could let users optionally provide the input data for each tuning task. IIRC, we had this logic before for correctness checking but removed it recently due to some stability issues.

cc @FrozenGene

Yes, user-provided, per-task input is an ideal solution (maybe via a callback). For some ops, even float inputs need to be initialized in a specialized way. For example, numpy’s searchsorted function, which we don’t have in TVM yet, requires one of its inputs to be sorted.

A quick question: why does random_fill need to be implemented in C++? Why not just use a numpy array to initialize the tvm array? @FrozenGene

In https://github.com/apache/tvm/pull/7233, I changed the initialization of tvm.nd.array to use a zero-initialized numpy array instead of nd.empty. Maybe you can take a look?
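
The distinction the PR relies on can be illustrated with the numpy analogues (np.empty behaves like nd.empty here):

```python
import numpy as np

# np.empty leaves the buffer uninitialized: the bytes are whatever
# happened to be in memory, which for an int64 index tensor can be huge
# out-of-range values. Zero-initializing guarantees every entry is a
# valid index.
shape = (5000,)
unsafe = np.empty(shape, dtype="int64")  # contents unspecified
safe = np.zeros(shape, dtype="int64")    # every entry is the valid index 0
```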

I think the solution to this problem should be to not use random input data (nd.empty / random_fill should not be used). Before, we could provide ref_input; however, I don’t know whether this has been deleted or not.

One reason for implementing it in C++ is that we want to leverage the logic in sort.cc; another is that we want to support CPU@remote device -> GPU@remote device, not CPU@local device -> CPU@remote device -> GPU@remote device. If we used numpy, we would generate the data on the local CPU.

It seems the ref_input stuff was removed very recently in https://github.com/apache/tvm/pull/7250. But even before that PR, ref_input was randomly initialized.

So I believe we need a new mechanism to hook user-provided inputs into tuning, in order to support tuning non-standard operators that have requirements on their inputs.