InplaceTransformParams

Background

The transform_params pass allows us to lift the original weight pre-transforms into a separate function, so that we can transform the layout of the parameters.

class Before:
    @R.function
    def main(x: R.Tensor,
             params: R.Tuple(R.Tensor, R.Tensor)):
        w0 = R.transpose(params[0])
        lv0 = R.mm(x, w0)
        w1 = R.transpose(params[1])
        lv1 = R.mm(lv0, w1)
        return lv1

class AfterTransformParams:
    @R.function
    def main(x: R.Tensor,
             params: R.Tuple(R.Tensor, R.Tensor)):
        lv0 = R.mm(x, params[0])
        lv1 = R.mm(lv0, params[1])
        return lv1

    @R.function
    def transform_params(params):
        w0 = R.transpose(params[0])
        w1 = R.transpose(params[1])
        new_params = (w0, w1)
        return new_params

One can then call transform_params either at compile time or runtime to pre-transform the layout of the parameters.
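As a rough illustration, below is a minimal sketch of the runtime option, assuming a module shaped like AfterTransformParams above. Here x, raw_params, and the target are placeholders, and the lifted function produced by the actual LiftTransformParams pass may carry a different name (e.g. main_transform_params):

import tvm
from tvm import relax

ex = relax.build(AfterTransformParams, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

# run the weight pre-transform once, then reuse the transformed
# parameters for every subsequent call to main
transformed_params = vm["transform_params"](raw_params)
out = vm["main"](x, transformed_params)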

When working in LLM settings, we face the problem that different model/platform pairs may require different weight layouts. For example, the FasterTransformer kernel requires an interleaved layout. We might also introduce more complicated data layouts in the future via layout tuning.

Most of our solutions so far focus on compile-time layout transformation, which means we pre-transform the weights to the corresponding format when running convert_weight. This leads to several different weight formats, such as q4f16_0, q4f16_1, or q4f16_ftgroup, all of which are reasonably compatible with each other up to permutations.

Ideally, we would like to enable weight loading from one common format, and allow runtime conversion of the layout. This is especially important as we start to introduce more platforms and layout-related optimizations.

Challenge and Solution

While in theory we can keep transform_params and compile it for runtime execution, we still face a key challenge: memory consumption. Specifically, LLM weights are huge, and directly running transform_params in AfterTransformParams would double the memory cost, even though we only need a single copy of the weight parameters. There is a natural solution: we can introduce another pass, or simply update the TransformParams pass, to update the params tuple in place, knowing that we will not need the original copy of params:

class AfterInplaceTransformParams:
    @R.function
    def main(x: R.Tensor,
             params: R.Tuple(R.Tensor, R.Tensor)):
        lv0 = R.mm(x, params[0])
        lv1 = R.mm(lv0, params[1])
        return lv1

    @R.function
    def inplace_transform_params(params):
        w0 = R.transpose(params[0])
        p0 = R.inplace_update_tuple(params, 0, w0)
        w1 = R.transpose(params[1])
        p1 = R.inplace_update_tuple(p0, 1, w1)
        return p1

In this particular case, at any point we only pay for one extra temporary tensor to perform the in-place transformation.
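To make the peak-memory argument concrete, here is a back-of-the-envelope comparison; the sizes are made up purely for illustration:

n_weights = 2                   # weight tensors in the example above
bytes_per_weight = 8 * 2**30    # assume each weight is 8 GiB

# naive transform_params keeps the originals and the transformed copies alive
naive_peak = 2 * n_weights * bytes_per_weight

# in-place update keeps the originals plus a single temporary at a time
inplace_peak = n_weights * bytes_per_weight + bytes_per_weight

print(naive_peak / 2**30, inplace_peak / 2**30)  # 32.0 vs 24.0 (GiB)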

Other things to note:

  • we should make sure the original references to the NDArrays are cleared, especially when they come from the global NDArrayCache, by calling NDArrayCache::Clear before calling transform_params (see the sketch after this list)
  • Additionally, we can hint the memory planner to explicitly de-allocate the temporary memory in transform_params, as it won't be used after the transformation.
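Below is a rough sketch of the intended runtime flow for the notes above; vm and params are placeholders, inplace_transform_params follows the naming in the example, and the exact registered name of the cache-clearing hook (shown here as vm.builtin.ndarray_cache.clear, mirroring NDArrayCache::Clear) is an assumption:

import tvm

# drop the global NDArrayCache's own references first, so the VM holds
# the only copy and the in-place pass can safely reuse the buffers
tvm.get_global_func("vm.builtin.ndarray_cache.clear")()

# transform the weights in place; only one extra temporary tensor is
# alive at any point, and it can be de-allocated right after its use
params = vm["inplace_transform_params"](params)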

Followup

The in-place transform_params opens up the following opportunities:

  • We can reduce the number of model weight formats in the convert_weight stage, e.g. we only need to keep q4f16_1.
  • We can perform platform-dependent layout optimizations.

I like this approach, and especially like maintaining the transform_params function within the compiled module.

I think the majority of the in-place handling could be done by first applying LazyTransformParams, then using the in-place passes that @slyubomirsky has been working on (the DataflowUseInplaceCalls pass, then FuseTIR with in-place support). The LazyTransformParams would make the parameters be internal objects, and therefore eligible for in-place operations. Afterwards, the DataflowUseInplaceCalls could update the parameter transformations to be performed in-place.
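For concreteness, a possible pass ordering following this suggestion might look like the sketch below. Whether this exact sequence works end-to-end is an assumption (DataflowUseInplaceCalls is still evolving), and mod is a placeholder for a module that already contains the lifted transform_params function; the sketch only illustrates the intended order:

import tvm
from tvm import relax

seq = tvm.transform.Sequential(
    [
        relax.transform.LazyTransformParams(),      # params become internal objects
        relax.transform.DataflowUseInplaceCalls(),  # rewrite eligible calls as in-place
        relax.transform.FuseTIR(),                  # fuse, with in-place support
    ]
)
mod = seq(mod)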


Additionally, we can also hint the memory planner to explicitly de-allocate the temp memory in transform params, as they won’t be used after the transformation.

If we use LazyTransformParams, I think we get this for free from the KillAfterLastUse pass, currently part of the default relax lowering pipeline. So long as the get_item function doesn’t cache the NDArray it returns, the only remaining reference would be in the VM. When the R.vm.kill_tensor call is encountered, that would remove the last reference to the temp memory.

This sounds right to me. Otherwise we would need a way to indicate that it is okay to modify the parameters in-place (if we have a pass called InplaceTransformParams, then this is probably okay if the user supplies a list of params that are expected to be modified in this manner).

Let me give my understanding of why we require an InplaceTransformParams infrastructure.

For example, in mlc-llm, we typically have two relax functions:

@R.function
def prefill(input_params, model_params) -> res: 
	...

@R.function
def decode(input_params, model_params) -> res: 
	...

The initial question is, why do we separate model parameters and the runtime module instead of packaging the weight into the TVM module with metadata? Could it be because the weight is too large?

Say we have to do some layout transform to optimize memory access:

@R.function
def prefill(input_params, model_params) -> res: 
	call(te.layout_transform, model_params[0])
	...

@R.function
def decode(input_params, model_params) -> res: 
	call(te.layout_transform, model_params[0])
	...

If we place model_params as function arguments, it becomes difficult to perform constant folding: the compiler doesn’t know whether the input model_params is a constant array. Furthermore, even if we implement an advanced FoldConstant pass in Relax that uses additional information connecting input parameters to constant parameters, constant folding is still challenging, because the input parameters might be consumed by different functions, which may require propagating various operators into constants (although in MLC-LLM, we can assume prefill and decode propagate the same items).

That’s why we should generate in-place transform functions on the arguments. We could also potentially utilize structural equality comparison to eliminate identical weights.

In this case, we want to have an InplaceTransformParams that is aware that the second argument (model_params) is a parameter, and lifts out a runtime pre-transform function.