Working with fused axes

Hello,

I’m trying to fuse axes that have “the same semantics”. However, when I try to bind the fused axis to blockIdx.y, I get

TVMError: Operate on iter var iter_var(nn.yy.fused.xx.fused, )that is not part of the schedule

Is there a way to achieve this? Ideally I’d fuse the axes and later split the result again (into different parts), but I’m stuck even earlier.

The background is that I would like to implement 1x1 convolutions that parallelize the reduction (I think the current GPU conv templates don’t do that). Here the batch and spatial axes play the same role (and there is a single reduction axis, in_channels).
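For context, the pattern I’m after looks roughly like this (a stripped-down sketch, not the code that triggers the error; the real kernel also has the reduction over in_channels):

from tvm import te

A = te.placeholder((1, 64, 56, 56), name="A")
B = te.compute(A.shape, lambda nn, cc, xx, yy: A[nn, cc, xx, yy] * 2, name="B")

s = te.create_schedule(B.op)
nn, cc, xx, yy = s[B].op.axis
# batch and the two spatial axes play the same role, so fuse them ...
fused = s[B].fuse(nn, xx, yy)
# ... and bind the result to one block dimension
s[B].bind(fused, te.thread_axis("blockIdx.y"))
s[B].bind(cc, te.thread_axis("blockIdx.x"))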

Best regards

Thomas

Could you share your code that can reproduce the error?

Hello @haichen,

thank you for your reply!

I think the error went away when I reordered things so that the line taking k, the reduction axis, (and the split/rfactor) and the handling of the non-reducing axes happen in the order shown below. (It’s a bit surprising that the order matters this much; presumably the rfactor rewrites the stage, so axis handles obtained before it are no longer part of the schedule.)

So it’s basically a medley of the conv2d compute and an oversimplified schedule for dense.

If you are finding this in the discussion forum archive: I don’t recommend copying this code, it doesn’t work well!

from tvm import te, autotvm
from tvm.topi.utils import simplify  # tvm.topi.util in older TVM releases

def conv2d_1x1_nchw(inp, w, stride, out_dtype=None):
    if out_dtype is None:
        out_dtype = inp.dtype
    assert isinstance(stride, int) or len(stride) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    batch, in_channel, in_height, in_width = inp.shape
    num_filter, channel, kernel_h, kernel_w = w.shape
    assert int(kernel_h) == 1 and int(kernel_w) == 1
    out_channel = num_filter
    out_height = simplify((in_height - 1) // stride_h + 1)
    out_width = simplify((in_width - 1) // stride_w + 1)

    # single reduction axis over the input channels
    rc = te.reduce_axis((0, in_channel), name='rc')

    return te.compute(
        (batch, out_channel, out_height, out_width),
        lambda b, oc, x, y: te.sum(
            inp[b, rc, x * stride_h, y * stride_w].astype(out_dtype) *
            w[oc, rc, 0, 0].astype(out_dtype),
            axis=[rc]), tag="conv2d_1x1_nchw")

@autotvm.template("tutorial/conv2d_1x1_gemm_no_batching")
def conv2d_1x1_gemm_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
    assert N == 1, "Only consider batch_size = 1 in this template"
    assert KH == 1 and KW == 1

    data = te.placeholder((N, CI, H, W), name='data')
    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
    
    conv = conv2d_1x1_nchw(data, kernel, stride, out_dtype='float32')
    s = te.create_schedule([conv.op])

    batch = N * H * W
    in_dim = CI
    out_dim = CO

    C = s[conv]
    
    k, = s[conv].op.reduce_axis    # I had this at the other axes before
    num_thread = 128
    # split the reduction and factor the inner part out into its own stage
    ko, kf = s[conv].split(k, factor=num_thread)
    partial_red = s.rfactor(conv, kf)

    # after the rfactor, take the spatial axes from the stage's (rewritten) op,
    # fuse the ones that play the same role, and bind them to the block grid
    ax_b, ax_out_ch, ax_x, ax_y = s[conv].op.axis
    s[conv].reorder(ax_b, ax_x, ax_y, ax_out_ch)
    ax_bxy = s[conv].fuse(ax_b, ax_x, ax_y)

    s[conv].bind(ax_out_ch, te.thread_axis("blockIdx.x"))
    s[conv].bind(ax_bxy, te.thread_axis("blockIdx.y"))

    # cross-thread reduction: bind the remaining reduction axis to threadIdx.x,
    # compute the per-thread partial sums there, and let only thread 0 store
    tx = s[conv].op.reduce_axis[0]
    thread_x = te.thread_axis("threadIdx.x")
    s[conv].bind(tx, thread_x)
    s[partial_red].compute_at(s[conv], tx)
    s[conv].set_store_predicate(thread_x.var.equal(0))
    
    return s, [data, kernel, conv]
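In case it matters, I instantiate the template roughly like this (the shapes are just an example):

task = autotvm.task.create(
    "tutorial/conv2d_1x1_gemm_no_batching",
    args=(1, 56, 56, 64, 64, 1, 1, 1, 0),
    target="cuda")
print(task.config_space)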

Again: I don’t recommend copying this code, it doesn’t work well!

So I guess, if that is now roughly OK, the next step is to set up the memory caching.
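What I have in mind is the usual shared-memory caching from the dense/reduction GPU schedules, roughly like this (untested, and AA/WW are just names I’m picking here):

    AA = s.cache_read(data, "shared", [partial_red])
    WW = s.cache_read(kernel, "shared", [partial_red])
    # materialize the loads once per reduction iteration of the rfactored stage
    ko_f = s[partial_red].op.reduce_axis[0]
    s[AA].compute_at(s[partial_red], ko_f)
    s[WW].compute_at(s[partial_red], ko_f)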

Best regards

Thomas

Could you try to bind the outermost axis ax_bxy before ax_out_ch?
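That is, something like this, so that the fused axis is bound first:

    s[conv].bind(ax_bxy, te.thread_axis("blockIdx.y"))
    s[conv].bind(ax_out_ch, te.thread_axis("blockIdx.x"))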