[PyTorch/ONNX] Model fails at `nn.conv2d`: "in particular dimension 1 conflicts: 4096 does not match 512"

[PyTorch] Model saved with PyTorch — Relay IR excerpt around the error:

 %105 = add(%104, 1f);
  %106 = reshape(%105, newshape=[8, 1, 512, 1, 1]);
  %107 = multiply(%3, %106);
  %108 = full(1, shape=[], dtype="float32");
  %109 = power(%107, 2f);
  %110 = sum(%109);
  %111 = add(%110, 1e-08f);
  %112 = power(%111, 0.5f);
  %113 = divide(%108, %112);
  %114 = multiply(%113, 1f);
  %115 = reshape(%114, newshape=[8, 512, 1, 1, 1]);
  %116 = multiply(%107, %115);
  %117 = reshape(%116, newshape=[4096, 512, 3, 3]);
  %118 = nn.conv2d(%1, %117, padding=[1, 1, 1, 1], channels=4096, kernel_size=[3, 3]) in particular dimension 1 conflicts 4096 does not match 512; unable to unify: `Tensor[(4096, 4096, 3, 3), float32]` and `Tensor[(4096, 512, 3, 3), float32]`; ;
  reshape(%118, newshape=[8, 512, 4, 4])
}

[PyTorch → ONNX] Model exported to ONNX — Relay IR excerpt around the error:

%122 = take(%121, %v635, axis=0);
%123 = expand_dims(%122, axis=0);
%124 = shape_of(%blocks.1.0.weight, dtype="int64");
%125 = take(%124, %v647, axis=0);
%126 = expand_dims(%125, axis=0);
%127 = reshape(%120, newshape=[1, 4096, 4, 4]);
%128 = reshape(%blocks.1.0.weight, newshape=[1, 512, 512, 3, 3]);
%129 = multiply(%128, %v678);
%130 = take(%93, %v364, axis=1);
%131 = multiply(%blocks.1.0.fc.weight, %v655);
%132 = transpose(%131, axes=[1, 0]);
%133 = transpose(%132, axes=[1, 0]);
%134 = nn.dense(%130, %133, units=None);
%135 = reshape(%blocks.1.0.bias.bias, newshape=[1, 512]);
%136 = add(%134, %135);
%137 = add(%136, %v668);
%138 = reshape(%137, newshape=[8, 512, 1, 1, 1]);
%139 = multiply(%129, %138);
%140 = power(%139, %v691);
%141 = sum(%140, axis=[1, 3, 4]);
%142 = add(%141, %v694);
%143 = power(%142, %v696);
%144 = divide(%v698, %143);
%145 = reshape(%144, newshape=[8, 1, 512, 1, 1]);
%146 = multiply(%139, %145); %147 = reshape(%146, newshape=[4096, 512, 3, 3]);
%148 = nn.conv2d_transpose(%127, %147, channels=512, kernel_size=[3, 3],
strides=[2, 2], padding=[0, 0, 0, 0], groups=8) in particular dimension 1 conflicts 64 does not match 512; unable to unify: Tensor[(4096, 64, 3, 3), float32] and
Tensor[(4096, 512, 3, 3), float32]; ;
%149 = shape_of(%148, dtype="int64");
%150 = take(%149, %v728, axis=0);
%151 = expand_dims(%150, axis=0);
%152 = shape_of(%148, dtype="int64");
%153 = take(%152, %v731, axis=0);
%154 = expand_dims(%153, axis=0);
%155 = (%123, %126, %151, %154);
concatenate(%155)

I have the same problem. How did you solve it? 我也遇到了同样的问题，请问您是怎么解决的？