Generated CUDA Function Input

I have used the matrix multiplication code from 4. Matrix Multiplication — Dive into Deep Learning Compiler 0.1 documentation (d2l.ai). However, instead of a normal matrix multiplication, I have modified it to perform a batched matrix multiplication.

The resulting code runs perfectly, but I want to run the generated CUDA code manually from C++ rather than compiling it into a PyTorch function and using it from Python. I have been able to extract the CUDA code as well, but the resulting kernel has many inputs, and I am not sure what the correct values for them would be.

For example, the generated CUDA code for the batched multiplication has the following kernel signature:

default_function_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ R, int d3, int d1, int stride, int bsz, int stride_1, int stride_2, int d2, int stride_3, int stride_4, int stride_5, int stride_6, int stride_7, int stride_8)

Here it’s pretty clear that A, B and R are the tensors, with d1, d2, d3 describing their shapes (i.e. m, n, k) and bsz being the batch size. However, I am not sure what the strides are. There are 9 such strides, and I did not use them anywhere while writing the TVM code.
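
My current guess is that the strides are the per-dimension element strides of the three tensors (three each for A, B and R, which would account for all 9), and that for plain contiguous row-major tensors they are just products of the trailing dimensions. Below is a small sketch of what I assume the values would be; the helper and the example sizes are only illustrative, and I have not verified this against TVM:

# Assumption (not verified): each tensor contributes one stride per dimension,
# so A, B and R give 3 + 3 + 3 = 9 stride arguments, measured in elements.
def row_major_strides(shape):
    # strides of a contiguous row-major tensor, in elements
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

bsz, d1, d2, d3 = 2, 64, 32, 128           # example sizes, hypothetical
print(row_major_strides((bsz, d1, d3)))    # A: (d1 * d3, d3, 1)
print(row_major_strides((bsz, d3, d2)))    # B: (d3 * d2, d2, 1)
print(row_major_strides((bsz, d1, d2)))    # R: (d1 * d2, d2, 1)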

Is there a way to know, when this function is executed through Python, what values these arguments take? Once I know them, I can pass the same values and run the function from C++ as well. (A rough sketch of how I would try to check this from Python is at the end of the post.)

The TVM code is as follows.

import pandas as pd
from pandas import DataFrame
import time
import tvm
from tvm.contrib import dlpack
from tvm import te
import torch
import numpy as np

b_blocksz, y_blocksz, x_blocksz, t_b, t_y, t_x, t_k = 1, 8, 16, 1, 8, 4, 16

def _compile_function():

    bsz = te.var('bsz') # Batch size
    d1 = te.var('d1')   # D1 -> # of rows of the first matrix
    d2 = te.var('d2')   # D2 -> # of columns of the second matrix
    d3 = te.var('d3')   # D3 -> # of columns of the first matrix and # of rows of the second matrix

    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')  # first tensor
    B = te.placeholder((bsz, d3, d2), name='B', dtype='float32')  # second tensor

    k = te.reduce_axis((0, d3), name='k')  # dimension to sum over

    output_shape = (bsz, d1, d2)  # shape of the result tensor

    algorithm = lambda l, i, j: te.sum(A[l, i, k] * B[l, k, j], axis=k)
    R = te.compute(output_shape, algorithm, name='R')
    s = te.create_schedule(R.op)

    # Stage the inputs through shared memory and then registers ("local"), and
    # accumulate the result in a register tile before writing it back to global memory.
    A_shared = s.cache_read(A, "shared", [R])
    A_local  = s.cache_read(A_shared, "local", [R])
    B_shared = s.cache_read(B, "shared", [R])
    B_local  = s.cache_read(B_shared, "local", [R])
    C_local  = s.cache_write(R, "local")

    # Tile the output: yb/xb are bound to thread blocks and yo/xo to threads below,
    # so each thread computes a t_y x t_x sub-tile (yi, xi).
    batch, y, x = s[R].op.axis

    yo, yi           = s[R].split(y      , t_y)
    yb, yo           = s[R].split(yo     , y_blocksz)
    xo, xi           = s[R].split(x      , t_x)
    xb, xo           = s[R].split(xo     , x_blocksz)

    s[R].reorder(batch, yb, xb, yo, xo, yi, xi)

    s[R].bind(xb, te.thread_axis("blockIdx.x"))
    s[R].bind(yb, te.thread_axis("blockIdx.y"))
    s[R].bind(xo, te.thread_axis("threadIdx.x"))
    s[R].bind(yo, te.thread_axis("threadIdx.y"))

    # Attach the per-thread accumulation to the thread loop and split the reduction
    # axis k into chunks of t_k, so the shared tiles below are reloaded once per ko.
    s[C_local].compute_at(s[R], xo)
    batch_i, y_i, x_i = s[C_local].op.axis
    k,              = s[C_local].op.reduce_axis
    ko, ki          = s[C_local].split(k, t_k)

    s[C_local].reorder(batch_i, ko, ki, y_i, x_i)


    def optimize_read_cache(shared, local):
        # Load the shared-memory tile cooperatively across the thread block at each
        # ko iteration; the per-thread register copy is refreshed at each ki iteration.
        s[shared].compute_at(s[C_local], ko)
        s[local ].compute_at(s[C_local], ki)

        b, y, x = s[shared].op.axis

        bo, bi  = s[shared].split(b, nparts=b_blocksz)
        yo, yi  = s[shared].split(y, nparts=y_blocksz)
        xo, xi  = s[shared].split(x, nparts=x_blocksz)

        s[shared].reorder(bo, bi, yo, xo, yi, xi)

        s[shared].bind(xo, te.thread_axis("threadIdx.x"))
        s[shared].bind(yo, te.thread_axis("threadIdx.y"))
        s[shared].bind(bo, te.thread_axis("threadIdx.z"))


    optimize_read_cache(A_shared, A_local)
    optimize_read_cache(B_shared, B_local)
    print('\n ===================== \n{}'.format(tvm.lower(s, [A, B, R], simple_mode=True)))
    return tvm.build(s, [A, B, R], target='cuda')

if __name__ == "__main__":
    custom_bmm = _compile_function()
    custom_bmm_pytorch = dlpack.to_pytorch_func(custom_bmm)  # wrap it as a pytorch function
    print(custom_bmm.imported_modules[0].get_source()) # to get cuda code
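
What I was planning to try in order to see the values from Python, continuing from the script above (this assumes the kernel's shape and stride arguments come from the metadata of the tensors passed in, which is exactly the part I am unsure about):

# Hypothetical check I would add at the end of the main block: run the wrapped
# function on contiguous torch tensors and compare against torch.bmm; if my
# assumption is right, the stride arguments should match tensor.stride()
# (strides in elements, not bytes).
bsz, d1, d2, d3 = 2, 64, 32, 128   # example sizes
a = torch.randn(bsz, d1, d3, device='cuda')
b = torch.randn(bsz, d3, d2, device='cuda')
r = torch.zeros(bsz, d1, d2, device='cuda')
custom_bmm_pytorch(a, b, r)
print(a.stride(), b.stride(), r.stride())
assert torch.allclose(r, torch.bmm(a, b), atol=1e-3)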