This post seeks to discuss how to handle in-place updates in Relax and how to make them compatible with dataflow blocks.
Background
In-place updates can be useful as a memory optimization where, instead of creating a new copy of memory, we simply update an existing buffer. In many cases they are not strictly necessary, since memory planning can get us quite far by reusing memory in a ping-pong fashion. However, there are still key places where in-place updates are useful, mostly cases where we need some form of incremental update:
- Incrementally update certain rows of a matrix (e.g., the KV-cache in LM decoding).
- Incrementally update a few entries of a large embedding matrix.
In both cases, it is desirable not to create a new copy of the input, in order to reduce both copying and memory cost.
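To make the contrast concrete, here is a minimal NumPy sketch (illustrative only, not Relax code) of the functional style versus the in-place style for a KV-cache-like row update:

import numpy as np

cache = np.zeros((16, 8), dtype="float32")   # e.g., a KV-cache with 16 slots
new_row = np.ones((1, 8), dtype="float32")   # the entry produced at this step

# Functional style: every step copies the whole cache.
updated = cache.copy()
updated[3:4] = new_row

# In-place style: only the affected row is written.
cache[3:4] = new_row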
While Relax does support mutation outside of dataflow blocks, it is desirable to also support some form of in-place update inside a dataflow block. This post proposes a way to do so.
Proposal
One approach is to adopt something like the move semantics of Rust/C++. If a value is not aliased and is referenced in only a single place, we can afford to pass it to a function that updates it in place, return the value, and still treat the overall process as functional.
@R.function
def func(x: R.Tensor((1, 2), "float32")):
    # Mark x as an in-place argument, which means the caller
    # allows in-place updates to happen and can choose to
    # receive the result through the return value.
    R.func_attr({"relax.inplace_arg": 0})
    with R.dataflow():
        # in-place add one to x, return the updated value
        x1 = inplace_addone(x)
        # it is OK to use x1 further
        R.output(x1)
    return x1
@R.function
def func_invalid(x: R.Tensor((1, 2), "float32")):
    # Mark x as an in-place argument, which means the caller
    # allows in-place updates to happen and can choose to
    # receive the result through the return value.
    R.func_attr({"relax.inplace_arg": 0})
    with R.dataflow():
        # in-place add one to x, return the updated value
        x1 = inplace_addone(x)
        # INVALID: after the in-place update, x cannot be
        # referenced by any other operator
        gv = add(x, x)
        R.output(gv)
    return gv
Such in-place updates within a dataflow block make it possible to keep one large dataflow block rather than breaking it into segments.
To ensure safety in such cases, we can impose strong requirements:
- C0 (no alias): The var has to be an internal var that is freshly allocated, or an argument marked as an in-place argument.
- C1 (single use): The var can only be used once, and that usage site is the in-place operator.
  - This condition can be relaxed to allow earlier usage sites that only read the var but do not retain or alias it.
  - Note that strict single use has its advantages: if a previous read and the in-place update execute in parallel, the DAG order alone is not sufficient for correctness, as the example below shows. A safer approach is to model the read as an in-place operator as well, which sequentializes the read and the update.
@R.function
def func0(x: R.Tensor((1, 2), "float32")):
    R.func_attr({"relax.inplace_arg": 0})
    with R.dataflow():
        # Based on the DAG order alone, y = exp(x) does not have
        # a strict ordering with the in-place update below,
        # which can cause unsafe reordering.
        y = exp(x)
        x1 = inplace_addone(x)
- C2 (in-place pass-through): The in-place operator returns a value backed by the same storage, and the return value can be further consumed by follow-up in-place or non-in-place operators.
We can mark certain operators as in-place operators, which means they are expected to satisfy these conditions.
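As an illustration, here is a minimal plain-Python sketch of how a verification pass could check C0 and C1 on a dataflow block. The (var, op, args) binding model and all names here are assumptions for illustration, not actual Relax APIs.

from collections import Counter

def check_inplace_safety(bindings, inplace_ops, inplace_args):
    """bindings: list of (var, op, args) tuples in program order."""
    # Count every use of every var inside the block.
    uses = Counter(a for _, _, args in bindings for a in args)
    # Vars eligible as in-place targets: in-place function arguments
    # and vars freshly defined inside this block (C0).
    eligible = set(inplace_args)
    for var, op, args in bindings:
        if op in inplace_ops:
            target = args[0]
            if target not in eligible:   # C0: no alias
                return False
            if uses[target] != 1:        # C1: single use
                return False
        # The result is freshly defined (or, for in-place ops,
        # backed by the same storage and reusable per C2).
        eligible.add(var)
    return True

# Mirrors func / func_invalid above:
assert check_inplace_safety(
    [("x1", "inplace_addone", ["x"])],
    {"inplace_addone"}, {"x"})
assert not check_inplace_safety(
    [("x1", "inplace_addone", ["x"]), ("gv", "add", ["x", "x"])],
    {"inplace_addone"}, {"x"})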
CallTIRInplace
One specific construct that can be useful to enable is call_tir_inplace. Specifically, it has the following semantics:
def call_tir_inplace(tir_func, args, shape_args):
    tir_func(*args, *shape_args)
    return args[0]
Here is some example code that leverages this construct to implement concat efficiently, without creating an extra copy of memory.
@tvm.script.ir_module
class IRModule:
    @T.prim_func
    def concat_update(
        X: T.Buffer((3, 8), "float32"),
        Y: T.Buffer((1, 8), "float32"),
        offset: T.int64
    ):
        for i in range(8):
            X[offset, i] = Y[0, i]

    @R.function
    def main(
        y0: R.Tensor((1, 8), "float32"),
        y1: R.Tensor((1, 8), "float32"),
        y2: R.Tensor((1, 8), "float32")
    ):
        with R.dataflow():
            z0 = R.log(y0)
            z1 = R.log(y1)
            z2 = R.log(y2)
            # x3 = concat([z0, z1, z2])
            x0 = R.empty((3, 8), "float32")
            x1 = R.call_tir_inplace(concat_update, [x0, z0], R.shape([0]))
            x2 = R.call_tir_inplace(concat_update, [x1, z1], R.shape([1]))
            x3 = R.call_tir_inplace(concat_update, [x2, z2], R.shape([2]))
            R.output(x3)
        return x3
The above example uses call_tir_inplace to build the result of concatenating three arrays in an in-place fashion.
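A plain NumPy analogue (illustrative only) of the chain above makes the pass-through behavior concrete: every step writes into the same allocation and returns it.

import numpy as np

def concat_update(X, Y, offset):
    X[offset, :] = Y[0, :]
    return X                    # C2: pass through the same buffer

z0 = np.log(np.full((1, 8), 1.0, dtype="float32"))
z1 = np.log(np.full((1, 8), 2.0, dtype="float32"))
z2 = np.log(np.full((1, 8), 3.0, dtype="float32"))

x0 = np.empty((3, 8), dtype="float32")
x1 = concat_update(x0, z0, 0)
x2 = concat_update(x1, z1, 1)
x3 = concat_update(x2, z2, 2)
assert np.shares_memory(x0, x3)  # one allocation across all three steps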
Fuse Compute and Store into the Final Buffer
The main advantage of this construct is that we can enable fusion that directly fuses the preceding computation into the in-place store. For example, we can run a fusion pass that fuses the z = log(...) calculations into the concat update. This is a useful optimization that reduces memory usage in workloads that involve concat.
@tvm.script.ir_module
class IRModule:
    @T.prim_func
    def log_concat_update(
        X: T.Buffer((3, 8), "float32"),
        Y: T.Buffer((1, 8), "float32"),
        offset: T.int64
    ):
        for i in range(8):
            X[offset, i] = T.log(Y[0, i])

    @R.function
    def main(
        y0: R.Tensor((1, 8), "float32"),
        y1: R.Tensor((1, 8), "float32"),
        y2: R.Tensor((1, 8), "float32")
    ):
        with R.dataflow():
            # x3 = concat([log(y0), log(y1), log(y2)])
            x0 = R.empty((3, 8), "float32")
            x1 = R.call_tir_inplace(log_concat_update, [x0, y0], R.shape([0]))
            x2 = R.call_tir_inplace(log_concat_update, [x1, y1], R.shape([1]))
            x3 = R.call_tir_inplace(log_concat_update, [x2, y2], R.shape([2]))
            R.output(x3)
        return x3
The same construct can be useful for sparse embedding updates and KV-cache updates.
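For instance, here is a sketch of the sparse-embedding case in the same illustrative NumPy style (all names are assumptions): only the touched rows of a large table are written.

import numpy as np

def embedding_update(table, indices, new_rows):
    table[indices, :] = new_rows    # touch only a few rows in place
    return table                    # pass through, as with call_tir_inplace

table = np.zeros((10000, 128), dtype="float32")  # large embedding matrix
new_rows = np.ones((3, 128), dtype="float32")
table = embedding_update(table, np.array([5, 42, 99]), new_rows)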
Discussions
This post aims to start the discussion on this topic. There are additional questions worth thinking about, for example:
- How would such in-place operators interact with operator fusion and scheduling?
- What are the ways to signal the correct intent?
- Users sometimes like to write explicit mutation code; can we perform "functionalization", which transforms explicit mutation code into the in-place update style sketched below?
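As a rough illustration of that last question, here is a toy functionalization on the same (var, op, args) binding model used earlier; the statement encoding and renaming scheme are assumptions for illustration, not a committed design.

def functionalize(stmts):
    """Rewrite explicit mutations into in-place-style bindings.

    stmts is a list of either ("mutate", op, target, extra_args),
    meaning op updates target in place, or ("bind", var, op, args).
    """
    rename, out, counter = {}, [], 0
    for stmt in stmts:
        if stmt[0] == "mutate":
            _, op, target, extra = stmt
            cur = rename.get(target, target)   # latest version of target
            counter += 1
            new_var = f"{target}${counter}"    # fresh var for the result
            out.append((new_var, op, [cur] + extra))
            rename[target] = new_var
        else:
            _, var, op, args = stmt
            out.append((var, op, [rename.get(a, a) for a in args]))
    return out

# "mutate x via set_row(y, 0); z = add(x, x)" becomes:
#   x$1 = set_row(x, y, 0); z = add(x$1, x$1)
print(functionalize([
    ("mutate", "set_row", "x", ["y", 0]),
    ("bind", "z", "add", ["x", "x"]),
]))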