import tvm
from tvm.script import relax as R
from tvm.script import ir as I
from tvm.script import tir as T
# IRModule holding one TIR kernel (`dense_loop`) and the Relax entry point
# (`main`) that invokes it.
# NOTE(review): indentation was lost in this paste; the nesting below is
# reconstructed from syntax alone — confirm against the original file.
@I.ir_module
class Module:
    # TIR kernel: accumulates a sparse part (csr_val / indices / indptr) and two
    # small dense pieces (val) multiplied by `vec` into `out`.
    @T.prim_func
    def dense_loop(
        CSR_VAL: T.handle,
        INDICES: T.handle,
        IND_PTR: T.handle,
        VAL: T.handle,
        VEC: T.handle,
        OUT: T.handle,
    ):
        # Bind each opaque handle to a fixed-shape 1-D buffer.
        csr_val = T.match_buffer(CSR_VAL, (8,), "float64")
        indices = T.match_buffer(INDICES, (8,), "int32")
        indptr = T.match_buffer(IND_PTR, (12,), "int32")
        val = T.match_buffer(VAL, (37,), "float64")
        vec = T.match_buffer(VEC, (11,), "float64")
        out = T.match_buffer(OUT, (11,), "float64")
        # Zero all 11 output entries before any accumulation.
        for i in T.serial(11):
            out[i] = 0.0
        # Sparse part: row i's entries occupy csr_val[indptr[i] : indptr[i + 1]];
        # indices[...] selects which element of vec each entry multiplies.
        for i in T.serial(11):
            row_start = indptr[i]
            row_end = indptr[i + 1]
            for j in T.serial(row_end - row_start):
                out[i] += csr_val[row_start + j] * vec[indices[row_start + j]]
        # Dense piece "db0": out[i] += val[2 * j + i] * vec[j] for i, j in {0, 1}
        # (reads val[0..3]).
        # NOTE(review): presumably this loop nest sits at function-body level,
        # after the sparse loop, not inside it — confirm.
        for j in T.serial(2):
            for i in T.serial(2):
                with T.block("db0"):
                    # NOTE(review): T.init() is called as a bare statement here
                    # rather than as `with T.init():` — confirm the intended
                    # init scope. `out` is already zeroed by the first loop.
                    T.init()
                    out[i + 0] += val[0 + j * 2 + i] * vec[j + 0]
        # Dense piece "db1": j only takes the value 0, so this adds
        # val[4 + i] * vec[5] to out[i] for i in {0, 1} (reads val[4..5]).
        for j in T.serial(1):
            for i in T.serial(2):
                with T.block("db1"):
                    # NOTE(review): same bare T.init() usage as in "db0".
                    T.init()
                    out[i + 0] += val[4 + j * 2 + i] * vec[j + 5]
    # Relax entry point: forwards its five tensor arguments to dense_loop and
    # returns the resulting length-11 float64 tensor.
    # NOTE(review): the parameters are declared with symbolic lengths
    # (n, m, l, k, v) while dense_loop matches fixed-size buffers
    # (8, 8, 12, 37, 11) — presumably callers must pass exactly those sizes;
    # confirm.
    @R.function
    def main(data: R.Tensor(("n",), dtype="float64"), indices: R.Tensor(("m",), dtype="int32"), indptr: R.Tensor(("l",), dtype="int32"), vec: R.Tensor(("k",), dtype="float64"), val: R.Tensor(("v",), dtype="float64")):
        cls = Module
        # main receives (..., vec, val) but dense_loop expects (..., VAL, VEC),
        # so the last two arguments are swapped in the call_tir tuple. The
        # output tensor described by out_sinfo is produced by this call.
        out = R.call_tir(cls.dense_loop, (data, indices, indptr, val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out
if __name__ == "__main__":
    # Host-side inputs (csr_val, indices, indptr, x, val) are prepared in the
    # elided section below.
    ...
    # Compile the module for the LLVM (CPU) target and load it into a Relax VM.
    cpu_dev = tvm.cpu()
    ex = tvm.relax.build(Module, target="llvm")
    vm = tvm.relax.VirtualMachine(ex, cpu_dev)
    # Copy every host array onto the CPU device, listed in the same order as
    # the parameters of `main`.
    data_arg, indices_arg, indptr_arg, vec_arg, val_arg = (
        tvm.nd.array(host_array, device=cpu_dev)
        for host_array in (csr_val, indices, indptr, x, val)
    )
    # Run the Relax entry point; `out` holds the tensor returned by `main`.
    out = vm["main"](data_arg, indices_arg, indptr_arg, vec_arg, val_arg)
    ...
In this program, I am creating the output tensor in the line `out = R.call_tir(cls.dense_loop, (data, indices, indptr, val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))` and then initializing it inside `dense_loop`.
Is there a way to instead allocate and initialize this array in the `if __name__ == "__main__"` block and pass it as an argument to the VM?