Adding a new kernel layout to the C++ backend

For the pull request #6137 (Better grouped convolution for CPU targets), I’ve been asked to add an alter_op so that the kernel is reshaped ahead of time.

It took me a while, but I’ve now worked out what is required on the TOPI side. However, I’m having an issue with the C++ backend.

The reason for this is that my kernel reshape is not NCHW to "OIHW%di%do" % (ic_bn, oc_bn), but to "GOIHW%di%do" % (ic_bn, oc_bn), where G is the number of groups. Therefore, when I try to use this layout in the alter op, I get the following error:

An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: false == false: [15:58:23] ../src/tir/ir/data_layout.cc:364:
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout) == false:

I think this is happening because there is no "GOIHW%di%do" % (ic_bn, oc_bn) layout in the C++ backend, and I guess conv2d_alter_op does the reshape in that backend. Going by the error message, the condition that fails is this one in GetStoreRule:

if (tir::is_zero(store)) {
  // Not convertible
  return false;
}
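
To illustrate why this check might reject the new layout, here is a simplified Python model. It is not TVM's actual code. It assumes the conversion is rejected whenever a destination axis has no axis with the same primal (upper-case) letter in the source layout. The helper names `parse_layout` and `missing_primal_axes` are made up for this sketch.

```python
import re


def parse_layout(layout):
    """Split a layout string such as 'OIHW4i4o' into (axis, factor) pairs.

    Primal axes (upper case) get a factor of None; sub-axes (lower case)
    carry the integer split factor written in front of them.
    """
    return [
        (name, int(factor) if factor else None)
        for factor, name in re.findall(r"(\d*)([A-Za-z])", layout)
    ]


def missing_primal_axes(src_layout, dst_layout):
    """Return destination primal axes that have no counterpart in the source.

    Assumption: a non-empty result corresponds to the 'Not convertible' case.
    """
    src_primals = {name.upper() for name, _ in parse_layout(src_layout)}
    missing = []
    for name, _ in parse_layout(dst_layout):
        primal = name.upper()
        if primal not in src_primals and primal not in missing:
            missing.append(primal)
    return missing


print(missing_primal_axes("OIHW", "OIHW4i4o"))   # []
print(missing_primal_axes("OIHW", "GOIHW4i4o"))  # ['G']
```

Under that assumption, the G axis is the problem: nothing in OIHW says how to derive it, because the group index is really a split of the O axis rather than an independent axis.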

Just using "OIHW%di%do" % (ic_bn, oc_bn) in my conv2d_alter_op doesn’t work, because it means that my TOPI function gets a 6D kernel, rather than a 7D kernel.
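
For reference, the 7D packing can be sketched in NumPy as below. This is only an illustration, not the code from the branch. It assumes the original kernel is stored as OIHW with shape (out_channels, in_channels // groups, KH, KW), and that the packed axes are ordered G, O chunk, I chunk, H, W, ic_bn block, oc_bn block. The function name `pack_grouped_kernel` is hypothetical.

```python
import numpy as np


def pack_grouped_kernel(kernel, groups, ic_bn, oc_bn):
    """Pack an OIHW grouped-conv kernel into a 7D GOIHW{ic_bn}i{oc_bn}o array.

    Assumes kernel.shape == (out_channels, in_channels // groups, KH, KW).
    """
    out_ch, ic_per_group, kh, kw = kernel.shape
    if out_ch % groups:
        raise ValueError("out_channels must be divisible by groups")
    oc_per_group = out_ch // groups
    if oc_per_group % oc_bn or ic_per_group % ic_bn:
        raise ValueError("per-group channels must be divisible by the block sizes")

    # Split O into (group, oc_chunk, oc_block) and I into (ic_chunk, ic_block).
    split = kernel.reshape(
        groups, oc_per_group // oc_bn, oc_bn,
        ic_per_group // ic_bn, ic_bn, kh, kw,
    )
    # Reorder to (group, oc_chunk, ic_chunk, kh, kw, ic_block, oc_block).
    return split.transpose(0, 1, 3, 5, 6, 4, 2)


kernel = np.arange(8 * 4 * 3 * 3, dtype="float32").reshape(8, 4, 3, 3)
packed = pack_grouped_kernel(kernel, groups=2, ic_bn=2, oc_bn=2)
print(packed.shape)  # (2, 2, 2, 3, 3, 2, 2)
```

With "OIHW%di%do" the leading group axis is missing, which is why the TOPI function would see a 6D kernel instead of this 7D one.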

So how would I add this new kernel layout to this process? Ideally, I would reuse the rules of the reshape in TOPI (see here).

You can see my working branch here, where the issue happens. This test script generates group convolution networks and runs them on x86.