Where does the layout transform of each op happen during the alter_op_layout pass?

Hello! I’m trying to figure out the question in the title. For example, when a module is built like this:

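```python
import tvm
from tvm import autotvm, relay

# mod, params, target and graph_opt_sch_file are assumed to be defined already
with autotvm.apply_graph_best(graph_opt_sch_file):
    with tvm.transform.PassContext(opt_level=3):
        graph_factory = relay.build_module.build(mod, target=target, params=params)
```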

the function CallWithNewLayouts in alter_op_layout.cc is called, and it goes through a chain of calls that ends in the following function (assuming an x86 CPU target):

def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):

However, this function only seems to change the layout information in attrs; I don’t see it change the actual layout of any tensor in the graph. Yet if I print the IR right after the AlterOpLayout pass, the tensor shapes have changed from 4D to 5D/6D and layout_transform nodes have been inserted. So my question is: where does this happen? Can anyone point me to the code?

Many thanks!

When you see the tensors change from 4D to 5D, the corresponding conv2d op has already been changed from NCHW to NCHWc; otherwise the types wouldn’t match. This is what “alter op layout” means. Specifically, the function you pointed to returns the altered NCHWc op:

Accordingly, your graph changes from 4D -> conv2d_NCHW to 4D -> layout_transform -> 5D -> conv2d_NCHWc.
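To make the 4D -> 5D/6D change concrete, here is a small pure-Python sketch (no TVM needed) of the shape arithmetic behind an NCHW -> NCHW{c}c data transform and an OIHW -> OIHW{i}i{o}o kernel transform. The helper names are hypothetical and the block size of 8 is only an illustrative assumption; in practice the split factors come from the tuned schedule.

```python
def nchw_to_nchwc_shape(shape, c_block):
    """Shape of data after layout_transform from NCHW to NCHW{c_block}c."""
    n, c, h, w = shape
    if c % c_block != 0:
        raise ValueError("channel count must be divisible by the block size")
    return (n, c // c_block, h, w, c_block)


def oihw_to_oihwio_shape(shape, ic_block, oc_block):
    """Shape of a kernel after transform from OIHW to OIHW{ic_block}i{oc_block}o."""
    o, i, h, w = shape
    if o % oc_block != 0 or i % ic_block != 0:
        raise ValueError("channels must be divisible by the block sizes")
    return (o // oc_block, i // ic_block, h, w, ic_block, oc_block)


def nchw_to_nchwc_index(idx, c_block):
    """Map a 4D NCHW element index to its position in the 5D NCHWc tensor."""
    n, c, h, w = idx
    return (n, c // c_block, h, w, c % c_block)


print(nchw_to_nchwc_shape((1, 64, 56, 56), 8))         # (1, 8, 56, 56, 8)
print(oihw_to_oihwio_shape((128, 64, 3, 3), 8, 8))     # (16, 8, 3, 3, 8, 8)
print(nchw_to_nchwc_index((0, 13, 2, 5), 8))           # (0, 1, 2, 5, 5)
```

The channel axis is split into an outer chunk and an inner block, which is why the data gains one dimension and the kernel (with both input and output channels blocked) gains two.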

I see, so are you saying the inputs in line 114 are already 5D? Or are they somehow converted to 5D?

Also, are you saying that an NCHW op can only take 4D inputs and an NCHWc op only 5D/6D inputs? I’m actually experimenting with a custom op. How do I make it accept both 4D and 5D/6D inputs?

An op only accepts inputs of a fixed (static) type, so you cannot make one op accept both 4D and 5D inputs. That’s why we need to “alter” the op.
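A toy way to picture this: during type inference each op checks its inputs against the shape it expects, so a 4D conv2d and a blocked 5D/6D conv2d end up as two separate ops. The sketch below mimics such checks in plain Python; conv2d_nchw_rel, conv2d_nchwc_rel and TypeInferError are hypothetical stand-ins, not TVM’s real type relations (those are implemented in C++), and stride 1 with no padding is assumed for simplicity.

```python
class TypeInferError(Exception):
    """Raised when an input does not match the type an op expects."""


def conv2d_nchw_rel(data_shape, weight_shape):
    # Hypothetical check for a plain NCHW conv2d: 4D data and a 4D OIHW kernel only.
    if len(data_shape) != 4 or len(weight_shape) != 4:
        raise TypeInferError("conv2d (NCHW) expects 4D data and a 4D OIHW kernel")
    n, _, h, w = data_shape
    o, _, kh, kw = weight_shape
    return (n, o, h - kh + 1, w - kw + 1)  # stride 1, no padding


def conv2d_nchwc_rel(data_shape, weight_shape):
    # Hypothetical check for the blocked NCHWc variant: 5D data and a 6D kernel only.
    if len(data_shape) != 5 or len(weight_shape) != 6:
        raise TypeInferError("conv2d_NCHWc expects 5D data and a 6D kernel")
    n, _, h, w, _ = data_shape
    oc_chunk, _, kh, kw, _, oc_block = weight_shape
    return (n, oc_chunk, h - kh + 1, w - kw + 1, oc_block)


print(conv2d_nchw_rel((1, 64, 56, 56), (128, 64, 3, 3)))          # (1, 128, 54, 54)
print(conv2d_nchwc_rel((1, 8, 56, 56, 8), (16, 8, 3, 3, 8, 8)))   # (1, 16, 54, 54, 8)
```

Feeding 5D data to conv2d_nchw_rel raises TypeInferError, which is why supporting both layouts for a custom op means registering two op variants and swapping one for the other during layout alteration.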

I see. That’s an important point I missed before. Thank you for letting me know!

But I’m still not sure when the 4D to 5D/6D conversion of the tensors happens, or when all the expand_dims and layout_transform ops are inserted. Does it happen somewhere before the alter_op_layout pass?

OK, I think I’ve found what I was looking for. It’s in transform_layout.h, which has a LayoutRewriter function for this purpose. Specifically, the memoizer’s Transform function (defined in the same file) does the job:

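```cpp
  Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
    // Nothing to insert if the layouts already match.
    if (src_layout.Equals(dst_layout)) {
      return raw;
    }

    // Memo key: (expression pointer, source layout name, destination layout name).
    std::tuple<const Object*, std::string, std::string> key =
        std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name());
    auto& memo = operator->()->memo;

    auto iter = memo.find(key);
    if (iter != memo.end()) {
      // Reuse the transform already built for this (expr, src, dst) triple.
      return iter->second;
    } else {
      // Build the layout_transform (plus expand_dims if needed) and cache it.
      Expr transform = TransformHelper(raw, src_layout, dst_layout);
      memo[key] = transform;
      return transform;
    }
  }
```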
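The point of the memo is that when several consumers need the same producer in the same new layout, only one layout_transform node is created and shared. Below is a plain-Python mimic of that caching logic, for illustration only: Node, TransformMemo and transform_helper are hypothetical stand-ins for a Relay Expr, the memoizer and TransformHelper, and this stand-in helper only wraps the input in a layout_transform node (it does not model expand_dims).

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(eq=False)  # identity-based equality, like comparing raw.get() pointers
class Node:
    op: str
    args: tuple = ()
    attrs: tuple = ()


def transform_helper(raw, src_layout, dst_layout):
    # Stand-in for TransformHelper: wrap raw in a single layout_transform node.
    return Node("layout_transform", (raw,),
                (("src_layout", src_layout), ("dst_layout", dst_layout)))


class TransformMemo:
    def __init__(self):
        # Keyed by (identity of the expression, source layout, destination layout).
        # The cached value holds raw in its args, so raw stays alive and its id
        # cannot be reused while the entry exists.
        self.memo: Dict[Tuple[int, str, str], Node] = {}

    def transform(self, raw, src_layout, dst_layout):
        if src_layout == dst_layout:
            return raw  # nothing to insert
        key = (id(raw), src_layout, dst_layout)
        if key not in self.memo:
            self.memo[key] = transform_helper(raw, src_layout, dst_layout)
        return self.memo[key]


memo = TransformMemo()
x = Node("var")
a = memo.transform(x, "NCHW", "NCHW8c")  # first consumer: builds the transform
b = memo.transform(x, "NCHW", "NCHW8c")  # second consumer: reuses it
print(a is b)                                 # True
print(memo.transform(x, "NCHW", "NCHW") is x)  # True
print(len(memo.memo))                         # 1
```

Keying on object identity rather than structural equality mirrors the raw.get() pointer in the C++ key: two different expressions that happen to look alike still get their own transforms.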