[Relay] Help with fusion bug

I’m trying to get operator fusion to work with a new bitserial_conv2d Relay op that I’ve added, but I’m running into a bug that I can’t figure out.
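For context, here is roughly what the op pattern registration looks like on my branch (a sketch; the surrounding compute/schedule registrations are omitted and the exact file layout may differ):

# Sketch of the Relay op pattern registration for the new op.
from tvm.relay.op import op as reg
from tvm.relay.op import OpPattern

# Works: keep the op opaque so nothing ever gets fused into it.
# reg.register_pattern("nn.bitserial_conv2d", OpPattern.OPAQUE)

# Triggers the error below: allow elementwise ops (like the add) to be
# fused into the conv output.
reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)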

When I register the bitserial_conv2d op pattern as OPAQUE, everything works exactly as expected. However, switching the pattern to OUT_ELEMWISE_FUSABLE causes even the simplest set of operations (in this case a bitserial_conv2d followed by an add) to spit out this error:

Traceback (most recent call last):
  File "rpi_simple_test.py", line 48, in <module>
    graph, lib, params = relay.build_module.build(y_func, target=target, params=params)
  File "/home/jwfromm/tvm/python/tvm/relay/build_module.py", line 196, in build
    params)
  File "/home/jwfromm/tvm/python/tvm/relay/build_module.py", line 107, in build
    self._build(func, target, target_host)
  File "tvm/_ffi/_cython/./function.pxi", line 310, in tvm._ffi._cy3.core.FunctionBase.__call__
  File "tvm/_ffi/_cython/./function.pxi", line 245, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./function.pxi", line 234, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 170, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/jwfromm/tvm/build/libtvm.so(+0x4bcd8d) [0x7fa825b47d8d]
  [bt] (7) /home/jwfromm/tvm/build/libtvm.so(+0x4bbfa8) [0x7fa825b46fa8]
  [bt] (6) /home/jwfromm/tvm/build/libtvm.so(+0x4e73f6) [0x7fa825b723f6]
  [bt] (5) /home/jwfromm/tvm/build/libtvm.so(+0x4e66f6) [0x7fa825b716f6]
  [bt] (4) /home/jwfromm/tvm/build/libtvm.so(+0x4de213) [0x7fa825b69213]
  [bt] (3) /home/jwfromm/tvm/build/libtvm.so(+0x4e184a) [0x7fa825b6c84a]
  [bt] (2) /home/jwfromm/tvm/build/libtvm.so(+0x4bda4f) [0x7fa825b48a4f]
  [bt] (1) /home/jwfromm/tvm/build/libtvm.so(+0x4c5890) [0x7fa825b50890]
  [bt] (0) /home/jwfromm/tvm/build/libtvm.so(+0x964c8b) [0x7fa825fefc8b]
  File "/home/jwfromm/tvm/python/tvm/relay/backend/_backend.py", line 52, in lower
    f = _build.lower(sch, inputs, name=func_name)
  File "/home/jwfromm/tvm/python/tvm/build_module.py", line 376, in lower
    stmt = form_body(sch)
  File "/home/jwfromm/tvm/python/tvm/build_module.py", line 325, in form_body
    bounds = schedule.InferBound(sch)
  File "tvm/_ffi/_cython/./function.pxi", line 310, in tvm._ffi._cy3.core.FunctionBase.__call__
  File "tvm/_ffi/_cython/./function.pxi", line 245, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./function.pxi", line 234, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 170, in tvm._ffi._cy3.core.CALL
  [bt] (4) /home/jwfromm/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7fa825ff51f1]
  [bt] (3) /home/jwfromm/tvm/build/libtvm.so(+0x19b792) [0x7fa825826792]
  [bt] (2) /home/jwfromm/tvm/build/libtvm.so(+0x4753d5) [0x7fa825b003d5]
  [bt] (1) /home/jwfromm/tvm/build/libtvm.so(+0x481faf) [0x7fa825b0cfaf]
  [bt] (0) /home/jwfromm/tvm/build/libtvm.so(+0x155b13) [0x7fa8257e0b13]
  File "/home/jwfromm/tvm/src/schedule/message_passing.cc", line 71
  File "tvm/_ffi/_cython/./function.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/home/jwfromm/tvm/python/tvm/relay/backend/_backend.py", line 60, in lower
    raise RuntimeError(msg)
  File "/home/jwfromm/tvm/python/tvm/relay/backend/_backend.py", line 52, in lower
    f = _build.lower(sch, inputs, name=func_name)
  File "/home/jwfromm/tvm/python/tvm/build_module.py", line 376, in lower
    stmt = form_body(sch)
  File "/home/jwfromm/tvm/python/tvm/build_module.py", line 325, in form_body
    bounds = schedule.InferBound(sch)
  File "tvm/_ffi/_cython/./function.pxi", line 310, in tvm._ffi._cy3.core.FunctionBase.__call__
  File "tvm/_ffi/_cython/./function.pxi", line 245, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./function.pxi", line 234, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 170, in tvm._ffi._cy3.core.CALL
  [bt] (4) /home/jwfromm/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7fa825ff51f1]
  [bt] (3) /home/jwfromm/tvm/build/libtvm.so(+0x19b792) [0x7fa825826792]
  [bt] (2) /home/jwfromm/tvm/build/libtvm.so(+0x4753d5) [0x7fa825b003d5]
  [bt] (1) /home/jwfromm/tvm/build/libtvm.so(+0x481faf) [0x7fa825b0cfaf]
  [bt] (0) /home/jwfromm/tvm/build/libtvm.so(+0x155b13) [0x7fa8257e0b13]
  File "/home/jwfromm/tvm/src/schedule/message_passing.cc", line 71
TVMError: Check failed: allow_missing:
During handling of the above exception, another exception occurred:

TVMError: Check failed: allow_missing:
Error during compile function
-----------------------------
v0.0.1
fn (%p0: Tensor[(1, 64, 64, 64), int16], %p1: Tensor[(3, 3, 64, 64), int16], %p2: int16, __dict__=meta[StrMap][0]) -> Tensor[(1, 62, 62, 64), int16] {
  %0 = nn.bitserial_conv2d(%p0, %p1, channels=64, data_layout="NHWC", kernel_layout="", pack_dtype="uint8", out_dtype="int16") /* ty=Tensor[(1, 62, 62, 64), int16] */
  add(%0, %p2) /* ty=Tensor[(1, 62, 62, 64), int16] */
}
/* meta data */
{
  "root": 1,
  "nodes": [
    {
      "type_key": ""
    },
    {
      "type_key": "StrMap",
      "keys": [
        "StrMap"
      ],
      "data": [2]
    },
    {
      "type_key": "Array",
      "data": [3]
    },
    {
      "type_key": "StrMap",
      "keys": [
        "Primitive"
      ],
      "data": [4]
    },
    {
      "type_key": "IntImm",
      "attrs": {
        "dtype": "int32",
        "value": "1"
      }
    }
  ],
  "b64ndarrays": [],
  "attrs": {"tvm_version": "0.6.dev"}
}

I’ve looked at the topi implementation, and as far as I can tell the add is being scheduled correctly. I feel like I’ve probably missed some sort of fused-operation definition somewhere, but I have no idea where. Does anyone have any tips that might help?

Can you show me your bitserial_conv2d schedule? For fusion to work, each parent schedule needs additional scheduling at the end to handle the to-be-fused ops (bias add, relu, etc.).
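The usual shape of this is roughly the following (a generic sketch with hypothetical names, not your op): when an elementwise op is fused after the conv, outs[0] is that op rather than the conv itself, so the schedule has to treat it as the final stage and pull the conv inside it.

# Hypothetical helper illustrating the pattern; not taken from any real topi schedule.
def _schedule_conv_tail(s, conv, outs):
    last = outs[0].op.output(0)          # fused add/relu if present, else the conv
    n, h, w, c = s[last].op.axis         # loop scheduling happens on the final stage
    co, ci = s[last].split(c, factor=8)
    s[last].vectorize(ci)
    if conv.op not in s.outputs:         # an elementwise op was fused after the conv
        s[conv].compute_at(s[last], co)  # anchor the conv body inside that stage
    return s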

Sure, here are all the relevant functions.

def traverse_inline(s, final_op, callback):
    """Traverse computation graph and do auto inline

    Parameters
    ----------
    s: schedule
        The schedule
    final_op: Operation
        The final output operator.
    callback: callable
        The callback function on each op
    """
    visited = set()

    def _traverse(op):
        if op in visited:
            return
        visited.add(op)
        if tag.is_injective(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors:
                    _traverse(tensor.op)
        callback(op)

    _traverse(final_op)

# ARM-specific schedule that uses a custom microkernel
def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
                                  conv_out, output, last, unipolar):
    _, _, _, _, _, IB, CI = data_vec.shape
    _, KH, KW, KB, _, _ = kernel_vec.shape
    KB = get_const_int(KB)                                                                                                                              
    IB = get_const_int(IB)                                                                                                                              

    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]           

    ##### Schedule data padding and packing
    if data_pad is not None:
        s[data_pad].compute_inline()

    _, h, _, _, _, _, _ = s[data_vec].op.axis
    cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)                                                                
    oh, ih = cfg["tile_ah"].apply(s, data_vec, h)                                                                                                       
    s[data_vec].parallel(oh)                                                                                                                            

    #### Schedule kernel packing
    co, _, _, _, _, _ = s[kernel_vec].op.axis
    cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)                                                              
    oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)                                                                                                 
    s[kernel_vec].parallel(oco)

    ##### Schedule Convolution
    n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
    kh, kw, kb, ib, ci = s[conv_out].op.reduce_axis

    ci_o, ci_i = cfg['tile_ci'].apply(s, conv_out, ci)
    re_axes = cfg["reorder_0"].apply(s, conv_out,
                                     [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i])                                                           

    # Use microkernel                                                                                                                                   
    kfactor = cfg['tile_ci'].size[1]                                                                                                                    
    if kfactor % 8 == 0:
        pc = _intrin_popcount(VC, kfactor, KB, IB, unipolar)                                                                                            
        s[conv_out].tensorize(kb, pc)

    n, h, w, co = s[last].op.axis                                                                                                                       
    co, vc = cfg['tile_co'].apply(s, last, co)
    oh, vh = cfg['tile_oh'].apply(s, last, h)                                                                                                           
    ow, vw = cfg['tile_ow'].apply(s, last, w)                                                                                                           
    s[last].reorder(n, oh, ow, co, vh, vw, vc)                                                                                                          
    s[last].vectorize(vc)
    if last != output:
        s[last].compute_inline()

    s[conv_out].compute_at(s[last], co)
    s[last].parallel(oh)                                                                                                                                
    return s   

@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct')                                                         
def schedule_bitserial_conv2d_nhwc(cfg, outs):                                                                                                          
    """Arm cpu schedule for bitserial conv2d"""                                                                                                         
    s = tvm.create_schedule([x.op for x in outs])                                                                                                       

    def _callback(op):                                                                                                                                  
        """Traverse operators from computation graph"""

        if 'spatial_bitserial_conv_nhwc' in op.tag:
            output = op.output(0)
            conv_out = op.input_tensors[0]
            kernel_vec = conv_out.op.input_tensors[0]
            kernel_q = kernel_vec.op.input_tensors[0]
            data_vec = conv_out.op.input_tensors[1]
            data_q = data_vec.op.input_tensors[0]
            data = data_q.op.input_tensors[0]
            data_pad = None
            if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag:
                data_pad = data_q
                data_q = data
                data = data.op.input_tensors[0]
            unipolar = "unipolar" in conv_out.op.tag
            _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
                                          conv_out, output, outs[0], unipolar)
    #traverse(outs[0].op)
    traverse_inline(s, outs[0].op, _callback)
    return s

My understanding is that traverse_inline should be properly catching any fusible ops. Is there an extra check missing from the actual conv schedule too?

This looks correct. For fusion, setting the op pattern to OUT_ELEMWISE_FUSABLE should be all that is needed. Sorry, I have no idea what is broken.

This is unfortunately a limitation of the current schedule language. Since we mainly schedule from input to output, we can inline an input stage into the output, but not the other way around.

If you take a look at other conv schedules, you will find that they schedule the outer loops of the conv using the final stage (the add), then use compute_at to place the inner stage of the conv2d at a point in that final stage, and inline the intermediate ones.
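In schedule terms, the pattern is roughly the following (a self-contained toy with made-up shapes and names, not the actual bitserial schedule):

import tvm

# Toy version of the pattern: do the loop scheduling on the *final* stage (the add),
# inline the intermediate stage into it, and compute_at the heavy stage inside it.
A = tvm.placeholder((64, 64), name="A")
W = tvm.placeholder((64, 64), name="W")
k = tvm.reduce_axis((0, 64), name="k")
conv_like = tvm.compute((64, 64), lambda i, j: tvm.sum(A[i, k] * W[k, j], axis=k),
                        name="conv_like")
scaled = tvm.compute((64, 64), lambda i, j: conv_like[i, j] * 2, name="scaled")  # intermediate
last = tvm.compute((64, 64), lambda i, j: scaled[i, j] + 1, name="add")          # fused elemwise

s = tvm.create_schedule(last.op)
i, j = s[last].op.axis
jo, ji = s[last].split(j, factor=8)
s[last].vectorize(ji)                  # outer scheduling lives on the last stage
s[scaled].compute_inline()             # inline the intermediate INTO the output
s[conv_like].compute_at(s[last], jo)   # anchor the heavy compute inside the last stage
# Doing it the other way around (e.g. s[last].compute_inline()) is not supported.
print(tvm.lower(s, [A, W, last], simple_mode=True))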

It would be useful if we could just use topi to compose the conv2d and the add directly and see what happens.
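Something along these lines, perhaps (an untested sketch; conv2d_nhwc and its generic schedule are stand-ins here, so swap in bitserial_conv2d_nhwc, schedule_bitserial_conv2d_nhwc, and the arm_cpu target from the branch to reproduce the fusion case outside of relay):

import tvm
import topi

data = tvm.placeholder((1, 64, 64, 64), name="data")
kernel = tvm.placeholder((3, 3, 64, 64), name="kernel")
bias = tvm.placeholder((64,), name="bias")

# stride=1, padding=0, dilation=1; conv2d_nhwc stands in for bitserial_conv2d_nhwc
conv = topi.nn.conv2d_nhwc(data, kernel, 1, 0, 1)
out = topi.add(conv, bias)       # the elementwise op that fusion would append

with tvm.target.create("llvm"):
    # The schedule is handed outs[0] == the add, just like after fusion in relay
    s = topi.generic.schedule_conv2d_nhwc([out])
print(tvm.lower(s, [data, kernel, bias, out], simple_mode=True))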

Thanks Tianqi! It was exactly as you described: we were trying to inline the output into the input.

s[last].reorder(n, oh, ow, co, vh, vw, vc)
s[last].vectorize(vc)
if last != output:
    s[last].compute_inline()

Here, compute_inline and other schedule primitives such as vectorize are applied to the same stage, s[last]. Is this legal?