Hi all,
When the EliminateCommonSubexpr pass runs on an IR containing two call_tir nodes with the same arguments, it replaces the argument tuple with a var.
Before Common Subexpression Elimination:
a = R.call_tir(cls.strided_slice, (x,), out_sinfo=R.Tensor((512,)))
b = R.call_tir(cls.strided_slice1, (x,), out_sinfo=R.Tensor((512,)))
After Common Subexpression Elimination:
lv: R.Tuple(R.Tensor((1024,))) = (x,)
a = R.call_tir(cls.strided_slice, lv, out_sinfo=R.Tensor((512,)))
b = R.call_tir(cls.strided_slice1, lv, out_sinfo=R.Tensor((512,)))
As you can see, the argument tuple (x,) shared by both call_tir calls is extracted and bound to a variable (lv).
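To make the transformation concrete, here is a toy model of what CSE does in this case. An argument tuple that occurs in more than one call_tir is bound once to a fresh var, and both calls then refer to that var. This is plain Python over hypothetical node classes (Var, TupleExpr, CallTIR), not TVM's actual implementation, and the `lv0` naming is an assumption for illustration.

```python
from collections import Counter
from dataclasses import dataclass
from typing import List, Tuple, Union


# Hypothetical toy IR nodes (not TVM classes), kept minimal for illustration.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class TupleExpr:
    fields: Tuple["Var", ...]


@dataclass(frozen=True)
class CallTIR:
    func: str
    args: Union[Var, TupleExpr]  # call_tir's definition expects a TupleExpr here


Binding = Tuple[Var, Union[TupleExpr, CallTIR]]


def hoist_common_call_args(bindings: List[Binding]) -> List[Binding]:
    """Bind call_tir argument tuples that occur more than once to a shared var."""
    counts = Counter(
        value.args for _, value in bindings if isinstance(value, CallTIR)
    )
    shared = {}  # TupleExpr -> Var that now holds it
    out: List[Binding] = []
    for var, value in bindings:
        if (
            isinstance(value, CallTIR)
            and isinstance(value.args, TupleExpr)
            and counts[value.args] > 1
        ):
            if value.args not in shared:
                lv = Var(f"lv{len(shared)}")
                shared[value.args] = lv
                out.append((lv, value.args))  # emit `lv0 = (x,)` before first use
            # The call now takes a Var instead of a tuple literal.
            value = CallTIR(value.func, shared[value.args])
        out.append((var, value))
    return out
```

Applied to the "before" bindings above, this yields `lv0 = (x,)` followed by both calls taking `lv0`. The dataflow is unchanged; only the syntactic form of call_tir's second operand changes, which is exactly what a pass that pattern-matches on a tuple literal would trip over.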
When we run the FoldConstant pass after this, it errors out at some point, complaining that the args of call_tir are expected to be a tuple. The definition of call_tir says that args must be a tuple, but what the common subexpression elimination pass does also seems valid.
So my question is whether this change by EliminateCommonSubexpr is valid or not. If it is valid, I can modify the FoldConstant pass to fetch the tuple from the variable (lv in this case); if it is not valid, I may have to update the EliminateCommonSubexpr pass accordingly.
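If the CSE output is considered valid, the FoldConstant-side change described above could look roughly like the toy sketch below: before treating call_tir's args as a tuple, follow a Var back to the tuple literal it is bound to. This is plain Python over hypothetical node classes, not TVM code. In a real Relax pass, the expression mutator's binding lookup (for example `LookupBinding` on the C++ `ExprMutator`) would presumably fill the role of the `bindings` dict, but that should be checked against the current source.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


# Hypothetical toy IR nodes (not TVM classes), kept minimal for illustration.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class TupleExpr:
    fields: Tuple[object, ...]


Expr = Union[Var, TupleExpr]


def resolve_tuple_args(args: Expr, bindings: Dict[Var, Expr]) -> TupleExpr:
    """Return the tuple literal behind a call_tir `args` operand.

    If `args` is a Var bound earlier (e.g. `lv = (x,)` produced by CSE),
    follow the binding chain until a TupleExpr is reached. Raise TypeError
    when the operand cannot be resolved (e.g. an unbound function parameter).
    """
    seen = set()
    while isinstance(args, Var):
        if args in seen or args not in bindings:
            raise TypeError(f"cannot resolve {args.name!r} to a tuple literal")
        seen.add(args)
        args = bindings[args]
    if not isinstance(args, TupleExpr):
        raise TypeError(f"expected a tuple, got {type(args).__name__}")
    return args
```

The trade-off is that every pass that inspects call_tir args would need the same lookup, whereas keeping the tuple inline (the second option) leaves those passes untouched.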