While deploying the BERT model with the Relax frontend, I found a bug in the Squeeze OpConvert: we can fetch the axis info directly if it is in the attr argument. The PR is here: https://github.com/apache/tvm/pull/16059
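To illustrate the idea behind the fix (this is only a rough sketch, not the code from the PR): ONNX Squeeze carries `axes` as a node attribute before opset 13 and as an optional second input from opset 13 on, so a converter can check the attribute first. The helper name `get_squeeze_axes` and the plain-Python `attrs`/`inputs` arguments are hypothetical stand-ins for the frontend's real data structures.

```python
def get_squeeze_axes(attrs, inputs):
    """Hypothetical helper: choose the axes for an ONNX Squeeze node.

    attrs  -- dict of node attributes (assumed plain Python values)
    inputs -- list of node inputs; inputs[1], if present, holds the axes
    """
    # Opsets before 13 store `axes` as a node attribute: use it directly.
    if attrs.get("axes") is not None:
        return [int(a) for a in attrs["axes"]]
    # Opset 13 and later pass `axes` as an optional second input.
    if len(inputs) > 1 and inputs[1] is not None:
        return [int(a) for a in inputs[1]]
    # No axes given anywhere: squeeze every dimension of size 1.
    return None


print(get_squeeze_axes({"axes": [2]}, ["x"]))   # axes taken from the attribute
print(get_squeeze_axes({}, ["x", [0, -1]]))     # axes taken from the second input
print(get_squeeze_axes({}, ["x"]))              # no axes: None
```

Returning `None` in the last case corresponds to the `axis=None` that shows up in the `R.squeeze` calls of the imported module.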
There is also another strange problem when legalizing R.squeeze(). My model is bert-base-uncased-squad-v1.onnx, and part of the model is:
@I.ir_module
class Module:
    @R.function
    def main(input_ids: R.Tensor(("batch", "sequence"), dtype="int64"), attention_mask: R.Tensor(("batch", "sequence"), dtype="int64"), token_type_ids: R.Tensor(("batch", "sequence"), dtype="int64")) -> R.Tuple(R.Tensor(("batch", "sequence"), dtype="float32"), R.Tensor(("batch", "sequence"), dtype="float32")):
        batch = T.int64()
        sequence = T.int64()
        R.func_attr({"num_input": 3})
        cls = Module
        with R.dataflow():
            ...
            lv623: R.Tensor((batch, sequence, 1), dtype="float32") = lv622[0]
            lv624: R.Tensor((batch, sequence, 1), dtype="float32") = lv622[1]
            lv625: R.Tensor((batch, sequence), dtype="float32") = R.squeeze(lv623, axis=None)
            lv626: R.Tensor((batch, sequence), dtype="float32") = R.squeeze(lv624, axis=None)
            gv: R.Tuple(R.Tensor((batch, sequence), dtype="float32"), R.Tensor((batch, sequence), dtype="float32")) = lv625, lv626
            R.output(gv)
        return gv
When we pass axis=None to R.squeeze(), LegalizeOps() still does not work, even though the input tensor has a symbolic shape; that is, we can't get the corresponding squeeze prim_func.
I found a unit test for R.squeeze() in test_transform_legalize_ops_manipulate.py, and it works. So I suspect something may be wrong with LegalizeOps() when it is given a fairly complex IRModule.
def test_squeeze_no_axis():
    # fmt: off
    @tvm.script.ir_module
    class Squeeze:
        @R.function
        def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) :
            gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) :
            gv = R.call_tir(Expected.squeeze, (x,), R.Tensor((2, 3, 4), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)):
                with T.block("T_squeeze"):
                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[ax0, T.int64(0), ax1, T.int64(0), T.int64(0), ax2])
                    T.writes(T_squeeze[ax0, ax1, ax2])
                    T_squeeze[ax0, ax1, ax2] = rxplaceholder[ax0, T.int64(0), ax1, T.int64(0), T.int64(0), ax2]
    # fmt: on

    mod = LegalizeOps()(Squeeze)
    tvm.ir.assert_structural_equal(mod, Expected)
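
For comparison, the output shapes that both snippets expect from a squeeze with no axis can be modelled in a few lines of plain Python. This is only a simplified sketch and not TVM's implementation: `squeeze_shape` is a hypothetical name, and representing symbolic dimensions as strings such as `"batch"` is an assumption made for illustration.

```python
def squeeze_shape(shape, axis=None):
    """Simplified model of the output shape of squeeze.

    Dimensions are ints, or strings standing in for symbolic dims.
    """
    rank = len(shape)
    if axis is None:
        # Drop only dims statically known to be 1; symbolic dims are kept.
        return tuple(d for d in shape if d != 1)
    # Normalize negative axes, e.g. -1 -> rank - 1.
    axes = {a % rank for a in axis}
    for a in axes:
        if shape[a] != 1:
            raise ValueError(f"cannot squeeze axis {a} with extent {shape[a]}")
    return tuple(d for i, d in enumerate(shape) if i not in axes)


# Static case from the unit test: matches the (2, 3, 4) buffer in Expected.
print(squeeze_shape((2, 1, 3, 1, 1, 4)))
# Symbolic case from the BERT module: matches the annotation on lv625.
print(squeeze_shape(("batch", "sequence", 1)))
```

In this model the symbolic case and the static case behave the same way, which is why the difference between the unit test and the BERT module is surprising.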