Background and Motivation
Currently, TVM uses the tir.Simplify pass to remove redundant expressions such as nested equivalent if-conditions. For example, given a simple softmax operation like
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [2, 10, 257, 1025], []),
placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2, 10, 257, 1025], [])}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]), storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 6;
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
if @tir.likely((floordiv(floordiv((threadIdx.x + (blockIdx.x*1024)), 257), 10) < 2), dtype=bool) {
if @tir.likely((floordiv((threadIdx.x + (blockIdx.x*1024)), 257) < 20), dtype=bool) {
if @tir.likely(((threadIdx.x + (blockIdx.x*1024)) < 5140), dtype=bool) {
T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
}
}
}
// ...
tir.Simplify will simplify this to
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [2, 10, 257, 1025], []),
placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2, 10, 257, 1025], [])}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]), storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 6;
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
if @tir.likely((((blockIdx.x*1024) + threadIdx.x) < 5140), dtype=bool) {
T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
}
// ...
where the three equivalent conditions are collapsed into one.
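For reference, the pass can be applied directly to a lowered module. Below is a minimal C++ sketch; the helper RunSimplify is ours (not TVM code) and it assumes an IRModule mod lowered elsewhere, e.g. from the softmax schedule above.
#include <tvm/ir/module.h>
#include <tvm/tir/transform.h>

#include <utility>

using namespace tvm;

// Hypothetical helper: run only the Simplify pass on an already-lowered
// IRModule under the default PassContext.
IRModule RunSimplify(IRModule mod) {
  transform::Pass simplify = tir::transform::Simplify();
  return simplify(std::move(mod));
}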
However, things are different when the given input has a dynamic shape. The current tir.Simplify fails to perform this simplification because the analyzer it uses (specifically the RewriteSimplifier and the CanonicalSimplifier) lacks corresponding rules for this “non-const” situation. In the rest of this post we will continue to use this simple softmax example to discuss the problem, which we try to fix by adding more rules to both the RewriteSimplifier and the CanonicalSimplifier. This is still an experimental idea; if you find something wrong or improper, feel free to correct us directly in this post.
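To make the gap concrete, here is a minimal standalone sketch (our own demo code, not TVM internals) using arith::Analyzer: with a constant divisor the comparison folds, while with SizeVar operands the floordiv survives because no rewrite rule matches the non-const case.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

using namespace tvm;

void SimplifyDemo() {
  arith::Analyzer ana;
  tir::Var idx("idx", DataType::Int(32));

  // Static divisor: floordiv(idx, 257) < 20 folds to idx < 5140 (257 * 20).
  PrimExpr static_cond = ana.Simplify(floordiv(idx, 257) < 20);

  // Dynamic divisor: no rule matches, so the floordiv stays in the condition.
  tir::SizeVar d1("d1", DataType::Int(32)), d2("d2", DataType::Int(32));
  PrimExpr dynamic_cond = ana.Simplify(floordiv(idx, d2) < d1);

  LOG(INFO) << static_cond << "\n" << dynamic_cond;
}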
Proposal
We will show our proposal by solving this problem for our simple softmax example. As shown in this post, we can eliminate some redundant expressions by introducing sign information into tensor shapes, but some other redundant expressions are not covered by that solution. These expressions can be eliminated by the tir.Simplify pass when the input’s shape is static, as shown before. For the dynamic situation, the reasons that prevent tir.Simplify from simplifying this redundancy are as follows:
- RewriteSimplifier only has rules for IntImm operands to simplify floordiv(x, c1) < c2 to x < c1 * c2.
- After the simplification from floordiv(x, c1) < c2 to x < c1 * c2, we can directly fold a new constant c3 = c1 * c2 provided c1 and c2 are IntImm. But if we are given variables (or, even worse, expressions), we cannot tell that v1 * v2 and v2 * v1 are the same term (see the sketch after this list).
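As a small illustration of the second point (our own demo code, not TVM internals): ExprDeepEqual, which the simplifier later uses to match likely-conditions against constraints, compares expressions structurally and therefore does not identify d0 * d1 with d1 * d0.
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

using namespace tvm;

void CommutativityDemo() {
  tir::SizeVar d0("d0", DataType::Int(32)), d1("d1", DataType::Int(32));
  tir::ExprDeepEqual deep_equal;
  // Same operand order: structurally identical.
  bool same_order = deep_equal(d0 * d1, d0 * d1);  // true
  // Swapped operands: structurally different, even though the value is equal.
  bool swapped = deep_equal(d0 * d1, d1 * d0);     // false
  (void)same_order;
  (void)swapped;
}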
For clarity, we use SizeVar d0, d1, d2, and d3 for the shape in our simple softmax example. The output of the current tir.Simplify is
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
if @tir.likely((floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0), dtype=bool) {
if @tir.likely((floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1)), dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
}
}
}
To simplify floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1) and floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0 to ((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2), we add a new rule in PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op):
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
// ...
PVar<PrimExpr> x, y, z, s1, s2;
// ...
TVM_TRY_REWRITE_IF(floordiv(x, s1) < s2, x < s1 * s2,
analyzer_->const_int_bound(s1.Eval())->min_value >= 0);
// ...
}
Here comes our first worry. The corresponding IntImm version of this rule is
TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), c1.Eval()->value > 0);
where the side condition is > 0 instead of >= 0. For the PrimExpr version we can only get the non-negative information from the ConstIntBoundAnalyzer (this bound information comes from simple facts like SizeVar + SizeVar >= 0 or SizeVar * SizeVar >= 0). Although = 0 is an invalid case, the transformation is not strictly equivalent: for example, if s1 == 0 then floordiv(x, s1) is a division by zero at run time, whereas the rewritten x < s1 * s2 evaluates silently, so the rewrite may hide some run-time errors. This post shows some possible solutions for this issue, but there may be simpler ones (if you have any idea, please share it with us).
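For reference, a minimal sketch (our own demo, not TVM code) of where that non-negative information comes from: a SizeVar is defined to be non-negative, and the ConstIntBoundAnalyzer reports this as min_value >= 0, which is exactly what the side condition of the new rule checks.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/var.h>

using namespace tvm;

void BoundDemo() {
  arith::Analyzer ana;
  tir::SizeVar d2("d2", DataType::Int(32));
  // SizeVar carries the non-negative guarantee, so the reported lower bound
  // is 0 (the upper bound is the analyzer's "positive infinity" sentinel).
  arith::ConstIntBound bound = ana.const_int_bound(d2);
  ICHECK_GE(bound->min_value, 0);
}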
After adding this rule for rewrite simplify, we get:
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d1*d0))), dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d0*d1))), dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
}
}
}
Now the question becomes how to recognize that (d2*(d1*d0)), (d2*(d0*d1)), and ((d0*d1)*d2) are all the same expression. We add a rule in PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) to get a canonical form of multiplication:
PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
// ...
// normalize
PrimExpr a = this->CanonicalMutate(op->a);
PrimExpr b = this->CanonicalMutate(op->b);
// ...
// var * expr => expr * var
if (a.as<VarNode>() && !b.as<VarNode>()) {
std::swap(a, b);
}
// if given var * var or expr * expr, use their
// structural hash value to sort
if (a.as<VarNode>() || !b.as<VarNode>()) {
auto ah = StructuralHash()(a);
auto bh = StructuralHash()(b);
if (ah > bh) {
std::swap(a, b);
}
}
// ...
}
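For illustration, here is the ordering idea in isolation (the helper CanonicalMul is hypothetical, not part of the patch above): both occurrences of the same variable hash to the same value, so sorting a product's operands by structural hash yields one fixed order, and d0 * d1 and d1 * d0 end up in the same canonical form.
#include <tvm/node/structural_hash.h>
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

#include <utility>

using namespace tvm;

// Hypothetical helper: order the two operands of a multiplication by their
// structural hash so that commuted products canonicalize identically.
PrimExpr CanonicalMul(PrimExpr a, PrimExpr b) {
  if (StructuralHash()(a) > StructuralHash()(b)) {
    std::swap(a, b);
  }
  return a * b;
}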
I think sorting by the structural hash value is a bit ugly, but I have no better idea currently. After adding this rule for canonical simplify, we get:
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
}
}
}
Next we need to find a way to remove these literally equivalent expressions. Actually, RewriteSimplify already has such a mechanism:
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<CallNode>();
// ...
ExprDeepEqual expr_equal;
if (op->op.same_as(tir::builtin::likely())) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (expr_equal(constraint, op->args[0])) {
return make_const(op->dtype, true);
}
}
}
return ret;
}
However, every constraint in literal_constraints_ has already been processed by the CanonicalSimplify by the time it is entered as a constraint:
Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition); // HERE
PrimExpr real_condition = condition;
static auto op_likely = Op::Get("tir.likely");
if (auto call = condition.as<CallNode>()) {
if (call->op.same_as(op_likely)) {
real_condition = call->args[0];
}
}
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(analyzer_, real_condition);
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
else_case = this->VisitStmt(op->else_case);
}
// ...
}
while op->args[0] has not been, since we are inside the RewriteSimplify and the CanonicalSimplify runs after it:
PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
if (tir::is_const_int(expr)) return expr;
PrimExpr res = expr;
for (int i = 0; i < steps; ++i) {
res = this->rewrite_simplify(res); // RewriteSimplify
if (tir::is_const_int(res) || ++i == steps) return res; // is ++i proper here?
res = this->canonical_simplify(res); // CanonicalSimplify
if (tir::is_const_int(res)) return res;
}
return res;
}
This makes op->args[0] look something like ((threadIdx.x: int32 + (blockIdx.x: int32*512)) < (((d1: int32*d0: int32)*d2: int32)*d3: int32)) while the constraint looks like (((blockIdx.x: int32*512) + threadIdx.x: int32) < (((d1: int32*d0: int32)*d2: int32)*d3: int32)). To solve this problem, we can perform a canonical simplification before the comparison:
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
// ...
ExprDeepEqual expr_equal;
if (op->op.same_as(tir::builtin::likely())) {
auto condition = analyzer_->canonical_simplify(op->args[0]);
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (expr_equal(constraint, condition)) {
return make_const(op->dtype, true);
}
}
}
return ret;
}
After that we get:
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax", "tir.noalias": True}
buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0: int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32, stride_2: int32, stride_3: int32], type="auto"),
T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32), float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32, stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]), storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 512 {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)), dtype=bool) {
T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
}
Again, this is only an experimental idea and there are still some issues to be solved. If you have any better ideas, please feel free to share them below.