In certain TensorIR transformation passes, a block’s attr is used as an input to enable some optimization or transformation. For example, the inject_ptx_async_copy pass requires the block to have the attribute async_scope set to 1, as in the following unittest example:
```python
@T.prim_func
def ptx_global_to_shared_dyn_copy_fp16x8(
    A: T.Buffer[(32, 128), "float16"],
    B: T.Buffer[(32, 128), "float16"],
    C: T.Buffer[(32, 128), "float16"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        T.reads(A[0:32, 0:128], B[0:32, 0:128])
        T.writes(C[0:32, 0:128])
        T.attr("default", "async_scope", 1)
        for i in T.serial(16):
            for j in T.vectorized(8):
                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]
        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))
        for i in range(128):
            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
```
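For context, the unittest exercises the pass through a normal CUDA build with async copies turned on; a minimal driver sketch (assuming a TVM version that supports the tir.use_async_copy config flag) looks like this:

```python
import tvm

# Build the PrimFunc above for CUDA with async copies enabled, so that
# inject_ptx_async_copy runs and rewrites the attributed scope into
# cp.async instructions.
mod = tvm.IRModule.from_expr(ptx_global_to_shared_dyn_copy_fp16x8)
with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
    cuda_mod = tvm.build(mod, target="cuda")
```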
However, in some cases it may not be desirable to set the T.attr attribute by hand, as in the pure compute script below:
```python
@T.prim_func
def func(
    A: T.Buffer[(256, 256), "int8"],
    B: T.Buffer[(256, 256), "int8"],
    C: T.Buffer[(256, 256), "int32"],
):
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    for i, j, k in T.grid(256, 256, 256):
        with T.block("B"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = 0
            C[vi, vj] = C[vi, vj] + T.Cast("int32", A[vi, vk]) * T.Cast("int32", B[vk, vj])
```
To optimize memory utilization and keep several global-to-shared stages, the cache_read and cache_write primitives can be used, producing staging blocks like the one below:
```python
for ax0, ax1 in T.grid(2048, 2048):
    with T.block("A_global_shared"):
        v0, v1 = T.axis.remap("SS", [ax0, ax1])
        T.reads(A_global[v0, v1])
        T.writes(A_global_shared[v0, v1])
        A_global_shared[v0, v1] = A_global[v0, v1]
```
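For reference, a staging block like this is what calls along these lines produce; a sketch, assuming the matmul PrimFunc from before (block name "B" and the read indices are taken from that example; the 2048 extents above come from a larger variant):

```python
sch = tvm.tir.Schedule(func)
block = sch.get_block("B")

# Stage the inputs global -> shared: indices 0 and 1 are the positions
# of A and B in the block's read list.
A_shared = sch.cache_read(block, 0, "shared")
B_shared = sch.cache_read(block, 1, "shared")

# Optionally stage the output through a local write cache as well.
C_local = sch.cache_write(block, 0, "local")
```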
I didn’t find any method or primitive to set a block’s attr; however, after applying the sch.annotate primitive to the block, I got the following new block:
```python
for ax0, ax1 in T.grid(2048, 2048):
    with T.block("A_global_shared"):
        v0, v1 = T.axis.remap("SS", [ax0, ax1])
        T.reads(A_global[v0, v1])
        T.writes(A_global_shared[v0, v1])
        T.block_attr({"attr": "value"})
        A_global_shared[v0, v1] = A_global[v0, v1]
```
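For completeness, the block annotation above came from a call like this (the key/value pair is just the placeholder from the snippet):

```python
sch.annotate(
    sch.get_block("A_global_shared"),
    ann_key="attr",   # placeholder key
    ann_val="value",  # placeholder value
)
```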
This doesn’t work, because the pass won’t consider the block_attr as an attr: T.block_attr lowers to a block annotation, not to the AttrStmt node that inject_ptx_async_copy matches on. For now, the way I handle this stage sidesteps that kind of transformation entirely; instead, I leveraged the tensorize primitive to cover it:
```python
def get_async_copy_intrin(dtype):
    # Number of elements moved per copy, chosen so each copy is 16 bytes.
    if dtype == "float32":
        elems = 4
    elif dtype == "float16":
        elems = 8
    elif dtype == "int8":
        elems = 16
    else:
        raise ValueError("Unsupported dtype: {}".format(dtype))

    @T.prim_func
    def async_copy_desc(global_handle: T.handle, shared_handle: T.handle) -> None:
        globalVar = T.match_buffer(
            global_handle,
            (elems,),
            dtype,
            align=64,
            offset_factor=elems,
            scope="global",
        )
        sharedVar = T.match_buffer(
            shared_handle, (elems,), dtype, align=64, offset_factor=elems, scope="shared"
        )
        with T.block("root"):
            T.reads(globalVar[0:elems])
            T.writes(sharedVar[0:elems])
            for ax0 in T.vectorized(elems):
                with T.block("shared_warp"):
                    v0 = T.axis.remap("S", [ax0])
                    T.reads(globalVar[v0])
                    T.writes(sharedVar[v0])
                    sharedVar[v0] = globalVar[v0]

    @T.prim_func
    def async_copy_impl(global_handle: T.handle, shared_handle: T.handle) -> None:
        globalVar = T.match_buffer(
            global_handle,
            (elems,),
            dtype,
            align=64,
            offset_factor=elems,
            scope="global",
        )
        sharedVar = T.match_buffer(
            shared_handle, (elems,), dtype, align=64, offset_factor=elems, scope="shared"
        )
        with T.block("root"):
            T.reads(globalVar[0:elems])
            T.writes(sharedVar[0:elems])
            # This is the attr that inject_ptx_async_copy looks for.
            T.attr(0, "async_scope", 1)
            for ax0 in T.vectorized(elems):
                with T.block("shared_warp"):
                    v0 = T.axis.remap("S", [ax0])
                    T.reads(globalVar[v0])
                    T.writes(sharedVar[v0])
                    sharedVar[v0] = globalVar[v0]

    return async_copy_desc, async_copy_impl
```
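To apply it, the description/implementation pair gets registered as a tensor intrinsic and tensorized onto the copy loop; a sketch, where the intrinsic name and the loop handle `copy_loop` are assumptions:

```python
from tvm.tir import TensorIntrin

# Register the desc/impl pair under a name of our choosing (hypothetical).
ASYNC_COPY_FP16 = "async_copy_fp16"
TensorIntrin.register(ASYNC_COPY_FP16, *get_async_copy_intrin("float16"))

# On the schedule: split the copy loop so its inner extent matches `elems`,
# then tensorize the inner loop.
_, inner = sch.split(copy_loop, factors=[None, 8])
sch.tensorize(inner, ASYNC_COPY_FP16)
```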
But I think there must be a better way to do this. Any suggestions, please?