Hi, I'm writing a schedule for NHWC Conv2d on Tensor Cores using the wmma intrinsics, and I've hit a problem. I need to split the `wmma_m` and `wmma_n` axes out of the `n`, `h`, `w` and `c` axes. `wmma_n` can be split from the `c` axis. To support batch=1, I want to split `wmma_m` from the `w` axis. `wmma_m` can be 8, 16 or 32, but `w` is usually not a multiple of `wmma_m` (e.g. `out_width = 28` with `MTile = 32` needs 4 extra columns), so I tried to pad `w` up to `padded_w`.
Here are my compute and schedule functions: I pad `input` to `input_pad_w`, compute `conv_padded` with a GEMM, and unpad `conv_padded` back to `conv`. The problem is that `conv` and `conv_padded` have different shapes and `conv_padded` is inlined, so in the lowered code only part of `conv_padded` is computed, which makes tensorize fail.
```python
import numpy as np
import tvm
from tvm import te
from tvm.topi import nn
from tvm.topi.nn.utils import get_pad_tuple
from tvm.topi.utils import get_const_tuple, simplify
# nb: intrin_wmma_load_matrix_A / intrin_wmma_load_matrix_W /
# intrin_wmma_store_matrix / intrin_wmma_gemm below are my locally modified
# copies of the helpers in tvm.topi.cuda.tensor_intrin (they take a few extra
# arguments compared with the stock versions)


def nhwc_tensorcore_int8_w(cfg, Input, Filter, strides, padding, dilation=1):
    assert isinstance(strides, int) or len(strides) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(strides, int):
        stride_h = stride_w = strides
    else:
        stride_h, stride_w = strides
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation
    batch, in_height, in_width, in_channel = get_const_tuple(Input.shape)
    kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape)
    assert in_channel % 16 == 0 and num_filter % 8 == 0
    # compute output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    # pad at the sides
    if pad_top or pad_left:
        input_pad_side = nn.pad(Input, [0, pad_top, pad_left, 0],
                                [0, pad_down, pad_right, 0], name="input_pad_side")
    else:
        input_pad_side = Input
    # compute gemm shape
    K = kernel_h * kernel_w * in_channel
    input_shape = (batch, out_height, out_width, K)
    # im2col
    if kernel_h == 1 and kernel_w == 1:
        input_im2col = te.compute(
            input_shape,
            lambda b, h, w, y: input_pad_side[b, stride_h * h, stride_w * w, y],
            name="input_im2col")
    else:
        input_im2col = te.compute(
            input_shape,
            lambda b, h, w, y: input_pad_side[
                b,
                stride_h * h + (y // in_channel) // kernel_w,
                stride_w * w + (y // in_channel) % kernel_w,
                y % in_channel],
            name="input_im2col")
    # pad w
    cfg.define_knob("MTile", [8, 16, 32, 64])
    padding_factor = cfg["MTile"].val
    pad_w = 0
    if out_width % padding_factor != 0:
        pad_w = padding_factor - (out_width % padding_factor)
    if pad_w != 0:
        input_pad_w = nn.pad(input_im2col, [0, 0, 0, 0], [0, 0, pad_w, 0],
                             name="input_pad_w")
    else:
        input_pad_w = input_im2col
    # GEMM
    padded_w = out_width + pad_w
    rx = te.reduce_axis((0, kernel_h), name='rx')
    ry = te.reduce_axis((0, kernel_w), name='ry')
    rc = te.reduce_axis((0, in_channel), name='rc')
    conv_padded = te.compute(
        (batch, out_height, padded_w, num_filter),
        lambda b, h, w, y: te.sum(
            input_pad_w[b, h, w, (rx * kernel_w + ry) * in_channel + rc].astype("int32") *
            Filter[rx, ry, rc, y].astype("int32"),
            axis=[rx, ry, rc]),
        name="conv_padded")
    # unpad
    conv = te.compute(
        (batch, out_height, out_width, num_filter),
        lambda b, h, w, c: conv_padded[b, h, w, c],
        name="conv", tag="conv2d_nhwc_tensorcore_int8_w")
    return conv


def schedule_nhwc_tensorcore_int8_w(cfg, s, conv):
    # todo: support unpad
    batch, out_height, out_width, num_filter = get_const_tuple(conv.shape)
    conv_padded = s[conv].op.input_tensors[0]
    rx, ry, rc = s[conv_padded].op.reduce_axis
    input_pad_w, filter = s[conv_padded].op.input_tensors
    if input_pad_w.op.name == "input_pad_w":
        input_im2col = s[input_pad_w].op.input_tensors[0]
    else:
        input_im2col = input_pad_w
    input_pad_side = s[input_im2col].op.input_tensors[0]
    # compute inline
    s[conv_padded].compute_inline()
    if input_pad_w.op.name == "input_pad_w":
        s[input_pad_w].compute_inline()
    s[input_im2col].compute_inline()
    if input_pad_side.op.name == "input_pad_side":
        s[input_pad_side].compute_inline()
    # Designate the memory hierarchy
    AS = s.cache_read(input_pad_w, "shared", [conv_padded])
    WS = s.cache_read(filter, "shared", [conv_padded])
    AF = s.cache_read(AS, "wmma.matrix_a", [conv_padded])
    WF = s.cache_read(WS, "wmma.matrix_b", [conv_padded])
    CF_padded = s.cache_write(conv_padded, "wmma.accumulator")
    CS_padded = s.cache_read(CF_padded, "shared", [conv_padded])
    if conv.op in s.outputs:
        output = conv
        CS = s.cache_write(conv, "shared")
        OL = CS
    else:
        output = s.outputs[0].output(0)
        s[conv].set_scope("shared")
        OL = conv
    # Knobs for autotvm
    # cfg.define_knob("MTile", [8, 16, 32, 64])  # defined in the compute above
    cfg.define_knob("NTile", [32, 64, 128, 256])
    cfg.define_knob("KTile", [32, 64, 128, 256])
    cfg.define_knob("block_row_warps", [1, 2, 4])
    cfg.define_knob("block_col_warps", [1, 2, 4])
    cfg.define_knob("chunk", [1, 2, 4, 8])
    cfg.define_knob("vector_width", [1, 4, 8, 16])
    cfg.define_knob("offset", [0, 16])
    cfg.define_knob("wmma_m", [8, 16, 32])
    cfg.define_knob("vthread", [1, 2])
    MTile = cfg["MTile"].val
    NTile = cfg["NTile"].val
    KTile = cfg["KTile"].val
    block_row_warps = cfg["block_row_warps"].val
    block_col_warps = cfg["block_col_warps"].val
    chunk = cfg["chunk"].val
    vector_width = cfg["vector_width"].val
    offset = cfg["offset"].val
    vthread = cfg["vthread"].val
    wmma_m = cfg["wmma_m"].val
    # wmma_m = 8
    wmma_k = 16
    wmma_n = 16
    if wmma_m == 8:
        wmma_n = 32
    elif wmma_m == 32:
        wmma_n = 8
    wmma_shape = (wmma_m, wmma_n, wmma_k)
    warp_size = 32
    block_x = te.thread_axis("blockIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    block_z = te.thread_axis("blockIdx.z")
    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")
    thread_z = te.thread_axis("threadIdx.z")
    thread_vx = te.thread_axis((0, vthread), "vthread", name="vx")
    thread_vy = te.thread_axis((0, vthread), "vthread", name="vy")

    # Define the intrin strides
    def get_strides(extents):
        return [np.prod(extents[i:]).tolist() for i in range(len(extents))]

    AS_align = KTile + offset
    WS_align = NTile + offset
    CS_align = NTile + offset
    NFrag = (NTile + block_col_warps - 1) // block_col_warps
    KFrag = chunk * wmma_k
    AS_shape = [wmma_m, wmma_k]
    AF_shape = [wmma_m, wmma_k]
    WS_shape = [wmma_k, wmma_n]
    WF_shape = [wmma_k, wmma_n]
    CS_shape = [wmma_m, wmma_n]
    CF_shape = [wmma_m, wmma_n]
    AS_strides = get_strides([AS_align, 1])
    AF_strides = get_strides([KFrag, 1])
    WS_strides = get_strides([WS_align, 1])
    WF_strides = get_strides([NFrag, 1])
    CF_strides = get_strides([NFrag, 1])
    CS_strides = get_strides([NTile, 1])
    # Schedule for output
    b, h, w, n = output.op.axis
    block_k = s[output].fuse(b, h)
    block_j, m = s[output].split(w, factor=MTile)
    block_i, n = s[output].split(n, factor=NTile)
    s[output].reorder(block_k, block_j, block_i, m, n)
    t = s[output].fuse(m, n)
    t, ti = s[output].split(t, factor=vector_width)
    t, tx = s[output].split(t, factor=warp_size)
    t, ty = s[output].split(t, factor=block_col_warps)
    t, tz = s[output].split(t, factor=block_row_warps)
    s[output].bind(block_k, block_z)
    s[output].bind(block_j, block_y)
    s[output].bind(block_i, block_x)
    s[output].bind(tz, thread_z)
    s[output].bind(ty, thread_y)
    s[output].bind(tx, thread_x)
    # Schedule for conv
    s[OL].compute_at(s[output], block_i)
    b, h, w, n = OL.op.axis
    s[OL].storage_align(h, CS_align - 1, CS_align)
    tz, m = s[OL].split(w, nparts=block_row_warps)
    ty, n = s[OL].split(n, nparts=block_col_warps)
    s[OL].reorder(tz, ty, m, n)
    t = s[OL].fuse(m, n)
    t, ti = s[OL].split(t, factor=vector_width)
    tx, _ = s[OL].split(t, nparts=warp_size)
    s[OL].bind(tz, thread_z)
    s[OL].bind(ty, thread_y)
    s[OL].bind(tx, thread_x)
    # Schedule for wmma store
    s[CS_padded].compute_at(s[OL], h)
    b, h, w, n = CS_padded.op.axis
    s[CS_padded].storage_align(w, CS_align - 1, CS_align)
    # _, w = s[CS_padded].split(w, factor=MTile)
    tz, m = s[CS_padded].split(w, nparts=block_row_warps)
    ty, n = s[CS_padded].split(n, nparts=block_col_warps)
    mo_cs, mi_cs = s[CS_padded].split(m, factor=wmma_m)
    no_cs, ni_cs = s[CS_padded].split(n, factor=wmma_n)
    s[CS_padded].reorder(tz, ty, mo_cs, no_cs, mi_cs, ni_cs)
    s[CS_padded].bind(tz, thread_z)
    s[CS_padded].bind(ty, thread_y)
    # Schedule for wmma compute
    s[CF_padded].compute_at(s[CS_padded], ty)
    b, h, w, n = CF_padded.op.axis
    mo_cf, mi_cf = s[CF_padded].split(w, factor=wmma_m)
    no_cf, ni_cf = s[CF_padded].split(n, factor=wmma_n)
    kfo, kfi = s[CF_padded].split(rc, factor=KTile)
    kfi, _kfi = s[CF_padded].split(kfi, factor=wmma_k)
    _kfo, kfi = s[CF_padded].split(kfi, factor=chunk)
    s[CF_padded].reorder(rx, ry, kfo, _kfo, kfi, mo_cf, no_cf, mi_cf, ni_cf, _kfi)
    s[AS].compute_at(s[CF_padded], kfo)
    s[WS].compute_at(s[CF_padded], kfo)
    s[AF].compute_at(s[CF_padded], _kfo)
    s[WF].compute_at(s[CF_padded], _kfo)
    # Schedule for input's shared memory
    b, h, w, k = AS.op.axis
    s[AS].storage_align(w, AS_align - 1, AS_align)
    _, w = s[AS].split(w, factor=MTile)
    tz, m = s[AS].split(w, nparts=block_row_warps)
    ty, k = s[AS].split(k, nparts=block_col_warps)
    s[AS].reorder(tz, ty, m, k)
    t = s[AS].fuse(m, k)
    t, ti = s[AS].split(t, factor=vector_width)
    tx, _ = s[AS].split(t, nparts=warp_size)
    s[AS].bind(tz, thread_z)
    s[AS].bind(ty, thread_y)
    s[AS].bind(tx, thread_x)
    s[AS].vectorize(ti)
    # Schedule for weight's shared memory
    h, w, ic, oc = WS.op.axis
    s[WS].storage_align(ic, WS_align - 1, WS_align)
    tvx, ic = s[WS].split(ic, nparts=vthread)
    tvy, oc = s[WS].split(oc, nparts=vthread)
    tz, ic = s[WS].split(ic, nparts=block_row_warps)
    ty, oc = s[WS].split(oc, nparts=block_col_warps)
    s[WS].reorder(tvx, tvy, tz, ty, ic, oc)
    t = s[WS].fuse(ic, oc)
    t, ti = s[WS].split(t, factor=vector_width)
    tx, _ = s[WS].split(t, nparts=warp_size)
    s[WS].bind(tz, thread_z)
    s[WS].bind(ty, thread_y)
    s[WS].bind(tx, thread_x)
    s[WS].bind(tvx, thread_vx)
    s[WS].bind(tvy, thread_vy)
    s[WS].vectorize(ti)
    # Schedule for input's local (fragment) memory
    b, h, w, k = AF.op.axis
    mo_af, mi_af = s[AF].split(w, factor=wmma_m)
    ko_af, ki_af = s[AF].split(k, factor=wmma_k)
    s[AF].reorder(mo_af, ko_af, mi_af, ki_af)
    # Schedule for weight's local (fragment) memory
    h, w, ic, oc = WF.op.axis
    ico_wf, ici_wf = s[WF].split(ic, factor=wmma_k)
    oco_wf, oci_wf = s[WF].split(oc, factor=wmma_n)
    s[WF].reorder(h, w, ico_wf, oco_wf, ici_wf, oci_wf)
    # Tensorize the wmma process
    AF_gemm = te.placeholder(AF_shape, name='A', dtype='int8')
    WF_gemm = te.placeholder(WF_shape, name='B', dtype='int8')
    k_gemm = te.reduce_axis((0, wmma_k), name='k')
    CF_compute = te.compute(
        CF_shape,
        lambda m, n: te.sum(
            AF_gemm[m, k_gemm].astype('int32') * WF_gemm[k_gemm, n].astype('int32'),
            axis=k_gemm),
        name='C')
    s[AF].tensorize(mi_af, intrin_wmma_load_matrix_A(
        AF_strides, AS_strides, wmma_shape, 'row_major',
        AS_shape, AF_shape, 'int8', 'shared'))
    s[WF].tensorize(ici_wf, intrin_wmma_load_matrix_W(
        WF_strides, WS_strides, wmma_shape, 'row_major',
        WS_shape, WF_shape, 'int8', 'shared'))
    s[CS_padded].tensorize(mi_cs, intrin_wmma_store_matrix(
        CS_strides, CF_strides, wmma_shape, 'row_major',
        'int32', CF_shape, CS_shape, 'shared'))
    s[CF_padded].tensorize(mi_cf, intrin_wmma_gemm(
        AF_gemm, WF_gemm, CF_compute, AF_strides, WF_strides,
        CF_strides, wmma_shape, 'row_major', 'row_major', 'row_major'))
```
For example, with batch=1, in_size=28, in_channel=128, num_filter=128, kernel_size=3 and MTile=32 (so `out_width = 28`, `pad_w = 4`, and `conv_padded` has shape (1, 28, 32, 128) while `conv` is (1, 28, 28, 128)), the program fails with: `TVMError: Tensorize failed, split condition tir.likely(((ax2.inner.inner + (ax2.inner.outer*8)) < 28)) relies on var defined inside tensorize scope`.
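For reference, a driver along these lines reproduces it (just a sketch: `FallbackConfigEntity` picks the first candidate of every knob, so `MTile` falls back to 8 rather than 32, but any `MTile` that does not divide `out_width = 28` hits the same guards):

```python
import tvm
from tvm import te
from tvm.autotvm.task.space import FallbackConfigEntity

# int8 NHWC input and HWIO weight, shapes from the failing case above
data = te.placeholder((1, 28, 28, 128), name="input", dtype="int8")
weight = te.placeholder((3, 3, 128, 128), name="weight", dtype="int8")

cfg = FallbackConfigEntity()  # every knob falls back to its first candidate
conv = nhwc_tensorcore_int8_w(cfg, data, weight, strides=1, padding=1)
s = te.create_schedule(conv.op)
schedule_nhwc_tensorcore_int8_w(cfg, s, conv)
print(tvm.lower(s, [data, weight, conv], simple_mode=True))
```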
Here's the lowered code before tensorize:

```
primfn(input_1: handle, weight_1: handle, bias_1: handle, compute_1: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {bias: Buffer(bias_2: handle, int32, [1, 1, 1, 128], []),
input: Buffer(input_2: handle, int8, [1, 28, 28, 128], []),
compute: Buffer(compute_2: handle, int8, [1, 28, 28, 128], []),
weight: Buffer(weight_2: handle, int8, [3, 3, 128, 128], [])}
buffer_map = {input_1: input, weight_1: weight, bias_1: bias, compute_1: compute} {
attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 28;
attr [conv_padded.wmma.accumulator: handle] "storage_scope" = "wmma.accumulator";
allocate(conv_padded.wmma.accumulator, int32, [896]);
attr [input_pad_w.shared: handle] "storage_scope" = "shared";
allocate(input_pad_w.shared, int8, [896]);
attr [weight.shared: handle] "storage_scope" = "shared";
allocate(weight.shared, int8, [1024]);
attr [input_pad_w.shared.wmma.matrix_a: handle] "storage_scope" = "wmma.matrix_a";
allocate(input_pad_w.shared.wmma.matrix_a, int8, [448]);
attr [weight.shared.wmma.matrix_b: handle] "storage_scope" = "wmma.matrix_b";
allocate(weight.shared.wmma.matrix_b, int8, [512]);
attr [conv_padded.wmma.accumulator.shared: handle] "storage_scope" = "shared";
allocate(conv_padded.wmma.accumulator.shared, int32, [896]);
attr [conv: handle] "storage_scope" = "shared";
allocate(conv, int32, [900]);
attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 1;
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 4 {
attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1 {
for (w.c.outer.init: int32, 0, 4) {
for (w.c.inner.init: int32, 0, 8) {
for (y.c.inner.init: int32, 0, 32) {
if @tir.likely((((w.c.outer.init*8) + w.c.inner.init) < 28), dtype=bool) {
conv_padded.wmma.accumulator[(((w.c.outer.init*256) + (w.c.inner.init*32)) + y.c.inner.init)] = 0
}
}
}
}
for (rx: int32, 0, 3) {
for (ry: int32, 0, 3) {
for (rc.outer: int32, 0, 4) {
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
for (ax2.inner.inner.ax3.inner.fused.outer.inner: int32, 0, 32) {
if @tir.likely((threadIdx.x < 28), dtype=bool) {
input_pad_w.shared[((threadIdx.x*32) + ax2.inner.inner.ax3.inner.fused.outer.inner)] = @tir.if_then_else(((((1 <= (blockIdx.z + rx)) && ((blockIdx.z + rx) < 29)) && (1 <= (threadIdx.x + ry))) && ((threadIdx.x + ry) < 29)), (int8*)input_2[(((((((blockIdx.z*3584) + (rx*3584)) + (threadIdx.x*128)) + (ry*128)) + (rc.outer*32)) + ax2.inner.inner.ax3.inner.fused.outer.inner) - 3712)]), 0i8, dtype=int8)
}
}
attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
for (ax2.inner.inner.ax3.inner.inner.fused.outer.inner: int32, 0, 32) {
weight.shared[((threadIdx.x*32) + ax2.inner.inner.ax3.inner.inner.fused.outer.inner)] = (int8*)weight_2[((((((rx*49152) + (ry*16384)) + (rc.outer*4096)) + (threadIdx.x*128)) + (blockIdx.x*32)) + ax2.inner.inner.ax3.inner.inner.fused.outer.inner)])
}
for (rc.inner.outer.outer: int32, 0, 2) {
for (ax2.outer: int32, 0, 4) {
for (ax2.inner: int32, 0, 8) {
for (ax3.inner: int32, 0, 16) {
if @tir.likely((((ax2.outer*8) + ax2.inner) < 28), dtype=bool) {
input_pad_w.shared.wmma.matrix_a[(((ax2.outer*128) + (ax2.inner*16)) + ax3.inner)] = (int8*)input_pad_w.shared[((((ax2.outer*256) + (ax2.inner*32)) + (rc.inner.outer.outer*16)) + ax3.inner)])
}
}
}
}
for (ax2.inner_1: int32, 0, 16) {
for (ax3.inner_1: int32, 0, 32) {
weight.shared.wmma.matrix_b[((ax2.inner_1*32) + ax3.inner_1)] = (int8*)weight.shared[(((rc.inner.outer.outer*512) + (ax2.inner_1*32)) + ax3.inner_1)])
}
}
for (w.c.outer: int32, 0, 4) {
for (w.c.inner: int32, 0, 8) {
for (y.c.inner: int32, 0, 32) {
for (rc.inner.inner: int32, 0, 16) {
if @tir.likely((((w.c.outer*8) + w.c.inner) < 28), dtype=bool) {
conv_padded.wmma.accumulator[(((w.c.outer*256) + (w.c.inner*32)) + y.c.inner)] = ((int32*)conv_padded.wmma.accumulator[(((w.c.outer*256) + (w.c.inner*32)) + y.c.inner)]) + (cast(int32, (int8*)input_pad_w.shared.wmma.matrix_a[(((w.c.outer*128) + (w.c.inner*16)) + rc.inner.inner)]))*cast(int32, (int8*)weight.shared.wmma.matrix_b[((rc.inner.inner*32) + y.c.inner)]))))
}
}
}
}
}
}
}
}
}
for (ax2.inner.outer: int32, 0, 4) {
for (ax2.inner.inner: int32, 0, 8) {
for (ax3.inner.inner: int32, 0, 32) {
if @tir.likely((((ax2.inner.outer*8) + ax2.inner.inner) < 28), dtype=bool) {
conv_padded.wmma.accumulator.shared[(((ax2.inner.outer*256) + (ax2.inner.inner*32)) + ax3.inner.inner)] = (int32*)conv_padded.wmma.accumulator[(((ax2.inner.outer*256) + (ax2.inner.inner*32)) + ax3.inner.inner)])
}
}
}
}
}
attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1;
attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
for (w.inner.c.inner.fused.outer.inner: int32, 0, 28) {
conv[((threadIdx.x*28) + w.inner.c.inner.fused.outer.inner)] = (int32*)conv_padded.wmma.accumulator.shared[((threadIdx.x*28) + w.inner.c.inner.fused.outer.inner)])
}
for (i2.inner.i3.inner.fused.outer.outer.outer.outer: int32, 0, 32) {
attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1;
attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
if @tir.likely((i2.inner.i3.inner.fused.outer.outer.outer.outer < 28), dtype=bool) {
compute_2[((((blockIdx.z*3584) + (i2.inner.i3.inner.fused.outer.outer.outer.outer*128)) + (blockIdx.x*32)) + threadIdx.x)] = cast(int8, max(min(@tir.round((cast(float32, ((int32*)conv[((i2.inner.i3.inner.fused.outer.outer.outer.outer*32) + threadIdx.x)]) + (int32*)bias_2[((blockIdx.x*32) + threadIdx.x)])))*0.00254f32), dtype=float32), 127f32), 0f32))
}
}
}
}
```
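If it helps to see the cause in isolation: once the wider `conv_padded` is inlined into the narrower `conv`, bound inference shrinks its realized extent from `padded_w` down to `out_width`, and splitting the now-28-wide axis by 8 is exactly what emits the guard. A stripped-down sketch (nothing conv- or wmma-specific) shows the same thing:

```python
import tvm
from tvm import te

A = te.placeholder((32,), name="A")
conv_padded = te.compute((32,), lambda w: A[w] + 1.0, name="conv_padded")
conv = te.compute((28,), lambda w: conv_padded[w], name="conv")  # unpad

s = te.create_schedule(conv.op)
s[conv_padded].compute_inline()  # same inline as in my schedule
wo, wi = s[conv].split(conv.op.axis[0], factor=8)
# lowered body: if tir.likely(((wo*8) + wi) < 28) { conv[...] = ... }
print(tvm.lower(s, [A, conv], simple_mode=True))
```

In the real schedule that factor 8 is `wmma_m`, so the guard lands inside the region being tensorized, which is what tensorize rejects.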
I want to get rid of these likely statements, such as `if @tir.likely((((ax2.outer*8) + ax2.inner) < 28), dtype=bool)` and `if @tir.likely((((w.c.outer*8) + w.c.inner) < 28), dtype=bool)`. Does anyone know how to handle schedules where consecutive stages have different shapes?