Relay Graph of Quantized Conv1D Operation

Hello,

I am trying to figure out how the Relay Graph for a quantized Conv1D operation is constructed. It looks like this:

```
fn (%serving_default_conv1d_input:0: Tensor[(1, 1, 16, 16), uint8] /* span=serving_default_conv1d_input:0:0:0 */, %v_param_2: Tensor[(1, 3, 16, 16), int8] /* span=sequential/conv1d/Conv1D/Conv2D1:0:0 */, %v_param_3: Tensor[(16), int32] /* span=sequential/conv1d/Conv1D/Conv2D:0:0 */, %v_param_5: Tensor[(16), int8] /* span=sequential/conv1d/squeeze_batch_dims/BiasAdd/ReadVariableOp:0:0 */, output_tensor_names=["StatefulPartitionedCall_0"]) {
  %0 = qnn.requantize(%serving_default_conv1d_input:0, 0.0352941f /* span=tfl.quantize:0:0 */, 0 /* span=tfl.quantize:0:0 */, 0.0352941f /* span=tfl.quantize:0:0 */, -128 /* span=tfl.quantize:0:0 */, out_dtype="int8") /* span=tfl.quantize:0:0 */;
  %1 = expand_dims(%0, axis=-3) /* span=sequential/conv1d/Conv1D/ExpandDims:0:0 */;
  %2 = reshape(%1, newshape=[-1, 1, 16, 16]) /* span=sequential/conv1d/Conv1D/Reshape:0:0 */;
  %3 = qnn.conv2d(%2, %v_param_2, -128 /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */, 0 /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */, 0.0352941f /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */, 0.00196695f /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */;
  %4 = nn.bias_add(%3, %v_param_3, axis=3) /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */;
  %5 = qnn.requantize(%4, 6.94219e-05f /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */, 0 /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */, 0.142336f /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */, -1 /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */, axis=3, out_dtype="int8") /* span=sequential/conv1d/Conv1D/Conv2D2:0:0 */;
  %6 = reshape(%5, newshape=[1, 1, 1, 14, 16]) /* span=sequential/conv1d/Conv1D/Reshape_1:0:0 */;
  %7 = squeeze(%6, axis=[-3]) /* span=sequential/conv1d/Conv1D/Squeeze:0:0 */;
  %8 = reshape(%7, newshape=[-1, 14, 16]) /* span=sequential/conv1d/squeeze_batch_dims/Reshape:0:0 */;
  %9 = qnn.add(%8, %v_param_5, 0.142336f /* span=sequential/conv1d/squeeze_batch_dims/BiasAdd:0:0 */, -1 /* span=sequential/conv1d/squeeze_batch_dims/BiasAdd:0:0 */, 0.00392157f /* span=sequential/conv1d/squeeze_batch_dims/BiasAdd:0:0 */, -128 /* span=sequential/conv1d/squeeze_batch_dims/BiasAdd:0:0 */, 0.142336f /* span=sequential/conv1d/squeeze_batch_dims/BiasAdd:0:0 */, -8 /* span=sequential/conv1d/squeeze_batch_dims/BiasAdd:0:0 */) /* span=sequential/conv1d/squeeze_batch_dims/BiasAdd:0:0 */;
  reshape(%9, newshape=[1, 1, 14, 16]) /* span=StatefulPartitionedCall:0:0:0 */
}
```
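Replaying the shape arithmetic op by op makes the observation about %5 concrete. This is a minimal pure-Python sketch (Python 3.8+ for `math.prod`). `expand_dims`, `reshape`, `squeeze` and `conv_valid_len` are hypothetical helpers that mimic the usual shape rules; they are not TVM's implementations. The conv output width assumes stride 1 and dilation 1, which is consistent with `padding=[0, 0, 0, 0]` and with the width 14 that %6 uses.

```python
from math import prod

def expand_dims(shape, axis):
    # Insert a size-1 dim; a negative axis is counted on the output rank.
    if axis < 0:
        axis += len(shape) + 1
    return shape[:axis] + [1] + shape[axis:]

def reshape(shape, newshape):
    # Infer at most one -1 so that the element count is preserved.
    total = prod(shape)
    known = prod(d for d in newshape if d != -1)
    out = [total // known if d == -1 else d for d in newshape]
    assert prod(out) == total, "reshape must keep the element count"
    return out

def squeeze(shape, axes):
    # Drop the listed size-1 dims; negative axes count from the end.
    axes = {a % len(shape) for a in axes}
    assert all(shape[a] == 1 for a in axes), "can only squeeze size-1 dims"
    return [d for i, d in enumerate(shape) if i not in axes]

def conv_valid_len(n, k, stride=1, dilation=1):
    # Output length of a convolution with no padding.
    return (n - dilation * (k - 1) - 1) // stride + 1

s0 = [1, 1, 16, 16]                    # %0: input shape (requantize keeps it)
s1 = expand_dims(s0, -3)               # %1
s2 = reshape(s1, [-1, 1, 16, 16])      # %2
n, h, w, _ = s2                        # data_layout="NHWC"
# %3 qnn.conv2d with kernel_size=[1, 3] and channels=16;
# %4 bias_add and %5 requantize keep the shape.
s5 = [n, conv_valid_len(h, 1), conv_valid_len(w, 3), 16]
s6 = reshape(s5, [1, 1, 1, 14, 16])    # %6
s7 = squeeze(s6, [-3])                 # %7
s8 = reshape(s7, [-1, 14, 16])         # %8 (qnn.add with the [16] bias keeps it)
s9 = reshape(s8, [1, 1, 14, 16])       # final output

for name, s in [("%1", s1), ("%2", s2), ("%5", s5), ("%6", s6),
                ("%7", s7), ("%8", s8), ("out", s9)]:
    print(name, s)
```

Both %5 and the final output come out as [1, 1, 14, 16], so with these sizes the reshapes are shape no-ops. Because `qnn.conv2d` uses `data_layout="NHWC"`, though, %5 is read as (N, H, W, C): N is the merged batch and H is the dimension added by `expand_dims`. After %7, the two leading dimensions are presumably the two original batch dimensions of the input again.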

I can understand everything up to the calculation of %5, but after that it gets weird. Why is the reshape-squeeze-reshape sequence (%6 to %8) needed? %5 already seems to have shape [1, 1, 14, 16], so the reshapes that follow look unnecessary. I know this is related to the batch size, because if I fix it to 1 these ops don’t appear, but I would like to understand why they are needed when the batch size is not fixed. Thanks in advance!
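The span names (`Conv1D/ExpandDims`, `Conv1D/Reshape`, `Conv1D/Reshape_1`, `squeeze_batch_dims/Reshape`) suggest that these ops come from a TensorFlow pattern. That pattern merges all leading batch dimensions into one, runs an op that expects a single batch dimension (Conv2D, BiasAdd), and then restores the original batch dimensions. The input here has two such dimensions ([1, 1]). The following pure-Python sketch illustrates the pattern on a flat row-major list. It is an assumption-based illustration, not TensorFlow's or TVM's code: `apply_with_flat_batch` and `add_bias` are hypothetical helpers, and the bias values are made up.

```python
from math import prod

def apply_with_flat_batch(data, shape, inner_rank, fn):
    # Hypothetical helper, not a TensorFlow or TVM API.
    # Merge all leading (batch) dims into one, like reshape(x, [-1] + inner_shape),
    # run `fn` on that view, then put the original batch dims back in front.
    batch_shape, inner_shape = shape[:-inner_rank], shape[-inner_rank:]
    flat_shape = [prod(batch_shape)] + inner_shape
    out_data, out_shape = fn(data, flat_shape)
    assert out_shape[0] == flat_shape[0], "fn must keep the merged batch dim"
    return out_data, batch_shape + out_shape[1:]

def add_bias(data, shape):
    # Toy bias add on the last (channel) axis of a flat row-major list;
    # the bias for channel c is simply c.
    channels = shape[-1]
    return [v + i % channels for i, v in enumerate(data)], shape

# Two batch dims of size 1, as in the question: every reshape is a no-op.
x_shape = [1, 1, 14, 16]
y, y_shape = apply_with_flat_batch([0] * prod(x_shape), x_shape, 2, add_bias)
print(y_shape)   # [1, 1, 14, 16]

# Same pattern with batch dims other than 1: the merge/restore reshapes now matter.
x2_shape = [2, 3, 14, 16]
y2, y2_shape = apply_with_flat_batch([0] * prod(x2_shape), x2_shape, 2, add_bias)
print(y2_shape)  # [2, 3, 14, 16]
```

When every batch dimension is 1, the merge and restore reshapes do not change any numbers. This would fit the observation that they disappear once the batch size is fixed to 1, presumably because the converter can then see that they are no-ops and fold them. With a symbolic batch size, the same reshapes are not known to be no-ops ahead of time, so they presumably have to stay in the graph.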