Implicit requirements on parameter types of AssertStmt?

Hi all, I found an interesting problem with AssertStmt in the TIR part.

When I try to build the following module, which contains an AssertStmt, I get an error:

import tvm
from tvm import tir

v = tir.Var("v", "int32")
# The condition (first argument) is the int32 variable v itself, not a boolean expression.
assert_stmt = tir.AssertStmt(v, v, tir.Evaluate(tir.const(0)))
prim_func = tir.PrimFunc({v}, assert_stmt)
mod = tvm.lower(prim_func)
tvm.build(mod)

But if I replace the condition variable with a comparison, the module builds successfully.

It seems that although the documented parameter type of condition is PrimExpr, there are actually some implicit requirements on it? Or is the assert statement just implemented incorrectly? I’m curious about the reason, since this behavior works fine in other languages like Python :joy:

I also found that we already have a specific type check on the message of AssertStmt, so it would be better to check the type of the condition as well to prevent this error.
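
A minimal sketch of the kind of check suggested here, written in plain Python rather than in TVM's own constructor. `require_bool_condition` and `FakeExpr` are hypothetical names for illustration, not TVM API; the only assumption about the expression is that it exposes its type through a string-like `dtype` attribute such as "int32" or "bool".

```python
from dataclasses import dataclass


@dataclass
class FakeExpr:
    """Hypothetical stand-in for a TIR expression; only `dtype` matters here."""
    name: str
    dtype: str


def require_bool_condition(cond):
    """Raise TypeError unless `cond` has dtype "bool"; otherwise return it unchanged."""
    dtype = str(getattr(cond, "dtype", None))
    if dtype != "bool":
        raise TypeError(f"condition must have dtype 'bool', got '{dtype}'")
    return cond


int_var = FakeExpr("v", "int32")        # like using the variable directly
comparison = FakeExpr("v != 0", "bool")  # like using a comparison on it

checked = require_bool_condition(comparison)  # passes through unchanged

try:
    require_bool_condition(int_var)
    rejected = False
except TypeError as err:
    rejected = True
    message = str(err)
```

With a check like this at construction time, the int32 condition would presumably be rejected with a readable message instead of surfacing later as a backend error.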

In Python, pretty much every object has some “truth value”, i.e. it can be placed in a context where a true/false condition is expected, such as an if statement or an assert statement. This is simply because Python is designed to work that way.
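
As a small illustration of that truth-value behavior (the `EmptyBox` class and `describe` helper are made up for this example):

```python
class EmptyBox:
    """A user-defined type that is falsy because its length is zero."""

    def __len__(self):
        return 0


def describe(x):
    """Use `x` directly as a condition, as an if or assert statement would."""
    return "truthy" if x else "falsy"


samples = [0, 7, "", "v", [], [0], None, EmptyBox()]
results = [describe(x) for x in samples]
```

Integers, strings, lists, `None`, and user-defined objects are all accepted as conditions here without any explicit comparison.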

In TIR, values are not automatically convertible to true/false, so in conditions, like the one in Assert, you need something that has a boolean type. The error you’re seeing comes from LLVM, and the complaint is that the argument to the br instruction is not of type i1. What most likely happened is that v got converted to an llvm::Value with type i32 and was then used directly in the branch. This is not legal LLVM IR, so you got a verifier error.

Thanks for your reply! @kparzysz

It seems that for LLVM, a boolean value is expected as the condition of an assert statement, while other types are accepted as the condition of an if statement.

I also found another interesting error that occurs if I add an else_case:

import tvm
from tvm import tir

v = tir.Var("v", "int32")
if_stmt = tir.IfThenElse(v, tir.Evaluate(tir.const(0)), tir.Evaluate(tir.const(0)))
prim_func = tir.PrimFunc({v}, if_stmt)
print(prim_func)
mod = tvm.lower(prim_func)
tvm.build(mod)

I think TVM doesn’t check the type of the condition strictly: the function implemented in src/arith/ir_mutator_with_analyzer.cc calls Not(real_condition) directly, which means it expects the condition to be of type bool. If we don’t pass an else_case, there is no such implicit requirement:

Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
  // ...
  Stmt then_case, else_case;
  {
    With<ConstraintContext> ctx(analyzer_, real_condition);
    then_case = this->VisitStmt(op->then_case);
  }
  if (op->else_case.defined()) {
    With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
    else_case = this->VisitStmt(op->else_case);
  }
  // ...
}

You’re probably right. I guess TVM could be more consistent in checking TIR for errors.