How to leverage tvm.tir.transform in TVM 0.7

Hi everyone, I’ve been working with the TVM stack for a couple of months and I really like it :)

I have a question related to the use of TVM 0.7 APIs, in particular tvm.tir.transform.Simplify.

In TVM 0.6 I could simply call tvm.tir.ir_pass.Simplify(stmt) in any of my custom IR passes to get the benefits of arithmetic simplification, and then apply further optimizations on top of the simplified IR. I read the documentation (here) and tried several ways to obtain the same result in TVM 0.7, but I have not found a way to use tvm.tir.transform.Simplify for this. I think I might be missing something.

What is the correct way to use this API? Could you show me an example in which you pass a tir.PrimFunc or a stmt object to the new API?

Thank you!

We have moved all passes over to the pass infrastructure, so you should be able to just invoke Simplify on your entire IRModule before running your pass, i.e. Simplify()(mod).
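As a rough illustration of the module-level route, here is a minimal Python sketch. It assumes a TVM 0.7 build where tvm.tir.transform.Simplify and tvm.transform.Sequential are available; the helper name simplify_then and its custom_pass argument are hypothetical, not part of TVM.

```python
def simplify_then(mod, custom_pass=None):
    """Run tir.Simplify over every PrimFunc in `mod`, then an optional pass.

    `simplify_then` and `custom_pass` are illustrative names, not TVM APIs.
    """
    import tvm  # imported lazily so the helper can be defined without TVM

    passes = [tvm.tir.transform.Simplify()]
    if custom_pass is not None:
        passes.append(custom_pass)
    # Sequential applies the passes in order and returns a new IRModule.
    return tvm.transform.Sequential(passes)(mod)
```

With no custom pass, simplify_then(mod) should give the same result as Simplify()(mod).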

If you want to just simplify a single expression/stmt you can use the internal API to do so.

Here is how the Simplify pass is defined:

Pass Simplify() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    arith::Analyzer analyzer;
    n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
}

// The same internal API, applied to a single statement:
arith::Analyzer analyzer;
stmt = arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt));

Thank you @jroesch, that definitely brings in more clarity.

The way I was using this pass in its previous version was to call it on stmt within my own pass as a step after some IR transformations, and it worked flawlessly.

Looking at your answer, it seems that what I am looking for is the Python API equivalent of

 arith::Analyzer analyzer;

Basically, what I am looking for is a way to run the Simplify pass from within my own pass at function or statement granularity, rather than on the whole module.
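One possible workaround that needs no new binding, sketched here under the assumption of a TVM 0.7 Python install, is to wrap the statement in a temporary PrimFunc, run Simplify on a one-function IRModule, and read the body back. The names simplify_stmt, make_rewrite_then_simplify and rewrite_body are hypothetical helpers for illustration.

```python
def simplify_stmt(stmt, params=()):
    """Simplify one tir.Stmt by wrapping it in a throwaway PrimFunc.

    Hypothetical helper: `params` lists any Vars/Buffers the function
    signature should carry (an empty list is accepted as well).
    """
    import tvm  # imported lazily so the helper can be defined without TVM

    func = tvm.tir.PrimFunc(list(params), stmt)
    mod = tvm.IRModule.from_expr(func)  # the function is registered as "main"
    return tvm.tir.transform.Simplify()(mod)["main"].body


def make_rewrite_then_simplify(rewrite_body):
    """Build a PrimFunc pass that applies `rewrite_body`, then simplifies.

    `rewrite_body` stands in for your own Stmt -> Stmt transformation.
    """
    import tvm

    @tvm.tir.transform.prim_func_pass(opt_level=0)
    def _pass(func, mod, ctx):
        new_body = simplify_stmt(rewrite_body(func.body))
        return func.with_body(new_body)

    return _pass
```

The resulting pass is applied like any other, for example make_rewrite_then_simplify(my_rewrite)(mod), where my_rewrite is your own function.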

I think someone just needs to expose this in Python. You could probably do so with

TVM_REGISTER_GLOBAL("SimplifyStmt").set_body_typed([](tir::Stmt stmt) {
    arith::Analyzer analyzer;
    return arith::StmtSimplifier(&analyzer).Simplify(stmt);
});

and in Python
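A minimal sketch of the Python side, assuming TVM has been rebuilt with the registration above so that a global function named "SimplifyStmt" exists. The wrapper name get_simplify_stmt is hypothetical; tvm.get_global_func is the standard lookup for registered globals.

```python
def get_simplify_stmt():
    """Look up the C++ function registered as "SimplifyStmt".

    Hypothetical wrapper: it only works on a TVM build that contains
    the TVM_REGISTER_GLOBAL("SimplifyStmt") registration shown above.
    """
    import tvm  # imported lazily so the helper can be defined without TVM

    func = tvm.get_global_func("SimplifyStmt", allow_missing=True)
    if func is None:
        raise RuntimeError('global function "SimplifyStmt" is not registered')
    return func
```

Inside a custom pass this would be used as simplified = get_simplify_stmt()(stmt).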


Thank you @jroesch, this is very helpful!