Flattening a buffer created by cache_read

When I schedule a convolution, my 4D input tensor is moved into on-chip memory, which is only a 2D data structure. Is there a way to represent this with cache_read? It seems this is only supported with te.compute, but I am not sure.

The schedule primitive transform_layout can help here.
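To make the idea concrete: transform_layout takes an index_map that sends each old buffer index to a new one. The plain-Python sketch below does not use TVM, and flatten_nhwc_index and map_all are hypothetical helpers. It shows the row-major map that folds a 4D (N, H, W, C) buffer into a 2D (N*H*W, C) one and checks that the map is one-to-one, so no two elements land in the same slot.

```python
from itertools import product

def flatten_nhwc_index(n, h, w, c, H, W):
    """Hypothetical row-major index map: (n, h, w, c) -> (row, col)."""
    return (n * H * W + h * W + w, c)

def map_all(shape):
    """Apply the map to every index of a 4D shape; return {old_index: new_index}."""
    N, H, W, C = shape
    return {
        idx: flatten_nhwc_index(*idx, H, W)
        for idx in product(range(N), range(H), range(W), range(C))
    }

mapping = map_all((2, 3, 4, 5))
targets = set(mapping.values())
# Every slot of the (2*3*4, 5) = (24, 5) target is hit exactly once.
assert len(targets) == len(mapping) == 120
assert targets == set(product(range(24), range(5)))
print(mapping[(1, 2, 3, 4)])  # (23, 4)
```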

Hi Siyuan, I want to use TVMScript to generate code for a fused quantization + GEMM op. Input matrix A is row-major and matrix B is column-major (in both global memory and shared memory), but TVM's default matrix layout is row-major. Would the transform_layout primitive help solve my problem?
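Storing B column-major is the same as storing its transpose row-major, so as an index map it is the swap (i, j) -> (j, i). In TVMScript terms that would presumably be an index_map such as lambda i, j: (j, i) passed to transform_layout, though how it interacts with the quantization step is untested here. The plain-Python sketch below (no TVM; all helper names are hypothetical) only checks the offset arithmetic behind that claim.

```python
def row_major_offset(i, j, rows, cols):
    """Offset of element (i, j) in a rows x cols matrix stored row-major."""
    return i * cols + j

def col_major_offset(i, j, rows, cols):
    """Offset of element (i, j) in a rows x cols matrix stored column-major."""
    return j * rows + i

def swap(i, j):
    """Hypothetical index map (i, j) -> (j, i), i.e. a transposed access."""
    return (j, i)

def layouts_agree(rows, cols):
    """Column-major B[i, j] sits where row-major B^T[swap(i, j)] would sit."""
    return all(
        col_major_offset(i, j, rows, cols) == row_major_offset(*swap(i, j), cols, rows)
        for i in range(rows)
        for j in range(cols)
    )

print(layouts_agree(3, 4))  # True
```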

Thanks! I experimented with it, but on its own it doesn't seem to turn a 4D buffer into a 2D one. I added an AXIS_SEPARATOR, but after that I still seem to be missing a step that actually reduces the buffer's dimensionality. Could you help me out here?
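As far as I understand it, an axis separator does not shrink the buffer at schedule time. The buffer keeps its N-d logical shape, and the separators are only consumed later, when lowering flattens buffers: the axes on each side of a separator are fused into one physical axis each. The plain-Python model below imitates that rule. physical_shape and physical_index are hypothetical helpers, not TVM API, and the shape is the 2x9x65x8 tile used later in this thread.

```python
from math import prod

def physical_shape(shape, axis_separators):
    """Fuse each group of logical axes (split at the separators) into one axis."""
    bounds = [0, *axis_separators, len(shape)]
    return tuple(prod(shape[a:b]) for a, b in zip(bounds, bounds[1:]))

def physical_index(index, shape, axis_separators):
    """Row-major flatten of each group of logical indices."""
    bounds = [0, *axis_separators, len(shape)]
    out = []
    for a, b in zip(bounds, bounds[1:]):
        flat = 0
        for i, extent in zip(index[a:b], shape[a:b]):
            flat = flat * extent + i
        out.append(flat)
    return tuple(out)

shape = (2, 9, 65, 8)
print(physical_shape(shape, []))                  # (9360,)   no separator: fully 1D
print(physical_shape(shape, [3]))                 # (1170, 8) separator before c: 2D
print(physical_index((1, 2, 3, 4), shape, [3]))   # (718, 4), since 1*585 + 2*65 + 3 = 718
```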

I found the FlattenBuffer transform, but it can't be used to modify the module (sch.mod) of my schedule. Is there something else I could use?

Sorry, I have no idea about this, as it's not a case used by any current target.

Thanks! I looked further into the primitive you originally suggested and found an example, but I can't get it to work for my case. Here is the mvin block for the input data:

        # Copy a 2x9x65x8 (N, H, W, C) tile of a_in into the scratchpad buffer
        for ax0, ax1, ax2, ax3 in T.grid(2, 9, 65, 8):
            with T.block("a_in_local.spad"):
                v0 = T.axis.spatial(4, n_0 * 2 + ax0)
                v1 = T.axis.spatial(256, p_0 * 8 + ax1)
                v2 = T.axis.spatial(256, q_0 * 64 + ax2)
                v3 = T.axis.spatial(128, ic_0 * 8 + ax3)
                # The tile is one row/column larger than its 8x64 stride, so mask indices past 255
                T.where(p_0 * 8 + ax1 < 256 and q_0 * 64 + ax2 < 256)
                T.reads(a_in[v0, v1, v2, v3])
                T.writes(a_in_local_spad[v0, v1, v2, v3])
                a_in_local_spad[v0, v1, v2, v3] = a_in[v0, v1, v2, v3]
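The T.where comes from the tile footprint overlapping its stride: ax1 covers 9 rows while p_0 advances by 8, and ax2 covers 65 columns while q_0 advances by 64, so the last tile in each direction would touch index 256. Assuming p_0 runs over 256 // 8 = 32 tiles and q_0 over 256 // 64 = 4 (the outer loops are not shown, so this is inferred), the sketch below lists which tiles the predicate actually trims; trimmed_tiles is a hypothetical helper.

```python
def trimmed_tiles(extent, tile, step):
    """Tile indices whose footprint [t*step, t*step + tile) runs past `extent`.

    Assumes the outer loop has extent // step iterations (inferred, not shown).
    """
    n_tiles = extent // step
    return [t for t in range(n_tiles) if t * step + tile > extent]

print(trimmed_tiles(256, 9, 8))    # [31]  H: p_0 stride 8, ax1 extent 9
print(trimmed_tiles(256, 65, 64))  # [3]   W: q_0 stride 64, ax2 extent 65
```

So only the last tile row and the last tile column are masked; every other tile is a full 9x65 rectangle.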

This block is created by cache_read. Here is the call, based on the example, that I use to reindex the buffer:

    sch.transform_layout(
        block_cwght,
        buffer=("read", 0),
        # Fold (n, h, w) into one row index: 585 = 9 * 65 (tile H * tile W)
        index_map=lambda n, h, w, c: (n * 585 + h * 65 + w, c),
    )

This fails with the error "Could not parse mapping as sum of iterators. Error: Could not normalize iterators". Do you have an idea how I can formulate the index map here? I can see that the T.where() is a problem, but I don't know of a way to avoid it.
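One thing worth checking, offered as an unverified hypothesis rather than a confirmed diagnosis. The T.axis.spatial extents imply a_in has shape (4, 256, 256, 128). As I understand cache_read, a_in_local_spad gets that same full shape, and the 2x9x65x8 footprint is only compacted later during lowering. Over the full shape, the map n*585 + h*65 + w is not one-to-one, because w runs to 255 rather than 64, and that may be what the iterator normalization rejects. Also, if block_cwght is the mvin block above, its T.reads/T.writes mean ("read", 0) selects a_in, while ("write", 0) would select a_in_local_spad. The plain-Python check below (no TVM; helper names hypothetical) compares the posted map with a plain row-major fold of the full extents. The c axis is dropped to extent 1 because both maps pass c through unchanged.

```python
from itertools import product

def is_injective(index_map, shape):
    """True if index_map sends every index of `shape` to a distinct output."""
    seen = set()
    for idx in product(*(range(e) for e in shape)):
        out = index_map(*idx)
        if out in seen:
            return False
        seen.add(out)
    return True

def posted(n, h, w, c):
    """The index map from the post: 585 = 9 * 65, i.e. tile H * tile W."""
    return (n * 585 + h * 65 + w, c)

def row_major(n, h, w, c):
    """Row-major fold of the full 256 x 256 extents instead of the tile extents."""
    return (n * 256 * 256 + h * 256 + w, c)

tile_shape = (2, 9, 65, 8)
full_shape = (4, 256, 256, 1)  # c reduced to 1; it maps to itself in both maps

print(is_injective(posted, tile_shape))      # True
print(is_injective(posted, full_shape))      # False
print(is_injective(row_major, full_shape))   # True
print(posted(0, 0, 65, 0), posted(0, 1, 0, 0))  # (65, 0) (65, 0): a collision
```

This only demonstrates the arithmetic; whether it is the actual cause of the TVM error is an assumption.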