[DISCUSS] Inplace Update in Dataflow Block

This post seeks to discuss the handling of in-place updates in Relax and how to make them compatible with dataflow blocks.

Background

In-place updates can be useful as a memory optimization: instead of creating a new copy of a buffer, we simply update the existing one. In many cases they are not strictly necessary, since memory planning can get us very far by reusing memory in a ping-pong fashion. However, there are still key places where in-place updates matter, mostly cases where we need some form of incremental update:

  • Incrementally update certain rows of a matrix (e.g., in the kv-cache in LM decoding).
  • Incrementally update a few entries of a large embedding matrix.

In both cases, it is desirable not to create a new copy of the input, to reduce both copying and memory cost.
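
As a rough plain-Python illustration (all names here are hypothetical, and Python lists stand in for tensors), the difference between a copy-based update and an incremental in-place update of a preallocated cache:

```python
def append_row_copy(cache, row):
    # functional update: allocates a new list on every decoding step
    return cache + [row]

def append_row_inplace(cache, length, row):
    # incremental update: writes into a preallocated buffer,
    # as a kv-cache would during LM decoding
    cache[length] = row
    return length + 1

# preallocate a 4-row "kv-cache"; only the fill pointer advances per step
cache = [None] * 4
n = 0
n = append_row_inplace(cache, n, [1.0, 2.0])
n = append_row_inplace(cache, n, [3.0, 4.0])
print(cache[:n])  # [[1.0, 2.0], [3.0, 4.0]]
```

The in-place version touches only the rows being written, while the copy-based version re-materializes the whole cache each step.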

While Relax does support mutation outside of the dataflow block, it is desirable to also support some form of in-place update inside the dataflow block. This post proposes a way to do so.

Proposal

One approach is to adopt move semantics, as in Rust/C++. If a value is not aliased and is only referenced in a single place, we can afford to pass it to a function that updates it in place, return the value, and still pretend that the overall process is functional.

@R.function
def func(x: R.Tensor((1, 2), "float32")):
    # mark x as an in-place argument, which means
    # the caller allows in-place updates to happen
    # and can choose to return the result in return values
    R.func_attr({"relax.inplace_arg": 0})
    with R.dataflow():
        # in-place add 1 to x and return the updated value
        x1 = inplace_addone(x)
        # it is OK to use x1 further
        R.output(x1)
    return x1

@R.function
def func_invalid(x: R.Tensor((1, 2), "float32")):
    # mark x as an in-place argument, which means
    # the caller allows in-place updates to happen
    # and can choose to return the result in return values
    R.func_attr({"relax.inplace_arg": 0})
    with R.dataflow():
        # in-place add 1 to x, return the updated value
        x1 = inplace_addone(x)
        # after in-place update, x cannot be referenced by any other operators
        gv = add(x, x)
        R.output(gv)
    return gv

Supporting such in-place updates within the dataflow block lets us keep a single large dataflow block without breaking it into segments.

We can impose strong requirements to ensure safety in such cases.

  • C0: No alias: A var has to be an internal var that is freshly allocated or an argument marked as an in-place arg.

  • C1: Single use: The var can only be used once, and the usage site is the in-place operator.

    • This condition can be relaxed to allow earlier usage sites that only read the var but do not retain or alias it.
    • Note that single use has its advantages: if a previous read and the in-place update can execute in parallel, DAG order alone is not sufficient for correctness. A safer alternative is to model the read as an in-place operator as well, which sequentializes the read and the update.
    @R.function
    def func0(x: R.Tensor((1, 2), "float32")):
        R.func_attr({"relax.inplace_arg": 0})
        with R.dataflow():
            y = exp(x)
            # based on DAG order alone, y = exp(x) has no
            # strict ordering with the in-place update,
            # which can cause unsafe reordering
            x1 = inplace_addone(x)
    
    
  • C2: In-place pass-through: the in-place operator returns a value backed by the same storage, and the return value can be further consumed by subsequent in-place or non-in-place operators.
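
As a sketch of how a C0/C1 check might look, here is a toy analysis over a hypothetical IR representation (bindings as `(dst, op, args)` tuples; none of this is existing Relax API):

```python
def check_inplace_safety(bindings, inplace_target):
    """Check C1 for a toy IR: `bindings` is a list of (dst, op, args)
    tuples, and `inplace_target` is the var an in-place op mutates.

    C1 (single use): the target is referenced exactly once, and that
    reference is the in-place call itself.
    """
    uses = [(dst, op) for dst, op, args in bindings
            if inplace_target in args]
    return len(uses) == 1 and uses[0][1].startswith("inplace_")

# mirrors func_invalid above: x is used by both the in-place op and add
bad = [("x1", "inplace_addone", ["x"]),
       ("gv", "add", ["x", "x"])]
# the valid pattern: only the updated value x1 is consumed downstream
good = [("x1", "inplace_addone", ["x"]),
        ("gv", "add", ["x1", "x1"])]
print(check_inplace_safety(bad, "x"))   # False
print(check_inplace_safety(good, "x"))  # True
```

Because C1 enforces a single use, the check only needs to count references; no further dataflow reasoning is required in this restricted form.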

We can mark certain operators as in-place operators, which means they would satisfy such conditions.

CallTIRInplace

One specific construct worth enabling is call_tir_inplace, which has the following semantics:

def call_tir_inplace(tir_func, args, shape_args):
    tir_func(*args, *shape_args)
    return args[0]
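
To make the destination-passing semantics concrete, here is a runnable plain-Python model (Python lists stand in for tensors, and addone is a hypothetical in-place kernel):

```python
def call_tir_inplace(tir_func, args, shape_args):
    # destination-passing style: the TIR function mutates args[0],
    # and the call "returns" that same buffer
    tir_func(*args, *shape_args)
    return args[0]

def addone(x):
    # a hypothetical in-place kernel, modeled on a Python list
    for i in range(len(x)):
        x[i] += 1

buf = [1.0, 2.0, 3.0]
out = call_tir_inplace(addone, [buf], [])
print(out)         # [2.0, 3.0, 4.0]
print(out is buf)  # True: the result aliases the input buffer
```

The key point is the aliasing on the last line: no new storage is created, which is exactly what conditions C0–C2 are designed to make safe.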

Here is example code that leverages this construct to implement concat efficiently without creating an extra copy of memory.


@tvm.script.ir_module
class IRModule:
    @T.prim_func
    def concat_update(
        X: T.Buffer((3, 8), "float32"),
        Y: T.Buffer((1, 8), "float32"),
        offset: T.int64
    ):
        for i in range(8):
            X[offset, i] = Y[0, i]

    @R.function
    def main(
        y0: R.Tensor((1, 8), "float32"),
        y1: R.Tensor((1, 8), "float32"),
        y2: R.Tensor((1, 8), "float32"),
    ):
        with R.dataflow():
            z0 = R.log(y0)
            z1 = R.log(y1)
            z2 = R.log(y2)
            # x3 = concat([z0, z1, z2])
            x0 = R.empty((3, 8), "float32")
            x1 = R.call_tir_inplace(concat_update, [x0, z0], R.shape([0]))
            x2 = R.call_tir_inplace(concat_update, [x1, z1], R.shape([1]))
            x3 = R.call_tir_inplace(concat_update, [x2, z2], R.shape([2]))
            R.output(x3)
        return x3

The above code example enables call_tir_inplace to create a result of concatenation of three arrays in an in-place fashion.
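
A plain-Python simulation of the example above (lists stand in for buffers, and math.log1p stands in for R.log to keep the values finite) shows that all three call_tir_inplace results alias the single buffer x0:

```python
import math

def concat_update(X, Y, offset):
    # writes row Y into row `offset` of X, as in the PrimFunc above
    for i in range(8):
        X[offset][i] = Y[0][i]

def call_tir_inplace(func, args, shape_args):
    # destination-passing model: mutate args[0], return it
    func(*args, *shape_args)
    return args[0]

# three (1, 8) inputs and their elementwise "log" (log1p as a stand-in)
ys = [[[float(r * 8 + i) for i in range(8)]] for r in range(3)]
zs = [[[math.log1p(v) for v in row] for row in y] for y in ys]

x0 = [[0.0] * 8 for _ in range(3)]
x1 = call_tir_inplace(concat_update, [x0, zs[0]], [0])
x2 = call_tir_inplace(concat_update, [x1, zs[1]], [1])
x3 = call_tir_inplace(concat_update, [x2, zs[2]], [2])
print(x3 is x0)  # True: the whole chain shares one buffer
```

No intermediate (3, 8) buffers are created; each call writes one row of the shared output.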

Fuse Compute Store into the Final Buffer

The main advantage of this construction is that we can enable fusion that directly fuses previous computation into the in-place storage. For example, we can run a fusion pass to fuse the z calculations into the concat. This is a useful optimization that reduces memory usage in operators that involve concat.

@tvm.script.ir_module
class IRModule:
    @T.prim_func
    def log_concat_update(
        X: T.Buffer((3, 8), "float32"),
        Y: T.Buffer((1, 8), "float32"),
        offset: T.int64):
        for i in range(8):
            X[offset, i] = T.log(Y[0, i])

    @R.function
    def main(
        y0: R.Tensor((1, 8), "float32"),
        y1: R.Tensor((1, 8), "float32"),
        y2: R.Tensor((1, 8), "float32"),
    ):
        with R.dataflow():
            # x3 = concat([log(y0), log(y1), log(y2)])
            x0 = R.empty((3, 8), "float32")
            x1 = R.call_tir_inplace(log_concat_update, [x0, y0], R.shape([0]))
            x2 = R.call_tir_inplace(log_concat_update, [x1, y1], R.shape([1]))
            x3 = R.call_tir_inplace(log_concat_update, [x2, y2], R.shape([2]))
            R.output(x3)
        return x3

The same construct can be useful for sparse embedding updates and KV-cache updates.

Discussions

This post aims to start discussion on the topic. There are additional questions worth thinking about, for example:

  • How would such in-place operators interact with operator fusion and scheduling?
  • What are the ways to signal the correct intent?
  • Users sometimes like to write explicit mutation code; can we perform “functionalization,” which transforms explicit mutation code into the in-place update style?

I think having in-place operators would be an important optimization, but I think it would be difficult to do it as proposed without greatly changing the type system. Rust is able to do many optimizations involving in-place computation because of its linear type system, whereas we have not implemented anything comparable for Relax.

Aliasing poses a big problem because there are very many actions that could possibly cause a value to have aliases. E.g., if a function takes an argument, you don’t know if the argument is aliased somewhere else; a PackedFunc might store a copy of an argument and return it elsewhere. We would have to be very conservative if we want to eliminate cases without aliases.

What I could imagine working is an optimization that lowers operations into in-place versions (late in compilation, after we’ve eliminated DataflowBlocks) where we can guarantee that no aliasing has occurred. I think otherwise, we should treat the in-place versions of operators as not being pure and thus not in dataflow blocks.


Indeed, we need to be extra careful, and it is hard to support the fully general case.

That is why the current proposal places strong restrictions per C0 and C1. One possible point for discussion is that even such restricted cases bring benefits, for example the call_tir_inplace construct, which can be important for enabling some optimizations explicitly.

Okay, it might be worth thinking about how these conditions would be enforced. This might pose some implementation issues, given how we’ve been checking StructInfo so far.

I would personally prefer keeping in-place operations out of dataflow blocks.

I originally shared that view (that it is sufficient to keep in-place operations outside the dataflow block). After seeing some recent use cases (like the concat optimization), I have started to feel that it is worthwhile to have a “functionalized” version of in-place operations (i.e., those that consume and return the same array) inside the dataflow block, as they can be analyzed and used in mostly similar ways.

One thing that might be useful is a well-formedness check for conditions C0 and C1. Given that we are enforcing single use, we likely do not need to check other usage sites, so such a check is simple. We can then verify the invariant when lowering call_tir_inplace to ensure correctness.

Another possibility is to still keep call_tir_inplace's semantics as functional (X0).

  • X0: During lowering, if we detect that the input has multiple uses, we trigger a copy of the input; otherwise we update in place. The semantics of the operator thus remain pure, even though the underlying TIR function runs the computation in place.
  • X1: We can also have a “strict mode” behavior, which simply errors out when a copy would be necessary.

A combination of X0 and X1 should be sufficient for some of the critical use-cases while not over-generalizing things to complicate the analysis.

I think if we go with the approach of “copy if we detect the conditions for keeping it in-place are not fulfilled,” it would be fine to keep it in dataflow blocks. My concern is that the check for possible aliasing will require multiple passes and might not be something we want to happen every time we run the normalizer/well-formed check. Hence, I think lowering ops to in-place versions later in compilation might be less disruptive.

Note from the June 13, 2023, TVM Unity Community Meeting: One approach that could be convenient in practice would be to have ops that are marked as “in-place” but only lower them to truly in-place (mutating) versions later in compilation if they meet the appropriate criteria (no aliasing and being in-place), otherwise replacing them with pure operations. This would get around the issue with purity tracking without losing any convenience.


As discussed in the June 27, 2023, community meeting, here is a more concrete proposal for implementing in-place operations.

The basic mechanism for implementing in-place operations would be to replace operator calls with TIR PrimFuncs that compute in-place. We could perform such a replacement using a special operator call_tir_inplace, which would be similar to call_tir:

call_tir_inplace(gvar, args, out_idxs, packed_ints):

  • gvar is the global variable identifying a TIR PrimFunc.
  • out_idxs indicates which of the args should be considered “outputs.”
  • The PrimFunc will be invoked with the arguments args, mutating the arguments whose indices match out_idxs. The operator will return aliases of those modified arguments.

As mentioned in the community meeting, having call_tir_inplace as a separate operator has the benefit of allowing us to treat it as pure and have it inside DataflowBlocks and also allow for special handling in passes like fusion or layout propagation.

The tricky problem with in-place operations, as we will explore below, is that detecting opportunities to make ops in-place should ideally be done early in compilation (when dataflow blocks are still intact) but implementing them may require low-level changes to memory allocations.

Simple Case

Let’s suppose that we intend to implement an in-place addition, where we have z = add(x, y). We can replace this call with an in-place version if the following conditions are met:

  1. Whichever argument is chosen (see condition 3) to hold the result must not be live after that point in the program
  2. No alias of the argument chosen to hold the result may be live after that point in the program
  3. Either x or y has the same type (especially shape) as the output.

If these conditions are met, we can replace this call with z = call_tir_inplace(inplace_add, [x, y], out_idxs=[0], packed_ints=[]) (the choice of which argument should be the output depends on which one matches the type). The inplace_add PrimFunc would be manually written. If the conditions are not met, we do nothing and legalize the add later.
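
A minimal sketch of condition 1 (liveness of the chosen argument) over a toy program representation; conditions 2 (aliasing) and 3 (shape) are not modeled here, and all names are hypothetical:

```python
def last_use_index(program, var):
    """program: list of (dst, op, args); returns the index of the last
    statement that reads `var`, or -1 if it is never read."""
    idx = -1
    for i, (_, _, args) in enumerate(program):
        if var in args:
            idx = i
    return idx

def can_rewrite_inplace(program, call_idx, arg):
    # condition 1: the chosen argument is dead after the call site
    # (conditions 2 and 3, aliasing and shape, are not modeled here)
    return last_use_index(program, arg) == call_idx

prog = [("z", "add", ["x", "y"]),
        ("w", "mul", ["y", "y"])]
print(can_rewrite_inplace(prog, 0, "x"))  # True: x is dead after the add
print(can_rewrite_inplace(prog, 0, "y"))  # False: y is read again later
```

Even this simple check breaks down once packed calls enter the picture, as discussed next, since they hide both aliases and later uses from the analysis.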

However, it should be noted that checking these conditions can be complicated, as there are many actions that could potentially create aliases. For example, we must conservatively assume that any packed call could alias any of its arguments or, if you want to consider some truly adversarial examples, alias any value ever passed to a packed func, per the below example. With the same example, we can also see how a packed call is a nightmare for liveness analysis, since the argument passed to packed1 is still live if packed2 is called later in the program, even though the compiler cannot know that there is any relationship between the two. Thus, packed calls are “black holes” for both alias and liveness analysis: The result of a packed call can be the alias of any value ever passed to any packed call.

GLOBAL_VAR = None

def packed1(x):
    global GLOBAL_VAR
    GLOBAL_VAR = x
    return x

# the result of packed2 aliases a value that was passed to packed1!
def packed2(y):
    return GLOBAL_VAR

Of course, we do not have to permit these kinds of adversarial examples, but if we want to define restrictions on what packed calls should be allowed to do, we need to be very up-front with them, as there is no practical way the compiler can enforce such restrictions. Unfortunately, I could imagine that the pattern of setting or reading a global variable via PackedFuncs could legitimately be needed, so there could be issues with adding such restrictions.

Even if we analyze only values within DataflowBlocks that are DataflowVars, there is no restriction on assigning a DataflowVar to a variable from outside the block, which brings all the same possible issues of aliasing…

Most likely, we will need some kind of escape hatch for liveness and alias analysis: some annotation a user can give to insist to the compiler that a variable is safe to use despite having been passed to a PackedFunc, e.g., a non_aliased(x) operator or a call_packed_nonaliasing(func, *args, sinfo_args) operator. These would definitely be finicky details for users to manage, though the opaque nature of PackedFuncs gives us little choice here.

When Shapes Do Not Match

In the above example of adding in-place, we may consider the situation where we have z = add(x, y) and neither x nor y has the same shape as z. This would be the case if broadcasting is taking place. It would still be possible to perform the addition in-place if either x or y has an underlying storage that is large enough to store the broadcasted result. This means that handling such a case would require reasoning about the sizes of allocated regions of memory rather than only tensor shapes.

The in-place optimizations proposed at the June 27, 2023, community meeting of eliding concatenations and splits (ensuring that concatenated tensors are allocated in a single storage so that the concatenation itself is a no-op, or ensuring that the tensors created by splitting are all just offsets within the same storage; both optimizations can be done only if the inputs to the operators will not be live afterwards) similarly require reasoning at the level of memory allocation.

This poses a bit of a problem from the perspective of phase ordering, since the simple case in the previous section does not need to reason about storages but cases like split, concat, and non-matching shapes do. However, reasoning about whether the aliasing and liveness conditions are satisfied at a low level can be more complicated, so it would be preferable to handle those analyses earlier in compilation.

One solution would be to do the liveness and alias analyses early, replace the operator with a different version that is marked “in-place,” and later in compilation replace the special operator and its inputs with the appropriate versions (performing the memory allocations as specified and an in-place TIR call for the operator).

This approach would require two passes:

  1. A pass early in compilation that identifies opportunities to perform operators in-place. This would be responsible for doing the liveness and alias analyses. In the simple case where the shapes of the input and output match, the in-place version can be inserted immediately. In cases where that is not so, the operator can be replaced with a version that indicates that it is a candidate for being done in-place (e.g., add_inplace for add).
  2. A pass later in compilation that analyzes the candidates for in-place versions and ensures that the memory allocations are performed appropriately to facilitate these in-place versions (this could perhaps be done as part of StaticPlanBlockMemory). This would be where the proposed optimizations for split and concat would be done.
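
The first pass could be sketched roughly as follows (toy program representation, hypothetical names; the real pass would also need the alias analysis discussed above):

```python
def annotate_inplace_candidates(program):
    """Pass 1 sketch: mark a call as an in-place candidate when its
    first argument is not read by any later statement. `program` is a
    list of (dst, op, args); returns (dst, op, args, candidate) tuples."""
    annotated = []
    for i, (dst, op, args) in enumerate(program):
        # liveness check only; real analysis must also rule out aliases
        later_reads = any(args[0] in a for _, _, a in program[i + 1:])
        annotated.append((dst, op, args, not later_reads))
    return annotated

prog = [("z", "add", ["x", "y"]),
        ("w", "mul", ["x", "x"])]
annotated = annotate_inplace_candidates(prog)
print([c for *_, c in annotated])  # [False, True]: x is still read later
```

Pass 2 would then consume these annotations, arranging storages so that each surviving candidate can actually mutate its input.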

An alternative approach might be to do all analysis and replacements in one pass, which would have to come late in compilation since the more complex cases involve dealing with explicit memory allocations. However (see the phase-ordering discussion), exactly where in the phase ordering this pass should be done is unclear: It should come before legalization, since legalization would get rid of all the Relax operators, but it would also have to deal with explicit memory allocations, which is normally handled after legalization.

A third possibility would be to add another parameter to call_tir_inplace that gives the size of the output. The operators could thus be replaced immediately with call_tir_inplace when appropriate and StaticPlanBlockMemory would only have to handle the single case of call_tir_inplace. This would solve certain cases, like the broadcasting add example, but it is not clear how this would account for the concat or split cases. We could perhaps address those through some manner of specifying that the arguments or results should be contiguous in memory or share an underlying storage (would a single flag suffice?).

Summary

The two-staged approach seems like it would be the least difficult to design, though that also poses some phase-ordering difficulties (StaticPlanBlockMemory would become a bit overloaded in terms of its intended functionality). Specifying memory allocation details in call_tir_inplace might also be a reasonably simple solution. In any case, it is clear that dealing with memory allocations as a part of handling in-place operations is a source of complexity and I would welcome further discussion on how the implementation should be approached, particularly from the perspective of how memory allocations are represented in Relax.


Hi @slyubomirsky @tqchen , can we enable multiple outputs for call_tir_inplace?

We have a use case of fusing rotary embedding and flashattention in MLC-LLM, we suppose the programming interface would look like:

@T.prim_func
def fused_rotary_flashattention(k: T.Buffer(...), q: T.Buffer(...), v: T.Buffer(...), o: T.Buffer(...)):
    ...

updated_k, o = T.call_tir_inplace(fused_rotary_flashattention, (k, q, v), inplace_shapes=(k_shape,), output_shapes=(o_shape,))

and the updates on k will be in place, and o will be a brand new output tensor.

To distinguish this with standard call_tir_inplace whose outputs are all in-place updated tensors, maybe we can rename this as call_tir_inplace_with_outputs.


Thanks for the great input. Here is another possible way of doing it:

@T.prim_func
def func(input0: T.Buffer(...), input1: T.Buffer(...), output0: T.Buffer(...)):
     ...

updated_input0, output0 = T.call_tir_inplace(
    func, [input0, input1], R.Tuple([R.Tensor(), R.Tensor()]),
    inplace_indices=[0, -1])

Semantics written in python:

def call_tir_inplace(func, args, out_sinfo, inplace_indices):
    results = []
    new_allocs = []
    for i, shape in enumerate(flatten_shape_as_list(out_sinfo)):
        if inplace_indices[i] == -1:
            # allocate a fresh buffer for this output
            alloc = empty(shape)
            new_allocs.append(alloc)
            results.append(alloc)
        else:
            # this output is written in place into the indicated argument
            results.append(args[inplace_indices[i]])
    func(*args, *new_allocs)
    return results

Describe the semantics in text:

  • The semantics from the caller’s point of view is much like call_tir: we always pass the complete output sinfo we expect to the function.
  • inplace_indices specifies the in-place option for each output in flattened order: -1 means a new output must be allocated (when every element is -1 it is the same as call_tir); otherwise, an entry k indicates that the output should be written in place into args[k].
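
A runnable plain-Python model of these semantics, with lists standing in for tensors and a hand-written func that mutates its first argument while filling a freshly allocated second output (all names are stand-ins, not Relax API):

```python
def call_tir_inplace(func, args, out_shapes, inplace_indices):
    # out_shapes stands in for the flattened out_sinfo above
    results, new_allocs = [], []
    for i, shape in enumerate(out_shapes):
        if inplace_indices[i] == -1:
            alloc = [0.0] * shape        # stand-in for empty(shape)
            new_allocs.append(alloc)
            results.append(alloc)
        else:
            # output aliases the indicated in-place argument
            results.append(args[inplace_indices[i]])
    # only freshly allocated outputs are passed as extra trailing args
    func(*args, *new_allocs)
    return results

def func(input0, input1, output0):
    # mutates input0 in place and fills the fresh output0
    for i in range(len(input0)):
        input0[i] += input1[i]
        output0[i] = input0[i] * 2

a, b = [1.0, 2.0], [10.0, 20.0]
upd, out = call_tir_inplace(func, [a, b], [2, 2], [0, -1])
print(upd is a)  # True: first output aliases input0
print(out)       # [22.0, 44.0]
```

This mirrors the inplace_indices=[0, -1] call above: the first result is input0 updated in place, the second is a brand-new buffer.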

Some high level remarks:

  • Preserving out_sinfo gives us the extra shape information and lets us specify the tuple structure of the returned output.

We can also develop similar variants, I think.


I like @tqchen’s latest post with the inplace_indices, as that would avoid the need to have a separate built-in altogether. That is, instead of introducing a new built-in T.call_tir_inplace, the existing T.call_tir could have an optional argument for the in-place indices, which defaults to an empty list.

I have a couple of questions on how it would interact with other Relax functionality, along with my current brainstorming thoughts on it.

  • Does a Relax variable remain valid in the calling scope after being used as an in-place argument? I’d lean toward “no”, as it could contain a value other than its original tensor value.
  • Does inplace_indices need to specify which output buffer corresponds to a given input buffer? I would lean toward “no”, as it would only serve to allow the calling scope access to the no-longer-valid input tensors.
  • Does the use of inplace_indices mandate that the operation be performed in-place, or merely allow in-place operations? I would lean toward the latter, as it would allow a two-step process for lowering, with an initial pass annotating any arguments that are valid for use as in-place arguments (i.e., the last point of use of an internal allocation), and a second pass that replaces the implementations where possible. This way, only the first pass would require global information about the end-to-end flow.

If others have the same preferences for these questions, then this starts to sound more like a general model for transfer of ownership, where in-place operations are allowed when the tensor argument is transferred to the callee, and are opaque to the caller. This could also be used to express double-buffers, for example prev, current = T.call_tir_inplace(prev, current, inplace_indices=[0,1]) where internally the returned prev uses the same backing buffer as the argument current.
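
The double-buffer idea can be modeled in plain Python, where “transferring ownership” amounts to returning the two backing buffers swapped (a hypothetical sketch, not actual Relax API):

```python
def swap_buffers(prev, current):
    # models prev, current = T.call_tir_inplace(prev, current,
    # inplace_indices=[0, 1]): the returned `prev` is backed by the
    # buffer that was passed in as `current`, and vice versa
    return current, prev

prev, current = [0.0] * 4, [1.0] * 4
old_prev, old_current = prev, current
prev, current = swap_buffers(prev, current)
print(prev is old_current)   # True
print(current is old_prev)   # True
```

From the caller’s perspective no mutation is visible, yet no new storage is allocated; the ownership of the two backing buffers has simply been exchanged.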


I think I agree with all the suggestions for those questions, but I wanted to clarify one doubt.

For the 3rd question, if inplace_indices is just a suggestion and does not guarantee that the operation happens inplace, how would future passes identify whether that call happens inplace or not?

After the second pass in your suggestion performs the actual in-place replacements where possible, it could clear inplace_indices for the remaining calls where this failed. That way, future passes could take advantage of that information (especially memory planning passes).

Personally, I feel having an explicit builtin might be useful, since the kinds of optimizations applied can differ in the presence of in-place ops.

In this case, inplace_indices is mainly a way for us to match to a prim_func that contains in-place behavior. The original proposal would detect whether the variable being updated in place gets “consumed,” and if so, we can perform the update in place. We can do such analysis when lowering call_tir_inplace, and the memory planner should then be able to pick it up.


I like @tqchen’s suggestion, especially the note about making call_tir_inplace general enough to specify that some arguments must be newly allocated and that others should be done in-place. What exactly happens, then, if the safety conditions for an in-place call are not met? Would the in-place argument be replaced with a copy?

Edit: Or is the intent that we would not insert the call_tir_inplace operator at all if the conditions aren’t met?


That makes sense, thanks.

Good point. I had been thinking of it from a discoverability standpoint, where it is much easier to find additional parameters of an existing function than a new function altogether.

After the second pass, I think the information would already be encoded by the PrimFunc implementing a specific operator. As TQ mentioned, the operators would be replaced by either a PrimFunc that performs in-place updates or a PrimFunc that fills an output buffer, and the memory planning could go from there.

The way I’m picturing it, the future passes would only know which PrimFunc had been inserted to replace the operator, and the inplace_indices would no longer be needed.


Tracking issue posted: https://github.com/apache/tvm/issues/15319

Further comments on implementation issues are welcome either here or on the tracking issue.


@tqchen @psrivas2 (and others), in terms of phase ordering, might it make more sense to do the in-place analysis after getting rid of dataflow blocks? I’m a little worried about creating the call_tir_inplace operator and treating it as pure when it isn’t. Are there any passes we would want to run after lowering ops to in-place versions? Is there a compelling reason to do it early in compilation?
