I was trying to compile a lidar object detection network for the CUDA target. The following error appeared when I tried to run it:
Check failed: ret == 0 (-1 vs. 0) : TVMError: CUDALaunch Error: CUDA_ERROR_INVALID_VALUE
grid=(16,1280000,1), block=(64,1,1)
// func_name=tvmgen_default_fused_nn_dense_kernel0
// CUDA Source
// -----------
As you can see, the grid size is invalid: gridDim.y is 1280000, while CUDA limits the y and z grid dimensions to 65535. Some posts suggest tuning, but that does not work for this model. The tuner tries a few times and then fails for a similar reason.
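To make the failure concrete, here is a minimal sketch that checks a launch configuration against the per-dimension limits documented for CUDA devices of compute capability 3.0 and later. The helper name `launch_config_errors` is my own and is not part of TVM or CUDA; it only mirrors the documented limits.

```python
# Launch limits for CUDA devices of compute capability 3.0 and later.
MAX_GRID = (2**31 - 1, 65535, 65535)   # gridDim.x, gridDim.y, gridDim.z
MAX_BLOCK = (1024, 1024, 64)           # blockDim.x, blockDim.y, blockDim.z
MAX_THREADS_PER_BLOCK = 1024


def launch_config_errors(grid, block):
    """Return a list of limit violations; the list is empty if the launch is valid.

    Hypothetical helper for illustration, not a TVM or CUDA API.
    """
    errors = []
    for axis, size, limit in zip("xyz", grid, MAX_GRID):
        if not 1 <= size <= limit:
            errors.append(f"gridDim.{axis}={size} outside [1, {limit}]")
    for axis, size, limit in zip("xyz", block, MAX_BLOCK):
        if not 1 <= size <= limit:
            errors.append(f"blockDim.{axis}={size} outside [1, {limit}]")
    threads = block[0] * block[1] * block[2]
    if threads > MAX_THREADS_PER_BLOCK:
        errors.append(f"{threads} threads per block exceeds {MAX_THREADS_PER_BLOCK}")
    return errors


# The configuration from the error message above.
print(launch_config_errors((16, 1280000, 1), (64, 1, 1)))
# ['gridDim.y=1280000 outside [1, 65535]']
```

Only gridDim.x can hold a value as large as 1280000, which is why the axis that a large extent is bound to matters.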
I was able to fix it by modifying the default schedule of dense_small_batch.gpu: I just bind the larger number to blockIdx.x. I don’t know if this is the correct way to fix it. I think I can create a PR regarding this problem if it is a bug.
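The idea behind the workaround can be sketched in plain Python, independent of TVM: order the block-level extents so that the largest one lands on blockIdx.x, which has a much higher limit than blockIdx.y and blockIdx.z. The function `assign_block_axes` is a hypothetical illustration, not the actual schedule code, and the extents 16 and 1280000 are taken from the grid in the error message.

```python
# Grid limits for CUDA devices of compute capability 3.0 and later.
MAX_GRID = (2**31 - 1, 65535, 65535)   # blockIdx.x, blockIdx.y, blockIdx.z


def assign_block_axes(extents):
    """Map up to three loop extents to (x, y, z) grid sizes, largest first.

    Hypothetical helper for illustration; raises ValueError if the
    extents cannot fit even after reordering.
    """
    if len(extents) > 3:
        raise ValueError("at most three extents can be bound to block axes")
    ordered = sorted(extents, reverse=True)
    grid = tuple(ordered + [1] * (3 - len(ordered)))
    for axis, size, limit in zip("xyz", grid, MAX_GRID):
        if size > limit:
            raise ValueError(f"extent {size} does not fit blockIdx.{axis} (max {limit})")
    return grid


# Extents from the failing dense kernel: originally launched as grid=(16, 1280000, 1).
print(assign_block_axes([16, 1280000]))
# (1280000, 16, 1)
```

Reordering only fixes the launch limit. Whether it is also a good choice for memory access patterns in the real schedule is a separate question that this sketch does not address.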
Model file looks like this:
My script:
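import numpy as np
import onnx
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# Load the ONNX model
model_encoder = "pts_voxel_encoder_centerpoint.onnx"
onnx_ = onnx.load(model_encoder)

# Dummy input of ones with the shape passed to the frontend
x = np.ones((40000, 32, 9), dtype=np.float32)
input_name = "input_features"
shape_dict = {input_name: x.shape}

# tvm.target.create is deprecated; tvm.target.Target is the replacement
target = tvm.target.Target("cuda")
mod, params = relay.frontend.from_onnx(onnx_, shape_dict)

# Compile with the default (untuned) schedules
with tvm.transform.PassContext(opt_level=3, config={}):
    lib = relay.build(mod, target=target, params=params)

# Run on GPU 0; the CUDALaunch error is raised by module.run()
dev = tvm.device(str(target), 0)
module = graph_executor.GraphModule(lib["default"](dev))
dtype = "float32"
module.set_input(input_name, x)
module.run()

# Fetch the first output
output_shape = (40000, 1, 32)
out_idx = 0
tvm_output = module.get_output(out_idx, tvm.nd.empty(output_shape, dtype=dtype)).numpy()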
PS: What is the correct way to get the schedule of a single operator from a network model?