I use the PyTorch model huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad (or another Hugging Face transformer model such as bert-base-uncased) to test sparse inference in TVM. However, when I try to find the dense ops in the model's relay.Expr, I find nothing. These are my steps:
# step1
mod, params = relay.frontend.from_pytorch(model, shape_list)
# step2
mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
# step3
dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(mod)
# dense_weight_names = []
The reason is that in step 2 simplify_fc_transpose cannot fold the duplicated, back-to-back transpose ops into the nn.dense call, so _search_dense_op_weight cannot find any dense op whose weight is a Var. Part of the mod after step 2 is shown below (the problematic lines are marked with ->).
%0 = cast(%input_ids, dtype="int32");
%1 = full(0, shape=[1, 128], dtype="int64");
%2 = cast(%1, dtype="int32");
%3 = take(%model.embeddings.word_embeddings.weight, %0, axis=0);
%4 = take(%model.embeddings.token_type_embeddings.weight, %2, axis=0);
%5 = strided_slice(%model.embeddings.position_ids, begin=[0, 0], end=[1, 128], strides=[1, 1]);
%6 = cast(%5, dtype="int32");
%7 = add(%3, %4);
%8 = take(%model.embeddings.position_embeddings.weight, %6, axis=0);
%9 = add(%7, %8);
%10 = nn.layer_norm(%9, %model.embeddings.LayerNorm.weight, %model.embeddings.LayerNorm.bias, epsilon=1e-12f);
%11 = nn.dropout(%10, rate=0.1f);
%12 = %11.0;
-> %13 = transpose(%model.encoder.layer.0.attention.self.query.weight, axes=[1, 0]);
%14 = reshape(%12, newshape=[-1, 768]);
-> %15 = transpose(%13, axes=[1, 0]);
-> %16 = nn.dense(%14, %15, units=None);
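Note that the two stacked transposes in %13 and %15 (both with axes=[1, 0]) cancel each other, so folding them away would be safe. A quick stand-alone check of that fact (plain Python on nested lists, not TVM):

```python
# transpose(transpose(W, axes=[1, 0]), axes=[1, 0]) == W for any 2-D tensor,
# which is why the back-to-back transposes in %13/%15 could be folded away.

def transpose2d(mat):
    """Swap axes [1, 0] of a 2-D matrix given as nested lists."""
    return [list(row) for row in zip(*mat)]

W = [[1, 2, 3],
     [4, 5, 6]]

assert transpose2d(transpose2d(W)) == W
```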
For _search_dense_op_weight to find the dense op, the relay.Expr should look like the following, where the weight VarNode is a direct argument of the nn.dense CallNode:
%16 = nn.dense(%14, %model.encoder.layer.0.attention.self.query.weight, units=None);
For the VarNode to end up as a direct argument of the CallNode, the duplicated transposes must first be combined, because simplify_fc_transpose can only handle the pattern with a single transpose, as in this relay.Expr:
%14 = reshape(%12, newshape=[-1, 768]);
%15 = transpose(%model.encoder.layer.0.attention.self.query.weight, axes=[1, 0]);
%16 = nn.dense(%14, %15, units=None);
Does anyone know which function we should fix? Or can anyone help fix the bug? I see three options:
- relay.frontend.from_pytorch() should merge the duplicated transpose ops
- simplify_fc_transpose.convert() should merge the duplicated transpose ops
- _search_dense_op_weight() should trace through the duplicated transpose ops
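For the last suggestion, the tracing idea can be sketched on a toy expression representation. This is only an illustration of the logic with a hypothetical `resolve_dense_weight` helper, not TVM's actual data structures: a search like _search_dense_op_weight would peel off cancelling transpose ops until it reaches the weight Var.

```python
# Toy model of a Relay-like expression: ("transpose", arg) wraps an
# expression, ("var", name) is a weight VarNode. Illustration only;
# not TVM's real IR classes.

def resolve_dense_weight(expr):
    """Peel off transpose ops to reach the underlying VarNode.

    Returns the variable name if the weight is a Var behind an even
    number of axes=[1, 0] transposes (which cancel out), else None.
    """
    transposes = 0
    while isinstance(expr, tuple) and expr[0] == "transpose":
        transposes += 1
        expr = expr[1]
    if expr[0] == "var" and transposes % 2 == 0:
        return expr[1]
    return None

# Mirrors %15 = transpose(%13); %13 = transpose(weight) from the dump above:
w = ("var", "model.encoder.layer.0.attention.self.query.weight")
double_t = ("transpose", ("transpose", w))

assert resolve_dense_weight(double_t) == w[1]
assert resolve_dense_weight(("transpose", w)) is None  # odd count: still transposed
```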
Thanks!