[AutoTVM] CUDAError: cuModuleUnload(module_[i]) failed with error: CUDA_ERROR_ILLEGAL_ADDRESS

Hello! I’m trying to run the following code, but I get a CUDA_ERROR_ILLEGAL_ADDRESS error, shown below the code:

"""Benchmark for dense-BSR-sparse matrix multiplication for TVM"""
import numpy as np
import tvm
from tvm import autotvm, te, topi
import tvm.topi.testing
from tvm.topi.utils import traverse_inline, get_const_int
import scipy.sparse as sp
import argparse
import logging
import sys

parser = argparse.ArgumentParser()

parser.add_argument('setting', choices=['PEP','PTP','PROB','PRWB','PRWB_AT'])
parser.add_argument('--n_trial', default=500, type=int, help='Number of trials for AutoTVM')
parser.add_argument('--repeat', default=3, type=int, help='Number of repeat for AutoTVM to profile on the device')
parser.add_argument('--tune', action='store_true', help='Enable AutoTVM tuning for setting PRWB_AT')
parser.add_argument('--autotvm_log', default='blocksparse.log', type=str, help='Log file for auto tuning')
args = parser.parse_args()

target = 'cuda'
device = tvm.device(target, 0)
np.random.seed(42)


def int_div(x, y):
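    # Ceiling division: the smallest integer >= x / y.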
    return (x + y - 1) // y


def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
    """Generate a random BSR matrix"""
    import itertools
    Y = np.zeros((M, N), dtype=dtype)
    assert M % BS_R == 0
    assert N % BS_C == 0
    nnz = int(density * M * N)
    num_blocks = int_div(nnz, (BS_R * BS_C))
    candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
    assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
    chosen_blocks = candidate_blocks[np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)]
    for i in range(len(chosen_blocks)):
        r, c = chosen_blocks[i]
        Y[r:r + BS_R, c:c + BS_C] = np.random.randn(BS_R, BS_C)
    s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
    assert s.data.shape == (num_blocks, BS_R, BS_C)
    assert s.indices.shape == (num_blocks, )
    assert s.indptr.shape == (M // BS_R + 1, )
    return s


def schedule_sparse_dense_cuda_allreduce(outs):
    """Create schedule for sparse dense"""
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        if op.tag == "sparse_dense_sp_rhs_bsrmm":
            y_bsrmm = op.input_tensors[0]
            w_indptr = y_bsrmm.op.input_tensors[0]
            assert y_bsrmm.op.tag == "sparse_dense_sp_rhs_bsrmm_block"
            y_reshape = op
            (m, num_blocks, b_r) = s[y_bsrmm].op.axis
            bs_r = get_const_int(b_r.dom.extent)
            (elem_idx, c) = s[y_bsrmm].op.reduce_axis
            
            (m_o, n_o) = s[y_reshape].op.axis
            s[y_reshape].bind(m_o, te.thread_axis("blockIdx.x"))
            s[y_reshape].bind(n_o, te.thread_axis("blockIdx.y"))
            s[y_bsrmm].compute_at(s[y_reshape], n_o)

            thread_x = te.thread_axis("threadIdx.x")
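            # Split the reduction axis and rfactor it so each threadIdx.x lane
            # accumulates a partial sum; the final value is then reduced across
            # the thread block (cross-thread all-reduce) and stored by lane 0.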
            co, ci = s[y_bsrmm].split(c, 8)
            y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
            tx = s[y_bsrmm].op.reduce_axis[0]
            s[y_bsrmm].bind(tx, thread_x)
            s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
            s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
            s[y_reshape].set_store_predicate(thread_x.var.equal(0))

    traverse_inline(s, outs[0].op, _callback)
    return s


def schedule_sparse_dense_cuda_allreduce_autotune(cfg, outs):
    """Create schedule for sparse dense"""
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        if op.tag == "sparse_dense_sp_rhs_bsrmm":
            y_bsrmm = op.input_tensors[0]
            w_indptr = y_bsrmm.op.input_tensors[0]
            assert y_bsrmm.op.tag == "sparse_dense_sp_rhs_bsrmm_block"
            y_reshape = op
            (m, num_blocks, b_r) = s[y_bsrmm].op.axis
            bs_r = get_const_int(b_r.dom.extent)
            (elem_idx, c) = s[y_bsrmm].op.reduce_axis
            
            (m_o, n_o) = s[y_reshape].op.axis
            s[y_reshape].bind(m_o, te.thread_axis("blockIdx.x"))
            s[y_reshape].bind(n_o, te.thread_axis("blockIdx.y"))
            s[y_bsrmm].compute_at(s[y_reshape], n_o)

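            # Dummy knob with a single value, so the config space has exactly
            # one entry and the generated schedule matches the non-tuned one.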
            cfg.define_knob("hehe", [0])
            thread_x = te.thread_axis("threadIdx.x")
            co, ci = s[y_bsrmm].split(c, 8)
            y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
            tx = s[y_bsrmm].op.reduce_axis[0]
            s[y_bsrmm].bind(tx, thread_x)
            s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
            s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
            s[y_reshape].set_store_predicate(thread_x.var.equal(0))

    traverse_inline(s, outs[0].op, _callback)
    return s

# cfg.define_split("tile_c", c, num_outputs=2)
# co, ci = cfg['tile_c'].apply(s, y_bsrmm, c)


def test_sparse_dense_bsr_autotune(M, N, K, BS_R, BS_C, density):
    """Benchmark sparse-dense matrix multiplication with auto tuning enabled"""
    print("testing param", M, N, K, BS_R, BS_C, density)
    X_np = np.random.randn(M, K).astype("float32")
    W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
    W_np = W_sp_np.todense()
    Y_np = X_np.dot(W_np.T)

    # logging config (for printing tuning log to screen)
    logging.getLogger('autotvm').setLevel(logging.DEBUG)
    logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

    W_sp_np_data_shape, W_sp_np_indices_shape, W_sp_np_indptr_shape, X_np_shape = W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, X_np.shape
    
    task = autotvm.task.create("benchmark/block_sparse",
                            args=(W_sp_np_data_shape, W_sp_np_indices_shape, W_sp_np_indptr_shape, X_np_shape),
                            target=target)
    print(task.config_space)

    # Use local gpu, measure multiple times for every config to reduce variance
    # The timeout for running is 4 seconds
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(repeat=args.repeat, min_repeat_ms=100, timeout=4)
    )

    # Begin tuning; log records to the file given by --autotvm_log (default `blocksparse.log`)
    # During tuning we will also try many invalid configs, so you are expected to
    # see many error reports. As long as you can see non-zero GFLOPS, it is okay.
    tuner = autotvm.tuner.XGBTuner(task)
    if args.tune:
        tuner.tune(n_trial=args.n_trial,
                measure_option=measure_option,
                callbacks=[autotvm.callback.log_to_file(args.autotvm_log)])

    # apply history best from log file
    with autotvm.apply_history_best(args.autotvm_log):
        with tvm.target.Target(target):
            s, arg_bufs = block_sparse_template(W_sp_np_data_shape, W_sp_np_indices_shape, W_sp_np_indptr_shape, X_np_shape)
            func = tvm.build(s, arg_bufs)
            print(tvm.lower(s, arg_bufs, simple_mode=True))
            print(func.imported_modules[0].get_source())

    timer = func.time_evaluator(func.entry_name, device, number=20)
    Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), device)

    mean_time = timer(tvm.nd.array(X_np, device),
                      tvm.nd.array(W_sp_np.data, device),
                      tvm.nd.array(W_sp_np.indices, device),
                      tvm.nd.array(W_sp_np.indptr, device),
                      Y_tvm).mean
    
    print('%g ms' % (mean_time * 1e3))
    print("------------------------")
    tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)


@autotvm.template("benchmark/block_sparse")
def block_sparse_template(W_sp_np_data_shape, W_sp_np_indices_shape, W_sp_np_indptr_shape, X_np_shape):
    W_data = te.placeholder(shape=W_sp_np_data_shape, dtype='float32', name='W_data')
    W_indices = te.placeholder(shape=W_sp_np_indices_shape, dtype='int32', name='W_indices')
    W_indptr = te.placeholder(shape=W_sp_np_indptr_shape, dtype='int32', name='W_indptr')
    X = te.placeholder(shape=X_np_shape, dtype='float32', name='X')
    Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)

    cfg = autotvm.get_config()
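    # FLOP count: 2 * num_blocks * BS_R * BS_C * M (one multiply-add per stored
    # element of W and per row of X); AutoTVM uses this to report GFLOPS.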
    cfg.add_flop(W_sp_np_data_shape[0] * X_np_shape[0] * W_sp_np_data_shape[1] * W_sp_np_data_shape[2] * 2)
    s = schedule_sparse_dense_cuda_allreduce_autotune(cfg, [Y])
    return s, [X, W_data, W_indices, W_indptr, Y]


def test_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, compute_func, schedule_func):
    """Benchmark sparse-dense matrix multiplication"""
    print("testing param", M, N, K, BS_R, BS_C, density)
    X_np = np.random.randn(M, K).astype("float32")
    W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
    W_np = W_sp_np.todense()
    Y_np = X_np.dot(W_np.T)

    W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype), name='W_data')
    W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype), name='W_indices')
    W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype), name='W_indptr')
    X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype), name='X')    
    Y = compute_func(X, W_data, W_indices, W_indptr)
    s = schedule_func([Y])
    print(tvm.lower(s, [X, W_data, W_indices, W_indptr, Y], simple_mode=True))
    func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
    timer = func.time_evaluator(func.entry_name, device, number=20)
    Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), device)

    mean_time = timer(tvm.nd.array(X_np, device),
                      tvm.nd.array(W_sp_np.data, device),
                      tvm.nd.array(W_sp_np.indices, device),
                      tvm.nd.array(W_sp_np.indptr, device),
                      Y_tvm).mean
    
    print('%g ms' % (mean_time * 1e3))
    print("------------------------")
    tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)


settings = {
    'PRWB': (topi.nn.sparse_dense, schedule_sparse_dense_cuda_allreduce)
}


if __name__ == "__main__":
    with tvm.target.create(target):
        for N, K in [(128, 128)]:
            for M in [32]:
                for BS_R in [32]:
                    BS_C = BS_R
                    for density in [0.25]:
                        if args.setting == 'PRWB_AT':
                            test_sparse_dense_bsr_autotune(M, N, K, BS_R, BS_C, density)
                        else:
                            test_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, *settings[args.setting])

The error I get is:

terminate called after throwing an instance of 'tvm::runtime::InternalError'
  what():  [01:32:51] /home/moderato/Documents/incubator-tvm/src/runtime/cuda/cuda_module.cc:61: CUDAError: cuModuleUnload(module_[i]) failed with error: CUDA_ERROR_ILLEGAL_ADDRESS
Stack trace:
  0: tvm::runtime::SimpleObjAllocator::Handler<tvm::runtime::CUDAModuleNode>::Deleter_(tvm::runtime::Object*)
  1: tvm::runtime::SimpleObjAllocator::Handler<tvm::runtime::LibraryModuleNode>::Deleter_(tvm::runtime::Object*)
  2: _ZNSt14_Function_base13_Base_managerIZN3tvm7runtime14WrapPackedFuncEPFiP8TVMValuePiiS4_S5_PvERKNS2_9ObjectPtrINS2_6ObjectEEEEUlNS2_7TVMArgsEPNS2_11TVMRetValueEE_E10_M_managerERSt9_Any_dataRKSJ_St18
  3: std::_Function_base::_Base_manager<tvm::runtime::WrapTimeEvaluator(tvm::runtime::PackedFunc, DLDevice, int, int, int, tvm::runtime::PackedFunc)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}>::_M_manager(std::_Any_data&, std::_Any_data const&, std::_Manager_operation)
  4: tvm::runtime::LocalSession::FreeHandle(void*, int)
  5: tvm::runtime::RPCFreeHandle(tvm::runtime::RPCSession*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  6: void tvm::runtime::RPCEndpoint::EventHandler::SysCallHandler<void (*)(tvm::runtime::RPCSession*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>(void (*)(tvm::runtime::RPCSession*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*))
  7: tvm::runtime::RPCEndpoint::EventHandler::HandleSyscall(tvm::runtime::RPCCode)
  8: tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void (tvm::runtime::TVMArgs)>)
  9: tvm::runtime::RPCEndpoint::ServerLoop()
  10: tvm::runtime::RPCServerLoop(int)
  11: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  12: TVMFuncCall

The commands to run the above code are:

python example.py PRWB           # works
python example.py PRWB_AT --tune # fails with the error above

You can see that although the auto-tuning is basically a dummy (it defines a single-value knob and produces the same schedule as the non-tuned version), it still throws this error. I have tried CUDA 11.3 with driver versions 470 and 510, but with no luck. The GPU I use is a GTX 1080.
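
For reference, here is a minimal sketch (not part of the script above; it assumes the task object from test_sparse_dense_bsr_autotune is still in scope) of how one can double-check which record apply_history_best picks from the tuning log:

dispatch_context = autotvm.apply_history_best(args.autotvm_log)
best_config = dispatch_context.query(task.target, task.workload)
print("Best config found in the log:")
print(best_config)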

Can anyone help? Thanks in advance!
