Would you mind helping me with an experiment? Please apply this patch and check whether it works for you with multi-batch inputs:
diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py
index 6492b78d6..0162780c1 100644
--- a/python/tvm/topi/x86/injective.py
+++ b/python/tvm/topi/x86/injective.py
@@ -37,7 +37,7 @@ def schedule_injective_from_existing(sch, out):
The updated schedule.
"""
if len(sch[out].op.axis) >= 5:
- fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
+ fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
sch[out].parallel(fused)
elif len(sch[out].op.axis) >= 3:
fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])