Simpler example (I can't edit the original post):
import tvm
from tvm.script import relax as R
from tvm.script import ir as I
from tvm.script import tir as T
# IRModule holding one TIR kernel (`dense_loop`) and the Relax entry point
# (`main`) that invokes it through `R.call_tir`.
@I.ir_module
class Module:
    # TIR kernel: writes `csr_val[i] * 2.0` into `out[i]` for i in [0, 8).
    # Both buffers are bound with a static shape of (8,) and dtype float64.
    @T.prim_func
    def dense_loop(CSR_VAL: T.handle, OUT: T.handle):
        csr_val = T.match_buffer(CSR_VAL, (8,), "float64")
        out = T.match_buffer(OUT, (8,), "float64")
        for i in T.serial(8):
            out[i] = csr_val[i] * 2.0

    # Relax entry point: passes `data` to `dense_loop` and returns the
    # resulting tensor, annotated as shape (8,) float64.
    # NOTE(review): `data` is declared with a symbolic length "n", while
    # `dense_loop` binds its buffers to a fixed length of 8. This mismatch is
    # presumably what the example is meant to demonstrate - confirm before
    # changing either annotation.
    @R.function
    def main(data: R.Tensor(("n",), dtype="float64")):
        cls = Module
        out = R.call_tir(cls.dense_loop, (data,), out_sinfo=R.Tensor((8,), dtype="float64"))
        return out
# Script entry point: compile `Module` for the LLVM target, load it into a
# Relax virtual machine on the CPU and call its "main" function.
if __name__ == "__main__":
    ...  # setup elided in the original post
    ex = tvm.relax.build(Module, target="llvm")
    vm = tvm.relax.VirtualMachine(ex, tvm.cpu())
    # NOTE(review): `csr_val` is not defined in the lines shown here;
    # presumably it is created in the elided setup above - confirm.
    data_arg = tvm.nd.array(csr_val, device=tvm.cpu())
    out = vm["main"](data_arg)
    ...  # remainder elided in the original post