As an exercise to improve my understanding of TVM, I am trying to write a Relay IR program that invokes a TIR PrimFunc.
The code should take a 2D matrix and a row index, and fill that row with zeros.
Here is the code:
import tvm
from tvm import relay
from tvm import te
import numpy as np
# TVM setup
target = tvm.target.Target(target="llvm", host="llvm")
device = tvm.device(target.kind.name, 0)

# Matrix dimensions, shared by the TE kernel and the Relay signature so the
# two cannot drift apart.
n, m = 3, 6

# TE kernel: B[i, j] = 0 if i == ROW[0] else A[i, j]
# (i.e. copy A into B with row ROW[0] zeroed).
A = te.placeholder((n, m), dtype="float32", name="A")
ROW = te.placeholder((1,), dtype="int32", name="ROW")
B = te.compute(
    shape=(n, m),
    fcompute=lambda i, j: te.if_then_else(i == ROW[0], 0, A[i, j]),
    name="B",
)
schedule = te.create_schedule(B.op)
clear_row = tvm.build(schedule, [A, ROW, B], target, name="clear_row")

# Sanity check: call the compiled TE kernel directly. Its argument order is
# (A, ROW, B), with B as the output buffer.
a_np = np.ones((n, m), dtype="float32")
row_np = np.array([0], dtype="int32")
b_nd = tvm.nd.empty((n, m), "float32", device)
clear_row(tvm.nd.array(a_np, device), tvm.nd.array(row_np, device), b_nd)
print(b_nd.numpy())

# The same computation as a TIR PrimFunc. Its default "global_symbol" is
# "main", which collides with the Relay entry point named "main" below, so
# give it a distinct symbol.
prim_func = tvm.te.create_prim_func([A, ROW, B]).with_attr("global_symbol", "opFn")

mod = tvm.ir.IRModule()
opFnVar = relay.GlobalVar("opFn")
mod.update_func(opFnVar, prim_func)

mainFnVar = relay.GlobalVar("main")
# The Relay parameter types must match the PrimFunc buffers:
# A is float32 of shape (n, m), and ROW is int32 of shape (1,).
x = relay.var("a", shape=(n, m), dtype="float32")
row = relay.var("row", shape=(1,), dtype="int32")
sb = relay.ScopeBuilder()
# The output buffer is passed in destination-passing style. B is float32,
# so the placeholder constant must be float32 too.
b = sb.let("var", relay.const(np.zeros((n, m), dtype="float32")))
tmp = sb.let("result", relay.Call(opFnVar, [x, row, b]))
sb.ret(tmp)
mainFn = relay.Function([x, row], sb.get(), tvm.ir.TensorType([n, m], "float32"))
mod.update_func(mainFnVar, mainFn)
print(mod)

# NOTE(review): Relay's standard compile flow may not support a relay.Call
# whose callee is a GlobalVar bound to a tir.PrimFunc. The error
# "NodeFunctor calls un-registered function on type tir.PrimFunc" suggests a
# Relay pass is visiting the PrimFunc. Relax's `call_tir` is designed for
# exactly this Relay-to-TIR call pattern; confirm against the TVM version in
# use.
model = tvm.relay.create_executor(
    kind="vm", mod=mod, target=target, device=device
).evaluate()
# Pass inputs whose dtypes match main's parameters (float32 / int32).
# Extra keywords such as dtype= are not parameters of main.
result = model(a_np, row_np)
print(result)
It is a simple Relay program that should invoke a TensorIR (TIR) kernel defined with a tensor expression (TE). The printed IR is:
def @main(%a: Tensor[(3, 6), float32], %row: Tensor[(1), float32]) -> Tensor[(3, 6), float32] {
let %var = meta[relay.Constant][0];
let %result = @opFn(%a, %row, %var);
%result
}
@opFn = primfn(var_A: handle, var_ROW: handle, var_B: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {B: Buffer(B_1: Pointer(global float32), float32, [3, 6], []),
A: Buffer(A_1: Pointer(global float32), float32, [3, 6], []),
ROW: Buffer(ROW_1: Pointer(global int32), int32, [1], [])}
buffer_map = {var_A: A, var_ROW: ROW, var_B: B} {
block([], "root") {
tir.reads([])
tir.writes([])
for (i0: int32, 0, 3) {
for (i1: int32, 0, 6) {
block([3, 6], "B") as [i, j] {
bind(i, i0)
bind(j, i1)
tir.reads([ROW[0], A[i, j]])
tir.writes([B[i, j]])
B[i, j] = @tir.if_then_else((i == ROW[0]), 0f32, A[i, j], dtype=float32)
}
}
}
My Python code fails when calling `evaluate()` on the executor, with this error:
Check failed: (can_dispatch(n)) is false: NodeFunctor calls un-registered function on type tir.PrimFunc
Is what I am trying to do supported?