How to pass tiled input to extern function in TensorIR?

Hi, I’m trying to tensorize a computation whose implementation part calls an extern C function named “extern03”. My question is about buffer “B” below.

With “tvm_stack_make_array”, the extern function only receives a pointer to the original buffer B (e.g. 128x128), not to the 4x8 tile. If I use the original pointer and do the index calculation myself as a workaround, I still lack the original buffer's shape information inside the extern function (e.g. the 128x128 extent is unknown there).

Is there anything wrong with my script, or is there a better way to achieve my goal?

The extern function is registered like this:

TVM_REGISTER_GLOBAL("extern03").set_body([](TVMArgs args, TVMRetValue* ret)

TensorIntrin description:

from tvm.script import tir as T
@T.prim_func
def dot_product_8x4_i32_desc(
    A: T.Buffer((4,), "int16", offset_factor=1),
    B: T.Buffer((4, 8), "int16", offset_factor=1),
    C: T.Buffer((8,), "int32", offset_factor=1),
) -> None:
    # Computational description of the intrinsic:
    #     C[i] += int32(A[k]) * int32(B[k, i])   for i in [0, 8), k in [0, 4)
    # A is a length-4 int16 vector, B a 4x8 int16 tile, C a length-8 int32
    # accumulator.
    with T.block("root"):
        # Regions touched by the whole intrinsic; C is both read and written
        # because the update accumulates into it.
        T.reads(C[0:8], A[0:4], B[0:4, 0:8])
        T.writes(C[0:8])
        for i in T.serial(0, 8):
            # NOTE(review): T.init() is normally placed directly inside the
            # block that owns the reduction axis ("update" here), not inside a
            # loop of the enclosing block -- confirm this parses and matches
            # the workload as intended.
            with T.init():
                C[i] = T.int32(0)
            for k in T.serial(0, 4):
                with T.block("update"):
                    # vi is a spatial axis, vk a reduction axis.
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vk, vi], "int32")

TensorIntrin implementation:

@T.prim_func
def dot_product_8x4_i32_impl(
    A: T.Buffer((4,), "int16", offset_factor=1),
    B: T.Buffer((4, 8), "int16", offset_factor=1),
    C: T.Buffer((8,), "int32", offset_factor=1),
) -> None:
    # Implementation side of the intrinsic: hands A, B and C to the packed
    # function "extern03" as stack-allocated arrays.
    #
    # tvm_stack_make_array takes, in order:
    #     data, shape, strides, ndim, dtype, elem_offset
    #
    # Each buffer's own elem_offset is forwarded so the callee can locate the
    # matched tile: as reported, the data handle alone refers to the start of
    # the original (un-tiled) buffer. The callee must therefore apply the
    # array's offset to the data pointer before indexing.
    #
    # NOTE(review): strides are still passed as 0 (null), so the callee will
    # assume each tile is densely packed. For a tile cut out of a wider parent
    # buffer (e.g. 4x8 out of 128x128) the real row stride is the parent's row
    # length -- presumably the buffers need symbolic strides declared and
    # forwarded via T.tvm_stack_make_shape(...) as well; confirm against the
    # TVM version in use.
    with T.block("root"):
        T.reads(C[0:8], A[0:4], B[0:4, 0:8])
        T.writes(C[0:8])

        T.evaluate(
            T.tvm_call_packed(
                "extern03",
                T.tvm_stack_make_array(
                    A.data,
                    T.tvm_stack_make_shape(4, dtype="handle"),
                    0,  # strides: null
                    1,  # ndim
                    T.int16(0),  # value whose dtype gives the element type
                    A.elem_offset,  # tile start within the parent buffer
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B.data,
                    T.tvm_stack_make_shape(4, 8, dtype="handle"),
                    0,  # strides: null (see NOTE above)
                    2,  # ndim
                    T.int16(0),
                    B.elem_offset,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C.data,
                    T.tvm_stack_make_shape(8, dtype="handle"),
                    0,  # strides: null
                    1,  # ndim
                    T.int32(0),
                    C.elem_offset,
                    dtype="handle",
                ),
                T.int32(0),
                dtype="int32",
            )
        )