Hi,
I am writing a schedule for a conv2d (a 1x1 convolution in this case), using the compute defined by the default topi template (topi.nn.conv2d). My schedule is below:
import tvm
import tvm.te as te
import tvm.topi as topi
def schedule(outs, x, w):
    outs = [outs] if isinstance(outs, te.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    te.schedule.AutoInlineInjective(s)
    C = outs[0].op.input_tensors[0]  # conv2d
    R = outs[0]                      # relu
    # declare caches
    CC = s.cache_write(C, "local")
    i0, i1, i2, i3 = R.op.axis
    s[C].compute_at(s[R], i3)
    s[R].reorder(i0, i2, i1, i3)
    # split the workloads
    ffor, ffir = s[R].split(i1, factor=16)
    xxo, xxir = s[R].split(i3, factor=7)
    s[R].reorder(ffor, xxo, ffir, xxir)
    nn, ffi, yy, xxi = s[CC].op.axis
    rc, ry, rx = s[CC].op.reduce_axis
    rco, rci = s[CC].split(rc, factor=4)
    s[CC].compute_at(s[R], xxo)  # ffi
    s[C].compute_at(s[R], xxo)
    s[CC].reorder(rco, ffi, rci, ry, rx)
    px, x = s[outs[0]].split(outs[0].op.axis[0], nparts=1)
    s[outs[0]].bind(px, te.thread_axis("pipeline"))
    return s
num_filter = 256
in_channels = 128
in_size = 28
filter_width = 1
filter_height = 1
x = tvm.te.placeholder([1, in_channels, in_size, in_size], name='in', dtype='float32')
w = tvm.te.placeholder([num_filter, in_channels, filter_width, filter_height], name='w', dtype='float32')
C = topi.nn.conv2d(x, w, 1, 0, 1)
out = tvm.topi.nn.relu(C)
s = schedule(out, x, w)
print(tvm.lower(s, [x, w], simple_mode=True))
This yields the following from TVM v0.7:
primfn(in_1: handle, w_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {w: Buffer(w_2: Pointer(float32), float32, [256, 128, 1, 1], []),
             in: Buffer(in_2: Pointer(float32), float32, [1, 128, 28, 28], [])}
  buffer_map = {in_1: in, w_1: w} {
  attr [compute: Pointer(float32)] "storage_scope" = "global";
  allocate(compute, float32, [200704]);
  attr [compute.local: Pointer(float32)] "storage_scope" = "local";
  allocate(compute.local, float32, [112]);
  attr [compute_1: Pointer(float32)] "storage_scope" = "global";
  allocate(compute_1, float32, [112]);
  attr [IterVar(pipeline: int32, (nullptr), "ThreadIndex", "pipeline")] "pipeline_exec_scope" = 1;
  for (i2: int32, 0, 28) {
    for (i1.outer: int32, 0, 16) {
      for (i3.outer: int32, 0, 4) {
        for (xx.c.init: int32, 0, 7) {
          for (ff.c.init: int32, 0, 16) {
            compute.local[((ff.c.init*7) + xx.c.init)] = 0f32
          }
        }
        for (rc.outer: int32, 0, 32) {
          for (xx.c: int32, 0, 7) {
            for (ff.c: int32, 0, 16) {
              for (rc.inner: int32, 0, 4) {
                compute.local[((ff.c*7) + xx.c)] = ((float32*)compute.local[((ff.c*7) + xx.c)] + ((float32*)in_2[(((((rc.outer*3136) + (rc.inner*784)) + (i2*28)) + (i3.outer*7)) + xx.c)]*(float32*)w_2[((((i1.outer*2048) + (ff.c*128)) + (rc.outer*4)) + rc.inner)]))
              }
            }
          }
        }
        for (ff: int32, 0, 16) {
          for (xx: int32, 0, 7) {
            compute_1[((ff*7) + xx)] = (float32*)compute.local[((ff*7) + xx)]
          }
        }
        for (i1.inner: int32, 0, 16) {
          for (i3.inner: int32, 0, 7) {
            compute[(((((i1.outer*12544) + (i1.inner*784)) + (i2*28)) + (i3.outer*7)) + i3.inner)] = max((float32*)compute_1[((i1.inner*7) + i3.inner)], 0f32)
          }
        }
      }
    }
  }
}
How would I change my schedule so that the lowered code looks more like the following?
primfn(in_1: handle, w_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {w: Buffer(w_2: Pointer(float32), float32, [256, 128, 1, 1], []),
             in: Buffer(in_2: Pointer(float32), float32, [1, 128, 28, 28], [])}
  buffer_map = {in_1: in, w_1: w} {
  attr [compute: Pointer(float32)] "storage_scope" = "global";
  allocate(compute, float32, [200704]);
  attr [compute.local: Pointer(float32)] "storage_scope" = "local";
  allocate(compute.local, float32, [112]);
  attr [compute_1: Pointer(float32)] "storage_scope" = "global";
  allocate(compute_1, float32, [112]);
  attr [IterVar(pipeline: int32, (nullptr), "ThreadIndex", "pipeline")] "pipeline_exec_scope" = 1;
  for (i2: int32, 0, 28) {
    for (i1.outer: int32, 0, 16) {
      for (i3.outer: int32, 0, 4) {
        for (xx.c.init: int32, 0, 7) {
          for (ff.c.init: int32, 0, 16) {
            compute_1[((ff.c.init*7) + xx.c.init)] = 0f32
          }
        }
        for (rc.outer: int32, 0, 32) {
          for (xx.c: int32, 0, 7) {
            for (ff.c: int32, 0, 16) {
              allocate(compute.local, float32, [1]);
              compute.local[0] = 0f32
              for (rc.inner: int32, 0, 4) {
                compute.local[0] = ((float32*)compute.local[0] + ((float32*)in_2[(((((rc.outer*3136) + (rc.inner*784)) + (i2*28)) + (i3.outer*7)) + xx.c)]*(float32*)w_2[((((i1.outer*2048) + (ff.c*128)) + (rc.outer*4)) + rc.inner)]))
              }
              compute_1[((ff.c*7) + xx.c)] = (float32*)compute.local[0]
            }
          }
        }
        for (i1.inner: int32, 0, 16) {
          for (i3.inner: int32, 0, 7) {
            compute[(((((i1.outer*12544) + (i1.inner*784)) + (i2*28)) + (i3.outer*7)) + i3.inner)] = max((float32*)compute_1[((i1.inner*7) + i3.inner)], 0f32)
          }
        }
      }
    }
  }
}
I know that for the compute.local buffer to be allocated in a narrower scope, that scope has to be declared parallel. But I can't do s[C].compute_at(s[CC], …), since that won't result in a legal graph. I've tried using rfactor, but I don't think that's what I need here. It seems like it would be an easy change, but I have been struggling with this for days.
And if I do s[CC].compute_at(s[C], s[C].op.axis[N]), then the reordering is lost. I tried to reorder after doing the compute_at, but it doesn't seem like I can reorder loops that live in two different stages with each other once that is done.
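For reference, here is roughly that variant, as a sketch only (the axis index 3 stands in for N, and the splits on R are the same ones from my schedule above):

    CC = s.cache_write(C, "local")
    s[C].compute_at(s[R], xxo)                # conv2d still computed under relu's i3.outer loop
    s[CC].compute_at(s[C], s[C].op.axis[3])   # sink the local accumulator into C's innermost spatial axis
    rc, ry, rx = s[CC].op.reduce_axis
    rco, rci = s[CC].split(rc, factor=4)
    # s[CC].reorder(...) now only permutes CC's own loops (rco, rci, ry, rx),
    # so the rco/ffi/rci interleaving from my original schedule can no longer be expressed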
If anyone can shed some light on what schedule changes need to be made here, it would be greatly appreciated.