Understanding Lowered TIR


I am writing a simple GEMM kernel and speeding it up with TVM's auto-scheduler. However, I am having difficulty understanding what the lowered TIR is doing and how to interpret what the auto-scheduler did to accelerate the code. I have pasted the baseline lowered TIR and the auto-tuned version below. It would help a lot if someone could point me to the documentation for lowered TIR scripts or help me interpret them!

# The code that I am trying to accelerate
@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def matmul(M, N, K, dtype):
    """Describe the GEMM workload ``out[i, j] = sum_k A[i, k] * B[k, j]``.

    Args:
        M: Number of rows of ``A`` and of the result.
        N: Number of columns of ``B`` and of the result.
        K: Extent of the shared reduction dimension.
        dtype: Element type given to both input placeholders.

    Returns:
        ``[A, B, out]``: the two input placeholders followed by the
        ``(M, N)`` result tensor.
    """
    A = te.placeholder((M, K), name="A", dtype=dtype)
    B = te.placeholder((K, N), name="B", dtype=dtype)

    k = te.reduce_axis((0, K), name="k")
    # Bound to `out` rather than `matmul` so the local does not shadow this
    # function's own name.
    out = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",  # matches the `matmul` buffer name seen in the lowered TIR
        attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B
    )

    return [A, B, out]

Base script (lowered TIR of the default schedule; here M=32, N=128, K=64, dtype=uint8):
    @main = primfn(A_1: handle, B_1: handle, matmul_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {matmul: Buffer(matmul_2: Pointer(uint8), uint8, [32, 128], []),
                 B: Buffer(B_2: Pointer(uint8), uint8, [64, 128], []),
                 A: Buffer(A_2: Pointer(uint8), uint8, [32, 64], [])}
      buffer_map = {A_1: A, B_1: B, matmul_1: matmul} {
      for (i: int32, 0, 32) {
        for (j: int32, 0, 128) {
          matmul_2[((i*128) + j)] = 0u8
          for (k: int32, 0, 64) {
            let cse_var_1: int32 = ((i*128) + j)
            matmul_2[cse_var_1] = ((uint8*)matmul_2[cse_var_1] + ((uint8*)A_2[((i*64) + k)]*(uint8*)B_2[((k*128) + j)]))

Autotuned script (lowered TIR after auto-scheduling; here M=N=K=2048, dtype=float32):
    @main = primfn(A_1: handle, B_1: handle, matmul_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
      buffers = {matmul: Buffer(matmul_2: Pointer(float32), float32, [2048, 2048], []),
                 B: Buffer(B_2: Pointer(float32), float32, [2048, 2048], []),
                 A: Buffer(A_2: Pointer(float32), float32, [2048, 2048], [])}
      buffer_map = {A_1: A, B_1: B, matmul_1: matmul} {
      allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [4194304]), storage_scope = global {
        for (ax0.ax1.fused.ax2.fused: int32, 0, 32) "parallel" {
          for (ax4: int32, 0, 64) {
            for (ax6: int32, 0, 32) {
              for (ax7: int32, 0, 64) {
                auto_scheduler_layout_transform[((((ax0.ax1.fused.ax2.fused*131072) + (ax4*2048)) + (ax6*64)) + ax7)] = (float32*)B_2[((((ax4*65536) + (ax6*2048)) + (ax0.ax1.fused.ax2.fused*64)) + ax7)]
        for (i.outer.outer.j.outer.outer.fused.i.outer.inner.fused: int32, 0, 512) "parallel" {
          allocate(matmul.local: Pointer(local float32), float32, [2048]), storage_scope = local;
          for (j.outer.inner: int32, 0, 4) {
            for (i.c.outer.inner.init: int32, 0, 32) {
              for (j.c.inner.init: int32, 0, 64) {
                matmul.local[((i.c.outer.inner.init*64) + j.c.inner.init)] = 0f32
            for (k.outer: int32, 0, 64) {
              for (i.c.outer.inner: int32, 0, 32) {
                for (k.inner: int32, 0, 32) {
                  for (j.c.inner: int32, 0, 64) {
                    let cse_var_1: int32 = ((i.c.outer.inner*64) + j.c.inner)
                    matmul.local[cse_var_1] = ((float32*)matmul.local[cse_var_1] + ((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256)*2097152) + (floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 32)*65536)) + (i.c.outer.inner*2048)) + (k.outer*32)) + k.inner)]*(float32*)auto_scheduler_layout_transform[(((((floordiv(floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256), 32)*524288) + (j.outer.inner*131072)) + (k.outer*2048)) + (k.inner*64)) + j.c.inner)]))
            for (i.inner: int32, 0, 32) {
              for (j.inner: int32, 0, 64) {
                matmul_2[((((((floordiv(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256)*2097152) + (floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 32)*65536)) + (i.inner*2048)) + (floordiv(floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256), 32)*256)) + (j.outer.inner*64)) + j.inner)] = (float32*)matmul.local[((i.inner*64) + j.inner)]