Combining Separate tir.Blocks Using compute_at()

Hello,

I noticed that the plan to support compute_at() on TIR is tracked in this issue: https://github.com/apache/tvm/issues/7527#

The compute_at() API among TVM's schedule primitives enables combining separate stages of computation when possible. This operation can bring performance gains or open up other optimization opportunities. For example, two loops with the exact same iteration range can be merged into one loop by compute_at().

import tvm
from tvm import te

m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
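
To see the effect, we can print the lowered IR. After compute_at, B's computation is moved inside C's loop over i (the comment sketches the rough shape of the output):

# Inspect the lowered IR; after compute_at, B is computed inside
# C's loop over i, roughly:
#   for i in range(m):
#       B[i] = A[i] + 1.0
#       C[i] = B[i] * 2.0
print(tvm.lower(s, [A, B, C], simple_mode=True))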

Similarly, when defining computations using tir.Block, we notice that there are situations where combining blocks is beneficial. For example,

for i, j in tir.grid(128, 128):
    with tir.block([128, 128], "A_Block") as [vi, vj]:
        A[vi, vj] = tir.float32(0)
for i, j in tir.grid(128, 128):
    with tir.block([128, 128], "B_Block") as [vi, vj]:
        B[vi, vj] = A[vi, vj]

There are two separate blocks in the above example. We would like to know: will compute_at() support combining two blocks into one block in the future? If so, what conditions must the two blocks satisfy for compute_at() to perform the combination?
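
To be concrete, by "combining" we mean producing a single block like the following sketch (the merged block name "AB_Block" is made up):

# Hypothetical merged form of the two blocks above:
for i, j in tir.grid(128, 128):
    with tir.block([128, 128], "AB_Block") as [vi, vj]:
        A[vi, vj] = tir.float32(0)
        B[vi, vj] = A[vi, vj]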


Thanks for asking!

To be clear, it is not necessary to break our block isolation when using compute-at. For example, after compute-at, the IR may become:

for i in tir.range(0, 128):
    for j in tir.range(0, 128):
        with tir.block([128, 128], "A_Block") as [vi, vj]:
            A[vi, vj] = tir.float32(0)
        with tir.block([128, 128], "B_Block") as [vi, vj]:
            B[vi, vj] = A[vi, vj]
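
For reference, once the primitive lands, a minimal sketch of driving this with the TensorIR scheduling API might look like the following (assuming `mod` is an IRModule whose PrimFunc contains the two blocks above):

import tvm
from tvm import tir

# Assumption: `mod` is an IRModule whose PrimFunc contains
# "A_Block" and "B_Block" as in the example above.
sch = tir.Schedule(mod)
block_a = sch.get_block("A_Block")
i, j = sch.get_loops(sch.get_block("B_Block"))
# Move A_Block's computation under B_Block's inner loop `j`,
# producing the interleaved structure shown above.
sch.compute_at(block_a, j)
print(tvm.script.asscript(sch.mod))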


Hello Junru,

Thanks for answering.

However, I am asking: if two blocks are already separate, will compute_at() merge the two blocks the way it merges the loops above, or is this API going to support this kind of merging?

It is definitely not hard to develop a schedule primitive that merges adjacent blocks. Would you like to name a use case in deep learning where this kind of transformation is useful?

Let me clarify some more:

  1. Two blocks do not have to be adjacent to be merged, though in this example they are.
  2. I would like to know whether block merging is already planned as a future feature, rather than how difficult it would be to implement.
  3. I believe it is a useful optimization (it could possibly reduce computation) to combine two separate blocks into one when possible.

Thanks @xiacijie for the reply!

  1. Yes, I got what you mean and definitely agree with the definition.
  2. No. We have a primitive called “blockize” that does the opposite (not exactly, but it creates more blocks), and we have thought about such a merging primitive. Development should be fairly simple (~200 lines in the core implementation), and we are more than happy to assist if you want :slight_smile:
  3. To be clear, merging blocks is a transformation that does not itself bring a performance gain: a Block in TensorIR is a construct that creates conceptual isolation, but it lowers to nothing. Whether or not blocks are merged, the generated code is unaffected.

The reason it has not been developed is that we haven’t found a real-world scenario yet where this primitive is useful, and I would definitely appreciate it if you could bring up an example :+1:

Hello @junrushao,

One further question on the syntax of tir.block which confuses me.

for i, j in tir.grid(128, 128):
    with tir.block([128, 128], "init") as [vi, vj]:
        C[vi, vj] = 0.0

with tir.block([128, 128], "init") as [vi, vj]:
    C[vi, vj] = 0.0

These two pieces of code seem to do the same job: initialize a 128 × 128 buffer to 0.

I am confused that when we have already declared a 128 × 128 loop nest outside, we still have to pass [128, 128] as a parameter to tir.block. What is the purpose of this range in tir.block? Why not just design tir.block this way when an outer loop is already present:

for i, j in tir.grid(128, 128):
    with tir.block("init"):
        C[i, j] = 0.0

Thanks!

Thanks @xiacijie for the question!

# Snippet-1:
for i, j in tir.grid(128, 128):
    with tir.block([128, 128], "init") as [vi, vj]:
        C[vi, vj] = 0.0

# Snippet-2:
with tir.block([128, 128], "init") as [vi, vj]:
    C[vi, vj] = 0.0

Yes, the two snippets above are strictly equivalent, and the second one is syntactic sugar for the first. TVM script has built-in functionality called “auto-completion” that desugars such snippets.

To get a sense of what the fully desugared IR looks like, you may print it out with the following command in Python:

print(tvm.script.asscript(PrimFunc-or-IRModule))

The syntax below describes the signature of a block:

with tir.block([128, 128], "init") as [vi, vj]:

TensorIR is designed with the “block isolation” philosophy, and a block here describes a chunk of computation without needing context. When desugared, your particular example above expands to:

for i in range(0, 128):
  for j in range(0, 128):
    with tir.block([128, 128], "init") as [vi, vj]:
      # vi's domain is [0, 128), and it is data-parallel
      # vj's domain is [0, 128), and it is data-parallel
      tir.bind(vi, i)  # binds `i` to `vi`
      tir.bind(vj, j)  # binds `j` to `vj`
      tir.reads([])  # reads nothing
      tir.writes(C[vi : vi + 1, vj : vj + 1])  # writes one element of C
      C[vi, vj] = 0.0

The properties of this block are:

  • Instances of block execution are described by the pair (vi, vj), where vi, vj ∈ [0, 128).
  • A given block instance (vi, vj) reads nothing and writes to the buffer region C[vi : vi + 1, vj : vj + 1].
  • vi and vj are both data-parallel, which means the block instances (vi, vj) can be executed in arbitrary order or in parallel.

Block bindings (tir.bind) describe how the surrounding loops “drive” the block execution. It is possible to execute in another order:

for i in range(0, 128):
  for j in range(0, 128):
    with tir.block([128, 128], "init") as [vi, vj]:
      tir.bind(vi, 127 - i)  # binds `127 - i` to `vi`
      tir.bind(vj, 127 - j)  # binds `127 - j` to `vj`
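
In plain Python terms, this binding visits the same block instances in reverse order, roughly:

# Plain-Python analogue of the reversed binding (illustration only).
C = [[None] * 128 for _ in range(128)]  # stand-in for the 128x128 buffer
for i in range(0, 128):
    for j in range(0, 128):
        vi, vj = 127 - i, 127 - j  # the non-trivial binding
        C[vi][vj] = 0.0            # the block body itself is unchanged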

In short, in TensorIR we decouple “the order in which the loops run” from “the computation in the block body”. Therefore, redundant information can occur (as you described) when the binding is trivial, and we provide syntactic sugar for this case.

Please let me know if this answers your question. Happy to assist further if you are interested.