Thanks for the answers.
I found where the attribute for the exp() op is registered at priority 10 (the default level=10). It is in python/tvm/relay/op/_tensor.py:
register_shape_func("exp", False, elemwise_shape_func)
def register_shape_func(op_name, data_dependent, shape_func=None, level=10):
"""Register operator shape function for an op.
I have made some progress.
I had made a mistake with the tensor shape. After setting the correct tensor shape, the TVM exp() op's performance is not as good as TF's or numpy's.
(1) The fused op in question, with tensor shape [100, 120, 120], float32: running 400 times costs 6.02 s (6020 ms).
(2) TVM tir.exp() with the same tensor shape, without the subtract: 5322 ms (which is bad).
(3) numpy exp() with the same tensor shape and the same amount of computation: 1933 ms.
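For reference, a measurement like (3) can be taken roughly as follows. This is a minimal sketch, not the exact script behind the numbers above; the helper name time_numpy_exp, the random input, and the warm-up call are assumptions made for illustration.

```python
import time

import numpy as np


def time_numpy_exp(shape=(100, 120, 120), runs=400, seed=0):
    """Time `runs` calls of np.exp on a float32 tensor of the given shape.

    Hypothetical helper: returns (elapsed milliseconds, last output array).
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape, dtype=np.float32)
    out = np.empty_like(x)

    # One untimed call so that first-use overhead is not counted.
    np.exp(x, out=out)

    start = time.perf_counter()
    for _ in range(runs):
        np.exp(x, out=out)
    elapsed_ms = (time.perf_counter() - start) * 1e3
    return elapsed_ms, out


if __name__ == "__main__":
    elapsed_ms, _ = time_numpy_exp()
    print(f"numpy exp, 400 runs: {elapsed_ms:.0f} ms")
```

Writing into a preallocated output keeps allocation cost out of the loop, which makes the comparison with a compiled TVM function a little fairer.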
So I think the problem is no longer about the fused op. The tir.exp() op itself is slower than TF and numpy. I also tested fast_exp(); its performance is similar to that of tir.exp().
The Relay IR and TIR for the fused subtract/exp op are as follows:
Relay IR:
%450 = subtract(%448, %449) /* StatefulPartitionedCall/functional_1/activation_2/sub */ /* ty=Tensor[(100, 120, 120), float32] */;
%451 = exp(%450) /* StatefulPartitionedCall/functional_1/activation_2/Exp */ /* ty=Tensor[(100, 120, 120), float32] */;
TIR:
, GlobalVar(tvmgen_default_fused_subtract_exp): PrimFunc([placeholder, placeholder, T_exp]) attrs={"from_legacy_te_schedule": (bool)1, "global_symbol": "tvmgen_default_fused_subtract_exp", "tir.noalias": (bool)1, "target": llvm -keys=cpu -libs=dnnl -link-params=0 -mattr=avx,avx2,sse3,sse4.2,fma,avx512er,avx512f -mcpu=x86-64 -opt-level=3} {
parallel (ax0.ax1.fused, 0, 12000) {
for (ax2.outer, 0, 8) {
for (ax2.inner.s, 0, 16) {
if ((((ax2.outer*16) + ax2.inner.s) < 120)) {
T_exp[(((ax0.ax1.fused*120) + (ax2.outer*16)) + ax2.inner.s)] = tir.exp((placeholder[(((ax0.ax1.fused*120) + (ax2.outer*16)) + ax2.inner.s)] - placeholder[ax0.ax1.fused]))
}
}
}
}
}
I still have a question: why are the inner and outer loop extents in the fused TIR 16 and 8, instead of 100, 120, 120?
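One possible reading of the printed TIR, going only by the index arithmetic: the two leading axes appear to be fused into a single parallel loop of 100*120 = 12000 iterations, and the last axis of 120 appears to be split by 16 into ceil(120/16) = 8 outer steps, with the if guard skipping the 8 padding positions (8*16 = 128 > 120). The sketch below emulates that loop nest in plain Python to check that it still covers the whole tensor. The function name is hypothetical, and the second operand's shape (n0, n1, 1) is an assumption inferred from the index placeholder[ax0.ax1.fused]; why the schedule picks 16 is not something this sketch explains.

```python
import numpy as np


def fused_subtract_exp_reference(a, b, inner=16):
    """Emulate the printed loop nest: fuse axes 0 and 1, split axis 2 by `inner`.

    Assumes a has shape (n0, n1, n2) and b has shape (n0, n1, 1).
    """
    n0, n1, n2 = a.shape
    fused = n0 * n1                      # 100 * 120 = 12000 in the post
    outer = (n2 + inner - 1) // inner    # ceil(120 / 16) = 8 in the post
    a_flat = a.reshape(-1)
    b_flat = b.reshape(-1)
    out = np.empty(a.size, dtype=a.dtype)
    for f in range(fused):               # parallel (ax0.ax1.fused, 0, 12000)
        for o in range(outer):           # for (ax2.outer, 0, 8)
            for i in range(inner):       # for (ax2.inner.s, 0, 16)
                k = o * inner + i
                if k < n2:               # guard against the padded tail
                    idx = f * n2 + k
                    out[idx] = np.exp(a_flat[idx] - b_flat[f])
    return out.reshape(a.shape)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Small shape so the pure-Python loops finish quickly.
    a = rng.standard_normal((2, 3, 20), dtype=np.float32)
    b = rng.standard_normal((2, 3, 1), dtype=np.float32)
    same = np.allclose(fused_subtract_exp_reference(a, b), np.exp(a - b), rtol=1e-5)
    print("matches broadcasted np.exp(a - b):", same)
```

If this reading is right, the loop extents 12000, 8 and 16 describe the same 100 x 120 x 120 iteration space, just reorganized by the schedule.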