Inplace and Fusion Opportunities
This is a quick note about the opportunities we can enable through inplace support. While inplace operations do not necessarily bring better speedup, there are quite a few opportunities in transforming the program to make better use of memory.
One motivating example is the embedding operator in LLM serving. Consider a multimodal LLM where we have vision encoding and text embedding lookups.
class Module:
    def image_encoding(image, params):
        # a simple vision pipeline to project into embedding space
        lv0 = conv2d(image, params[0])
        lv1 = flatten(lv0)
        lv2 = matmul(lv1, params[1])
        return lv2

    def embedding_lookup(token_ids, params):
        # look up embedding for text; embedding_lookup here refers to
        # the underlying table-lookup operator, not this function
        lv0 = embedding_lookup(token_ids, params[3])
        return lv0
In order to make use of both the image and text modalities, we need to run the embedding lookup and the image encoding, then concatenate the results at runtime.
def runtime_make_embedding(mod, token_ids, image, params):
    """This code runs in the engine to enable flexible
    customization of the batching/chunking strategy"""
    img_embedding = mod["image_encoding"](image, params)
    text_embedding = mod["embedding_lookup"](token_ids, params)
    # assume we have a concat function
    final_embedding = mod["concat"](text_embedding, img_embedding)
    return final_embedding
Note that the above example is somewhat simplified: text may be interleaved before/after images, which would require a different method of concatenation. The main issue is that we have to allocate img_embedding, text_embedding and final_embedding separately. One main challenge is that the sizes of the image and text embeddings can change depending on the scenario. In a serving setting, we would like strong control over memory use, with allocation that is as static as possible. A typical way to achieve this is to set a maximum chunk size, which corresponds to the maximum sequence length that final_embedding can take. We can then use the remaining memory for other purposes (e.g., maximizing the KV cache).
For this scenario, we would ideally like a different way to call the functions (through destination passing), as shown below.
class ModuleUpdated:
    @T.prim_func
    def copy_into(
        inp: Buffer(("m", 128), "f16"),
        out: Buffer((256, 128), "f16"),
        offset: T.int32,
    ):
        m = T.int32()
        for i, j in grid(m, 128):
            out[i + offset, j] = inp[i, j]

    def image_encoding(image, params,
                       final_embedding: R.Tensor((256, 128), "f16"),
                       offset):
        # a simple vision pipeline to project into embedding space
        lv0 = conv2d(image, params[0])
        lv1 = flatten(lv0)
        lv2 = matmul(lv1, params[1])
        # copy lv2 into final_embedding starting at row `offset`;
        # inplace_indices=[1] marks final_embedding as the in-place output
        lv3 = call_tir_place(
            copy_into, [lv2, final_embedding, offset],
            R.Tensor((256, 128), "f16"),
            inplace_indices=[1]
        )
        return lv3
    def embedding_lookup(token_ids, params,
                         final_embedding: R.Tensor((256, 128), "f16"),
                         offset):
        # look up embedding for text (via the underlying lookup operator)
        lv0 = embedding_lookup(token_ids, params[3])
        lv1 = call_tir_place(
            copy_into, [lv0, final_embedding, offset],
            R.Tensor((256, 128), "f16"),
            inplace_indices=[1]
        )
        return lv1
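To make the intended semantics concrete, here is a minimal Python sketch of what a call_tir_place call does, under the assumption that the kernel mutates the arguments selected by inplace_indices and the call returns aliases of them rather than fresh allocations (the body below is illustrative, not an existing TVM implementation):

def call_tir_place(prim_func, args, out_sinfo, inplace_indices):
    # run the TIR kernel; it mutates args[i] for each i in inplace_indices
    prim_func(*args)
    # return the mutated argument(s) as the result: no new allocation happens
    outs = [args[i] for i in inplace_indices]
    return outs[0] if len(outs) == 1 else tuple(outs)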
Some remarks:
- 256 is the maximum sequence length of the final embedding
- The image encoding and embedding lookup now explicitly take the final output and an offset (which indicates the sequence index at which we should start writing)
- This can be viewed as a special form of destination passing, where we directly write into a region of the final data structure
- Now the engine can simply allocate a fixed-size embedding buffer based on the maximum chunk size
def runtime_make_embedding(mod, token_ids, image, params):
    """This code runs in the engine to enable flexible
    customization of the batching/chunking strategy"""
    final_embedding = pre_allocate((256, 128), "f16")
    # run inplace operations: the image embedding goes first, then the
    # text embedding starts right after it (at offset image_encoding_len)
    final_embedding = mod["image_encoding"](image, params, final_embedding, 0)
    final_embedding = mod["embedding_lookup"](
        token_ids, params, final_embedding, image_encoding_len)
    return final_embedding
Additionally, we can enhance FuseOps and FuseTensorIR to allow fusion of copy_into with the preceding TensorIR operators. This would give us a directly fused operator, allowing the matmul or embedding lookup to write directly into the corresponding region of final_embedding.
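For instance, a fused matmul + copy_into kernel could look roughly like the following hand-written sketch (the input shapes are assumptions for illustration; this is not the actual output of FuseTensorIR):

@T.prim_func
def fused_matmul_copy_into(
    A: Buffer(("m", 64), "f16"),      # flattened image features (assumed shape)
    B: Buffer((64, 128), "f16"),      # projection weight (assumed shape)
    out: Buffer((256, 128), "f16"),   # final_embedding, updated in place
    offset: T.int32,
):
    m = T.int32()
    # the matmul writes straight into the target region of out, so the
    # intermediate buffer for lv2 is never materialized
    for i, j in grid(m, 128):
        out[i + offset, j] = T.float16(0)
        for k in range(64):
            out[i + offset, j] += A[i, k] * B[k, j]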
Notably, this need arises when we have interesting interactions between the runtime and compiled functions: the runtime would like some form of fixed-size memory allocation for the concatenated values, and we can use compiler techniques to transform a function that previously did not handle inplace updates into an inplace form that is memory efficient.