I constructed a simple IRModule containing only a Relax function, as follows:
@I.ir_module
class Module:
    # Minimal reproducer: one Relax function that adds two 1-D float32 tensors
    # whose length "m" is a symbolic shape variable (after legalization it is
    # printed as `m = T.int64()`, i.e. an int64 variable).
    @R.function
    def main(x: R.Tensor(("m",), dtype="float32"), y: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float32"):
        # Element-wise add; after LegalizeOps this becomes an R.call_tir into a
        # generated TIR PrimFunc named `add`.
        z = R.add(x, y)
        return z
Then I legalized the module with the following passes:
# NOTE: "xxx" stands for a self-developed, AMD-like GPU target.
mod = Module
target = tvm.target.Target("xxx")
# NOTE(review): `device` is not used by the build step shown in this report;
# presumably it is kept for running the compiled module later.
device = tvm.device("xxx", 0)

# Two-stage pipeline: lower Relax ops (here R.add) into TIR PrimFuncs, then
# bind the resulting loops to GPU blockIdx.x / threadIdx.x axes.
legalize_and_schedule = tvm.transform.Sequential(
    [
        LegalizeOps(),
        DefaultGPUSchedule(),
    ]
)

# Apply the pipeline with "xxx" as the current target. Presumably
# DefaultGPUSchedule reads the target from this context; confirm.
with target:
    mod = legalize_and_schedule(mod)
This produced the following IRModule, which contains the PrimFunc add():
@I.ir_module
class Module:
    # IRModule printed after running LegalizeOps() and DefaultGPUSchedule()
    # on the single-function Relax module (R.add on two length-"m" tensors).
    @T.prim_func(private=True)
    def add(var_A: T.handle, var_B: T.handle, var_T_add: T.handle):
        T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # Symbolic length. It is int64, so the index arithmetic below is int64.
        m = T.int64()
        A = T.match_buffer(var_A, (m,))
        B = T.match_buffer(var_B, (m,))
        T_add = T.match_buffer(var_T_add, (m,))
        # with T.block("root"):
        # Both thread-bound loops have int64 extents (T.int64(256) and
        # T.int64(2048)), so their loop variables are int64 too.
        # NOTE(review): the LLVM error `%5 = mul nsw i32 %0, i64 2048` lines up
        # with `ax0_fused_1 * T.int64(2048)` below, and
        # `%7 = add nsw i64 %6, i32 %5` with adding it to the int64
        # `ax0_fused_0 * T.int64(524288)`. This suggests the "xxx" codegen emits
        # the blockIdx.x index as i32 without widening it to the loop
        # variable's int64 dtype. Confirm in the backend's thread-index lowering.
        for ax0_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
            for ax0_fused_2 in T.thread_binding(T.int64(2048), thread="threadIdx.x"):
                # Each iteration covers 256 * 2048 = 524288 elements. The extent
                # is ceil(m / 524288).
                for ax0_fused_0 in range((m + T.int64(524287)) // T.int64(524288)):
                    with T.block("T_add"):
                        v_ax0 = T.axis.spatial(m, ax0_fused_0 * T.int64(524288) + ax0_fused_1 * T.int64(2048) + ax0_fused_2)
                        # Tail guard: skip indices >= m in the last iteration.
                        T.where((ax0_fused_0 * T.int64(256) + ax0_fused_1) * T.int64(2048) + ax0_fused_2 < m)
                        T.reads(A[v_ax0], B[v_ax0])
                        T.writes(T_add[v_ax0])
                        T_add[v_ax0] = A[v_ax0] + B[v_ax0]

    @R.function
    def main(x: R.Tensor(("m",), dtype="float32"), y: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float32"):
        m = T.int64()
        cls = Module
        # The original R.add is now a call into the generated `add` PrimFunc above.
        z = R.call_tir(cls.add, (x, y), out_sinfo=R.Tensor((m,), dtype="float32"))
        return z
When I tried to build this module with Relax:
# Compile the legalized and scheduled module for the "xxx" target. This is
# the step that fails with the LLVM verification error below.
ex = relax.build(mod, target=target)
I got only the following LLVM error log:
1: tvm::codegen::CodeGenLLVM::Finish()
at /path/tvm/src/target/llvm/codegen_llvm.cc:366
0: tvm::codegen::CodeGenLLVM::Verify() const
at /path/tvm/src/target/llvm/codegen_llvm.cc:354
File "/path/tvm/src/target/llvm/codegen_llvm.cc", line 354
TVMError: LLVM module verification failed with the following errors:
Both operands to a binary operator are not of the same type!
%5 = mul nsw i32 %0, i64 2048
Both operands to a binary operator are not of the same type!
%7 = add nsw i64 %6, i32 %5