Int32/Int64 mismatch during codegen into llvm::Function?

I construct a simple IRModule containing only a Relax function, as follows:

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor(("m",), dtype="float32"), y: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float32"):
            z = R.add(x, y)
            return z

and legalize the module with the following passes:

# note! "xxx" is an AMD-Like GPU, which was self-developed.
mod = Module 
target = tvm.target.Target("xxx")
device = tvm.device("xxx", 0)

seq = tvm.transform.Sequential([
    LegalizeOps(),
    DefaultGPUSchedule(),
])
with target:
    mod = seq(mod)    

and got an IRModule with the PrimFunc add(), as follows:

    @I.ir_module
    class Module:
        @T.prim_func(private=True)
        def add(var_A: T.handle, var_B: T.handle, var_T_add: T.handle):
            T.func_attr({"tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
            m = T.int64()
            A = T.match_buffer(var_A, (m,))
            B = T.match_buffer(var_B, (m,))
            T_add = T.match_buffer(var_T_add, (m,))
            # with T.block("root"):
            for ax0_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
                for ax0_fused_2 in T.thread_binding(T.int64(2048), thread="threadIdx.x"):
                    for ax0_fused_0 in range((m + T.int64(524287)) // T.int64(524288)):
                        with T.block("T_add"):
                            v_ax0 = T.axis.spatial(m, ax0_fused_0 * T.int64(524288) + ax0_fused_1 * T.int64(2048) + ax0_fused_2)
                            T.where((ax0_fused_0 * T.int64(256) + ax0_fused_1) * T.int64(2048) + ax0_fused_2 < m)
                            T.reads(A[v_ax0], B[v_ax0])
                            T.writes(T_add[v_ax0])
                            T_add[v_ax0] = A[v_ax0] + B[v_ax0]

        @R.function
        def main(x: R.Tensor(("m",), dtype="float32"), y: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float32"):
            m = T.int64()
            cls = Module
            z = R.call_tir(cls.add, (x, y), out_sinfo=R.Tensor((m,), dtype="float32"))
            return z

While trying to build this module with Relax:

    ex = relax.build(mod, target=target)

I only got the following LLVM error log:

    1: tvm::codegen::CodeGenLLVM::Finish()
            at /path/tvm/src/target/llvm/codegen_llvm.cc:366
    0: tvm::codegen::CodeGenLLVM::Verify() const
            at /path/tvm/src/target/llvm/codegen_llvm.cc:354
    File "/path/tvm/src/target/llvm/codegen_llvm.cc", line 354
    TVMError: LLVM module verification failed with the following errors:
    Both operands to a binary operator are not of the same type!
    %5 = mul nsw i32 %0, i64 2048
    Both operands to a binary operator are not of the same type!
    %7 = add nsw i64 %6, i32 %5

I dug into this issue and figured out the root cause: the extents passed to T.thread_binding() are T.int64() values, but the ForFrame's thread iter_var dtype does not match int64:

    for ax0_fused_2 in T.thread_binding(T.int64(2048), thread="threadIdx.x"):
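
As a quick sanity check on the Python side, the mismatch can be surfaced by walking the scheduled PrimFunc and comparing each thread-binding loop variable's dtype with the dtype of the attached thread IterVar. This is only a minimal sketch (check_thread_binding_dtypes is a hypothetical helper name), assuming the buggy behaviour described above:

    import tvm
    from tvm import tir

    # Hypothetical helper: report thread-binding loops whose loop var dtype
    # differs from the dtype of the attached thread IterVar.
    def check_thread_binding_dtypes(prim_func):
        def visit(node):
            if isinstance(node, tir.For) and node.kind == tir.ForKind.THREAD_BINDING:
                loop_dtype = node.loop_var.dtype
                iter_dtype = node.thread_binding.var.dtype
                if loop_dtype != iter_dtype:
                    print(f"{node.thread_binding.thread_tag}: loop var is {loop_dtype}, "
                          f"thread iter var is {iter_dtype}")

        tir.stmt_functor.post_order_visit(prim_func.body, visit)

    # With the behaviour described above, this reports int64 vs. int32 mismatches
    # for blockIdx.x / threadIdx.x in the scheduled add() PrimFunc.
    check_thread_binding_dtypes(mod["add"])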

In /path/tvm/src/script/ir_builder/tir/ir.cc:352, n->vars' dtype is derived from bits, so why is iter_var's dtype fixed to DataType::Int(32)? I guess this might be a bug, and the reason behind the issue above:

    ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
                           Optional<Map<String, ObjectRef>> annotations) {
      using namespace tvm::tir;
      PrimExpr min = start;
      PrimExpr extent = arith::Analyzer().Simplify(stop - start);
      ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
      int bits = std::max(min.dtype().bits(), extent.dtype().bits());
      n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))};
      n->doms = {Range::FromMinExtent(min, extent)};
      n->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For {
        ICHECK_EQ(vars.size(), 1);
        ICHECK_EQ(doms.size(), 1);
        // The thread IterVar is always created as Int(32), even when n->vars is int64.
        IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex,
                         thread);
        return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
                   annotations.value_or(Map<String, ObjectRef>()));
      };
      return ForFrame(n);
    }

I did find a pass called ForceNarrowIndexToInt32() that can forcibly cast the indices to T.int32, but I don't think this should be patched with a pass. Instead, some check or conversion mechanism should be added at a lower level to ensure that the generated LLVM IR is valid.
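
For completeness, this is roughly how that workaround would be wired into the existing pass sequence. It is only a sketch, under the assumption that forcibly narrowing is acceptable for the given shapes; if m can exceed the int32 range, the narrowed indices would overflow:

    # Workaround sketch: additionally narrow all TIR indices to int32 before building.
    seq = tvm.transform.Sequential([
        LegalizeOps(),
        DefaultGPUSchedule(),
        tvm.tir.transform.ForceNarrowIndexToInt32(),
    ])
    with target:
        mod = seq(mod)
    ex = relax.build(mod, target=target)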

Indeed, in this case it seems we need to make sure that the iter var follows the dtype of the bound (see the sketch below).

also cc @junrushao @Hzfengsy
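
Concretely, the expectation would be that a thread-binding loop keeps both dtypes in sync. A hand-constructed sketch of a consistent int64 loop at the TIR level (purely illustrative, not the actual fix):

    import tvm
    from tvm import tir

    # Both the loop var and the thread IterVar's var share the dtype of the
    # bound (int64 here), so the lowered index arithmetic stays in one type.
    loop_var = tir.Var("v", "int64")
    thread_iv = tir.IterVar(None, tir.Var("iter", "int64"),
                            tir.IterVar.ThreadIndex, "threadIdx.x")
    loop = tir.For(loop_var, tir.IntImm("int64", 0), tir.IntImm("int64", 2048),
                   tir.ForKind.THREAD_BINDING, tir.Evaluate(0),
                   thread_binding=thread_iv)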

This line was updated in https://github.com/apache/tvm/pull/15547. I will take a look. Thanks for your report :slight_smile:

It seems that I missed the dtype of the iter var when I modified it. It needs to be kept the same as n->vars.

It is tricky: since the user explicitly defines T.int64(2048), it SHOULD be an int64 var rather than being narrowed to int32.

So, could we use ForceNarrowIndexToInt32() to force the IR to int32 and avoid the "Int32/Int64" compile issue, or might that cause unknown side effects?

P.S.: "T.int64(2048)" is more likely auto-generated by the TVMScript printer than explicitly written by the user. We could write it by hand, but we tend to write the IRModule with only the R.function :grinning_face_with_smiling_eyes:

Fixed in https://github.com/apache/tvm/pull/16041