Hi everyone,
I’m working on a custom accelerator backend for TVM and ran into a scheduling / memory issue that I’m not sure how to model correctly.
Hardware background
On our hardware we have:
- A matrix compute unit
- A vector compute unit
- A shared on-chip buffer (local buffer) that both units can access
In a tiled matrix–vector computation, the idea is:
- The matrix unit computes one tile of data and writes it into the local buffer.
- The vector unit immediately consumes that same tile from the same local buffer, without going through global memory.
So in hardware there is no local → global → local roundtrip for that intermediate tile.
What I see in TVM
In my current TVM schedule / lowering, the behavior looks more like this:
- Matrix compute writes the tile to a local buffer.
- TVM then copies that local buffer back to global memory.
- Before the vector compute stage, TVM copies the data from global to another local buffer for the vector unit to use.
In other words, the intermediate result always gets stored to global memory and then reloaded, even though in our hardware a single shared local buffer is enough.
import tvm
from tvm import relax
from tvm import tir
from tvm.relax.frontend import nn
from tvm.relax.transform import LegalizeOps, AnnotateTIROpPattern
import numpy as np
from tvm.relax.expr_functor import PyExprMutator
class NNModule(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(48, 64)
self.relu1 = nn.ReLU()
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
return x
mod, param_spec = NNModule().export_tvm(
spec={"forward": {"x": nn.spec.Tensor((1, 48), "float32")}}
)
pipeline = tvm.transform.Sequential([
LegalizeOps(),
AnnotateTIROpPattern(),
relax.transform.FuseOps(fuse_opt_level=4),
relax.transform.FuseTIR(),
])
mod_lowered = pipeline(mod)
sch = tvm.tir.Schedule(mod_lowered)
mod_lowered.show()
sch.work_on("fused_matmul_add_relu")
block_add = sch.get_block("T_add")
sch.compute_inline(block_add)
# move vector calculation after matrix
matmul_block = sch.get_block("matmul")
compute_block = sch.get_block("compute")
# tile
i0, i1, k = sch.get_loops(matmul_block)
i1_outer, i1_inner = sch.split(i1, factors=[None, 16])
k_outer, k_inner = sch.split(k, factors=[None, 16])
sch.reorder(i0, i1_outer, k_outer, k_inner, i1_inner)
# tile
i0_c, i1_c = sch.get_loops(compute_block)
i1_outer_c, i1_inner_c = sch.split(i1_c, factors=[None, 16])
sch.reverse_compute_at(compute_block, i1_outer)
sch.mod.show()
# add local/shared buffer
x_shared = sch.cache_read(matmul_block, 0, "shared") # x buffer
w_shared = sch.cache_read(matmul_block, 1, "shared") # weight buffer
matmul_local = sch.cache_write(matmul_block, 0, "local")
# 3.
compute_block = sch.get_block("compute")
# sch.cache_read(compute_block, 0, "local", [matmul_local])
bias_local = sch.cache_read(compute_block, 1, "local") # bias
result_local = sch.cache_write(compute_block, 0, "local")
# 4.
i0, i1_0, k_0, k_1, i1_1 = sch.get_loops(matmul_block)
sch.compute_at(x_shared, i1_0)
sch.compute_at(w_shared, i1_0)
sch.compute_at(bias_local, i1_0)
sch.mod.show()
the scheduled tir is as follows, and we can see
- matmul_intermediate[v0, v1] = matmul_intermediate_local[v0, v1]
@I.ir_module
class Module:
@T.prim_func(private=True)
def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(48)), "float32"), permute_dims: T.Buffer((T.int64(48), T.int64(64)), "float32"), fc1_bias: T.Buffer((T.int64(64),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(64)), "float32")):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(64)))
x_shared = T.alloc_buffer((T.int64(1), T.int64(48)), scope="shared")
permute_dims_shared = T.alloc_buffer((T.int64(48), T.int64(64)), scope="shared")
matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(64)), scope="local")
fc1_bias_local = T.alloc_buffer((T.int64(64),), scope="local")
compute_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(64)), scope="local")
for i0, i1_0 in T.grid(T.int64(1), T.int64(4)):
for ax0 in range(T.int64(48)):
with T.block("x_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(48), ax0)
T.reads(x[v0, v1])
T.writes(x_shared[v0, v1])
x_shared[v0, v1] = x[v0, v1]
for ax0, ax1 in T.grid(T.int64(48), T.int64(16)):
with T.block("permute_dims_shared"):
v0 = T.axis.spatial(T.int64(48), ax0)
v1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax1)
T.reads(permute_dims[v0, v1])
T.writes(permute_dims_shared[v0, v1])
permute_dims_shared[v0, v1] = permute_dims[v0, v1]
for k_0, k_1, i1_1 in T.grid(T.int64(3), T.int64(16), T.int64(16)):
with T.block("matmul"):
v_i0 = T.axis.spatial(T.int64(1), i0)
v_i1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + i1_1)
v_k = T.axis.reduce(T.int64(48), k_0 * T.int64(16) + k_1)
T.reads(x_shared[v_i0, v_k], permute_dims_shared[v_k, v_i1])
T.writes(matmul_intermediate_local[v_i0, v_i1])
with T.init():
matmul_intermediate_local[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate_local[v_i0, v_i1] = matmul_intermediate_local[v_i0, v_i1] + x_shared[v_i0, v_k] * permute_dims_shared[v_k, v_i1]
for ax0, ax1 in T.grid(T.int64(1), T.int64(16)):
with T.block("matmul_intermediate_local"):
v0 = T.axis.spatial(T.int64(1), i0 + ax0)
v1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax1)
T.reads(matmul_intermediate_local[v0, v1])
T.writes(matmul_intermediate[v0, v1])
matmul_intermediate[v0, v1] = matmul_intermediate_local[v0, v1]
for ax0 in range(T.int64(16)):
with T.block("fc1.bias_local"):
v0 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax0)
T.reads(fc1_bias[v0])
T.writes(fc1_bias_local[v0])
fc1_bias_local[v0] = fc1_bias[v0]
for ax0 in range(T.int64(16)):
with T.block("compute"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax0)
T.reads(matmul_intermediate[v_i0, v_i1], fc1_bias_local[v_i1])
T.writes(compute_intermediate_local[v_i0, v_i1])
compute_intermediate_local[v_i0, v_i1] = T.max(matmul_intermediate[v_i0, v_i1] + fc1_bias_local[v_i1], T.float32(0.0))
for ax0, ax1 in T.grid(T.int64(1), T.int64(64)):
with T.block("compute_intermediate_local"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(compute_intermediate_local[v0, v1])
T.writes(compute_intermediate[v0, v1])
compute_intermediate[v0, v1] = compute_intermediate_local[v0, v1]