Take this code for example:
import numpy as np
import pytest
import tvm
from tvm import relay, autotvm
from tvm.autotvm.tuner import XGBTuner


def test_dense_autotvm():
    target = tvm.target.cuda()

    # fp16 dense with fp32 accumulation: (16384, 768) x (768, 768)^T
    batch, in_dim, out_dim = 16384, 768, 768
    data_shape = (batch, in_dim)
    weight_shape = (out_dim, in_dim)
    data = relay.var("data", shape=data_shape, dtype="float16")
    weight = relay.var("weight", shape=weight_shape, dtype="float16")
    dense_val = relay.nn.dense(data, weight, out_dtype="float32")
    func = relay.Function(relay.analysis.free_vars(dense_val), dense_val)
    mod = tvm.IRModule()
    mod["main"] = func

    log_filename = "dense_autotvm.log"
    tmp_logfile = log_filename + ".tmp"

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10, n_parallel=1),
        runner=autotvm.LocalRunner(
            number=1, repeat=2, timeout=10, min_repeat_ms=100),
    )

    # extract the AutoTVM tasks and tune only the third one
    tasks = autotvm.task.extract_from_program(
        func, target=target, params=None, ops=None)
    tsk = tasks[2]
    tuner_obj = XGBTuner(tsk, loss_type="rank")
    tuner_obj.tune(
        n_trial=10,
        early_stopping=0,
        measure_option=measure_option,
        callbacks=[
            autotvm.callback.progress_bar(10),
            autotvm.callback.log_to_file(tmp_logfile),
        ],
    )
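As a side note, printing the extracted tasks makes it clearer which schedule template the magic index tasks[2] refers to. A small sketch, reusing the tasks list from the test above:

# hedged aside: list the tasks AutoTVM extracted for this module so that
# index 2 can be tied to a concrete schedule template name
for i, t in enumerate(tasks):
    print(i, t.name, t.args)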
When I run this test with pytest -s test_my_dense.py, errors like the following show up:
[17:52:54] .../verify_gpu_code.cc:298: VerifyGPUCode err: Extent of threadIdx.y (1) does not match the bound 16
[17:52:54] .../verify_gpu_code.cc:298: VerifyGPUCode err: Extent of threadIdx.x (16) does not match the bound 1
[17:52:54] .../verify_gpu_code.cc:298: VerifyGPUCode err: Used shared memory per block (2146304) is greater than the allowed maximum (49152)
The test device is a T4 GPU.
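For reference, the limits that VerifyGPUCode checks against can be read back from the device API. A minimal sketch, assuming a local CUDA device 0; the values in the comments are what I would expect on a T4:

import tvm

dev = tvm.cuda(0)
print(dev.device_name)                  # e.g. "Tesla T4"
print(dev.max_shared_memory_per_block)  # 49152 bytes, the limit quoted in the errors above
print(dev.max_threads_per_block)        # 1024
print(dev.max_thread_dimensions)        # per-dimension limits for threadIdx.x/y/z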
Printing the lowered TIR for the failing config gives the dump below (a sketch of how such a dump can be reproduced follows the IR). To keep the IR more concise, I commented out the unroll and double-buffer scheduling steps.
[17:52:54] /home/qqqqq/source_code/tvm/src/tir/analysis/verify_gpu_code.cc:298: VerifyGPUCode err: Used shared memory per block (1609728) is greater than the allowed maximum (49152)
Current/Best: 0.00/ 0.00 GFLOPS | Progress: (9/10) | 2.76 s
@main = primfn(placeholder_2: handle, placeholder_3: handle, T_matmul_NT_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {T_matmul_NT: Buffer(T_matmul_NT_2: Pointer(float32), float32, [12582912], []),
placeholder_1: Buffer(placeholder_4: Pointer(float16), float16, [589824], []),
placeholder: Buffer(placeholder_5: Pointer(float16), float16, [12582912], [])}
buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1, T_matmul_NT_1: T_matmul_NT} {
attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
allocate(T_matmul_NT.local: Pointer(local float32), float32, [98304]), storage_scope = local;
allocate(placeholder.shared: Pointer(shared float16), float16, [1048576]), storage_scope = shared;
allocate(placeholder.d.shared: Pointer(shared float16), float16, [24576]), storage_scope = shared;
allocate(placeholder.shared.local: Pointer(local float16), float16, [131072]), storage_scope = local;
allocate(placeholder.d.shared.local: Pointer(local float16), float16, [192]), storage_scope = local;
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 16;
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1 {
for (i.c.init: int32, 0, 8) {
for (j.c.init: int32, 0, 3) {
for (vthread.s: int32, 0, 1024) {
let cse_var_1: int32 = (((vthread.s*24) + (i.c.init*3)) + j.c.init)
{
T_matmul_NT.local_1: Buffer(T_matmul_NT.local, float32, [14155776], [], scope="local", align=64)[cse_var_1] = 0f32
T_matmul_NT.local_1[(cse_var_1 + 24576)] = 0f32
T_matmul_NT.local_1[(cse_var_1 + 49152)] = 0f32
T_matmul_NT.local_1[(cse_var_1 + 73728)] = 0f32
}
}
}
}
for (k.outer: int32, 0, 6) {
attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1;
for (ax0.inner: int32, 0, 8192) {
for (ax1.outer: int32, 0, 32) {
attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
for (ax1.inner.inner: int32, 0, 4) {
let cse_var_2: int32 = (ax1.outer*4)
placeholder.shared_1: Buffer(placeholder.shared, float16, [1048576], [], scope="shared")[(((ax0.inner*128) + cse_var_2) + ax1.inner.inner)] = placeholder[(((((blockIdx.x*6291456) + (ax0.inner*768)) + (k.outer*128)) + cse_var_2) + ax1.inner.inner)]
}
}
}
attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 16;
for (ax0.inner_1: int32, 0, 12) {
for (ax1.outer_1: int32, 0, 2) {
attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 16;
for (ax1.inner.inner_1: int32, 0, 4) {
let cse_var_3: int32 = (ax1.outer_1*64)
placeholder.d.shared_1: Buffer(placeholder.d.shared, float16, [24576], [], scope="shared")[(((((threadIdx.y_2*1536) + (ax0.inner_1*128)) + cse_var_3) + (threadIdx.x_2*4)) + ax1.inner.inner_1)] = placeholder_1[(((((((blockIdx.y*147456) + (threadIdx.y_2*9216)) + (ax0.inner_1*768)) + (k.outer*128)) + cse_var_3) + (threadIdx.x_2*4)) + ax1.inner.inner_1)]
}
}
}
for (k.inner.outer: int32, 0, 8) {
for (ax0: int32, 0, 8) {
for (ax1: int32, 0, 16) {
for (vthread.s_1: int32, 0, 1024) {
placeholder.shared.local_1: Buffer(placeholder.shared.local, float16, [16384], [], scope="local")[(((vthread.s_1*128) + (ax0*16)) + ax1)] = placeholder.shared_1[((((vthread.s_1*1024) + (ax0*128)) + (k.inner.outer*16)) + ax1)]
}
}
}
for (ax0_1: int32, 0, 3) {
for (ax1_1: int32, 0, 16) {
let cse_var_4: int32 = ((ax0_1*16) + ax1_1)
{
placeholder.d.shared.local_1: Buffer(placeholder.d.shared.local, float16, [2304], [], scope="local", align=64)[cse_var_4] = placeholder.d.shared_1[((((threadIdx.y*384) + (ax0_1*128)) + (k.inner.outer*16)) + ax1_1)]
placeholder.d.shared.local_1[(cse_var_4 + 48)] = placeholder.d.shared_1[(((((threadIdx.y*384) + (ax0_1*128)) + (k.inner.outer*16)) + ax1_1) + 6144)]
placeholder.d.shared.local_1[(cse_var_4 + 96)] = placeholder.d.shared_1[(((((threadIdx.y*384) + (ax0_1*128)) + (k.inner.outer*16)) + ax1_1) + 12288)]
placeholder.d.shared.local_1[(cse_var_4 + 144)] = placeholder.d.shared_1[(((((threadIdx.y*384) + (ax0_1*128)) + (k.inner.outer*16)) + ax1_1) + 18432)]
}
}
}
for (k.inner.inner: int32, 0, 16) {
for (i.c: int32, 0, 8) {
for (j.c: int32, 0, 3) {
for (vthread.s_2: int32, 0, 1024) {
let cse_var_10: int32 = ((j.c*16) + k.inner.inner)
let cse_var_9: int32 = (((vthread.s_2*24) + (i.c*3)) + j.c)
let cse_var_8: int32 = (((vthread.s_2*128) + (i.c*16)) + k.inner.inner)
let cse_var_7: int32 = (cse_var_9 + 24576)
let cse_var_6: int32 = (cse_var_9 + 49152)
let cse_var_5: int32 = (cse_var_9 + 73728)
{
T_matmul_NT.local_1[cse_var_9] = (T_matmul_NT.local_1[cse_var_9] + (cast(float32, placeholder.shared.local_1[cse_var_8])*cast(float32, placeholder.d.shared.local_1[cse_var_10])))
T_matmul_NT.local_1[cse_var_7] = (T_matmul_NT.local_1[cse_var_7] + (cast(float32, placeholder.shared.local_1[cse_var_8])*cast(float32, placeholder.d.shared.local_1[(cse_var_10 + 48)])))
T_matmul_NT.local_1[cse_var_6] = (T_matmul_NT.local_1[cse_var_6] + (cast(float32, placeholder.shared.local_1[cse_var_8])*cast(float32, placeholder.d.shared.local_1[(cse_var_10 + 96)])))
T_matmul_NT.local_1[cse_var_5] = (T_matmul_NT.local_1[cse_var_5] + (cast(float32, placeholder.shared.local_1[cse_var_8])*cast(float32, placeholder.d.shared.local_1[(cse_var_10 + 144)])))
}
}
}
}
}
}
}
for (j.inner.inner.inner: int32, 0, 3) {
for (i.inner.inner.inner: int32, 0, 8) {
for (vthread.s_3: int32, 0, 1024) {
let cse_var_11: int32 = (((vthread.s_3*24) + (i.inner.inner.inner*3)) + j.inner.inner.inner)
{
T_matmul_NT[((((((blockIdx.x*6291456) + (vthread.s_3*6144)) + (i.inner.inner.inner*768)) + (blockIdx.y*192)) + (threadIdx.y*3)) + j.inner.inner.inner)] = T_matmul_NT.local_1[cse_var_11]
T_matmul_NT[(((((((blockIdx.x*6291456) + (vthread.s_3*6144)) + (i.inner.inner.inner*768)) + (blockIdx.y*192)) + (threadIdx.y*3)) + j.inner.inner.inner) + 48)] = T_matmul_NT.local_1[(cse_var_11 + 24576)]
T_matmul_NT[(((((((blockIdx.x*6291456) + (vthread.s_3*6144)) + (i.inner.inner.inner*768)) + (blockIdx.y*192)) + (threadIdx.y*3)) + j.inner.inner.inner) + 96)] = T_matmul_NT.local_1[(cse_var_11 + 49152)]
T_matmul_NT[(((((((blockIdx.x*6291456) + (vthread.s_3*6144)) + (i.inner.inner.inner*768)) + (blockIdx.y*192)) + (threadIdx.y*3)) + j.inner.inner.inner) + 144)] = T_matmul_NT.local_1[(cse_var_11 + 73728)]
}
}
}
}
}
}
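For completeness, a dump like the one above can be produced for a single AutoTVM config entry roughly as follows. This is only a sketch: it assumes tsk and target from the test above are in scope, and the index of the config that generated this particular IR is not recorded in the log, so 0 is just a placeholder.

# pick one config entity from the task's tuning space and lower it;
# index 0 is a placeholder, not necessarily the config behind the dump above
config = tsk.config_space.get(0)
with target:
    sch, arg_bufs = tsk.instantiate(config)
    print(tvm.lower(sch, arg_bufs, simple_mode=True))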
So in the stages that move the data and weight tiles from global memory into shared memory, the thread bindings look strange: one cooperative-load stage runs with (threadIdx.x, threadIdx.y) extents of (1, 1) while the other uses (16, 16), neither of which matches the (1, 16) extents declared for the compute body, and that mismatch (together with the oversized shared buffers) seems to be what VerifyGPUCode is rejecting.
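For reference, the two shared allocations in the IR above already account for the reported overflow: placeholder.shared holds 1048576 float16 elements (2,097,152 bytes) and placeholder.d.shared holds 24576 float16 elements (49,152 bytes), so a block would need 2,146,304 bytes of shared memory, which is exactly the figure in the first VerifyGPUCode message and far above the 49,152-byte limit.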