I’m just beginning to learn TVM, so please correct me if anything here is wrong.
The equivalent question is: is a DLTensor with decreasing strides stored in row-major order?
For example, consider a compact 3-D tensor of shape (4, 3, 2). If it is stored in row-major order, its strides will be (6, 2, 1).
I’m really confused by the code in ./src/runtime/contrib/cblas/gemm_utils.h, where C = matmul(A, B) is encapsulated, specifically in the function CallGemm:
// Reports whether a tensor should be treated as transposed in place: it carries
// explicit strides and the stride of dimension 1 exceeds the stride of dimension 0.
// A tensor without a strides array is never reported as transposed.
inline bool IsInPlaceTransposed(DLTensor* tensor) {
  if (!tensor->strides) {
    return false;
  }
  return tensor->strides[0] < tensor->strides[1];
}
// Unpacks the packed-function arguments and dispatches a 2-D GEMM through `op`.
//
// Arguments read from `args`:
//   args[0..2]  DLTensor* A, B, C (all must be 2-D with an element stride of 1)
//   args[3..4]  bool transa, transb (caller-requested transposes of A and B)
//   args[5]     optional alpha, defaults to 1.0
//   args[6]     optional beta, defaults to 0.0
//
// `ret` is not written by this function.
//
// NOTE(review): `op` receives B before A and the column count of B first. This
// looks like the usual trick of computing C^T = B^T * A^T so that a column-major
// GEMM can operate on row-major buffers -- confirm against the TGemmOp contract.
template <typename TGemmOp>
inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) {
  DLTensor* A = args[0];
  DLTensor* B = args[1];
  DLTensor* C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  // Width in bits of the element type the GEMM operator works on.
  int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8;
  CHECK_EQ(A->ndim, 2);
  CHECK_EQ(B->ndim, 2);
  CHECK_EQ(C->ndim, 2);
  CHECK_EQ(ElementStride(A), 1);
  CHECK_EQ(ElementStride(B), 1);
  CHECK_EQ(ElementStride(C), 1);
  // C can never be transposed.
  CHECK(!IsInPlaceTransposed(C));
  // Reversed strides indicates an in-place transpose operation.
  transa = IsInPlaceTransposed(A) ? !transa : transa;
  transb = IsInPlaceTransposed(B) ? !transb : transb;
  // Every buffer is reinterpreted as TGemmOp::TDatatype below, so all three
  // tensors (including A) must hold floats of exactly that width.
  CHECK(TypeMatch(A->dtype, kDLFloat, bit_depth));
  CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth));
  CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
  double alpha = args.size() > 5 ? args[5] : 1.0;
  double beta = args.size() > 6 ? args[6] : 0.0;
  // Data pointers are advanced by each tensor's byte_offset before the cast.
  op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa),
     static_cast<typename TGemmOp::TDatatype>(alpha),
     reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(B->data) + B->byte_offset),
     ColumnStride(B),
     reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(A->data) + A->byte_offset),
     ColumnStride(A), static_cast<typename TGemmOp::TDatatype>(beta),
     reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(C->data) + C->byte_offset),
     ColumnStride(C));
}
I wonder whether these two checks are redundant?
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
Since A and B here must be stored in row-major order (B is passed ahead of A in the op call, which is a trick for using a column-major matmul on row-major A and B), A and B must have decreasing strides.
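That is to say, A->strides[0] must be greater than A->strides[1]. So IsInPlaceTransposed(A), which returns true only when strides[1] > strides[0], must always return false, and the two lines above would never flip transa or transb.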