Hi! I am trying to understand TIR. While reading the VTA tutorials, I came across this code, and I am wondering where I can find more information about `tir.vta.uop_push`?
/*!
 * \brief Push a uop into the kernel buffer.
 *
 * In GEMM mode, the uop describes a blocked GEMM with a 2d access pattern.
 * In ALU mode, the uop describes a vectorized ALU operation with a 2d access
 * pattern.
 *
 * Pseudocode for the operation described by the pushed uop:
 *
 * \code
 *
 * DType accum[INP_BUFF_DEPTH][l][n];
 * DType weight[WGT_BUFF_DEPTH][n][m];
 * DType input[INP_BUFF_DEPTH][l][m];
 *
 * if reset_out == 1
 *   accum[dst_index] = 0
 * elif mode == 0
 *   accum[dst_index] += GEMM(input[src_index], weight[wgt_index]);
 * else
 *   if (use_imm)
 *     accum[dst_index] = opcode(accum[dst_index], imm_val);
 *   else
 *     accum[dst_index] = opcode(accum[dst_index], accum[src_index]);
 *
 * \endcode
 *
 * NOTE(review): accum is declared above with INP_BUFF_DEPTH as its first
 * dimension; confirm whether a dedicated accumulator-buffer depth constant
 * was intended here.
 *
 * \param mode Operation mode: 0 selects GEMM mode, 1 selects ALU mode.
 * \param reset_out If 1, zero accum[dst_index]. This check comes before the
 *        mode dispatch in the pseudocode, so it takes precedence over mode.
 * \param dst_index Destination index into the accum memory.
 * \param src_index Source index: into the input memory in GEMM mode, or into
 *        the accum memory in ALU mode.
 * \param wgt_index Index into the weight memory (used in GEMM mode only).
 * \param opcode The ALU opcode (used in ALU mode only).
 * \param use_imm In ALU mode, if true, use imm_val as the second operand
 *        instead of accum[src_index].
 * \param imm_val Immediate operand used in ALU mode when use_imm is set.
 */
TVM_DLL void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index,
uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val);