A question about asynchronous stages in the software pipeline

Hi, I’m trying to use the recently merged async pipeline (https://github.com/apache/tvm/pull/12171). Thanks to @masahi for the excellent work.

I have a question about three_stage_compute in test_tir_transform_inject_software_pipeline.py:

@T.prim_func
def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 1, 2],
                "software_pipeline_order": [0, 1, 2],
            },
        ):
            with T.block("compute"):
                T.reads(A[tx, i])
                T.writes(D[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                C = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, 0])
                    C[tx, 0] = B[tx, 0] + T.float32(2)
                with T.block():
                    T.reads(C[tx, 0])
                    T.writes(D[tx, i])
                    D[tx, i] = C[tx, 0] + T.float32(1)

If I set software_pipeline_async_stages = [0, 1, 2], all three stages become async stages. In this case stage 2 (“C[tx, 0] = B[tx, 0] + T.float32(2)”) waits for stage 1 (“B[tx, 0] = A[tx, i] * T.float32(2)”): stage 1 is an async producer, and stage 2 reads from the asynchronously written buffer.
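
Concretely, I mean adding the async-stages entry next to the existing annotations of three_stage_compute. A sketch of just the loop header (everything else unchanged):

        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 1, 2],
                "software_pipeline_order": [0, 1, 2],
                "software_pipeline_async_stages": [0, 1, 2],
            },
        ):
            ...  # loop body identical to three_stage_compute above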

I think there is another wait/commit relation, between stage 2 and stage 3, that the async stage handling does not consider. The software-pipelined TIR looks like this:

with T.block("_1"):
    T.reads(A[tx, 2 : 16], B[0 : 2, tx, 0], C[0 : 2, tx, 0])
    T.writes(B[0 : 2, tx, 0], C[0 : 2, tx, 0], D[tx, 0 : 14])
    for i in T.serial(14):
        with T.block("_2"):
            T.where(i + 2 < 16)
            T.reads(A[tx, i + 2])
            T.writes(B[i % 2, tx, 0])
            B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
        with T.block("_3"):
            T.where(i + 2 - 1 < 16)
            T.reads(B[(i + 1) % 2, tx, 0])
            T.writes(C[(i + 1) % 2, tx, 0])
            C[(i - 1 + 2) % 2, tx, 0] = B[(i - 1 + 2) % 2, tx, 0] + T.float32(2)
        with T.block("_4"):
            T.where(i + 2 - 2 < 16)
            T.reads(C[i % 2, tx, 0])
            T.writes(D[tx, i])
            D[tx, i - 2 + 2] = C[(i - 2 + 2) % 2, tx, 0] + T.float32(1)

Suppose i = 1. Then “C[0, tx, 0] = B[0, tx, 0] + T.float32(2)” should wait for “D[tx, 0] = C[0, tx, 0] + T.float32(1)” from the i - 1 iteration, because stage 2 may overwrite C[0, tx, 0] only after “D[tx, 0] = C[0, tx, 0] + T.float32(1)” has completed.

I don’t know if this requirement is HW-dependent. In my case the three stages are DMA read / compute / DMA write, so the compute stage should wait for both the DMA read and the DMA write from the previous iteration.
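
To spell out the dependence I have in mind, here is hand-written Python-style pseudocode of one steady-state iteration for that DMA case (dma_read, compute, dma_write, wait_dma_read and wait_dma_write are made-up placeholders, not TVM intrinsics):

# Pseudocode only: steady-state body with double-buffered B and C.
for i in range(14):
    dma_read(dst=B[(i + 2) % 2, tx, 0], src=A[tx, i + 2])          # stage 1: async load for iteration i + 2
    wait_dma_read(i - 1)   # RAW: B[(i + 1) % 2] loaded at iteration i - 1 must have arrived
    wait_dma_write(i - 1)  # WAR: the store issued at iteration i - 1 read C[(i + 1) % 2];
                           # it must finish before the next line overwrites that slot (the missing wait)
    compute(dst=C[(i + 1) % 2, tx, 0], src=B[(i + 1) % 2, tx, 0])  # stage 2: overwrites C[(i + 1) % 2]
    dma_write(dst=D[tx, i], src=C[i % 2, tx, 0])                   # stage 3: async store for iteration i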

Correct me if I misunderstood asynchronous stages in the software pipeline. Thanks.


Thanks for pointing this out. You are right that the current implementation only handles read-after-write dependencies; I didn’t think about write-after-read dependencies like yours.

Let me think about how to “fix” this. The PR was already merged, so if you already have a solution, you are more than welcome to send a follow-up PR.

So the bottom line is: if a stage is both an async producer and an async consumer, we need two async_wait operations, one waiting for the previous stage in the current iteration (supported by the current impl) and another waiting for the next stage in the previous iteration (not currently considered)?

I think it should be easy to patch the current implementation to support such cases: just add another async_wait with the correct stage dependency relation and wait count. The wait count would probably be calculated in the same way, but we need to subtract 1 to account for the “previous” iteration. Similar logic already exists at https://github.com/apache/tvm/blob/3c737fbd5baccc60aff355b40105220c148b7d7f/src/tir/transforms/inject_software_pipeline.cc#L667
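
In pseudo-Python, the adjustment to the wait count I have in mind is only this (a hypothetical sketch, not the actual code in inject_software_pipeline.cc):

def war_wait_count(raw_wait_count: int) -> int:
    # Hypothetical: reuse the in-flight count computed for the read-after-write
    # wait, minus 1 so the wait also covers the consumer from the previous
    # iteration (clamped at zero).
    return max(raw_wait_count - 1, 0)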

So the bottom line is: if a stage is both an async producer and an async consumer, we need two async_wait operations, one waiting for the previous stage in the current iteration (supported by the current impl) and another waiting for the next stage in the previous iteration (not currently considered)?

I think it’s one async_wait waiting for the previous stage in the previous iteration and another waiting for the next stage in the previous iteration (not currently considered).

Thanks, I will try to send a PR for this issue.

if a stage is both an async producer and an async consumer, we need two async_wait operations, one waiting for the previous stage in the current iteration (supported by the current impl) and another waiting for the next stage in the previous iteration (not currently considered)

I think this handles most cases, but not all. A synchronous stage might be writing to a buffer read by an asynchronous stage. In this case, there is no need for an explicit guard on data availability, because the producer is synchronous; however, it’s possible that the asynchronous stage hasn’t finished a prior iteration before the synchronous stage overwrites a buffer space.

To put it another way, we have to protect against the cases where the asynchronous stages “run ahead” and do the overwriting, as well as the cases where they “fall behind” and are the ones overwritten.

I think the rules that might cover all cases to prevent buffer overwrites are (see the sketch after this list):

  1. Asynchronous producers must have an explicit guard
  2. Synchronous producers need an explicit guard if the consumer of the write is asynchronous
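
In pseudo-Python, the decision the two rules boil down to is just this (a sketch only; nothing like this helper exists in TVM, and the name is made up):

def needs_explicit_guard(producer_is_async: bool, consumer_is_async: bool) -> bool:
    # Rule 1: an asynchronous producer always needs an explicit guard, so the
    # consumer never reads data that has not landed yet.
    if producer_is_async:
        return True
    # Rule 2: a synchronous producer still needs a guard when the consumer of
    # the written buffer is asynchronous, because that consumer may not have
    # finished reading the slot about to be overwritten.
    return consumer_is_async

# Example: an async DMA write consuming what a synchronous compute stage wrote.
assert needs_explicit_guard(producer_is_async=False, consumer_is_async=True)
# A purely synchronous producer/consumer pair needs no explicit guard.
assert not needs_explicit_guard(producer_is_async=False, consumer_is_async=False)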

Are you planning on a PR in the near future? Explicit buffer dependencies are also a feature I need (it sounds like we’re working on very similar things!), so it would be good not to duplicate the effort.

  1. Asynchronous producers must have an explicit guard
  2. Synchronous producers need an explicit guard if the consumer of the write is asynchronous

I think your rules to prevent buffer overwrites will cover my case. In fact, I don’t know how to implement this feature myself, so I am really looking forward to your follow-up PR.