When auto-tuning inference for a model imported from TensorFlow, the following error occurred on x86:

%43 = reshape(%42, newshape=[4, 34, 34, 16])
%44 = transpose(%43, axes=[0, 3, 1, 2])
free_var %MMNet/enc_block0/branch1/depthwise_conv_dilation/depthwise_weights: Tensor[(3, 3, 16, 1), float32]
%45 = transpose(%MMNet/enc_block0/branch1/depthwise_conv_dilation/depthwise_weights, axes=[2, 3, 0, 1])
%46 = nn.conv2d(%44, %45, groups=64, channels=64, kernel_size=[3, 3])
an internal invariant was violated while typechecking your program [14:08:40] /home/dolphin/tvm/src/relay/pass/type_solver.cc:119: Check failed: resolved.defined(): Unable to unify parent types: TensorType([64, 0, 3, 3], float32) and TensorType([16, 1, 3, 3], float32)

The problem might be located in the code below:
for i, rate in enumerate(rates):
    with tf.variable_scope(f"branch{i}"):
        conv = slim.conv2d(inputs, num_outputs=expanded_depth, kernel_size=1, stride=1, scope="pointwise_conv")
        if stride > 1:
            conv = separable_conv(conv, num_outputs=None, kernel_size=3, stride=stride, depth_multiplier=1, scope="depthwise_conv_stride")
        conv = separable_conv(conv, num_outputs=None, kernel_size=3, stride=1, depth_multiplier=1, rate=rate, scope="depthwise_conv_dilation")
        convs.append(conv)

What causes this error? Can TVM auto-tune dilated convolutions from TensorFlow on x86? Thanks in advance!

Are you running a custom model? This may be a frontend bug during model import.

Thanks!
Yes, it's a custom model from GitHub: https://github.com/hyperconnect/MMNet. It contains only standard conv2d, depthwise conv2d, depthwise conv2d with different dilation rates, and tf.image.resize_bilinear; there are no other special ops. TensorFlow decomposes the depthwise dilated conv2d into three steps: 1. SpaceToBatchND 2. DepthWiseConv2d 3. BatchToSpaceND. The error happens there, as you can see in the attached pictures below.

By the way, to load the .pb file I copied the code from the tutorial but commented out the last two lines:
with tf.gfile.FastGFile(model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    #with tf.Session() as sess:
    #    graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')

thanks in advance!

For the depthwise convolution with dilation rate 2 and filter [3, 3, 16, 1], the output of SpaceToBatchND (after the transpose to NCHW) has dims [4, 16, 34, 34], i.e. batch_size=4 and channels=16. So shouldn't the group convolution (which here is a depthwise convolution) nn.conv2d(%44, %45, groups=64, channels=64, kernel_size=[3, 3]) instead be nn.conv2d(%44, %45, groups=16, channels=16, kernel_size=[3, 3])?
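The shape arithmetic behind the type error can be sketched in a few lines of plain Python. This is only an illustration: it assumes the type checker derives the expected OIHW weight shape of a grouped conv2d as (channels, in_channels // groups, kh, kw), and the helper names are hypothetical, not TVM functions.

```python
def expected_oihw_weight_shape(in_channels, channels, groups, kernel_size):
    """Weight shape a grouped conv2d would expect in OIHW layout (assumed rule)."""
    kh, kw = kernel_size
    return (channels, in_channels // groups, kh, kw)


def transpose_shape(shape, axes):
    """Shape of a tensor after transposing with the given axes."""
    return tuple(shape[a] for a in axes)


# %45 in the log: the (3, 3, 16, 1) TensorFlow filter transposed with axes=[2, 3, 0, 1].
actual_weight = transpose_shape((3, 3, 16, 1), (2, 3, 0, 1))

# What groups=64, channels=64 implies when the data only has 16 input channels.
expected_from_log = expected_oihw_weight_shape(16, 64, 64, (3, 3))

# What groups=16, channels=16 implies.
expected_fixed = expected_oihw_weight_shape(16, 16, 16, (3, 3))

print(actual_weight)      # (16, 1, 3, 3)
print(expected_from_log)  # (64, 0, 3, 3)
print(expected_fixed)     # (16, 1, 3, 3)
```

Under that assumption, the two shapes in the message "Unable to unify parent types" are exactly the expected shape for groups=64 and the actual transposed filter shape, and groups=16 makes them agree.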

It seems TVM does not have a depthwise convolution op with different dilation rates for auto-tuning on x86. What about on NVIDIA GPUs?

What do you mean by “different” dilation rates? Dilation is handled here in the CUDA depth wise schedule: https://github.com/dmlc/tvm/blob/28e8eca18dc2b2d6e23583c8293764cf6dcfdf32/topi/python/topi/cuda/depthwise_conv2d.py#L76

Thank you very much! Depthwise convolutions with different dilation rates (also called atrous depthwise convolutions) can be combined into a structure called atrous spatial pyramid pooling (ASPP), which is widely used for semantic segmentation tasks; Google's well-known DeepLabv3 model uses depthwise conv2ds with different dilation rates. Can TVM handle atrous depthwise convolution? Who maintains the TensorFlow frontend? Thanks again!
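To make "different dilation rates" concrete, here is a small sketch of how the dilation rate widens the area a 3x3 kernel covers. The formula (k - 1) * rate + 1 is the standard effective kernel extent; the list of rates is hypothetical and not taken from the MMNet code.

```python
def effective_kernel_size(kernel_size, rate):
    """Spatial extent covered by a kernel whose taps are `rate` pixels apart."""
    return (kernel_size - 1) * rate + 1


def same_padding_total(kernel_size, rate):
    """Total padding per spatial dim that keeps the size unchanged at stride 1."""
    return effective_kernel_size(kernel_size, rate) - 1


rates = [1, 2, 4, 8]  # hypothetical ASPP-style branch rates
extents = [effective_kernel_size(3, r) for r in rates]
print(extents)                   # [3, 5, 9, 17]
print(same_padding_total(3, 2))  # 4, i.e. 2 pixels on each side
```

Each branch uses the same nine weights per channel but samples them over a different extent, which is why one schedule has to cope with several dilation values.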

When I looked into the TensorFlow frontend at tvm/python/tvm/relay/frontend/tensorflow.py, I made a small change at line 361, adding attr['channels'] = weights_shape[0] (shown in the picture below). It is only a workaround for my case, but with it the TensorFlow frontend passes and auto-tuning works too.

TensorFlow decomposes the depthwise dilated conv2d into three steps: 1. SpaceToBatchND 2. DepthWiseConv2d 3. BatchToSpaceND:
input_converted = array_ops.space_to_batch_nd(
    input=inp, block_shape=dilation_rate, paddings=paddings)
result = self.op(input_converted, filter)
result_converted = array_ops.batch_to_space_nd(
    input=result, block_shape=dilation_rate, crops=crops)
After space_to_batch_nd the number of channels does not increase; only the batch dimension grows, and the same applies in the TensorFlow frontend. In my opinion, this frontend bug is most probably what caused the error.
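The claim that space_to_batch_nd grows only the batch dimension can be checked with a minimal NumPy sketch of the 2-D case. The function below is an illustrative re-implementation, not TensorFlow's code, and the input size (1, 64, 64, 16) with padding 2 is an assumption chosen so that the result matches the [4, 34, 34, 16] shape in the log.

```python
import numpy as np


def space_to_batch_2d(x, block, pad):
    """Minimal NHWC space-to-batch with a square block and symmetric zero padding."""
    n, _, _, c = x.shape
    x = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    hp, wp = x.shape[1], x.shape[2]
    assert hp % block == 0 and wp % block == 0
    # Split each spatial dim into (coarse position, offset inside the block).
    x = x.reshape(n, hp // block, block, wp // block, block, c)
    # Move the two block offsets in front of the batch dim so they fold into it.
    x = x.transpose(2, 4, 0, 1, 3, 5)
    return x.reshape(n * block * block, hp // block, wp // block, c)


# Hypothetical input: batch 1, 64x64, 16 channels, dilation rate 2.
inp = np.arange(1 * 64 * 64 * 16, dtype=np.float32).reshape(1, 64, 64, 16)
out = space_to_batch_2d(inp, block=2, pad=2)
print(inp.shape, "->", out.shape)  # (1, 64, 64, 16) -> (4, 34, 34, 16)
```

The channel axis stays at 16 while the batch goes from 1 to 4, so the depthwise convolution that follows should still see 16 channels and 16 groups.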

During auto-tuning, an error occurred:
[Task 2/48] Current/Best: 11.99/ 18.48 GFLOPS | Progress: (224/560) | 557.88 s Done.
[Task 3/48] Current/Best: 2.84/ 18.95 GFLOPS | Progress: (352/1792) | 874.00 s Done.
Traceback (most recent call last):
  File "autotune_tensorflow_x86_mmnet.py", line 236, in <module>
    tune_kernels(tasks, **tuning_option)
  File "autotune_tensorflow_x86_mmnet.py", line 198, in tune_kernels
    target=target, template_key='direct')
  File "/home/dolphin/tvm/python/tvm/autotvm/task/task.py", line 191, in create
    sch, _ = func(*args)
  File "/home/dolphin/tvm/topi/python/topi/x86/depthwise_conv2d.py", line 217, in _topi_nn_depthwise_conv2d_NCHWc
    data_layout, out_layout, dtype)
  File "/home/dolphin/tvm/topi/python/topi/x86/depthwise_conv2d.py", line 78, in _depthwise_conv2d_NCHWc_cpu
    = get_const_tuple(kernel.shape)
ValueError: not enough values to unpack (expected 6, got 4)
Is this related to the frontend?
Thanks in advance!

This looks like the data layout of the kernel is not what is expected (e.g., NCHW instead of NCHWc). Can you check what the 4th task is?

Thank you very much @eqy! In tvm/topi/python/topi/x86/depthwise_conv2d.py, line 209 reads new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn). It has 4 values, but 6 are expected, which causes this error.
The mismatch arises because line 217 calls C = _depthwise_conv2d_NCHWc_cpu(cfg, new_data, new_kernel, strides, padding, dilation, data_layout, out_layout, dtype), and that function expects new_kernel to have a 6-value shape.
Could you help look into this issue? Thank you very much!
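The unpack failure itself is easy to reproduce outside TVM. The sketch below uses hypothetical names and illustrative sizes (out_channel=16, oc_bn=8); the 6-value shape with singleton dimensions is an assumption about what a blocked depthwise kernel layout could look like, not a statement of what the TVM schedule requires.

```python
import math


def unpack_6d(shape):
    """Mimics the six-way tuple unpack that fails in the traceback (names hypothetical)."""
    oc_chunk, ic_chunk, kh, kw, ic_block, oc_block = shape
    return oc_chunk, ic_chunk, kh, kw, ic_block, oc_block


out_channel, oc_bn, kh, kw = 16, 8, 3, 3  # illustrative values

shape_4d = (out_channel // oc_bn, kh, kw, oc_bn)
try:
    unpack_6d(shape_4d)
    error = None
except ValueError as exc:
    error = str(exc)
print(error)  # not enough values to unpack (expected 6, got 4)

# A 6-value shape holding the same number of elements needs singleton dims of 1;
# a dimension of 0 would describe an empty tensor.
shape_6d = (out_channel // oc_bn, 1, kh, kw, 1, oc_bn)
print(unpack_6d(shape_6d))
print(math.prod(shape_4d), math.prod(shape_6d))  # 144 144
```

Whatever the intended layout is, the placeholder shape built on line 209 and the unpack inside the compute function have to agree on the number of dimensions.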

After I changed this line to new_kernel_shape = (out_channel // oc_bn,0, kh, kw,0, oc_bn), the auto-tuning passed, at least for my case.
But then another error occurred:
Traceback (most recent call last):
  File "autotune_tensorflow_x86_mmnet_input.py", line 252, in <module>
    net, target=target, params=params)
  File "/home/dolphin/tvm/python/tvm/relay/build_module.py", line 305, in build
    graph_json, lowered_funcs, params = graph_gen.codegen(func)
  File "/home/dolphin/tvm/python/tvm/relay/backend/graph_runtime_codegen.py", line 90, in codegen
    self._codegen(func)
  File "/home/dolphin/tvm/python/tvm/_ffi/_ctypes/function.py", line 209, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/dolphin/tvm/build/libtvm.so(+0xcf12c8) [0x7f3621bdf2c8]
  [bt] (7) /home/dolphin/tvm/build/libtvm.so(+0xcebca6) [0x7f3621bd9ca6]
  [bt] (6) /home/dolphin/tvm/build/libtvm.so(+0xcf12c8) [0x7f3621bdf2c8]
  [bt] (5) /home/dolphin/tvm/build/libtvm.so(+0xcebca6) [0x7f3621bd9ca6]
  [bt] (4) /home/dolphin/tvm/build/libtvm.so(+0xcf0e7a) [0x7f3621bdee7a]
  [bt] (3) /home/dolphin/tvm/build/libtvm.so(+0xca6e7f) [0x7f3621b94e7f]
  [bt] (2) /home/dolphin/tvm/build/libtvm.so(+0xcadf4c) [0x7f3621b9bf4c]
  [bt] (1) /home/dolphin/tvm/build/libtvm.so(+0xcad424) [0x7f3621b9b424]
  [bt] (0) /home/dolphin/tvm/build/libtvm.so(+0xf4db0b) [0x7f3621e3bb0b]
  File "/home/dolphin/tvm/python/tvm/_ffi/_ctypes/function.py", line 71, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/dolphin/tvm/python/tvm/relay/op/nn/_nn.py", line 165, in schedule_conv2d
    return topi.generic.schedule_depthwise_conv2d_nchw(outs)
  File "</home/dolphin/.local/lib/python3.5/site-packages/decorator.py:decorator-gen-57>", line 2, in schedule_depthwise_conv2d_nchw
  File "/home/dolphin/tvm/python/tvm/target.py", line 372, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "</home/dolphin/.local/lib/python3.5/site-packages/decorator.py:decorator-gen-151>", line 2, in config_dispatcher
  File "/home/dolphin/tvm/python/tvm/autotvm/task/dispatcher.py", line 220, in dispatch_func
    return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
  File "/home/dolphin/tvm/python/tvm/autotvm/task/topi_integration.py", line 437, in template_call
    return f(cfg, outs, *args, **kwargs)
  File "/home/dolphin/tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py", line 151, in schedule_depthwise_conv2d_nchw_arm
    traverse_inline(s, outs[0].op, _callback)
  File "/home/dolphin/tvm/topi/python/topi/util.py", line 51, in traverse_inline
    _traverse(final_op)
  File "/home/dolphin/tvm/topi/python/topi/util.py", line 49, in _traverse
    callback(op)
  File "/home/dolphin/tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py", line 135, in _callback
    _schedule(cfg, s, data, data_pad, kernel, output)
  File "/home/dolphin/tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py", line 75, in _schedule
    c, vc = cfg['tile_c'].apply(s, A0, c)
  File "/home/dolphin/tvm/python/tvm/autotvm/task/space.py", line 773, in __getitem__
    return self._entity_map[name]
KeyError: 'tile_c'

File "/home/dolphin/tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py", line 151, in schedule_depthwise_conv2d_nchw_arm

Why does it end up in arm_cpu? I ran the auto-tuning on an x86 CPU.

I have the same problem! Can someone help? @kevinthesun @tqchen

Hi @eqy, could you give me some advice on this issue? Thank you very much!

Can you share your script for this issue? What is your target set to?

Thank you, @eqy!
The target is set to target = 'llvm', ctx = tvm.cpu().
The script is at https://github.com/dolphintear/tvm_x86_autotune (autotune_tensorflow_x86_mmnet_input.py); the .pb file is in the same repository.
I made changes to two TVM files, as discussed above:
In tvm/python/tvm/relay/frontend/tensorflow.py, I made a small change at line 361, adding attr['channels'] = weights_shape[0].
In tvm/topi/python/topi/x86/depthwise_conv2d.py, line 209 was new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn); I changed it to new_kernel_shape = (out_channel // oc_bn,0, kh, kw,0, oc_bn).

Hi @eqy,

This issue has been bothering and confusing me for a couple of days. Today I looked into topi/python/topi/arm_cpu/depthwise_conv2d.py and found something odd: the registrations list not only 'arm_cpu' but also 'cpu', as shown below:
autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',…
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],…
Why is 'cpu' in the list ['arm_cpu', 'cpu']?
I added two prints at lines 53 and 56 of topi/python/topi/arm_cpu/depthwise_conv2d.py and ran the auto-tuning again. During tasks = autotvm.task.extract_from_program(…), TVM extracted the arm_cpu depthwise conv2d tasks rather than the x86 depthwise conv2d tasks.
The logs contain these messages:
wangyz debug: [task.py] <function TaskExtractEnv._register_topi_task.._topi_nn_depthwise_conv2d_nchw at 0x7fa27cf7c268>, func_name topi_nn_depthwise_conv2d_nchw
wangyz debug: [topi_integration.py] outs=[Tensor(shape=[16, 40, 16, 16], op.name=DepthwiseConv2d)], args=(), kwargs={}
wangyz debug: [arm_cpu/depthwise_conv2d.py] def schedule_depthwise_conv2d_nchw_arm
wangyz debug: [arm_cpu/depthwise_conv2d.py] def schedule_depthwise_conv2d_nchw_arm/def _schedule
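A toy model of key-based dispatch shows how this could lead to the KeyError: 'tile_c' seen earlier. It is a deliberately simplified sketch and not TVM's dispatcher: the registry, the decorator, the config contents, and the assumption that an 'llvm' target only carries the key 'cpu' are all hypothetical.

```python
dispatch_dict = {}


def register(keys):
    """Register one schedule function under several target keys (toy version)."""
    def decorator(func):
        for key in keys:
            dispatch_dict[key] = func
        return func
    return decorator


@register(["arm_cpu", "cpu"])
def schedule_depthwise_arm(cfg):
    # The ARM-style schedule looks up a knob that only its own template defines.
    return cfg["tile_c"]


def dispatch(target_keys, cfg):
    """Call the schedule registered for the first matching target key."""
    for key in target_keys:
        if key in dispatch_dict:
            return dispatch_dict[key](cfg)
    raise LookupError("no schedule registered for %r" % (target_keys,))


# Config produced by tuning some other template; it has no 'tile_c' knob.
x86_cfg = {"tile_ow": 8}

try:
    dispatch(["cpu"], x86_cfg)
    missing = None
except KeyError as exc:
    missing = exc.args[0]
print(missing)  # tile_c
```

In this model, registering the ARM schedule under the generic 'cpu' key is enough for a plain CPU target to reach it, and the lookup then fails because the tuned config comes from a different template.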

thank you !

Hello,

This PR fixed the ARM error for me: https://github.com/dmlc/tvm/pull/3264

Good luck!

Thank you very much, @kevinthesun! This PR fixed both the second and the third error for my TensorFlow model. When will you merge the PR? Also, could you tell me who contributes to relay/frontend/tensorflow.py? My first issue is related to the TensorFlow frontend; I only applied a workaround for my model, after which TVM gets through the frontend. Thank you @Equay and @kevinthesun.

This PR is merged. For tf converter I think @srkreddy1238 and @yongwww contribute a lot.