Error in writing tune graph for nn.conv3d

Hi @comaniac, the error from "[autoTVM] Crash when tuning "dense" on x86 cpu - #17 by aiblackman" that we discussed last week still has not been solved. I have updated some code, which can be found in my GitHub repository aiblackmaner/tvm. My changes are in python/tvm/topi/x86/conv3d.py and are marked with # modify. My test code is in the tests/test_3d folder, which contains four main files:
- network.py defines the network structure;
- load_test.py loads and tests the saved .so file;
- from_pytorch_3d.py is the same as the official TVM example from_pytorch.py, adapted to my 3D network;
- tune_relay_x86_3d.py is likewise adapted from the official tune_relay_x86.py example.

My test environment is: Ubuntu 16.04 x86_64, Torch 1.4.0, TorchVision 0.5.0, clang+llvm-9.0.0-x86_64, and TVM 0.7.dev1 built from commit 1d5504b of https://github.com/apache/tvm.git.

The network defined in network.py is as follows:

import torch
import torch.nn as nn

class Down(nn.Module):
    def __init__(self, in_channels, kernel_size=2, stride=2):
        super(Down, self).__init__()
        out_channels = in_channels * 2
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, groups=1)
        self.bn = nn.BatchNorm3d(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, input):
        out = self.act(self.bn(self.conv(input)))
        return out

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super(Up, self).__init__()
        self.conv = nn.ConvTranspose3d(in_channels, out_channels // 2, kernel_size=kernel_size, stride=stride, groups=1)
        self.bn = nn.BatchNorm3d(out_channels // 2)
        self.act = nn.ReLU(inplace=True)

    def forward(self, input, skip):
        out = self.act(self.bn(self.conv(input)))
        out = torch.cat((out, skip), 1)
        return out

class Input(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Input, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm3d(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, input):
        out = self.act(self.bn(self.conv(input)))
        return out

class Output(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Output, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        out = self.act1(self.bn1(self.conv1(input)))
        out = self.conv2(out)
        out = self.softmax(out)
        return out

class SegNet(nn.Module):
    # network1
    # def __init__(self, in_channels, out_channels):
    #     super(SegNet, self).__init__()
    #     self.in_block = Input(in_channels, 8)
    #     self.out_block = Output(8, out_channels)
    #
    # def forward(self, input):
    #     out_in = self.in_block(input)
    #     out = self.out_block(out_in)
    #     return out

    # network2
    def __init__(self, in_channels, out_channels):
        super(SegNet, self).__init__()
        self.in_block = Input(in_channels, 8)
        self.conv = Down(8)
        self.up = Up(16, 16)
        self.out_block = Output(16, out_channels)

    def forward(self, input):
        out_in = self.in_block(input)
        out = self.conv(out_in)
        up = self.up(out, out_in)
        out = self.out_block(up)
        return out
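
To double-check which shapes network2 produces for a given cubic input, here is a small shape-propagation sketch. It assumes the standard PyTorch output-size formulas for Conv3d and ConvTranspose3d with dilation 1 and output_padding 0; the helper names are hypothetical and not part of network.py.

```python
# Hypothetical shape-propagation helpers for SegNet (network2), batch size 1.
# Formulas assume dilation=1 and output_padding=0, as in the layers above.

def conv3d_out(size, kernel, stride=1, padding=0):
    """Output edge length of nn.Conv3d along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose3d_out(size, kernel, stride=1, padding=0):
    """Output edge length of nn.ConvTranspose3d along one spatial axis."""
    return (size - 1) * stride - 2 * padding + kernel


def segnet_shapes(edge, out_channels=3):
    """Return (block name, output shape) pairs for a (1, 1, edge, edge, edge) input."""
    shapes = []
    # Input: Conv3d(in_channels, 8, kernel_size=3, padding=1)
    s = conv3d_out(edge, 3, padding=1)
    shapes.append(("in_block", (1, 8, s, s, s)))
    # Down: Conv3d(8, 16, kernel_size=2, stride=2)
    d = conv3d_out(s, 2, stride=2)
    shapes.append(("down", (1, 16, d, d, d)))
    # Up: ConvTranspose3d(16, 8, kernel_size=2, stride=2), then cat with the 8-channel skip
    u = conv_transpose3d_out(d, 2, stride=2)
    if u != s:
        raise ValueError(f"skip connection mismatch: up={u}, skip={s}")
    shapes.append(("up_cat", (1, 8 + 8, u, u, u)))
    # Output: 3x3x3 conv (padding=1), then 1x1x1 conv; softmax keeps the shape
    o = conv3d_out(conv3d_out(u, 3, padding=1), 1)
    shapes.append(("out_block", (1, out_channels, o, o, o)))
    return shapes


for name, shape in segnet_shapes(48):
    print(name, shape)
```

For example, segnet_shapes(48)[1] is ("down", (1, 16, 24, 24, 24)), and an odd edge such as 33 raises ValueError, because the upsampled tensor would no longer match the skip connection in torch.cat.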

When I use network2 from network.py to produce state_params.pth, with input size (1, 1, 32, 32, 32) and output size (1, 3, 32, 32, 32), tune_relay_x86_3d.py in the tests/test_3d folder runs normally:

Extract tasks...
[Task  1/ 4]  Current/Best:    5.37/  10.45 GFLOPS | Progress: (32/32) | 12.53 s Done.
[Task  2/ 4]  Current/Best:   13.12/  18.13 GFLOPS | Progress: (80/80) | 52.90 s Done.
[Task  3/ 4]  Current/Best:   19.74/  31.14 GFLOPS | Progress: (160/160) | 65.09 s Done.
[Task  4/ 4]  Current/Best:   20.76/  21.74 GFLOPS | Progress: (32/32) | 13.37 s Done.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 1, 32, 32, 32), 'float32'), ('TENSOR', (8, 1, 3, 3, 3), 'float32'), (1, 1, 1), (1, 1, 1, 1, 1, 1), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 8, 32, 32, 32), 'float32'), ('TENSOR', (16, 8, 2, 2, 2), 'float32'), (2, 2, 2), (0, 0, 0, 0, 0, 0), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 16, 32, 32, 32), 'float32'), ('TENSOR', (3, 16, 3, 3, 3), 'float32'), (1, 1, 1), (1, 1, 1, 1, 1, 1), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 3, 32, 32, 32), 'float32'), ('TENSOR', (3, 3, 1, 1, 1), 'float32'), (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
2021-02-01 17:52:09,640 INFO Start to benchmark layout transformation...
2021-02-01 17:52:59,951 INFO Benchmarking layout transformation successful.
2021-02-01 17:52:59,953 INFO Start to run dynamic programming algorithm...
2021-02-01 17:52:59,954 INFO Start forward pass...
2021-02-01 17:52:59,956 INFO Finished forward pass.
2021-02-01 17:52:59,956 INFO Start backward pass...
2021-02-01 17:52:59,958 INFO Finished backward pass...
2021-02-01 17:52:59,958 INFO Finished DPExecutor run.
2021-02-01 17:52:59,959 INFO Writing optimal schedules to segnet_graph_opt.log successfully.
Compile...
Cannot find config for target=llvm -keys=cpu -mcpu=core-avx2, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 16, 33, 33, 33), 'float32'), ('TENSOR', (8, 16, 2, 2, 2), 'float32'), (1, 1, 1), (0, 0, 0), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
/home/project/test_3d/tune_relay_x86_3d.py:187: DeprecationWarning: legacy graph runtime behaviour of producing json / lib / params will be  removed in the next release. Please see documents of tvm.contrib.graph_runtime.GraphModule for the  new recommended usage.
  graph, lib, params = relay.build(mod, target=target, params=params)
Evaluate inference time cost...
Mean inference time (std dev): 17.76 ms (0.05 ms)
Process finished with exit code 0

But when I change the input size to (1, 1, 48, 48, 48) and the output size to (1, 3, 48, 48, 48), tune_relay_x86_3d.py in the tests/test_3d folder crashes while evaluating the inference time:

Extract tasks...
[Task  1/ 4]  Current/Best:    5.28/   7.80 GFLOPS | Progress: (48/48) | 19.98 s Done.
[Task  2/ 4]  Current/Best:   16.06/  20.51 GFLOPS | Progress: (120/120) | 148.96 s Done.
[Task  3/ 4]  Current/Best:   26.39/  38.90 GFLOPS | Progress: (240/240) | 114.28 s Done.
[Task  4/ 4]  Current/Best:   36.73/  36.73 GFLOPS | Progress: (48/48) | 22.51 s Done.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 1, 48, 48, 48), 'float32'), ('TENSOR', (8, 1, 3, 3, 3), 'float32'), (1, 1, 1), (1, 1, 1, 1, 1, 1), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 8, 48, 48, 48), 'float32'), ('TENSOR', (16, 8, 2, 2, 2), 'float32'), (2, 2, 2), (0, 0, 0, 0, 0, 0), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 16, 48, 48, 48), 'float32'), ('TENSOR', (3, 16, 3, 3, 3), 'float32'), (1, 1, 1), (1, 1, 1, 1, 1, 1), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=tracing,cpu -device=tracing, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 3, 48, 48, 48), 'float32'), ('TENSOR', (3, 3, 1, 1, 1), 'float32'), (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
2021-02-01 17:44:08,252 INFO Start to benchmark layout transformation...
2021-02-01 17:45:37,742 INFO Benchmarking layout transformation successful.
2021-02-01 17:45:37,744 INFO Start to run dynamic programming algorithm...
2021-02-01 17:45:37,744 INFO Start forward pass...
2021-02-01 17:45:37,748 INFO Finished forward pass.
2021-02-01 17:45:37,748 INFO Start backward pass...
2021-02-01 17:45:37,750 INFO Finished backward pass...
2021-02-01 17:45:37,750 INFO Finished DPExecutor run.
2021-02-01 17:45:37,750 INFO Writing optimal schedules to segnet_graph_opt.log successfully.
Compile...
Cannot find config for target=llvm -keys=cpu -mcpu=core-avx2, workload=('conv3d_NCDHWc.x86', ('TENSOR', (1, 16, 49, 49, 49), 'float32'), ('TENSOR', (8, 16, 2, 2, 2), 'float32'), (1, 1, 1), (0, 0, 0), (1, 1, 1), 'NCDHW', 'float32'). A fallback configuration is used, which may bring great performance regression.
/home/project/test_3d/tune_relay_x86_3d.py:187: DeprecationWarning: legacy graph runtime behaviour of producing json / lib / params will be  removed in the next release. Please see documents of tvm.contrib.graph_runtime.GraphModule for the  new recommended usage.
  graph, lib, params = relay.build(mod, target=target, params=params)
Evaluate inference time cost...

Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)
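
In both runs the only workload reported at compile time is a stride-1, unpadded conv3d with kernel (8, 16, 2, 2, 2) on a (1, 16, 33, 33, 33) tensor for the 32 case and a (1, 16, 49, 49, 49) tensor for the 48 case. These shapes match the arithmetic of rewriting the Up block's ConvTranspose3d (PyTorch weight shape (16, 8, 2, 2, 2), first two axes swapped here) as an ordinary conv3d over an input that is dilated by the stride and then padded by kernel - 1 - padding on each side. This workload is not among the four tuned tasks, so it uses the fallback configuration. The sketch below assumes that lowering; the helper names are hypothetical and not from the repository.

```python
# Assumption: conv3d_transpose is lowered to a stride-1 conv3d over an input
# that is dilated by `stride` and then padded by (kernel - 1 - padding) per side.

def lowered_transpose_input(edge, kernel=2, stride=2, padding=0):
    """Spatial edge of the dilated + padded tensor fed to the stride-1 conv3d."""
    dilated = (edge - 1) * stride + 1  # insert (stride - 1) zeros between elements
    pad = kernel - 1 - padding         # padding added before and after
    return dilated + 2 * pad


def lowered_transpose_output(edge, kernel=2, stride=2, padding=0):
    """Edge after the stride-1, unpadded conv3d over the lowered input."""
    return lowered_transpose_input(edge, kernel, stride, padding) - kernel + 1


for cube in (32, 48):
    down = cube // 2  # edge after the Down block (kernel 2, stride 2), even cubes only
    print(cube, lowered_transpose_input(down), lowered_transpose_output(down))
# prints:
# 32 33 32
# 48 49 48
```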

My changes to python/tvm/topi/x86/conv3d.py follow the structure of python/tvm/topi/x86/conv2d.py. With a (32, 32, 32) input everything works, but with (48, 48, 48) or anything larger the script crashes with a segmentation fault. This problem has bothered me for a long time. Could you take a look at the code, or do you have any suggestions for this error? Thank you very much!
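
The workload name conv3d_NCDHWc.x86 indicates that the modified schedule, like the conv2d NCHWc schedule it is modeled on, works on a channel-blocked layout. As a hedged sketch, assuming the packing is the direct 3D analogue of NCHW[x]c (the channel axis split into C // block outer blocks plus an innermost axis of size block), the index mapping looks like this. The block parameter and helper names are hypothetical and not taken from the modified conv3d.py; a mapping like this can be used to sanity-check layout-transform indexing in a custom schedule.

```python
# Hypothetical NCDHW <-> NCDHW[x]c index mapping; `block` stands for the
# channel block size a tuning config would choose.

def packed_shape(shape, block):
    """(N, C, D, H, W) -> (N, C // block, D, H, W, block)."""
    n, c, d, h, w = shape
    if c % block != 0:
        raise ValueError(f"channels {c} not divisible by block {block}")
    return (n, c // block, d, h, w, block)


def packed_index(index, block):
    """Map an NCDHW index to its position in the packed NCDHW[x]c tensor."""
    n, c, d, h, w = index
    return (n, c // block, d, h, w, c % block)


def unpacked_index(packed, block):
    """Inverse of packed_index: NCDHW[x]c position back to an NCDHW index."""
    n, c_outer, d, h, w, c_inner = packed
    return (n, c_outer * block + c_inner, d, h, w)


print(packed_shape((1, 16, 49, 49, 49), 8))
print(packed_index((0, 13, 5, 6, 7), 8))
```

For example, channel 13 with block 8 lands in outer block 1 at inner position 5, and unpacked_index restores the original index.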