Getting block and grid dimensions from CUDA graph

I’m trying to manually run a model generated using relay.build. Basically, I’m pretending to be the graph runtime (for an experimental system that uses CUDA kernels but not the local CUDA drivers).

So far I know how to get the raw source, the compiled PTX, the execution graph, and the parameters. What I’m missing are the block and grid dimensions. Here’s a minimal example:

import numpy as np
from tvm import relay
from tvm.relay import testing
import tvm

mod, params = relay.testing.mlp.get_workload(1)

target = tvm.target.cuda()
with tvm.transform.PassContext():
    graphMod = relay.build(mod, target, params=params)

lib = graphMod.get_lib()              # host module
cudaLib = lib.imported_modules[0]     # imported CUDA device module

executionGraph = graphMod.get_json()  # graph executor JSON
rawParams = {k: p.asnumpy() for k, p in params.items()}
rawSource = cudaLib.get_source()      # generated CUDA source of the kernels
cudaLib.save("foo.ptx")               # compiled PTX
# gridDim, blockDim = ????

From cuda_module.cc I see that these are encoded in the TVMArgs for a CUDAWrappedFunc, but I’m not clear on how to access them from Python or otherwise derive them. Ideally I’d do this without having to instrument TVM and run the model once.

One other question about this: The parameters are named in params, but numbered in the execution graph. Are the arrays in params ordered the same way as in the graph?
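For what it’s worth, here is a minimal sketch of how one might line the two up, assuming graphMod.get_params() is keyed by the same numbered names (e.g. "p0", "p1") that appear in the graph JSON:

import json

graph = json.loads(executionGraph)
# "arg_nodes" are the indices of the input/parameter nodes, in graph order
inputNames = [graph["nodes"][i]["name"] for i in graph["arg_nodes"]]
# Parameters bound at build time are renamed, so match against the factory
# module's get_params() rather than the original `params` dict
builtParams = {k: p.asnumpy() for k, p in graphMod.get_params().items()}
orderedParams = [(name, builtParams.get(name)) for name in inputNames]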


Hey, I’m trying to do the same thing you did: run a TVM model manually. And I’m just as confused as you are about how to get gridDim and blockDim. I wonder if you have found a solution after all this time.

Block/grid dimensions are encoded as part of the host module, which is executed using LLVM and invokes CUDAWrappedFunc at runtime. You will need to customize the compilation pipeline to extract this information.
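If you do get at the lowered device-side TIR that way, the launch dimensions sit on "thread_extent" attribute nodes. Here is a minimal sketch of what the extraction could look like, assuming prim_func is a lowered device PrimFunc (the helper name is mine, and the extents are only plain integers for static shapes):

import tvm

def collect_launch_dims(prim_func):
    """Collect blockIdx/threadIdx extents from a lowered device PrimFunc."""
    dims = {}

    def visit(stmt):
        if isinstance(stmt, tvm.tir.AttrStmt) and stmt.attr_key == "thread_extent":
            # stmt.node is the IterVar (e.g. blockIdx.x), stmt.value its extent
            dims[stmt.node.var.name] = int(stmt.value)

    tvm.tir.stmt_functor.post_order_visit(prim_func.body, visit)
    return dims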

Thank you for your reply. I first tried to get this information by modifying HostDeviceSplitter, but I was really new to TVM, ran into some problems, and gave up (I was confused about the meaning of the elements contained in m.thread_extent_). Now I modify CUDAWrappedFunc and get the information by running the kernel once. This may not be the best way, I guess, because I’m not sure whether the launch parameters for each kernel are constants or context-dependent values.

If the input shapes are static, you can assume the kernel launch params are constant.

Hi, I’ve hit a new problem and wonder if you could give me a hint. I want to know the order in which the fused operators generated by relay.build are launched and where their input parameters come from. I first wanted to use graph_json to get this information, but in my simple demo graph_json only had 5 compute nodes while 6 kernels were actually launched (measured with NVIDIA Nsight Compute), so I don’t know where the extra kernel comes from. I would appreciate any hint. More details are here: How to understand graph_json generated by relay.build?
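For context, this is roughly how I tried to read the compute nodes out of graph_json (a sketch reusing executionGraph from the first post; func_name and inputs are fields of the graph’s tvm_op nodes):

import json

graph = json.loads(executionGraph)
for idx, node in enumerate(graph["nodes"]):
    if node["op"] == "tvm_op":
        # func_name names the fused function in the compiled module
        print(idx, node["attrs"]["func_name"], node["inputs"])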

Building with target("cuda", host="c") will make it easier to see what happened. For example, the following generated host code indicates gridDim=<16, 28, 112> and blockDim=<4, 4, 1>.

  (((TVMValue*)stack_value)[3].v_int64) = ((int64_t)16);   // gridDim.x
  ((int32_t*)stack_tcode)[3] = 0;
  (((TVMValue*)stack_value)[4].v_int64) = ((int64_t)28);   // gridDim.y
  ((int32_t*)stack_tcode)[4] = 0;
  (((TVMValue*)stack_value)[5].v_int64) = ((int64_t)112);  // gridDim.z
  ((int32_t*)stack_tcode)[5] = 0;
  (((TVMValue*)stack_value)[6].v_int64) = ((int64_t)4);    // blockDim.x
  ((int32_t*)stack_tcode)[6] = 0;
  (((TVMValue*)stack_value)[7].v_int64) = ((int64_t)4);    // blockDim.y
  ((int32_t*)stack_tcode)[7] = 0;
  (((TVMValue*)stack_value)[8].v_int64) = ((int64_t)1);    // blockDim.z
  ((int32_t*)stack_tcode)[8] = 0;
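
For reference, here is a minimal sketch of building the same workload with a C host so that host code can be dumped and inspected (the exact output layout may differ between TVM versions):

import tvm
from tvm import relay
from tvm.relay import testing

mod, params = relay.testing.mlp.get_workload(1)

target = tvm.target.Target("cuda", host="c")   # CUDA device code, C host code
with tvm.transform.PassContext(opt_level=3):
    graphMod = relay.build(mod, target, params=params)

hostSource = graphMod.get_lib().get_source()   # generated C host code
print(hostSource)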