Background
The `transform_params` pass allows us to lift the original weight pre-transforms into a separate function, so that we can transform the layout of the parameters.
```python
class Before:
    @R.function
    def main(x: R.Tensor,
             params: R.Tuple(R.Tensor, R.Tensor)):
        w0 = R.transpose(params[0])
        lv0 = R.mm(x, w0)
        w1 = R.transpose(params[1])
        lv1 = R.mm(lv0, w1)
        return lv1


class AfterTransformParams:
    @R.function
    def main(x: R.Tensor,
             params: R.Tuple(R.Tensor, R.Tensor)):
        lv0 = R.mm(x, params[0])
        lv1 = R.mm(lv0, params[1])
        return lv1

    @R.function
    def transform_params(params):
        w0 = R.transpose(params[0])
        w1 = R.transpose(params[1])
        new_params = (w0, w1)
        return new_params
```
One can then call `transform_params` either at compile time or at runtime to pre-transform the layout of the parameters.
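As a rough sketch (not the exact MLC/TVM flow), the runtime path could look like the snippet below: compile the module, run the lifted `transform_params` once, then feed the transformed weights to `main`. This assumes `AfterTransformParams` is a proper IRModule that has already been legalized so `relax.build` can compile it; the target, shapes, and tuple-argument conventions are illustrative and may differ across TVM versions.

```python
import numpy as np
import tvm
from tvm import relax

# Compile the module that contains both main and the lifted transform_params.
ex = relax.build(AfterTransformParams, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

# Illustrative weights and input; real weights would come from the weight loader.
w0 = tvm.nd.array(np.random.rand(16, 16).astype("float32"))
w1 = tvm.nd.array(np.random.rand(16, 16).astype("float32"))
x = tvm.nd.array(np.random.rand(8, 16).astype("float32"))

# Pre-transform the layout once at runtime, then reuse the transformed weights.
new_params = vm["transform_params"]((w0, w1))
out = vm["main"](x, new_params)
```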
When working in LLM settings, we face the problem that different model/platform pairs may require different weight layouts. For example, the FasterTransformer kernel requires an interleaved layout. We might also introduce more complicated data layouts in the future via layout tuning.
Most of our solutions so far focus on compile-time layout transformation, which means we pre-transform the weights into the corresponding format when running `convert_weight`. This can lead to several different weight formats, such as `q4f16_0`, `q4f16_1`, or `q4f16_ftgroup`, all of which are reasonably compatible with each other up to permutations.
Ideally, we would like to enable weight loading from one common format, and allow runtime conversion of the layout. This is especially important as we start to introduce more platforms and layout-related optimizations.
Challenge and Solution
While in theory we can keep `transform_params` around and compile it for runtime execution, we still face a key challenge: memory consumption. Specifically, LLM weights are huge, and directly running `transform_params` from `AfterTransformParams` will result in double the memory cost, even though we only need a single copy of the weight parameters. There is one natural solution: we can introduce another pass, or simply update the TransformParams pass, to update the params tuple in place, knowing that we will not need the original copy of params:
```python
class AfterInplaceTransformParams:
    @R.function
    def main(x: R.Tensor,
             params: R.Tuple(R.Tensor, R.Tensor)):
        lv0 = R.mm(x, params[0])
        lv1 = R.mm(lv0, params[1])
        return lv1

    @R.function
    def inplace_transform_params(params):
        w0 = R.transpose(params[0])
        p0 = R.inplace_update_tuple(params, 0, w0)
        w1 = R.transpose(params[1])
        p1 = R.inplace_update_tuple(p0, 1, w1)
        return p1
```
In this particular case, we only pay for one extra temporary buffer at a time to perform the in-place transformation: the peak overhead is the size of the single weight currently being transformed, rather than a full duplicate of all the parameters.
Other things to note:
- We should make sure the original references to the NDArrays are cleared, especially when they come from the global NDArrayCache, by calling `NDArrayCache::Clear` before calling `transform_params` (see the sketch below).
- Additionally, we can also hint the memory planner to explicitly de-allocate the temporary memory in `transform_params`, as it won't be used after the transformation.
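As a rough sketch of the first point, the cache clearing would happen right before invoking the transformation. The registered global function name below mirrors `NDArrayCache::Clear` in the runtime; verify it is available in your TVM build, and note that `vm` and the params tuple are assumed to come from driver code like the earlier snippet.

```python
import tvm

# Drop the NDArrayCache's references to the original weights so that, once a
# tuple slot is overwritten in place, the old NDArray can actually be freed
# and only one live copy of each weight remains.
clear_ndarray_cache = tvm.get_global_func("vm.builtin.ndarray_cache.clear")
clear_ndarray_cache()

# Now run the in-place transformation on the (sole) remaining references.
new_params = vm["inplace_transform_params"](params)
```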
Followup
In-place `transform_params` opens up the following opportunities:
- We can reduce the number of model weight formats in the convert-weight stage, e.g. we only need to keep `q4f16_1`.
- We can perform platform-dependent layout optimizations.