Primitives to set tir block attr?

In certain TensorIR transformation passes, a block's attr is used as an input to enable some optimization or transformation. For example, the inject_ptx_async_copy pass requires the block to have the attribute async_scope set to 1, as in the unit test example below:


@T.prim_func
def ptx_global_to_shared_dyn_copy_fp16x8(
    A: T.Buffer[(32, 128), "float16"],
    B: T.Buffer[(32, 128), "float16"],
    C: T.Buffer[(32, 128), "float16"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        T.reads(A[0:32, 0:128], B[0:32, 0:128])
        T.writes(C[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        for i in T.serial(16):
            for j in T.vectorized(8):
                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]

However, in some cases it may not be desirable to set T.attr manually, for example in the pure compute script below:


@T.prim_func
def func(A: T.Buffer[(256, 256), "int8"], B: T.Buffer[(256, 256), "int8"], C: T.Buffer[(256, 256), "int32"]):
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    for i, j, k in T.grid(256, 256, 256):
        with T.block("B"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = 0
            C[vi, vj] = C[vi, vj] + T.Cast("int32", A[vi, vk]) * T.Cast("int32", B[vk, vj])

To optimize memory utilization and keep several global-to-shared stages, the cache_read and cache_write primitives can be used, producing copy blocks like the one below:

    for ax0, ax1 in T.grid(2048, 2048):
        with T.block("A_global_shared"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A_global[v0, v1])
            T.writes(A_global_shared[v0, v1])
            A_global_shared[v0, v1] = A_global[v0, v1]
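For illustration, a minimal sketch of how such copy blocks can be created with cache_read (the block name "B" comes from the compute script above; my actual schedule also stages through a global cache first, which this sketch omits):

    import tvm

    sch = tvm.tir.Schedule(func)   # `func` is the pure compute PrimFunc above
    block_b = sch.get_block("B")

    # Read buffer indices 0 and 1 correspond to A and B in the block's reads.
    a_shared = sch.cache_read(block_b, 0, "shared")
    b_shared = sch.cache_read(block_b, 1, "shared")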

I didn’t find any method or primitive to set a block's attr. However, when I apply the sch.annotate primitive to the block, I get the following new block:

    for ax0, ax1 in T.grid(2048, 2048):
        with T.block("A_global_shared"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A_global[v0, v1])
            T.writes(A_global_shared[v0, v1])
            T.block_attr({"attr": "value"})
            A_global_shared[v0, v1] = A_global[v0, v1]
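(For reference, the annotate call behind this block is roughly the following; the key and value are just placeholders:)

    copy_block = sch.get_block("A_global_shared")
    sch.annotate(copy_block, ann_key="attr", ann_val="value")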

This doesn’t work, because the pass does not consider block_attr as the attr it looks for. So the way I currently handle this stage does not rely on that transformation at all; instead, I leveraged the tensorize primitive to cover this stage:

def get_async_copy_intrin(dtype):
    # Each vectorized copy moves 16 bytes, so the element count depends on dtype.
    if dtype == "float32":
        elems = 4
    elif dtype == "float16":
        elems = 8
    elif dtype == "int8":
        elems = 16
    else:
        raise ValueError("Unsupported dtype: {}".format(dtype))

    @T.prim_func
    def async_copy_desc(global_handle: T.handle, shared_handle: T.handle) -> None:
        globalVar = T.match_buffer(
            global_handle,
            (elems,),
            dtype,
            align=64,
            offset_factor=elems,
            scope="global",
        )
        sharedVar = T.match_buffer(
            shared_handle, (elems,), dtype, align=64, offset_factor=16, scope="shared"
        )

        with T.block("root"):
            T.reads(globalVar[0:elems])
            T.writes(sharedVar[0:elems])

            for ax0 in T.vectorized(elems):
                with T.block("shared_warp"):
                    v0 = T.axis.remap("S", [ax0])
                    T.reads(globalVar[v0])
                    T.writes(sharedVar[v0])
                    sharedVar[v0] = globalVar[v0]

    @T.prim_func
    def async_copy_impl(global_handle: T.handle, shared_handle: T.handle) -> None:
        globalVar = T.match_buffer(
            global_handle,
            (elems,),
            dtype,
            align=64,
            offset_factor=elems,
            scope="global",
        )
        sharedVar = T.match_buffer(
            shared_handle, (elems,), dtype, align=64, offset_factor=elems, scope="shared"
        )

        with T.block("root"):
            T.reads(globalVar[0:elems])
            T.writes(sharedVar[0:elems])
            # Mark the copy so that InjectPTXAsyncCopy lowers it to cp.async.
            T.attr(0, "async_scope", 1)
            for ax0 in T.vectorized(elems):
                with T.block("shared_warp"):
                    v0 = T.axis.remap("S", [ax0])
                    T.reads(globalVar[v0])
                    T.writes(sharedVar[v0])
                    sharedVar[v0] = globalVar[v0]

    return async_copy_desc, async_copy_impl
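For completeness, a rough sketch of how such an intrinsic can be registered and applied (the intrinsic name, the split factor, and the loop handling are illustrative, not my exact schedule):

    from tvm.tir import TensorIntrin

    desc, impl = get_async_copy_intrin("float16")
    TensorIntrin.register("async_copy_global_to_shared_fp16", desc, impl)

    # `copy_block` is assumed to be the global-to-shared copy block; split its
    # innermost loop so the inner extent matches `elems`, then tensorize it.
    *_, inner = sch.get_loops(copy_block)
    _, inner = sch.split(inner, factors=[None, 8])
    sch.tensorize(inner, "async_copy_global_to_shared_fp16")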

But I think there must be a better option. Any suggestions, please?

There are some test cases that describe how annotations are lowered: https://github.com/apache/tvm/blob/main/tests/python/unittest/test_tir_transform_lower_opaque_block.py#L324. As a summary (a small sketch follows the list):

  • only annotation keys that start with pragma_ are converted to T.attr
  • loop annotations that do not start with pragma_ are preserved as-is
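A minimal sketch of the difference, assuming loop_rv is a loop obtained from sch.get_loops (the key names are placeholders and the lowered form is written only roughly):

    # A loop annotation whose key starts with "pragma_" ...
    sch.annotate(loop_rv, ann_key="pragma_my_flag", ann_val=1)
    # ... is turned into an AttrStmt during lowering, roughly:
    #     T.attr(loop_var, "pragma_my_flag", 1)

    # Any other key stays on the loop as a plain annotation:
    sch.annotate(loop_rv, ann_key="my_flag", ann_val=1)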

The community generally tends to discard T.attr in the TensorIR lowering flow (refer to Can we lift tir.AttrStmt value type to ObjectRef? - #5 by tqchen). I am not sure why inject_ptx_async_copy requires the current fashion; maybe we could adjust it to be consistent with the TensorIR annotation system? cc authors @junrushao @masahi for help:)


I don’t remember the details, but for PTX async copy, this attribute is not supposed to be set manually; rather, it is added by the software pipeline pass https://github.com/apache/tvm/blob/61c9742ea79d0057290502379a81a5487c77790d/src/tir/transforms/inject_software_pipeline.cc#L889

We need AttrStmt because we need to look at this attribute during InjectPTXAsyncCopy, at which point all TensorIR constructs have already been lowered to normal TIR.

And InjectPTXAsyncCopy works on normal TIR, not TensorIR, because it requires StorageRewrite to have happened before it.

Does this answer your question? @wrongtest

cc @vinx13

InjectPTXAsyncCopy operates on normal TIR; it doesn’t get to see blocks or block annotations.

This makes a lot of sense to me~

Makes sense to me, thanks a lot! @wrongtest @masahi

It works now, but I observe that this annotation converts all of the global-to-shared assignments into async copy stages. For example, if there are two stages, one reading A into A_shared and one reading B into B_shared, both of them get converted into async copies, but I only want to convert one of them and keep the other as a naive copy. How can I realize this?

sch.annotate(ko, "software_pipeline_async_stages", [0, 0])

It seems like the ann_val of this annotation only changes the location of wait_group and commit_group.

You can try the tir.merge_async_commit_queue_scope option. This might be exactly what you are looking for.
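In case it is useful, tir.* options like this are normally set through PassContext; a minimal sketch, assuming it is a boolean pass-config option and that sch.mod is your scheduled module:

    import tvm

    # Toggle the option (True/False depending on the grouping you want).
    with tvm.transform.PassContext(config={"tir.merge_async_commit_queue_scope": False}):
        rt_mod = tvm.build(sch.mod, target="cuda")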


@masahi

given the schedule I applied for TVM's software pipeline:

sch.annotate(ki, ann_key="software_pipeline_stage", ann_val=[0, 0, 0])
sch.annotate(ki, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
sch.annotate(ko, ann_key="software_pipeline_stage",
             ann_val=[0, 0, 0, stage - 1, stage - 1])
sch.annotate(ko, ann_key="software_pipeline_order",
             ann_val=[0, 1, 3, 2, 4])
sch.annotate(ko, ann_key="software_pipeline_async_stages", ann_val=[0])

the generated CUDA kernel is:

extern "C" __global__ void __launch_bounds__(128) main_kernel0(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) {
  extern __shared__ uchar buf_dyn_shmem[];
  half C_warp[256];
  half A_shared_dyn_warp[16];
  half B_shared_dyn_warp[128];
  
  // init c
  for (int ax0_ax1_ax2_ax3_0_fused_2 = 0; ax0_ax1_ax2_ax3_0_fused_2 < 4; ++ax0_ax1_ax2_ax3_0_fused_2)
      // async copy A from global to shared
  for (int ax0_ax1_ax2_ax3_0_fused_2_1 = 0; ax0_ax1_ax2_ax3_0_fused_2_1 < 8; ++ax0_ax1_ax2_ax3_0_fused_2_1)
      // async copy B from global to shared
    __asm__ __volatile__("cp.async.commit_group;");
  for (int kk_0 = 0; kk_0 < 511; ++kk_0) {
    __syncthreads();
    for (int ax0_ax1_ax2_ax3_0_fused_2_2 = 0; ax0_ax1_ax2_ax3_0_fused_2_2 < 4; ++ax0_ax1_ax2_ax3_0_fused_2_2)
      // copy A_2
    for (int ax0_ax1_ax2_ax3_0_fused_2_3 = 0; ax0_ax1_ax2_ax3_0_fused_2_3 < 8; ++ax0_ax1_ax2_ax3_0_fused_2_3)
      // copy B_2
    __asm__ __volatile__("cp.async.commit_group;");

	__asm__ __volatile__("cp.async.wait_group 1;");
    __syncthreads();
    for (int kk_1 = 0; kk_1 < 2; ++kk_1) {
    	// ldmatrix and compute
    }
  }
  __asm__ __volatile__("cp.async.wait_group 0;");
  // epilogue computation
  for (int ax0_4 = 0; ax0_4 < 2; ++ax0_4) {
    for (int ax1 = 0; ax1 < 16; ++ax1) {
      for (int local_id = 0; local_id < 8; local_id += 2) {
        // write_back c
      }
    }
  }
}

I agree that the logic of the software pipeline and the usage of the cp.async PTX instructions is definitely correct. However, it is not always the best-performing pattern. For example, we could issue just a single commit group and have one wait_group wait until all of the DMA operations are done; the performance may be better (maybe nvcc sometimes covers this optimization). The code would be:

extern "C" __global__ void __launch_bounds__(128) main_kernel0(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) {
  extern __shared__ uchar buf_dyn_shmem[];
  half C_warp[256];
  half A_shared_dyn_warp[16];
  half B_shared_dyn_warp[128];
  
  // init c
  for (int ax0_ax1_ax2_ax3_0_fused_2 = 0; ax0_ax1_ax2_ax3_0_fused_2 < 4; ++ax0_ax1_ax2_ax3_0_fused_2)
      // async copy A from global to shared
  for (int ax0_ax1_ax2_ax3_0_fused_2_1 = 0; ax0_ax1_ax2_ax3_0_fused_2_1 < 8; ++ax0_ax1_ax2_ax3_0_fused_2_1)
      // async copy B from global to shared
   // __asm__ __volatile__("cp.async.commit_group;");
  for (int kk_0 = 0; kk_0 < 511; ++kk_0) {
    __syncthreads();
    for (int ax0_ax1_ax2_ax3_0_fused_2_2 = 0; ax0_ax1_ax2_ax3_0_fused_2_2 < 4; ++ax0_ax1_ax2_ax3_0_fused_2_2)
      // copy A_2
    for (int ax0_ax1_ax2_ax3_0_fused_2_3 = 0; ax0_ax1_ax2_ax3_0_fused_2_3 < 8; ++ax0_ax1_ax2_ax3_0_fused_2_3)
      // copy B_2
    __asm__ __volatile__("cp.async.commit_group;");

	__asm__ __volatile__("cp.async.wait_group 0;");
    __syncthreads();
    for (int kk_1 = 0; kk_1 < 2; ++kk_1) {
    	// ldmatrix and compute
    }
  }
  // __asm__ __volatile__("cp.async.wait_group 0;");
  // epilogue computation
  for (int ax0_4 = 0; ax0_4 < 2; ++ax0_4) {
    for (int ax1 = 0; ax1 < 16; ++ax1) {
      for (int local_id = 0; local_id < 8; local_id += 2) {
        // write_back c
      }
    }
  }
}

But I don’t know how to leverage software_pipeline_async_stages to implement it. Currently I just patch the CUDA kernel in an ad-hoc way to achieve my purpose. Can you please give me some help?

I resolved it by leveraging tvm_callback_cuda_postproc, thanks. But any comment is still welcome.
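For reference, a minimal sketch of what I mean (the callback signature may differ between TVM versions, and the string rewrite is only an illustration of merging the wait groups discussed above):

    import tvm

    @tvm.register_func("tvm_callback_cuda_postproc", override=True)
    def cuda_postproc(code, *args):
        # Rewrite the generated CUDA source before it is compiled.
        return code.replace("cp.async.wait_group 1;", "cp.async.wait_group 0;")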

Sorry for the late reply. Have you seen the test case and its annotations in https://github.com/apache/tvm/blob/424c749a3dac0ba42e89d3cbd04b024658d7d104/tests/python/unittest/test_tir_transform_inject_software_pipeline.py#L1516-L1527?

I think it should generate the ideal kernel you described.

It works now, thanks. :wink: