Better memory planning for relax?


I am currently trying to make TVM run on large-resolution images, so I want internal memory to be reused whenever possible.

After getting things working with Relax, I found the current Relax memory planning is still far from perfect. As the example below shows, besides the storage allocated for Vars like lv1/lv_1, R.nn.avg_pool2d itself also needs to allocate a separate buffer.

Since the function runs serially, I wonder whether the allocation for lv1/lv_1 could be reused for R.nn.avg_pool2d?

class Module:
    def f(x1: R.Tensor((1, 3, "h", "w"), dtype="float16")) -> R.Tensor((1, 64, "h // 2", "w // 2"), dtype="float16"):
        h = T.int64()
        w = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, h, w, 3), dtype="float16") = R.permute_dims(x1, axes=[0, 2, 3, 1])
            lv_1: R.Tensor((1, h, w, 64), dtype="float16") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_leakyrelu_cutlass(lv, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1])
            lv1: R.Tensor((1, h, w, 64), dtype="float16") = cls.fused_relax_nn_conv2d_relax_add_relax_nn_leakyrelu1_cutlass(lv_1, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3])
            lv5: R.Tensor((1, h // 2, w // 2, 64), dtype="float16") = R.nn.avg_pool2d(lv1, pool_size=[2, 2], strides=[2, 2], dilation=[1, 1], padding=[0, 0, 0, 0], ceil_mode=False, layout="NHWC", out_layout="NHWC")
            gv: R.Tensor((1, 64, h // 2, w // 2), dtype="float16") = R.permute_dims(lv5, axes=[0, 3, 1, 2])
        return gv

Usually input and output buffers are not reused. However, if we run things in several stages, we can likely reuse another buffer immediately (say, the output of a previous conv that is no longer needed), so overall memory usage should be kept to a minimum.
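The reuse described above can be sketched with a toy liveness-based planner (hypothetical names and structure — this is not TVM's actual memory-planning pass): each intermediate gets a live interval, and a new storage is allocated only when no dead buffer of sufficient size exists.

```python
# Toy greedy memory planner: reuse a storage once the tensor living in it
# is dead. Illustrative sketch only, not TVM's StaticPlanBlockMemory.

def plan(ops):
    """ops: list of (name, size_bytes, last_use_index).
    Returns (name -> storage_id, list of storage sizes)."""
    free = []          # (size, storage_id) whose tensors are dead
    expiry = []        # (last_use, storage_id, size) still live
    assignment = {}
    storages = []      # storage_id -> allocated size
    for i, (name, size, last_use) in enumerate(ops):
        # Release storages whose tensor's last use is before this op.
        still_live = []
        for lu, sid, sz in expiry:
            if lu < i:
                free.append((sz, sid))
            else:
                still_live.append((lu, sid, sz))
        expiry = still_live
        # Reuse the smallest free storage that fits, else allocate a new one.
        free.sort()
        for j, (sz, sid) in enumerate(free):
            if sz >= size:
                free.pop(j)
                break
        else:
            sid = len(storages)
            storages.append(size)
        assignment[name] = sid
        expiry.append((last_use, sid, size))
    return assignment, storages

# Intermediates from the example function, with sizes for h = w = 512, fp16:
h = w = 512
ops = [
    ("lv",   h * w * 3 * 2,  1),              # dead after the first conv
    ("lv_1", h * w * 64 * 2, 2),              # dead after the second conv
    ("lv1",  h * w * 64 * 2, 3),              # dead after avg_pool2d
    ("lv5",  (h // 2) * (w // 2) * 64 * 2, 4),
    ("gv",   (h // 2) * (w // 2) * 64 * 2, 5),
]
assignment, storages = plan(ops)
# lv5 reuses lv_1's storage and gv reuses lv1's: 3 storages instead of 5.
print(assignment, [s // 1024 for s in storages])
```

This greedy interval reuse is exactly the "run in serial" observation from the question: because execution order is known, lv_1's buffer is provably dead by the time avg_pool2d's output needs space.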

So for this avg_pool2d case, is there an existing mechanism for telling avg_pool2d that it can reuse conv2d's buffer?

At the moment, no, since we don't support in-place planning by default, and it also requires a special op implementation. Although I think when the network goes slightly deeper, the overall cost won't increase as much, because there will be another temporary swapping buffer and we only need one.
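To illustrate why in-place reuse "requires special op implementation": the kernel must be written so that writing the output into the input's own storage never clobbers an element that is still going to be read. A minimal NumPy sketch (not TVM code; names are made up) of a 2x2 average pool that writes its result into the front of the input's buffer:

```python
import numpy as np

# Sketch of an "in-place capable" avg_pool2d: the output is a view into the
# same storage as the input. This is only safe because output element
# (i, j) lands at a buffer offset strictly before every input element that
# later windows still need to read -- the output rows shrink 4x, so the
# write cursor never catches up with the read cursor.

def avg_pool2d_inplace(buf, h, w, c):
    """buf is a flat array holding an NHWC (1, h, w, c) tensor; the
    (1, h//2, w//2, c) result is written into the first elements of buf."""
    x = buf[: h * w * c].reshape(1, h, w, c)
    oh, ow = h // 2, w // 2
    out = buf[: oh * ow * c].reshape(1, oh, ow, c)  # aliases x's storage
    for i in range(oh):
        for j in range(ow):
            # The 2x2 window is fully read (mean materializes a new array)
            # before the aliased write below happens.
            out[0, i, j] = x[0, 2 * i : 2 * i + 2, 2 * j : 2 * j + 2].mean(axis=(0, 1))
    return out
```

A plain conv2d kernel has no such ordering guarantee (every output element reads a neighborhood of inputs), which is why in-place planning cannot simply be switched on for all ops.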