Inspecting generated CUDA host code

Hi,

I’m trying to generate different memory allocation plans for NVIDIA GPUs, and would like to analyze the CUDA memory allocations of TVM IRModules under various optimizations. For this, I need to see the cudaMemcpy and cudaMalloc calls that TVM would generate. I found how to dump the CUDA device code (with the computations), but not the CUDA host code (with the allocations). The only host code I could extract was LLVM code.
I also found that we can split host and device tir modules with tvm.tir.build.split_host_device_mods, but couldn’t see how to generate CUDA host code from this.
Is it possible to extract CUDA host code? If so, how should I proceed?

Related (but somewhat old) question: “Any way to dump the cuda host code?” (reply #2 by x-huan). The answers there mention microTVM, which seems to have been phased out.

[Python API] Getting CUDA host code (example)

import tvm
from tvm.script import ir as I
from tvm.script import tir as T

# Example input: a single 2-D convolution written as plain TIR.
# Shapes: input (1, 1024, 13, 13), weights (1024, 1024, 3, 3), output (1, 1024, 13, 13).
# The loops below carry no GPU thread bindings yet; a GPU schedule is applied separately.
@I.ir_module
class Module:
    @T.prim_func
    def conv2d7(lv34: T.Buffer((T.int64(1), T.int64(1024), T.int64(13), T.int64(13)), "float32"), B: T.Buffer((T.int64(1024), T.int64(1024), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(1), T.int64(1024), T.int64(13), T.int64(13)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        # (Printer artifact: the body is implicitly wrapped in a root block.)
        # with T.block("root"):
        # Intermediate buffer: lv34 zero-padded by one element on each spatial
        # side (13x13 -> 15x15).
        pad_temp = T.alloc_buffer((T.int64(1), T.int64(1024), T.int64(15), T.int64(15)))
        # Stage 1: copy lv34 into the interior (indices 1..13) of pad_temp; the
        # border positions are filled with 0.0.
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1024), T.int64(15), T.int64(15)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv34[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(14) and T.int64(1) <= v_i3 and v_i3 < T.int64(14), lv34[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0.0))
        # Stage 2: direct convolution over the padded input. nn/ff/yy/xx are
        # spatial axes; rc/ry/rx are reduction axes ("SSSSRRR").
        for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(1024), T.int64(13), T.int64(13), T.int64(1024), T.int64(3), T.int64(3)):
            with T.block("conv2d_nchw"):
                v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], B[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                # Zero the output element before the first reduction step.
                with T.init():
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0.0)
                # Multiply-accumulate one input element by one weight.
                conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * B[v_ff, v_rc, v_ry, v_rx]

import pycuda.driver as cuda

def get_cuda_target_config():
    """Build a TVM ``cuda`` target describing CUDA device 0 of this machine.

    The device properties are queried through TVM's own runtime device API
    (``tvm.cuda(0)``) rather than PyCUDA. The previous version imported
    ``pycuda.autoinit`` here. That import creates a separate CUDA context as a
    side effect, and in practice it failed with
    ``pycuda._driver.LogicError: cuDeviceGet failed: initialization error``.
    Querying through TVM avoids the extra dependency and the second context.

    Returns
    -------
    tvm.target.Target
        A ``cuda`` target whose ``arch``, ``max_threads_per_block``,
        ``max_shared_memory_per_block`` and ``thread_warp_size`` attributes
        match the detected device.

    Raises
    ------
    RuntimeError
        If CUDA device 0 is not visible to the TVM runtime.
    """
    dev = tvm.cuda(0)
    if not dev.exist:
        raise RuntimeError("No CUDA device 0 is visible to the TVM runtime")

    # compute_version is reported as a "major.minor" string, e.g. "8.6" -> sm_86.
    major, minor = (int(part) for part in dev.compute_version.split("."))
    cc = (major, minor)
    sm_arch = f"sm_{major}{minor}"
    print(f"Detected CUDA compute capability: {cc} → arch={sm_arch}")

    attrs = {
        "kind": "cuda",
        "arch": sm_arch,
        "max_threads_per_block": dev.max_threads_per_block,
        "max_shared_memory_per_block": dev.max_shared_memory_per_block,
        "thread_warp_size": dev.warp_size,
    }
    return tvm.target.Target(attrs)

# Define the CUDA target from the GPU that is actually present.
target = get_cuda_target_config()

# Simple default binding of loops to GPU blocks/threads. DefaultGPUSchedule
# uses the target attached to each function, so bind the target first.
bound_mod = tvm.tir.transform.DefaultGPUSchedule()(
    tvm.tir.transform.BindTarget(target)(Module)
)
# IRModule.show() prints the module itself and returns None; wrapping it in
# print() would emit an extra "None" line.
bound_mod.show()

# Build with the detected target rather than the bare "cuda" string, so the
# detected arch and device limits are the ones actually used for codegen.
lib = tvm.build(bound_mod, target=target)
print("--- Cuda generated device code ---")
# The CUDA kernel source lives in the imported device module.
print(lib.imported_modules[0].get_source())

Splitting host and device TIR

Adapted from the build function in tvm.tir.build (tvm.tir.build.build):

from typing import Union, Optional
from tvm import IRModule
from tvm.tir import PrimFunc
from tvm.target import Target
from tvm.runtime import ndarray
from tvm.tir.build import split_host_device_mods

def get_device_host_code(
    mod: Union[PrimFunc, IRModule],
    target: Optional[Union[str, Target]] = None,
    pipeline: Union[None, str, tvm.transform.Pass] = "default",
):
    """Lower ``mod`` through the TIR build pipeline, stopping before codegen.

    This mirrors the first steps of ``tvm.tir.build``:

    * target resolution,
    * target binding,
    * the TIR lowering pipeline,
    * host/device splitting,
    * the host/device finalization passes.

    It stops before the modules are turned into a runtime module, so the
    lowered host and device TIR can be inspected.

    Parameters
    ----------
    mod : Union[PrimFunc, IRModule]
        The input to be lowered. A bare PrimFunc is wrapped into an IRModule.
    target : Optional[Union[str, Target]]
        The target for compilation. When None, ``Target.current()`` is used.
        If that is also None, ``llvm`` is the default target bound to the
        module. The pipeline target is then taken from the first function
        that carries a ``target`` attribute.
    pipeline : Union[None, str, tvm.transform.Pass]
        One of:

        * a registered TIR pipeline name;
        * a Pass;
        * None, to use the target's default TIR pipeline.

    Returns
    -------
    host_mod : IRModule
        The host-side TIR module after the host finalization passes. This is
        the part ``tvm.tir.build`` later compiles for the host target. That
        host target is ``target.host`` if set; otherwise ``llvm`` when LLVM is
        enabled, else ``stackvm``.
    device_mod_dict : dict[Target, IRModule]
        One finalized TIR module per device target (e.g. the CUDA kernels).

    Notes
    -----
    The result is ``(host_mod, device_mod_dict)`` -- host first.

    NOTE(review): buffers passed in as arguments are allocated by the caller
    (e.g. via ``tvm.nd.array``), so no explicit cudaMalloc/cudaMemcpy is
    expected for them here. Intermediate global buffers are presumably lowered
    to runtime workspace-allocation calls in ``host_mod``; confirm by
    inspecting the output.
    """
    # Convert PrimFunc to IRModule
    if isinstance(mod, PrimFunc):
        mod = tvm.IRModule.from_expr(mod)
    else:
        assert isinstance(mod, tvm.IRModule)

    # Step 0: Determine the target in environment
    # It's used to bind the PrimFunc without target attr to serve as a default target
    target_to_bind = Target.current() if target is None else target
    if target_to_bind is None:
        target_to_bind = "llvm"
    target_to_bind = Target.canon_target(target_to_bind)

    # Step 1: Determine the target to search for tir pipeline
    target = Target.current() if target is None else target
    if target is None:
        for func in mod.functions.values():
            f_target = func.attrs.get("target", None)
            if f_target is not None:
                target = f_target
                break
    if target is not None:
        target = Target.canon_target(target)

    # Step 2: Determine the host target
    target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
    if target is not None:
        if target.host is not None:
            target_host = target.host
        elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type:
            # CPU targets act as their own host.
            target_host = target
    target_host = Target.canon_target(target_host)
    target_to_bind = target_to_bind.with_host(target_host)

    # Step 3: Bind the target to the input module
    mod = tvm.tir.transform.BindTarget(target_to_bind)(mod)

    # Step 4: Apply the tir pipeline
    if pipeline is not None:
        # custom pipeline
        if isinstance(pipeline, str):
            pipeline = tvm.tir.get_tir_pipeline(pipeline)
    else:
        # default pipeline depends on the target
        pipeline = tvm.tir.get_default_tir_pipeline(target)
    mod = pipeline(mod)

    # Step 5: Get host and device modules
    host_mod, device_mod_dict = split_host_device_mods(mod)

    # Step 6: Apply finalization passes
    host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod)
    # Build the device finalization pass sequence once and reuse it for
    # every device module. Use a distinct loop name so it is not confused
    # with the `target` variable above.
    finalize_device = tvm.tir.pipeline.finalize_device_passes()
    device_mod_dict = {
        device_target: finalize_device(device_mod)
        for device_target, device_mod in device_mod_dict.items()
    }

    return host_mod, device_mod_dict

# get_device_host_code returns (host_mod, device_mod_dict) -- host first.
host_mod, device_mod_dict = get_device_host_code(bound_mod, target)

# Host-side TIR. show() prints the module itself and returns None, so it is
# not wrapped in print(). This is the part compiled for the host target
# (LLVM by default), not CUDA C.
host_mod.show()

# Device-side TIR: one IRModule per device target.
for device_target, device_mod in device_mod_dict.items():
    print(f"--- Device TIR for {device_target} ---")
    device_mod.show()

I’m not familiar with the C++ part of TVM; is it possible to generate CUDA host code from there?

Thanks !

Memory Allocation & Data Copy in TVM for CUDA

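Because the same TVM Python code runs on CPUs, GPUs and other backends, TVM manages memory implicitly: low-level CUDA calls such as cudaMalloc/cudaMemcpy are hidden behind high-level APIs such as tvm.nd.array(). Consequently, TVM does not generate “CUDA host code” containing these calls. The host side of a built module is compiled through LLVM (which is why you only found LLVM host code). The allocations and copies happen at run time inside the TVM runtime’s CUDA device API, as traced below.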

Use the following simple vector addition IRModule as an example:

# Minimal element-wise example: C = A + B over 1024 float32 elements.
@tvm.script.ir_module
class VecAdd:
    @T.prim_func
    def main(A: T.Buffer((1024,), "float32"),
             B: T.Buffer((1024,), "float32"),
             C: T.Buffer((1024,), "float32")) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Single spatial loop. The block is named "C" so that a schedule can
        # look it up with get_block("C") and bind its loop to GPU indices.
        for i in T.grid(1024):
            with T.block("C"):
                vi = T.axis.remap("S", [i])
                C[vi] = A[vi] + B[vi]

import numpy as np  # np is used below but was never imported

# Schedule: split the 1024-iteration loop into 8 blocks x 128 threads and
# bind the two parts to CUDA block / thread indices.
sch = tvm.tir.Schedule(VecAdd)
block_C = sch.get_block("C")
(i,) = sch.get_loops(block=block_C)
i0, i1 = sch.split(i, [None, 128])
sch.bind(i0, "blockIdx.x")
sch.bind(i1, "threadIdx.x")

# Build
rt_mod = tvm.build(sch.mod, target="cuda")

# Initial inputs/outputs. Each tvm.nd.array(..., dev) call allocates device
# memory and copies the NumPy data into it (explained below).
dev = tvm.cuda(0)
A_np = np.random.uniform(size=(1024,)).astype("float32")
B_np = np.random.uniform(size=(1024,)).astype("float32")
A_nd = tvm.nd.array(A_np, dev)
B_nd = tvm.nd.array(B_np, dev)
C_nd = tvm.nd.array(np.zeros((1024,), dtype="float32"), dev)

# Execute, then copy the result back to the host and check it.
rt_mod["main"](A_nd, B_nd, C_nd)
np.testing.assert_allclose(C_nd.numpy(), A_np + B_np)


When we initialize input/output tensors using tvm.nd.array(), here’s what happens under the hood:

1. Memory Allocation (empty() => cudaMalloc)

  • The tvm.nd.array() function returns:
    return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr)
    
  • empty() triggers the following C++ backend chain:
    NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional<String> mem_scope);
    
    • NDArray::Empty() calls the virtual method of the device’s DeviceAPI:
    void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint);
    
  • AllocDataSpace has a separate implementation for each backend. For CUDA, it resolves to cudaMalloc (or cuMemAlloc in the CUDA Driver API) to allocate device memory (see src/runtime/cuda/cuda_device_api.cc in apache/tvm, commit 8e478f57b9c75c8778ccb2b32f92cc89209e6aed).

2. Data Copy (copyfrom() => cudaMemcpyAsync)

  • The tvm.nd.array() function returns:
    return empty(arr.shape, arr.dtype, device, mem_scope).copyfrom(arr)
    
  • The .copyfrom(arr) operation calls the C++ backend function:
    void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset,
                        size_t size, Device dev_from, Device dev_to, DLDataType type_hint,
                        TVMStreamHandle stream);
    
  • When copying from host (CPU) to device (GPU), this ultimately calls cudaMemcpyAsync to perform the transfer asynchronously (see src/runtime/cuda/cuda_device_api.cc in apache/tvm, commit 8e478f57b9c75c8778ccb2b32f92cc89209e6aed).

I have simplified this process considerably, but you can follow this call path to find the relevant code. Hope this helps.

2 Likes

Hi,

Thanks for your answer; it helps a lot. I will dig in to find the relevant cudaMalloc/cudaMemcpy calls.