[tvm0.20.0-Relax] How to codegen llvm-ir with Relax?

Hello everyone, I’ve worked a lot with LLVM but am new to TVM. I would like to implement my own TVM backend to codegen custom LLVM IR.

However, it seems that the infrastructure API has changed substantially in the latest release (0.20.0). I’m looking for an equivalent of Relay’s get_source method to obtain the generated LLVM IR in text form, namely:

with tvm.transform.PassContext(opt_level=3):
    lib = tvm.build(mod, target=target)
print(lib.get_source())

I’ve been struggling to find the correct API. I’ve managed to generate VirtualMachine code and run the Module with the VirtualMachine, but I still have no idea how to obtain the LLVM IR. Here is my code:

from tvm.script import ir as I
from tvm.script import relax as R
import tvm.relax as relax
import tvm
import numpy as np

@I.ir_module
class InputModule:
    @R.function
    def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
        z = R.add(x, y)
        return z

mod2 = InputModule
mod2.show()
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu()
ex = relax.build(mod2, target=target)
vm = relax.VirtualMachine(ex, dev)

data1 = tvm.nd.array(np.random.randint(1, 2, size=(3, 4)).astype(np.float32))
data2 = tvm.nd.array(np.random.randint(5, 6, size=(3, 4)).astype(np.float32))
nd_res = vm["foo"](data1, data2)
print(nd_res)

Outputs: [[6. 6. 6. 6.], [6. 6. 6. 6.], [6. 6. 6. 6.]]

Hello,
I once managed to get an MLF archive with C implementations of the layers, plus a kind of “flow” pseudocode of their calls, using code like this:

import onnx
import tvm
from tvm import relax
from tvm.relax.frontend import detach_params
from tvm.relax.frontend.onnx import from_onnx

onnx_model_path = "models/model.onnx"
onnx_model = onnx.load(onnx_model_path)

input_name = "input4"
input_shape = (1, 12, 40, 1)  
shape_dict = {input_name: input_shape}

mod_from_onnx = from_onnx(onnx_model, shape_dict)
mod, params = detach_params(mod_from_onnx)

target = tvm.target.Target("c")
ex = tvm.compile(mod, target)

ex.export_library("mlf.tar")
flow = ex._as_text()
print(flow)

The “flow” pseudocode looked like this:

@main:
  call  vm.builtin.check_tensor_info in: %0, i4, c[0], c[1] dst: %void
  call  vm.builtin.match_shape in: %0, %void, i4, i0, i1, i0, i12, i0, i40, i0, i1, c[1] dst: %void
  call  vm.builtin.reshape in: %0, c[2]     dst: %1
  call  vm.builtin.alloc_storage in: %vm, c[3], i0, c[4], c[5] dst: %2
  call  vm.builtin.alloc_tensor in: %2, i0, c[6], c[7] dst: %3
  call  conv2d           in: %1, c[8], %3 dst: %void
  call  vm.builtin.null_value in:              dst: %1
  call  vm.builtin.reshape in: c[9], c[10]  dst: %4
  call  vm.builtin.alloc_storage in: %vm, c[3], i0, c[11], c[5] dst: %5
  call  vm.builtin.alloc_tensor in: %5, i0, c[6], c[12] dst: %6
  call  add              in: %3, %4, %6   dst: %void
  call  vm.builtin.null_value in:              dst: %3
  call  vm.builtin.null_value in:              dst: %4
  call  vm.builtin.alloc_tensor in: %2, i0, c[6], c[13] dst: %7
  call  relu             in: %6, %7       dst: %void
  call  vm.builtin.null_value in:              dst: %6
  call  vm.builtin.alloc_tensor in: %5, i0, c[14], c[15] dst: %8
  call  conv2d1          in: %7, c[16], %8 dst: %void
  call  vm.builtin.null_value in:              dst: %7
  call  vm.builtin.reshape in: c[17], c[18] dst: %9
  call  vm.builtin.alloc_tensor in: %2, i0, c[14], c[19] dst: %10
  call  add1             in: %8, %9, %10  dst: %void
  call  vm.builtin.null_value in:              dst: %8
  call  vm.builtin.null_value in:              dst: %9
  call  vm.builtin.alloc_tensor in: %5, i0, c[14], c[20] dst: %11
  call  relu1            in: %10, %11     dst: %void
...
@vm.builtin.check_tensor_info packed_func;
@vm.builtin.match_shape packed_func;
@vm.builtin.reshape packed_func;
@vm.builtin.alloc_storage packed_func;
@vm.builtin.alloc_tensor packed_func;
@conv2d packed_func;
@vm.builtin.null_value packed_func;
@add packed_func;
@relu packed_func;
@conv2d1 packed_func;
...

The layer implementations in lib.c and devc.c inside the tar archive looked similar to the ones we were getting with the microTVM tvmc tool. However, we have not actually built and tested them yet, as we are still using tvmc installed from the v0.18.0 branch on GitHub.
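
If you just want to read those generated C sources, you don’t strictly need a microTVM flow: the exported archive can be unpacked directly. A minimal sketch using Python’s tarfile module (the member names, e.g. lib0.c / devc.c, depend on the TVM version and target):

import tarfile

# Sketch: look inside the archive produced by ex.export_library("mlf.tar") above.
with tarfile.open("mlf.tar") as tar:
    print(tar.getnames())              # list the packed C sources / objects
    tar.extractall("mlf_unpacked")     # extract them for inspection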

I’ve also seen the mod.script() function, which gives this style of output:

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(conv_1_input: R.Tensor((1, 12, 40, 1), dtype="float32")) -> R.Tensor((1, 1, 40, 1), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((1, 1, 12, 40), dtype="float32") = R.reshape(conv_1_input, R.shape([1, 1, 12, 40]))
            lv1: R.Tensor((1, 6, 1, 40), dtype="float32") = R.nn.conv2d(lv, metadata["relax.expr.Constant"][0], strides=[12, 1], padding=[0, 2, 0, 3], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][1], R.shape([1, 6, 1, 1]))
            lv3: R.Tensor((1, 6, 1, 40), dtype="float32") = R.add(lv1, lv2)
            lv4: R.Tensor((1, 6, 1, 40), dtype="float32") = R.nn.relu(lv3)
            lv5: R.Tensor((1, 9, 1, 40), dtype="float32") = R.nn.conv2d(lv4, metadata["relax.expr.Constant"][2], strides=[12, 1], padding=[0, 1, 0, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv6: R.Tensor((1, 9, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][3], R.shape([1, 9, 1, 1]))
            lv7: R.Tensor((1, 9, 1, 40), dtype="float32") = R.add(lv5, lv6)
            lv8: R.Tensor((1, 9, 1, 40), dtype="float32") = R.nn.relu(lv7)
            lv9: R.Tensor((1, 3, 1, 40), dtype="float32") = R.nn.conv2d(lv8, metadata["relax.expr.Constant"][4], strides=[12, 1], padding=[0, 2, 0, 3], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            ...
            lv26: R.Tensor((1, 1, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][13], R.shape([1, 1, 1, 1]))
            lv27: R.Tensor((1, 1, 1, 40), dtype="float32") = R.add(lv25, lv26)
            gv: R.Tensor((1, 1, 40, 1), dtype="float32") = R.reshape(lv27, R.shape([1, 1, 40, 1]))
            R.output(gv)
        return gv
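
For reference, the listing above is simply what the TVMScript printer produces for the imported module:

# Print the TVMScript representation of the Relax module shown above.
print(mod.script())
# mod.show() prints the same representation (with syntax highlighting when available).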

Hope that helps in your case.

Is this what you’re looking for?

@I.ir_module
class InputModule:
    @R.function
    def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
        z = R.add(x, y)
        return z

lib = tvm.relax.build(InputModule, target="llvm")

print(lib.mod.imported_modules[0].get_source())

You can also inspect the generated C code by lowering to the C target:

lib = tvm.relax.build(InputModule, target="c")

print(lib.mod.imported_modules[0].get_source())
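
If you want to keep the generated code around, here is a small sketch that inspects the imported modules and dumps the source to a file (the index 0 and the file name are just illustrative):

# List what was imported into the built module, then dump the LLVM IR.
for i, m in enumerate(lib.mod.imported_modules):
    print(i, m.type_key)               # e.g. "llvm" or "c" depending on the target

with open("foo.ll", "w") as f:
    f.write(lib.mod.imported_modules[0].get_source())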

Here, you have to lower the Relax IR before generating LLVM: tvm.relax.build gives you a VMExecutable, from which you retrieve the runtime module.
It’s usually simpler to get LLVM code from TIR (which you can obtain by lowering Relax to TIR, for example with tvm.relax.transform.LegalizeOps):

@I.ir_module
class TirModule:
    @T.prim_func
    def add(x: T.Buffer((T.int64(3), T.int64(4)), "float32"), y: T.Buffer((T.int64(3), T.int64(4)), "float32"), T_add: T.Buffer((T.int64(3), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(3), T.int64(4)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x[v_ax0, v_ax1], y[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + y[v_ax0, v_ax1]

print(tvm.build(TirModule, target="c").get_source()) # or target="llvm" of course
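
If you would rather start from the Relax InputModule above than write the TIR by hand, here is a hedged sketch of the LegalizeOps route; the exact build entry point may differ slightly between versions, so treat it as a starting point rather than a definitive recipe:

import tvm
from tvm import relax

# Lower the Relax ops (R.add, ...) into TIR PrimFuncs inside the same IRModule.
lowered = relax.transform.LegalizeOps()(InputModule)

# Keep only the TIR PrimFuncs and build them with the LLVM backend.
tir_only = tvm.IRModule(
    {gv: func for gv, func in lowered.functions.items()
     if isinstance(func, tvm.tir.PrimFunc)}
)
lib = tvm.build(tir_only, target="llvm")
print(lib.get_source())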

I’d suggest looking at tir_to_runtime. This function calls codegen_build, which in turn calls the appropriate target build function (e.g., target.build.llvm for the llvm target).
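
From Python you can at least confirm that this codegen entry point is registered in your build (it is exposed as a packed function); a tiny sketch, assuming tvm.get_global_func is still exposed at the top level in your version:

import tvm

# target.build.llvm is the packed function that codegen_build dispatches to for
# the llvm target; it is only registered when TVM was built with LLVM support.
f = tvm.get_global_func("target.build.llvm", allow_missing=True)
print("target.build.llvm registered:", f is not None)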

Yes, this solves my problem. Thank you. To me, the most confusing part is the Python object hierarchy: you write a Relax module (tvm.ir.module.IRModule) and get a VMExecutable (tvm.relax.vm_build.VMExecutable) after relax.build(); the returned object has a method as_text() that shows a kind of VM-instruction representation, and it can be executed on the host machine. Yet there is also a “lowered module” (tvm.runtime.module.Module) that can be retrieved either from VMExecutable.mod or directly from a hand-written TE/TIR script, and only that lower-level Module can be built by tvm.build().
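
To make that hierarchy concrete, a short sketch based on the example earlier in the thread (class paths and method names follow the description above):

from tvm import relax

# IRModule --relax.build()--> VMExecutable --.mod--> runtime.Module
ex = relax.build(InputModule, target="llvm")
print(type(ex))                     # tvm.relax.vm_build.VMExecutable
print(ex.as_text()[:300])           # VM instruction listing (what the VM executes)
print(type(ex.mod))                 # tvm.runtime.module.Module
print(ex.mod.imported_modules[0].get_source()[:300])   # generated LLVM IR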

P.S. I would appreciate it if this information were added to the official tutorial sample code and documented properly.