There are four inputs, data, scale_d, scale_h, scale_w. in op, “dyn.nn.upsampling3d”. However, only three data layouts is set in UpsamplingInferCorrectLayout.
template <typename T>
InferCorrectLayoutOutput UpsamplingInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<T>();
ICHECK(attrs_ptr);
ObjectPtr<T> params = make_object<T>(*attrs_ptr);
if (new_in_layouts.defined()) {
ICHECK_GT(new_in_layouts.size(), 0);
Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
(input.IndexOf(LayoutAxis::Get('D')) == -1 ||
(input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
!input.Contains(LayoutAxis::Get('d'))))) {
params->layout = input.name(); // modify self to follow the input layout
}
}
Layout inferred_layout(params->layout);
Layout param_layout("NCHW");
return InferCorrectLayoutOutput({inferred_layout, param_layout, param_layout}, {inferred_layout},
Attrs(params));
}