Hello @haichen,
thank you for your reply!
I think the error went away when I moved the line that grabs k, the reduction axis, so that it sits right before the split/rfactor rather than next to the non-reduction axes. (It’s a bit surprising that the order matters this much; presumably rfactor rewrites the op, so axes grabbed at the wrong point no longer belong to it.)
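If it helps to see that in isolation, here is a tiny toy example of what I believe is going on (my own illustration, not code from the template below):

from tvm import te

A = te.placeholder((1024,), name='A')
rc = te.reduce_axis((0, 1024), name='rc')
B = te.compute((1,), lambda i: te.sum(A[rc], axis=rc), name='B')
s = te.create_schedule(B.op)
ko, kf = s[B].split(s[B].op.reduce_axis[0], factor=128)
BF = s.rfactor(B, kf)
# After rfactor, s[B].op has been replaced: its reduce_axis[0] now runs over
# the 128 partial results rather than the original rc, so IterVars taken from
# the old op before the rfactor no longer match the op being scheduled.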
So it’s basically a medley of the conv2d compute and an oversimplified schedule for dense.
If you found this in the discussion forum archive: I don’t recommend copying this code, it doesn’t work well!
from tvm import autotvm, te
from tvm.topi.utils import simplify  # 'util' instead of 'utils' on older TVM versions

def conv2d_1x1_nchw(inp, w, stride, out_dtype=None):
    """1x1 conv2d in NCHW layout, i.e. a GEMM over the input channels."""
    if out_dtype is None:
        out_dtype = inp.dtype
    assert isinstance(stride, int) or len(stride) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    batch, in_channel, in_height, in_width = inp.shape
    num_filter, channel, kernel_h, kernel_w = w.shape
    assert int(kernel_h) == 1 and int(kernel_w) == 1
    out_channel = num_filter
    out_height = simplify((in_height - 1) // stride_h + 1)
    out_width = simplify((in_width - 1) // stride_w + 1)
    rc = te.reduce_axis((0, in_channel), name='rc')
    return te.compute(
        (batch, out_channel, out_height, out_width),
        lambda b, oc, x, y: te.sum(
            inp[b, rc, x * stride_h, y * stride_w].astype(out_dtype) *
            w[oc, rc, 0, 0].astype(out_dtype),
            axis=[rc]),
        tag="conv2d_1x1_nchw")
@autotvm.template("tutorial/conv2d_1x1_gemm_no_batching")
def conv2d_1x1_gemm_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
    assert N == 1, "Only consider batch_size = 1 in this template"
    assert KH == 1 and KW == 1

    data = te.placeholder((N, CI, H, W), name='data')
    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
    conv = conv2d_1x1_nchw(data, kernel, stride, out_dtype='float32')
    s = te.create_schedule([conv.op])

    # GEMM view of the problem and a stage handle (both unused below for now)
    batch = N * H * W
    in_dim = CI
    out_dim = CO
    C = s[conv]

    # Split the reduction (input channel) axis and rfactor it so that each
    # thread accumulates a partial sum over a slice of the input channels.
    k, = s[conv].op.reduce_axis  # I had this at the other axes before
    num_thread = 128
    ko, kf = s[conv].split(k, factor=num_thread)
    partial_red = s.rfactor(conv, kf)

    # Non-reduction axes: fuse batch and the spatial dims, bind everything to blocks.
    ax_b, ax_out_ch, ax_x, ax_y = s[conv].op.axis
    s[conv].reorder(ax_b, ax_x, ax_y, ax_out_ch)
    ax_bxy = s[conv].fuse(ax_b, ax_x, ax_y)
    s[conv].bind(ax_out_ch, te.thread_axis("blockIdx.x"))
    s[conv].bind(ax_bxy, te.thread_axis("blockIdx.y"))

    # Cross-thread reduction: bind the remaining reduction axis to threadIdx.x,
    # compute the partial sums there, and let only thread 0 store the result.
    tx = s[conv].op.reduce_axis[0]
    thread_x = te.thread_axis("threadIdx.x")
    s[conv].bind(tx, thread_x)
    s[partial_red].compute_at(s[conv], tx)
    s[conv].set_store_predicate(thread_x.var.equal(0))

    return s, [data, kernel, conv]
I don’t recommend copying this code, it doesn’t work well!
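In case someone wants to poke at it anyway: this is roughly how I instantiate the template (the shapes below are just an example, and since there are no tuning knobs yet the config space is trivial):

from tvm import autotvm

# Example shapes only: N=1, 56x56 spatial, 64 -> 256 channels, 1x1 kernel, stride 1.
task = autotvm.task.create(
    "tutorial/conv2d_1x1_gemm_no_batching",
    args=(1, 56, 56, 256, 64, 1, 1, 1, 0),
    target="cuda")
print(task.config_space)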
So I guess, if that is now roughly OK, the next step is to set up the memory caching.
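What I have in mind is the usual cache_read/cache_write pattern from the GPU tutorials, something like the sketch below (untested, written against a fresh schedule rather than integrated with the template above, and the compute_at placements and thread bindings are still missing):

data = te.placeholder((1, 64, 56, 56), name='data')
kernel = te.placeholder((256, 64, 1, 1), name='kernel')
conv = conv2d_1x1_nchw(data, kernel, 1, out_dtype='float32')
s = te.create_schedule([conv.op])

CL = s.cache_write(conv, "local")          # accumulate partial results in registers
AA = s.cache_read(data, "shared", [CL])    # stage input tiles in shared memory
WW = s.cache_read(kernel, "shared", [CL])  # stage weight tiles in shared memory
# These cache stages still need compute_at placements and thread bindings
# before this turns into an actual GPU schedule.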
Best regards
Thomas