[Relay] [NN] Does relay.nn.dense support multi-dimensional input?

Hi folks, an investigation of a different issue led me to this post. Since nn.dense does not support inputs with more than two dimensions, our frontends add a reshape before and after each nn.dense. But our op fusion pass doesn't fuse reshape with dense, so dense can no longer be fused with the activation ops that follow it, because of the reshape stuck in the middle.
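To make the pattern concrete, here is a minimal NumPy sketch (not TVM code; function names are hypothetical) of what the frontends emit: flatten the leading dimensions, run the 2-D dense, then restore the original shape.

```python
import numpy as np

def dense_2d(x2d, w):
    # nn.dense semantics: x of shape (M, K), weight w of shape (N, K),
    # output x @ w.T of shape (M, N).
    return x2d @ w.T

def dense_nd(x, w):
    # What the frontend does for rank > 2 inputs:
    # reshape before dense, 2-D dense, reshape after dense.
    lead = x.shape[:-1]
    x2d = x.reshape(-1, x.shape[-1])              # reshape before
    out2d = dense_2d(x2d, w)                      # 2-D nn.dense
    return out2d.reshape(lead + (w.shape[0],))    # reshape after

x = np.random.randn(8, 128, 1024).astype("float32")
w = np.random.randn(4096, 1024).astype("float32")
y = dense_nd(x, w)
print(y.shape)  # (8, 128, 4096)
```

It is exactly these two reshapes that end up sandwiched between dense and the elemwise ops, blocking fusion.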

In particular, when we import HuggingFace transformer models, most dense ops are not fused with elemwise ops at all, so we end up with something like:

  ...
  %987 = fn (%p0420: Tensor[(1024, 1024), float16], %p1305: Tensor[(4096, 1024), float16], Primitive=1, hash="c13735290dc46bbc") -> Tensor[(1024, 4096), float16] {
    nn.dense(%p0420, %p1305, units=None, out_dtype="float16") /* ty=Tensor[(1024, 4096), float16] */
  };
  %988 = %987(%986, meta[relay.Constant][16] /* ty=Tensor[(4096, 1024), float16] */) /* ty=Tensor[(1024, 4096), float16] */;
  %989 = fn (%p0419: Tensor[(1024, 4096), float16], %p1304: Tensor[(4096), float16], %p2142: float16, Primitive=1, hash="ab37ab7bd1a05f99") -> Tensor[(1024, 4096), float16] {
    %887 = reshape(%p0419, newshape=[8, 128, 4096]) /* ty=Tensor[(8, 128, 4096), float16] */;
    %888 = add(%887, %p1304) /* ty=Tensor[(8, 128, 4096), float16] */;
    %889 = multiply(%888, %p2142) /* ty=Tensor[(8, 128, 4096), float16] */;
    %890 = cast(%889, dtype="float32") /* ty=Tensor[(8, 128, 4096), float32] */;
    %891 = erf(%890) /* ty=Tensor[(8, 128, 4096), float32] */;
    %892 = multiply(%891, 0.5f /* ty=float32 */) /* ty=Tensor[(8, 128, 4096), float32] */;
    %893 = cast(%888, dtype="float32") /* ty=Tensor[(8, 128, 4096), float32] */;
    %894 = add(0.5f /* ty=float32 */, %892) /* ty=Tensor[(8, 128, 4096), float32] */;
    %895 = multiply(%893, %894) /* ty=Tensor[(8, 128, 4096), float32] */;
    %896 = reshape(%895, newshape=[-1, 4096]) /* ty=Tensor[(1024, 4096), float32] */;
    cast(%896, dtype="float16") /* ty=Tensor[(1024, 4096), float16] */
  };
  ...
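For reference, the elemwise function %989 above is the erf-based GELU applied after the bias add; assuming the scalar %p2142 is 1/sqrt(2), it computes (x + bias) * 0.5 * (1 + erf((x + bias) / sqrt(2))). A NumPy sketch of the same computation, including the fp16/fp32 casts visible in the IR:

```python
import math
import numpy as np

def fused_bias_gelu(x, bias, scale):
    # Mirrors %989: add bias, multiply by scale (assumed 1/sqrt(2)),
    # erf in fp32, then x * (0.5 + 0.5 * erf(...)), cast back to fp16.
    t = x + bias                                # %888
    u = (t * scale).astype(np.float32)          # %889, %890
    e = np.vectorize(math.erf)(u)               # %891
    out = t.astype(np.float32) * (0.5 + 0.5 * e)  # %892, %894, %895
    return out.astype(np.float16)               # final cast

x = np.random.randn(8, 128, 4096).astype(np.float16)
b = np.random.randn(4096).astype(np.float16)
y = fused_bias_gelu(x, b, np.float16(1.0 / math.sqrt(2.0)))
```

All of these ops are cheap elemwise work that we would like to fuse into the preceding dense, if not for the reshape in between.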

So it is very important that we fix this issue.
