Incomplete type in BroadcastToRel while doing MergeComposite

Hi, community,

I am seeing a segfault when running this code snippet. The snippet runs the following small model:

def @main(%data: Tensor[(1, 3, 16, 16), float32]) -> Tensor[(1, 6, 8, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(6, 3, 4, 4), float32] */, strides=[2, 2], padding=[1, 1, 1, 1], channels=6, kernel_size=[4, 4]) /* ty=Tensor[(1, 6, 8, 8), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(6, 1, 1), float32] */) /* ty=Tensor[(1, 6, 8, 8), float32] */;
  %2 = nn.global_avg_pool2d(%1) /* ty=Tensor[(1, 6, 1, 1), float32] */;
  %3 = broadcast_to(%2, meta[relay.attrs.InitOpAttrs][0]) /* ty=Tensor[(1, 6, 8, 8), float32] */;
  multiply(%1, %3) /* ty=Tensor[(1, 6, 8, 8), float32] */
}

When I apply a MergeComposite pass to merge weight + nn.conv2d + bias + add, I get a segfault caused by an incomplete type in BroadCastToRel (src/relay/op/tensor/transform.cc): types[0] in that function is an IncompleteType rather than a TensorType.

Has anyone encountered a similar issue? Any suggestions on how to fix it?

Thanks, Joey

It looks like the type relation of broadcast_to doesn't handle the case where its input type is still an IncompleteType. I fixed it in this PR.
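The idea behind this kind of fix can be sketched in plain Python. This is a simplified, hypothetical model, not TVM's API: IncompleteType, TensorType, broadcast_to_rel and solve are illustrative names, and the actual change in the PR is C++ code in BroadCastToRel that may differ in detail. The point is that a type relation should return "not solved yet" while its input is unsolved, instead of treating it as a TensorType, and the solver retries it after the relation that produces the input has run.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, List, Tuple, Union

# Hypothetical stand-ins for Relay's IncompleteType and TensorType.
@dataclass
class IncompleteType:
    name: str

@dataclass
class TensorType:
    shape: Tuple[int, ...]
    dtype: str

Type = Union[IncompleteType, TensorType]

def broadcast_to_rel(types: List[Type], target_shape: Tuple[int, ...]) -> bool:
    """Model of a broadcast_to type relation; types = [input, output].

    Returns False (not solved yet) while the input is still incomplete,
    instead of dereferencing it. On success, writes the output type into
    types[1] and returns True.
    """
    data = types[0]
    if isinstance(data, IncompleteType):
        # Defer: the solver will call this relation again later.
        return False
    if not isinstance(data, TensorType):
        raise TypeError(f"broadcast_to: expected TensorType, got {data!r}")
    types[1] = TensorType(tuple(target_shape), data.dtype)
    return True

def solve(relations: Iterable[Callable[[], bool]], max_rounds: int = 10) -> None:
    """Re-run unsolved relations until all succeed (a tiny fixed-point loop)."""
    pending = list(relations)
    for _ in range(max_rounds):
        pending = [rel for rel in pending if not rel()]
        if not pending:
            return
    raise RuntimeError(f"{len(pending)} relation(s) could not be solved")

# Input (%3, the pool output) and output (%4) of broadcast_to start out unknown.
bcast_types: List[Type] = [IncompleteType("%3"), IncompleteType("%4")]

def pool_rel() -> bool:
    # Pretend global_avg_pool2d's relation solved %3 as (1, 6, 1, 1) float32.
    bcast_types[0] = TensorType((1, 6, 1, 1), "float32")
    return True

def bcast_rel() -> bool:
    return broadcast_to_rel(bcast_types, (1, 6, 8, 8))

# broadcast_to runs first, sees an incomplete input, defers, then succeeds in round 2.
solve([bcast_rel, pool_rel])
print(bcast_types[1])  # TensorType(shape=(1, 6, 8, 8), dtype='float32')
```

Running the relations in this order reproduces the situation from the question: broadcast_to is visited before its input has been inferred. With the deferral it simply gets retried instead of crashing.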

The output with this PR:

def @main(%data: Tensor[(1, 3, 16, 16), float32]) -> Tensor[(1, 6, 8, 8), float32] {
  %1 = fn (%FunctionVar_0_0: Tensor[(1, 3, 16, 16), float32], PartitionedFromPattern="nn.conv2d_add_", Composite="conv2d_add") -> Tensor[(1, 6, 8, 8), float32] {
    %0 = nn.conv2d(%FunctionVar_0_0, meta[relay.Constant][0] /* ty=Tensor[(6, 3, 4, 4), float32] */, strides=[2, 2], padding=[1, 1, 1, 1], channels=6, kernel_size=[4, 4]) /* ty=Tensor[(1, 6, 8, 8), float32] */;
    add(%0, meta[relay.Constant][1] /* ty=Tensor[(6, 1, 1), float32] */) /* ty=Tensor[(1, 6, 8, 8), float32] */
  };
  %2 = %1(%data) /* ty=Tensor[(1, 6, 8, 8), float32] */;
  %3 = nn.global_avg_pool2d(%2) /* ty=Tensor[(1, 6, 1, 1), float32] */;
  %4 = broadcast_to(%3, meta[relay.attrs.InitOpAttrs][0]) /* ty=Tensor[(1, 6, 8, 8), float32] */;
  multiply(%2, %4) /* ty=Tensor[(1, 6, 8, 8), float32] */
}

Hi @comaniac,

Thanks a lot for your help! It works perfectly now!