Hello everyone! I will keep the story short: I am working on the conv2d_int8 template with tensor core support. Using vthread leads to broken kernels (wrong results) after tuning. I was able to narrow the problem down to the following simplest use case:
import tvm
# Simple algorithm [A[i] @ B for i in range(len(A))]
VIRTUAL_THREAD = 2
A_shape = (VIRTUAL_THREAD, 16, 16)
B_shape = (16, 16)
A = tvm.placeholder(A_shape, name='A', dtype='float16')
B = tvm.placeholder(B_shape, name='B', dtype='float16')
r = tvm.reduce_axis((0, 16), name='r')
C = tvm.compute((VIRTUAL_THREAD, 16, 16),
                lambda vth, row, col: tvm.sum(
                    A[vth, row, r].astype("float32") * B[r, col].astype("float32"),
                    axis=[r]),
                name="C")
s = tvm.create_schedule(C.op)
# print(tvm.lower(s, [A, B, C], simple_mode=True))
# Memory hierarchy
AF = s.cache_read(A, 'wmma.matrix_a', [C])
BF = s.cache_read(B, 'wmma.matrix_b', [C])
CF = s.cache_write(C, 'wmma.accumulator')
def intrin_wmma_load_matrix(scope):
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='global', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
                                BC.data, n, n, n, BC.elem_offset // 256,
                                BA.access_ptr('r'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_gemm():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    B = tvm.placeholder((n, n), name='B', dtype='float16')
    k = tvm.reduce_axis((0, n), name="k")
    C = tvm.compute((n, n),
                    lambda ii, jj:
                    tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
                    name='C')
    BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
    BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
    BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        BA, BB = ins
        BC, = outs

        def init():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
            return ib.get()

        def update():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
                                    BC.data, BC.elem_offset // 256,
                                    BA.data, BA.elem_offset // 256,
                                    BB.data, BB.elem_offset // 256,
                                    BC.data, BC.elem_offset // 256))
            return ib.get()

        return update(), init(), update()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
def intrin_wmma_store_matrix():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float32')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
                                BA.data, n, n, n, BA.elem_offset // 256,
                                BC.access_ptr('w'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
vth, row, col = C.op.axis
s[CF].compute_at(s[C], vth)
s[AF].compute_at(s[CF], CF.op.axis[-3])
s[BF].compute_at(s[CF], CF.op.axis[-3])
s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a'))
s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b'))
s[C].tensorize(row, intrin_wmma_store_matrix())
s[CF].tensorize(CF.op.axis[-2], intrin_wmma_gemm())
# Lowered IR before the vthread binding
print(tvm.lower(s, [A, B, C], simple_mode=True))

# Bind the outer axis to a virtual thread and lower again
s[C].bind(C.op.axis[-3], tvm.thread_axis('vthread'))
print(tvm.lower(s, [A, B, C], simple_mode=True))
So the problem is: after the vthread binding, the lowered IR contains multiple load/store expressions for the same memory fragment but only one gemm call (the quick count after the two dumps makes this explicit):
// attr [A.wmma.matrix_a] storage_scope = "wmma.matrix_a"
allocate A.wmma.matrix_a[float16 * 256]
// attr [B.wmma.matrix_b] storage_scope = "wmma.matrix_b"
allocate B.wmma.matrix_b[float16 * 256]
// attr [C.wmma.accumulator] storage_scope = "wmma.accumulator"
allocate C.wmma.accumulator[float32 * 256]
produce C {
  for (vth, 0, 2) {
    produce C.wmma.accumulator {
      produce A.wmma.matrix_a {
        tvm_load_matrix_sync(A.wmma.matrix_a, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), A, (vth*256), 256, 1), 16, "row_major")
      }
      produce B.wmma.matrix_b {
        tvm_load_matrix_sync(B.wmma.matrix_b, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), B, 0, 256, 1), 16, "row_major")
      }
      tvm_mma_sync(C.wmma.accumulator, 0, A.wmma.matrix_a, 0, B.wmma.matrix_b, 0, C.wmma.accumulator, 0)
    }
    tvm_store_matrix_sync(C.wmma.accumulator, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), C, (vth*256), 256, 2), 16, "row_major")
  }
}
turns into:
// attr [A.wmma.matrix_a] storage_scope = "wmma.matrix_a"
allocate A.wmma.matrix_a[float16 * 256]
// attr [B.wmma.matrix_b] storage_scope = "wmma.matrix_b"
allocate B.wmma.matrix_b[float16 * 256]
// attr [C.wmma.accumulator] storage_scope = "wmma.accumulator"
allocate C.wmma.accumulator[float32 * 256]
produce C {
  produce C.wmma.accumulator {
    produce A.wmma.matrix_a {
      tvm_load_matrix_sync(A.wmma.matrix_a, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), A, 0, 256, 1), 16, "row_major")
      tvm_load_matrix_sync(A.wmma.matrix_a, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), A, 256, 256, 1), 16, "row_major")
    }
    produce B.wmma.matrix_b {
      tvm_load_matrix_sync(B.wmma.matrix_b, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), B, 0, 256, 1), 16, "row_major")
    }
    tvm_mma_sync(C.wmma.accumulator, 0, A.wmma.matrix_a, 0, B.wmma.matrix_b, 0, C.wmma.accumulator, 0)
  }
  tvm_store_matrix_sync(C.wmma.accumulator, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), C, 0, 256, 2), 16, "row_major")
  tvm_store_matrix_sync(C.wmma.accumulator, 16, 16, 16, 0, tvm_access_ptr(type_annotation(), C, 256, 256, 2), 16, "row_major")
}
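To make the asymmetry easy to see, here is a quick sanity check (nothing more than string counting over the lowered IR, appended to the script above after the vthread bind):

# Count the tensor core intrinsic calls that survive lowering.
lowered = str(tvm.lower(s, [A, B, C], simple_mode=True))
for name in ('tvm_load_matrix_sync', 'tvm_mma_sync', 'tvm_store_matrix_sync'):
    print(name, lowered.count(name))
# With VIRTUAL_THREAD = 2 this reports 3 loads and 2 stores, but only 1 mma_sync:
# the loads/stores are duplicated per virtual thread while the gemm is not.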
I am curious: is this a bug, or am I doing something wrong? The pipeline seems to be fairly close to the one used in TOPI.
Please note that for simplicity I omitted the threadIdx bindings; the behavior can be observed even in simple mode. The intrinsics are taken from the TVM tutorial, with a single memory scope change (shared -> global).
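For reference, the toy workload above is just a per-vthread 16x16 matmul. A plain numpy equivalent of the intended result (hypothetical names, independent of the schedule) would be:

import numpy as np

VIRTUAL_THREAD = 2  # same constant as in the script above
a = np.random.uniform(size=(VIRTUAL_THREAD, 16, 16)).astype('float16')
b = np.random.uniform(size=(16, 16)).astype('float16')
# Expected result: each 16x16 slice of A multiplied by B, accumulated in float32.
expected = np.stack([a[i].astype('float32') @ b.astype('float32')
                     for i in range(VIRTUAL_THREAD)])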