Calling a PrimFunc from TVM IR

I am trying to write an IR kernel that invokes a PrimFunc as part of an exercise for my own understanding of TVM.

The code should take a 2D matrix and zero out one of its rows.

Here is the code:

import tvm
from tvm import relay
from tvm import te
import numpy as np

# TVM setup
target = tvm.target.Target(target="llvm", host="llvm")
device = tvm.device(target.kind.name, 0)

# Matrix dimensions used throughout
n, m = 3, 6

A = te.placeholder((n, m), name="A")
ROW = te.placeholder((1,), dtype="int32", name="ROW")
B = te.compute(
    shape=[n, m],
    fcompute=lambda i, j: te.if_then_else(i == ROW[0], 0.0, A[i, j]),
    name="B")

schedule = te.create_schedule(B.op)
clear_row = tvm.build(schedule, [A, ROW, B], target, name="clear_row")

prim_func = tvm.te.create_prim_func([A, ROW, B])
 
mod = tvm.ir.IRModule()

opFnVar = relay.GlobalVar("opFn")
mod.update_func(opFnVar, prim_func)

mainFnVar = relay.GlobalVar("main")
x = relay.var('a', shape=(n, m))
row = relay.var('row', shape=[1])
sb = relay.ScopeBuilder()
b = sb.let('var', relay.const([[0, 0, 0, 0, 0, 0],
                               [0, 0, 0, 0, 0, 0],
                               [0, 0, 0, 0, 0, 0]], dtype="float32"))
tmp = sb.let("result", relay.Call(opFnVar, [x, row, b]))
sb.ret(tmp)
mainFn = relay.Function([x, row], sb.get(), tvm.ir.TensorType([n, m]))
mod.update_func(mainFnVar, mainFn)

print(mod)

model = tvm.relay.create_executor(
    kind="vm", mod=mod, target=target, device=device
).evaluate()

result = model(
    np.array([[1, 1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1, 1],
              [1, 1, 1, 1, 1, 1]], dtype="float32"),
    np.array([0], dtype="float32"),
)

print(result)

It is a simple TVM IR program that should invoke a Tensor IR kernel defined as a tensor expression. The resulting printed IR follows:

def @main(%a: Tensor[(3, 6), float32], %row: Tensor[(1), float32]) -> Tensor[(3, 6), float32] {
  let %var = meta[relay.Constant][0];
  let %result = @opFn(%a, %row, %var);
  %result
}

@opFn = primfn(var_A: handle, var_ROW: handle, var_B: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_1: Pointer(global float32), float32, [3, 6], []),
             A: Buffer(A_1: Pointer(global float32), float32, [3, 6], []),
             ROW: Buffer(ROW_1: Pointer(global int32), int32, [1], [])}
  buffer_map = {var_A: A, var_ROW: ROW, var_B: B} {
  block([], "root") {
    tir.reads([])
    tir.writes([])
    for (i0: int32, 0, 3) {
      for (i1: int32, 0, 6) {
        block([3, 6], "B") as [i, j] {
          bind(i, i0)
          bind(j, i1)
          tir.reads([ROW[0], A[i, j]])
          tir.writes([B[i, j]])
          B[i, j] = @tir.if_then_else((i == ROW[0]), 0f32, A[i, j], dtype=float32)
      }
    }
}

My Python code fails when invoking the executor's evaluate with this error:

Check failed: (can_dispatch(n)) is false: NodeFunctor calls un-registered function on type tir.PrimFunc

Is what I am trying supported?

Hi @abel-bernabeu!

I’m not 100% sure if this will work, but you can try using the call_lowered Relay op, which I introduced in https://github.com/apache/tvm/pull/9312/.

This is our internal calling convention for lowered functions. In theory, you should be able to just stick the global var that represents the PrimFunc in as the function argument.

You’ll want to do something like this:

mainFnVar = relay.GlobalVar("main")
x = relay.var('a', shape=(n, m))
row = relay.var('row', shape=[1])
sb = relay.ScopeBuilder()
b = sb.let('var', relay.const([[0, 0, 0, 0, 0, 0],
                               [0, 0, 0, 0, 0, 0],
                               [0, 0, 0, 0, 0, 0]], dtype="float32"))
call_lowered = relay.op.get("call_lowered")
args = relay.Tuple([x, row, b])
# attrs and span are intentionally left undefined here, see the note below
tmp = sb.let("result", relay.Call(call_lowered, [opFnVar, args], attrs, span))
sb.ret(tmp)
mainFn = relay.Function([x, row], sb.get(), tvm.ir.TensorType([n, m]))
mod.update_func(mainFnVar, mainFn)

call_lowered hasn’t been set up to work with Python (as you can see, I didn’t define attrs or span), so I’d suggest trying in C++. This example might be useful: https://github.com/apache/tvm/blob/main/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc#L119
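
If you do want to experiment from Python anyway, one possible, untested way to supply the missing attrs is to build the attrs node through reflection and simply omit the span. The node key "relay.attrs.CallLoweredAttrs" and its metadata field are assumptions taken from the PR linked above, and this sketch builds on the snippet just shown:

import tvm
from tvm import relay

# Untested sketch: construct the call_lowered attrs node via reflection.
# The type key and the "metadata" field are assumptions and may differ
# between TVM versions.
call_lowered_attrs = tvm.ir.make_node("relay.attrs.CallLoweredAttrs", metadata={})

call_lowered = relay.op.get("call_lowered")
args = relay.Tuple([x, row, b])
# relay.Call accepts span=None by default, so only attrs needs to be passed.
tmp = sb.let("result", relay.Call(call_lowered, [opFnVar, args], call_lowered_attrs))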

Thanks for asking! Mutual function call in TIR is not supported for now. Calling TIR from Relay, as @electriclilies mentioned, is supported via call_lowered :slight_smile:

I tried using call_lowered from Python, but it failed because the attrs parameter was not passed (exactly as anticipated by @electriclilies).

Moving to C++ could work for me, but I also need to re-evaluate my options for meeting the requirements while staying within budget :slight_smile:

Thanks @electriclilies and @junrushao.

Hi~ Out of curiosity, are there ongoing efforts to support “Mutual function call in TIR” now? :drooling_face:

@wrongtest thanks for asking! We haven’t invested effort in enabling this particular feature yet, but mechanically you can put a GlobalVar inside a Call at the representation level. Do you have a production need for this feature?
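
For instance, such a node can already be constructed today (a minimal, purely representational sketch; nothing in the current lowering pipeline will resolve the call):

import tvm

# Representational sketch only: a TIR Call whose op is a GlobalVar.
# The node can be built, but no lowering pass currently resolves it.
callee = tvm.ir.GlobalVar("callee")
x = tvm.tir.Var("x", "float32")
call = tvm.tir.Call("float32", callee, [x])
assert isinstance(call.op, tvm.ir.GlobalVar)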

Thanks for the reply~ We hit this issue when using TVMScript together with Relay.

Sometimes we have to compose primfunc scripts, similar to how we chain TE expressions. Take a simple example: suppose we have registered a manually written conv2d script for nn.conv2d and a bias_add script for nn.bias_add, and Relay tells us they should be fused together.

Then the lowering process should produce a “root” primfunc that calls each “sub” primfunc in order, and perform inlining to create a flattened function. But this kind of thing is not officially supported yet. I also find it hard to prototype the fusion behavior from scratch, since I cannot write a script which invokes existing primfuncs (currently it is an undefined symbol error if the call op is a GlobalVar).

We have the same requirement on our side. Another advantage of “Mutual function call in TIR” is that we can make code composable in TIR/TVMScript, which should be a strong point for an eDSL: you don’t need to duplicate similar code. I discussed this requirement with @Hzfengsy offline.

Generally, we have two types of nodes/functions in IRModule: GraphOp (relay op) and TensorOp (TIR op). So, there are four interactive calls:

  • G->T (GraphOp calls TensorOp): call_lowered and Relax should natively support it.
  • G->G: Relay does a good job
  • T->T: Not implemented
  • T->G: No real-world cases

For the requirement from @wrongtest and @xqdan, I guess two paths can do it:

  • Mutate the Module runtime (hard but native)
  • Inline when compiling (easy but hacky)

I don’t have enough bandwidth to finish it, but I am happy to give a hand.


This is what Taichi supports for basic code composition via ti.func. See for example https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/simulation/pbf2d.py#L97
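
A minimal illustration of that ti.func composition pattern (assuming Taichi is installed; the helper here is just an example, not taken from the linked file):

import taichi as ti

ti.init(arch=ti.cpu)

@ti.func
def clamp01(x):
    # Reusable helper: only callable from inside Taichi kernels/funcs.
    return ti.min(ti.max(x, 0.0), 1.0)

@ti.kernel
def apply_clamp(v: ti.f32) -> ti.f32:
    # Kernels can freely call ti.func helpers, which keeps code composable.
    return clamp01(v)

print(apply_clamp(1.5))  # prints 1.0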

I think this is a good first step to improve the usability of TVMScript. Otherwise we have to repeat the same code for different axes etc.

I don’t think a mutual function call is needed; one direction is enough (one “main” kernel calling other functions, similar to “global” vs “device” in CUDA; of course, a “device” function should also be able to call other device functions). Such a simple improvement would go a long way. I believe @wrongtest’s ask can be met by this.


To add on the TIR => TIR call: because TIR functions are PackedFuncs, a TIR => TIR call can be implemented via call_packed. Note that we implemented a limited version of this for host-to-device function calls.

Note this is the solution for host-to-host and host-to-device function calls. For device-side function calls we need to use the native function calling convention.

In our case, we should separate the kind of function (TIR vs Relay/Relax) from the calling convention. In general, the TVM stack supports roughly two kinds of calling conventions:

  • PackedFunc convention: used in most cases
  • Native convention (C or device native): used in low-level function interactions

From the IR (data structure) point of view, we already support TIR => TIR function calls via calling a GlobalVar, assuming we know the calling convention of the particular GlobalVar (assuming it is a host-side call). Then the only delta is a symbol relocation pass that settles the global symbols of possibly-private TIR functions and translates these calls into call_packed calls to those global symbols. This approach supports mutual function calls.

There is no need to update the module runtime or the IR, so it is mainly a lowering/codegen enhancement. For device-side functions, we will need to generate calls that depend on the device-side calling convention. Function inlining can be viewed as an optional optimization pass here.
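
As a concrete illustration of the PackedFunc path that already works today (a minimal sketch, not the relocation pass described above): a host-side TE/TIR function can call another compiled kernel through tir.call_packed, as long as the callee's global symbol is resolvable at runtime. Here the symbol is made resolvable by registering the built kernel in the global function registry; the caller-side names (A2, ROW2, OUT, caller) are just for illustration.

import numpy as np
import tvm
from tvm import te

# Build the "clear_row" kernel from the original post and register its
# PackedFunc under a global symbol so the packed call below can resolve it.
A = te.placeholder((3, 6), name="A")
ROW = te.placeholder((1,), dtype="int32", name="ROW")
B = te.compute((3, 6), lambda i, j: te.if_then_else(i == ROW[0], 0.0, A[i, j]), name="B")
clear_row = tvm.build(te.create_schedule(B.op), [A, ROW, B], "llvm", name="clear_row")
tvm.register_func("clear_row", clear_row["clear_row"], override=True)

# A caller that invokes "clear_row" by symbol via the PackedFunc convention.
# te.extern lowers the lambda into a host-side tir.call_packed(...) call.
A2 = te.placeholder((3, 6), name="A2")
ROW2 = te.placeholder((1,), dtype="int32", name="ROW2")
OUT = te.extern(
    (3, 6),
    [A2, ROW2],
    lambda ins, outs: tvm.tir.call_packed("clear_row", ins[0], ins[1], outs[0]),
    name="OUT",
    dtype="float32",
)
caller = tvm.build(te.create_schedule(OUT.op), [A2, ROW2, OUT], "llvm", name="caller")

a = tvm.nd.array(np.ones((3, 6), dtype="float32"))
row = tvm.nd.array(np.array([0], dtype="int32"))
out = tvm.nd.array(np.zeros((3, 6), dtype="float32"))
caller(a, row, out)
print(out)  # row 0 should be all zeros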


I think to make this work in Python, you just need to register a make function for the op and also a Python constructor, though that might be more futzing than you’re willing to do.