iter_var: T.int32() = T.ceildiv(K, block_K)
for ko in T.serial(iter_var):
    ...
The generated CUDA code becomes:
int32_t iter_var = (K + block_K - 1) / block_K;
for (int32_t ko = 0; ko < iter_var; ++ko) {
    ...
}
This is powerful, but I’d like to discuss feasible solutions for cleanly defining local variables. For instance, using single-element buffers in TVM:
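(A minimal sketch of what this workaround looks like; the function and names below are illustrative, not from the original post.)

from tvm.script import tir as T

@T.prim_func
def scalar_via_buffer(X: T.Buffer((16,), "int32")):
    # The "local variable" is a shape-[1] buffer; every mutation is a BufferStore.
    acc = T.alloc_buffer((1,), "int32")
    acc[0] = 0
    for i in range(16):
        acc[0] = acc[0] + X[i]
    X[0] = acc[0]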
Concern: This might violate Let semantics (immutable bindings).
Approach 2: Introduce Local Variables via Buffer Extension
This would require AST modifications to support first-class local variables instead of single-element buffers.
Looking forward to community thoughts on these approaches or alternative solutions.
I have an implementation in my fork, which introduces sugar at the parser/printer level. In the IR, local variables remain buffers.
A = T.alloc_cell("int32")
A = A + 1
T.cp_async(A.buffer.data, ...)
The parser creates a buffer with dtype int32 and shape [1], but in the parser’s value table, A is recorded as A[0] (a BufferLoad). Then anywhere A is used as a PrimExpr it works naturally.
BufferStore is the place that needs additional handling: the parser accepts the case where the LHS of an assignment (or augmented assignment) is a BufferLoad of a shape-[1] buffer.
If the user wants to access buffer attributes, or in any other case where the underlying buffer of A is needed, A.buffer can be used.
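To make the mapping concrete, here is a rough sketch of how the sugared snippet above could be understood after parsing (my own illustration, not the fork’s actual printer output):

from tvm.script import tir as T

@T.prim_func
def parsed_cell(X: T.Buffer((1,), "int32")):
    # A = T.alloc_cell("int32")  ->  a shape-[1] int32 buffer; `A` maps to A[0] in the value table
    A = T.alloc_buffer((1,), "int32")
    A[0] = 0
    # A = A + 1  ->  accepted because the LHS resolves to a BufferLoad of a shape-[1] buffer
    A[0] = A[0] + 1
    # any read of `A` as a PrimExpr is simply the BufferLoad A[0]
    X[0] = A[0]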
The primary motivation for this solution is that:
- I don’t prefer heavy solutions like introducing more nodes into the IR.
- For backend code (like CUDA), it doesn’t seem to matter whether local variables are kept as arrays; alternatively, you can modify the code generator. It doesn’t affect other parts of the system anyway.
TVMScript also provides doc.NodeVisitor and doc.NodeTransformer for AST-level mutations. In our private fork we collect all LHS variables with multiple Assign or AugAssign bindings; they are implicitly transformed into scalar buffer loads & stores during parsing, giving the same structure as @spectrometerHBH’s answer.
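As an illustration of the collection step only (written against Python’s standard ast module rather than TVM’s doc classes, since the private fork’s code isn’t shown; the class and method names are mine):

import ast

class MultiBindingCollector(ast.NodeVisitor):
    """Collect names bound by more than one Assign/AugAssign."""

    def __init__(self):
        self.bind_counts = {}

    def _record(self, target):
        if isinstance(target, ast.Name):
            self.bind_counts[target.id] = self.bind_counts.get(target.id, 0) + 1

    def visit_Assign(self, node):
        for target in node.targets:
            self._record(target)
        self.generic_visit(node)

    def visit_AugAssign(self, node):
        self._record(node.target)
        self.generic_visit(node)

    def multi_bound(self):
        return {name for name, n in self.bind_counts.items() if n > 1}

source = """
v = 0
for i in range(10):
    v = i
"""
collector = MultiBindingCollector()
collector.visit(ast.parse(source))
print(collector.multi_bound())  # {'v'} -> candidates for the scalar-buffer rewrite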
from tvm.script import tir as T

@T.prim_func
def function(X: T.Buffer([16], "int32")):
    v = 0
    X[0] = v
    for i in range(10):
        v = i
        X[i + 1] = v

print(function)
# default tir parser output:
# each assignment to `v` is treated as a new let binding
@T.prim_func
def function(X: T.Buffer((16,), "int32")):
    v: T.int32 = 0
    X[0] = v
    for i in range(10):
        v_1: T.int32 = i
        X[i + 1] = v_1
# scalar transformed output
@T.prim_func
def function(X: T.Buffer((16,), "int32")):
    # with T.block("root"):
    v = T.alloc_buffer((1,), "int32")
    v[0] = 0
    X[0] = v[0]
    for i in range(10):
        v[0] = i
        X[i + 1] = v[0]
In LLVM-based backends, we could also introduce a new storage scope or some other annotation to ensure the scalar buffer allocation maps to an alloca and is finally turned into an SSA value via mem2reg.
One important note: although this behavior is more “Pythonic”, it breaks the concise scoping behavior of default IR parsing, and it may be hard to keep the round-trip property. So, from my understanding, it could be useful when we use TVMScript as a programming language, but it may not be appropriate as the default behavior for TVMScript’s original purpose.
Great discussions. It would be good for us to form an opinion on what our reasonable “official path” should be. My understanding is that @spectrometerHBH’s proposal may still retain the round-trip property with some careful pattern matching; maybe some flavor that combines the “cell” notation with an option in the parser to turn things into cells?