[Dynamic Shape] Better simplify support for dynamic boundary check

Background and Motivation

Currently, TVM uses the tir.Simplify pass to remove redundant expressions such as nested, equivalent if-conditions. For example, given a simple softmax operation like

primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [2, 10, 257, 1025], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2, 10, 257, 1025], [])}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 6;
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
      if @tir.likely((floordiv(floordiv((threadIdx.x + (blockIdx.x*1024)), 257), 10) < 2), dtype=bool) {
        if @tir.likely((floordiv((threadIdx.x + (blockIdx.x*1024)), 257) < 20), dtype=bool) {
          if @tir.likely(((threadIdx.x + (blockIdx.x*1024)) < 5140), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }
      // ...

tir.Simplify will simplify this to

primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [2, 10, 257, 1025], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2, 10, 257, 1025], [])}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 6;
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
      if @tir.likely((((blockIdx.x*1024) + threadIdx.x) < 5140), dtype=bool) {
        T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
      }
      // ...

where the three equivalent conditions are simplified to one.

However, things are different when the input has a dynamic shape. The current tir.Simplify fails to remove this redundancy for dynamic-shape inputs, because the analyzer it uses (specifically, the RewriteSimplifier and the CanonicalSimplifier) lacks the corresponding rules for this “non-const” situation. In the rest of this post, we continue to use this simple softmax example to discuss the problem. We try to fix it by adding more rules to both the RewriteSimplifier and the CanonicalSimplifier. This is still an experimental idea; if you find anything wrong or improper, feel free to correct us directly in this thread.

Proposal

We illustrate our proposal by solving this problem for the softmax example. As shown in this post, we can eliminate some redundant expressions by introducing sign information into tensor shapes. However, some redundant expressions are not covered by that solution. tir.Simplify can eliminate them when the input shape is static, as shown above. In the dynamic case, the following reasons prevent tir.Simplify from removing this redundancy:

  1. RewriteSimplifier only has rules that simplify floordiv(x, c1) < c2 to x < c1 * c2 when c1 and c2 are IntImm (integer constants).
  2. After simplifying floordiv(x, c1) < c2 to x < c1 * c2, we directly get a new constant c3 = c1 * c2, provided that c1 and c2 are IntImm. But if we are given variables (or, even worse, expressions), we cannot recognize that v1 * v2 and v2 * v1 are equal.
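Point 2 can be illustrated with a toy term representation. The sketch below is hypothetical (Expr, Const, Var, Mul and StructEqual are illustration-only names, not TVM's PrimExpr API): a product of two constants folds into a single constant, so both operand orders end up identical, while a product of variables keeps its operand order and a purely structural comparison sees d1*d0 and d0*d1 as different.

```cpp
#include <memory>
#include <string>

// Hypothetical toy expression (not TVM's PrimExpr): a constant, a named
// variable, or a binary multiplication node.
struct Expr {
  enum Kind { kConst, kVar, kMul };
  Kind kind = kConst;
  long long value = 0;           // meaningful when kind == kConst
  std::string name;              // meaningful when kind == kVar
  std::shared_ptr<Expr> a, b;    // operands when kind == kMul
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Const(long long v) {
  auto e = std::make_shared<Expr>();
  e->value = v;
  return e;
}

ExprPtr Var(const std::string& name) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kVar;
  e->name = name;
  return e;
}

// Builds a * b. Two constants are folded into one (like c3 = c1 * c2 for
// IntImm); anything else keeps the operands in the order given.
ExprPtr Mul(const ExprPtr& a, const ExprPtr& b) {
  if (a->kind == Expr::kConst && b->kind == Expr::kConst) {
    return Const(a->value * b->value);
  }
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kMul;
  e->a = a;
  e->b = b;
  return e;
}

// Purely structural comparison: operands must match position by position,
// so commutativity of multiplication is not taken into account.
bool StructEqual(const ExprPtr& x, const ExprPtr& y) {
  if (x->kind != y->kind) return false;
  switch (x->kind) {
    case Expr::kConst: return x->value == y->value;
    case Expr::kVar:   return x->name == y->name;
    case Expr::kMul:   return StructEqual(x->a, y->a) && StructEqual(x->b, y->b);
  }
  return false;
}
```

For example, Mul(Const(257), Const(10)) and Mul(Const(10), Const(257)) both become the constant 2570 and compare equal, whereas Mul(Var("d1"), Var("d0")) and Mul(Var("d0"), Var("d1")) do not.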

For clarity, we use the SizeVars d0, d1, d2, and d3 for the shape in our softmax example. The output of the current tir.Simplify is

primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0), dtype=bool) {
        if @tir.likely((floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1)), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }

To simplify floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1) and floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0 to ((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2), we add a new rule to PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op):

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
    // ...
    PVar<PrimExpr> x, y, z, s1, s2;
    // ...
    TVM_TRY_REWRITE_IF(floordiv(x, s1) < s2, x < s1 * s2,
                       analyzer_->const_int_bound(s1.Eval())->min_value >= 0);
    // ...
}
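To sanity-check the rewrite outside TVM, the sketch below brute-forces the equivalence floordiv(x, s1) < s2 <=> x < s1 * s2 over small integers with a positive divisor, and also checks the nested form from the softmax kernel, where applying the rule twice yields x < d2 * (d1 * d0). FloorDiv, RuleHoldsForPositiveDivisor and NestedRuleHolds are hypothetical helpers; FloorDiv assumes TIR's floordiv rounds toward negative infinity, unlike C++'s truncating `/`.

```cpp
// Floor division (rounds toward negative infinity); C++ '/' truncates
// toward zero, so adjust the quotient when the signs differ. Requires b != 0.
long long FloorDiv(long long a, long long b) {
  long long q = a / b;
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;
  return q;
}

// Exhaustively checks floordiv(x, s1) < s2  <=>  x < s1 * s2 for
// 1 <= s1 <= n, -n <= s2 <= n and -n*n <= x <= n*n.
bool RuleHoldsForPositiveDivisor(long long n) {
  for (long long s1 = 1; s1 <= n; ++s1)
    for (long long s2 = -n; s2 <= n; ++s2)
      for (long long x = -n * n; x <= n * n; ++x)
        if ((FloorDiv(x, s1) < s2) != (x < s1 * s2)) return false;
  return true;
}

// Checks the nested form: floordiv(floordiv(x, d2), d1) < d0 should be
// equivalent to x < d2 * (d1 * d0) for positive extents 1 <= d0, d1, d2 <= n.
bool NestedRuleHolds(long long n) {
  for (long long d0 = 1; d0 <= n; ++d0)
    for (long long d1 = 1; d1 <= n; ++d1)
      for (long long d2 = 1; d2 <= n; ++d2) {
        const long long total = d2 * (d1 * d0);
        for (long long x = -total - 2; x <= total + 2; ++x)
          if ((FloorDiv(FloorDiv(x, d2), d1) < d0) != (x < total)) return false;
      }
  return true;
}
```

Both checks only cover positive divisors; they say nothing about s1 <= 0, which is exactly the case the side condition has to exclude.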

Here comes our first worry. The corresponding IntImm version of this rule is

TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0);

where the side condition is > 0 rather than >= 0. For the PrimExpr version, the ConstIntBoundAnalyzer can only tell us that s1 is non-negative (this bound comes from simple facts such as SizeVar + SizeVar >= 0 or SizeVar * SizeVar >= 0). Although s1 == 0 is an invalid case (division by zero), admitting it means the transformation is not an equivalence and may hide run-time errors. This post shows some possible solutions to this issue, but there may be simpler ones (if you have any ideas, please share them with us).
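The gap between `>= 0` and `> 0` can be made concrete by evaluating both sides of the rewrite. In this hypothetical sketch (EvalOriginal and EvalRewritten are illustration-only names), the original condition has no defined value when s1 == 0 (modelled as std::nullopt), while the rewritten condition x < s1 * s2 silently evaluates to x < 0, so a division by zero would be turned into an ordinary comparison.

```cpp
#include <optional>

// Hypothetical evaluator for the original side, floordiv(x, s1) < s2.
// Division by zero has no defined result; it is modelled as std::nullopt.
std::optional<bool> EvalOriginal(long long x, long long s1, long long s2) {
  if (s1 == 0) return std::nullopt;
  long long q = x / s1;
  if (x % s1 != 0 && ((x < 0) != (s1 < 0))) --q;  // round toward -infinity
  return q < s2;
}

// Hypothetical evaluator for the rewritten side, x < s1 * s2.
// It always produces a value, even when s1 == 0.
bool EvalRewritten(long long x, long long s1, long long s2) {
  return x < s1 * s2;
}
```

For instance, EvalOriginal(-1, 0, 3) is undefined while EvalRewritten(-1, 0, 3) is true; for s1 = 2 both sides agree.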

After adding the new rule to the RewriteSimplifier, we get:

primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d1*d0))), dtype=bool) {
        if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d0*d1))), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }

Now the question becomes how to recognize that (d2*(d1*d0)), (d2*(d0*d1)), and ((d0*d1)*d2) are equal. We add a rule to PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) to get a canonical form of multiplication:

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
  // ...
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // ...

  bool a_is_var = a.as<VarNode>() != nullptr;
  bool b_is_var = b.as<VarNode>() != nullptr;

  // var * expr => expr * var
  if (a_is_var && !b_is_var) {
    std::swap(a, b);
    std::swap(a_is_var, b_is_var);
  }

  // if given var * var or expr * expr, use their
  // structural hash values to sort the operands
  if (a_is_var == b_is_var) {
    auto ah = StructuralHash()(a);
    auto bh = StructuralHash()(b);
    if (ah > bh) {
      std::swap(a, b);
    }
  }
  
  // ...
}
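The goal of the rule above is that every ordering and grouping of the same factors ends up in one canonical form. The sketch below shows that idea with a different, swapped-in technique: instead of pairwise swapping by structural hash, it flattens a product chain, sorts the factors by variable name, and rebuilds a left-nested product. Term, V, M, CollectFactors and CanonicalProduct are hypothetical names; sorting by name is an assumption that only works in this toy, since real TIR variables are distinguished by identity rather than by name.

```cpp
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy term: a variable (var non-empty) or a product lhs * rhs.
struct Term {
  std::string var;
  std::shared_ptr<Term> lhs, rhs;
};
using TermPtr = std::shared_ptr<Term>;

TermPtr V(const std::string& name) {
  auto t = std::make_shared<Term>();
  t->var = name;
  return t;
}

TermPtr M(TermPtr a, TermPtr b) {
  auto t = std::make_shared<Term>();
  t->lhs = std::move(a);
  t->rhs = std::move(b);
  return t;
}

// Collects the variable factors of a (possibly nested) product.
void CollectFactors(const TermPtr& t, std::vector<std::string>* out) {
  if (!t->var.empty()) {
    out->push_back(t->var);
    return;
  }
  CollectFactors(t->lhs, out);
  CollectFactors(t->rhs, out);
}

// Canonical form: sorted factors rebuilt as a left-nested product, so any
// commutation or reassociation of the same factors prints identically.
std::string CanonicalProduct(const TermPtr& t) {
  std::vector<std::string> factors;
  CollectFactors(t, &factors);
  std::sort(factors.begin(), factors.end());
  std::string s = factors[0];
  for (std::size_t i = 1; i < factors.size(); ++i) {
    s = "(" + s + "*" + factors[i] + ")";
  }
  return s;
}
```

With this sketch, d2*(d1*d0), d2*(d0*d1) and (d0*d1)*d2 all map to "((d0*d1)*d2)". Flattening also handles reassociation such as d0*(d1*d2), which a single pairwise swap at each MulNode may not.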

Using structural hash values for sorting is admittedly a bit ugly, but I have no better idea at the moment. After adding this rule to the CanonicalSimplifier, we get:

primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
        if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
          if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
            T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
          }
        }
      }

Next, we need a way to remove these literally identical conditions. RewriteSimplifier already has such a mechanism:

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
  // add condition context to if_then_else
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<CallNode>();

  // ...

  ExprDeepEqual expr_equal;
  if (op->op.same_as(tir::builtin::likely())) {
    for (const auto& constraint : literal_constraints_) {
      // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
      if (expr_equal(constraint, op->args[0])) {
        return make_const(op->dtype, true);
      }
    }
  }
  return ret;
}
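The literal-constraint mechanism can be modelled with a scoped stack of conditions. In this hypothetical sketch (LiteralConstraints, Scope and SimplifyLikely are illustration-only names, and conditions are plain strings), a Scope object plays the role of entering a constraint for the then-branch, and likely(cond) folds to true only when cond is textually identical to an enclosing constraint, mirroring the ExprDeepEqual check.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of literal constraint tracking. Conditions are strings
// and are matched only by exact equality, like ExprDeepEqual on the IR.
class LiteralConstraints {
 public:
  // RAII guard: the constraint is visible only while the guard is alive.
  class Scope {
   public:
    Scope(LiteralConstraints* owner, std::string cond) : owner_(owner) {
      owner_->stack_.push_back(std::move(cond));
    }
    ~Scope() { owner_->stack_.pop_back(); }
    Scope(const Scope&) = delete;
    Scope& operator=(const Scope&) = delete;

   private:
    LiteralConstraints* owner_;
  };

  // Returns "true" when likely(cond) is literally implied by an enclosing
  // constraint; otherwise returns the call unchanged.
  std::string SimplifyLikely(const std::string& cond) const {
    for (const auto& c : stack_) {
      if (c == cond) return "true";
    }
    return "likely(" + cond + ")";
  }

 private:
  std::vector<std::string> stack_;
};
```

Because matching is purely textual, two conditions that differ only in operand order (for example a + b < n versus b + a < n) are not recognized as the same.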

However, every constraint in literal_constraints_ has already been processed by CanonicalSimplify at the point where it is entered:

Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
  PrimExpr condition = this->VisitExpr(op->condition); // HERE
  PrimExpr real_condition = condition;
  static auto op_likely = Op::Get("tir.likely");

  if (auto call = condition.as<CallNode>()) {
    if (call->op.same_as(op_likely)) {
      real_condition = call->args[0];
    }
  }

  Stmt then_case, else_case;
  {
    With<ConstraintContext> ctx(analyzer_, real_condition);
    then_case = this->VisitStmt(op->then_case);
  }
  if (op->else_case.defined()) {
    With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
    else_case = this->VisitStmt(op->else_case);
  }
  
  // ...
}

whereas op->args[0] has not, because we are still inside RewriteSimplify and CanonicalSimplify runs only after it:

PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
  if (tir::is_const_int(expr)) return expr;
  PrimExpr res = expr;
  for (int i = 0; i < steps; ++i) {
    res = this->rewrite_simplify(res);                       // RewriteSimplify
    if (tir::is_const_int(res) || ++i == steps) return res;  // is ++i proper here?
    res = this->canonical_simplify(res);                     // CanonicalSimplify
    if (tir::is_const_int(res)) return res;
  }
  return res;
}
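Regarding the `++i` question in the comment above, one way to check the intent is to trace the loop with the simplifiers replaced by stubs. In this hypothetical sketch (TraceSimplifySteps is an illustration-only name, and the early returns on constant results are omitted), 'R' records a rewrite_simplify call and 'C' a canonical_simplify call. Our reading of the trace is that the in-body increment makes each simplifier call consume one step, so the two passes alternate and the total number of calls equals steps; this is an interpretation of the snippet, not a statement of TVM's intent.

```cpp
#include <string>

// Hypothetical trace of the loop structure in Analyzer::Simplify above.
// 'R' = rewrite_simplify call, 'C' = canonical_simplify call.
std::string TraceSimplifySteps(int steps) {
  std::string trace;
  for (int i = 0; i < steps; ++i) {
    trace += 'R';                    // res = rewrite_simplify(res)
    if (++i == steps) return trace;  // in-body increment consumes a step
    trace += 'C';                    // res = canonical_simplify(res)
  }
  return trace;
}
```

For steps = 2 this yields "RC", and for steps = 3 it yields "RCR".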

Because canonicalization runs only after rewriting, op->args[0] looks something like ((threadIdx.x: int32 + (blockIdx.x: int32*512)) < (((d1: int32*d0: int32)*d2: int32)*d3: int32)), while the constraint looks like (((blockIdx.x: int32*512) + threadIdx.x: int32) < (((d1: int32*d0: int32)*d2: int32)*d3: int32)). To solve this problem, we can perform a canonical simplification before the comparison:

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
  // ...
  ExprDeepEqual expr_equal;
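  if (op->op.same_as(tir::builtin::likely())) {
    // Canonicalize the argument so that it has the same form as the stored
    // constraints, which were fully simplified when they were entered.
    auto condition = analyzer_->canonical_simplify(op->args[0]);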
    for (const auto& constraint : literal_constraints_) {
      // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
      if (expr_equal(constraint, condition)) {
        return make_const(op->dtype, true);
      }
    }
  }
  return ret;
}
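The effect of canonicalizing before the comparison can be shown with a toy condition of the form sum(terms) < bound. In this hypothetical sketch (LtCond, Canonicalize and MatchesAnyConstraint are illustration-only names), sorting the summands stands in for CanonicalSimplify: the stored constraint is canonical, the likely() argument is not, and the two match only when the argument is canonicalized first.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical condition "sum(terms) < bound", with terms kept as strings.
struct LtCond {
  std::vector<std::string> terms;
  std::string bound;
};

// Literal (order-sensitive) equality, like ExprDeepEqual.
bool LiteralEqual(const LtCond& a, const LtCond& b) {
  return a.terms == b.terms && a.bound == b.bound;
}

// Toy stand-in for canonical simplification: order the summands.
LtCond Canonicalize(LtCond c) {
  std::sort(c.terms.begin(), c.terms.end());
  return c;
}

// Compares a likely() argument against stored (already canonical)
// constraints, optionally canonicalizing the argument first.
bool MatchesAnyConstraint(const LtCond& cond,
                          const std::vector<LtCond>& constraints,
                          bool canonicalize_first) {
  LtCond probe = canonicalize_first ? Canonicalize(cond) : cond;
  for (const auto& c : constraints) {
    if (LiteralEqual(c, probe)) return true;
  }
  return false;
}
```

With the constraint stored as {blockIdx.x*512, threadIdx.x} < N and the argument written as {threadIdx.x, blockIdx.x*512} < N, the match fails without canonicalization and succeeds with it.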

After that we get:

primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
             T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
  allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
      if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
        T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
      }

Again, this is only an experimental idea and there are still some issues to be solved. If you have any better ideas, please feel free to suggest below.


We are working on dynamic-shape model compilation with the Relay VM and are currently running into this issue. Although these redundant expressions may be eliminated when compiling with the clang/gcc -O2 option (@lygztq found that LLVM's EarlyCSE pass may help), we hope to fix it within the TIR scope to avoid confusing us developers.


Thanks for the proposal. I agree that it is a valuable problem for dynamic shape.

Here are two questions from me:

  1. Is it necessary to rewrite (d1*d2)*d0 into d0*d1*d2? Can we prove them equal with the Analyzer directly?
  2. Can we embed the new rule into tir.Simplify rather than creating a new method in RewriteSimplifier?

Thanks again for the great work and proposal!


Thanks for the reply.

I think proving the two expressions equal with the Analyzer directly is another reasonable solution, but it may have a higher time overhead than the “rewrite” method. Currently, the StructuralEqual mechanism has limited ability (it only handles the variable-remapping case) to recognize these expressions as equal. Overall, this idea is worth trying, and I will try to write a new mechanism similar to StructuralEqual to solve this problem.

As for your second question, the RewriteSimplifier is actually part of the Analyzer. But it is true that adding the new rule directly to tir.Simplify may be a better way, and I will try it.

Thanks again for your excellent suggestions!
