Thank you for this detailed explanation! We have digested the content and tried to apply the concept to an existing pass.
There are still many implementation details we have not figured out, but the following is how we imagine the var
mechanism should work. Please kindly correct us if we have misunderstood anything.
Goal
Implement a pass that constructs a graph. The graph is a tracing map that records the transformation before and after each pass.
What the map should look like
Personally, I would prefer the keys to be the new equivalent vars now inside f, and the values to be the original vars.
That would make it more convenient for us to trace back to the source. So it would look like:
Map<Var, Var>
// Keys: the current equivalent Vars inside f
// Values: the originally-imported Vars
After a sequence of pass transformations we end up with a final IRModule; given a certain expression in the final
IRModule["main"], we can trace it back to the source. If we used the originally-imported Var as the key instead, we would
probably have to iterate through the whole map to find the resulting Var after the transformations.
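To make the lookup direction concrete, here is a minimal C++ sketch, assuming the map is held somewhere accessible to the tooling; the trace_map and TraceBack names below are hypothetical:

#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>

using tvm::relay::Var;

// Hypothetical tracing map: new (post-transformation) Var -> originally-imported Var.
tvm::Map<Var, Var> trace_map;

// With the new Var as the key, tracing back is a direct lookup. With the
// reversed layout we would have to scan every entry to find which original
// Var produced `current`.
Var TraceBack(const Var& current) {
  if (trace_map.count(current)) {
    return trace_map.at(current);
  }
  return current;  // untouched by the passes; already an original Var
}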
How to invoke
Considering the function GetPassPrefix in "src/relay/backend/utils.cc", we insert an OutLiner pass between the existing passes:
//...
pass_seqs.push_back(transform::SimplifyInference());
pass_seqs.push_back(OutLiner());
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(OutLiner());
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(OutLiner());
//...
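For reference, a rough C++ sketch of how OutLiner() might be defined as a module-level pass via tvm::transform::CreateModulePass (module scope is needed because it adds new global functions); the actual outlining and recording logic is left as a TODO since we have not designed it yet:

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>

tvm::transform::Pass OutLiner() {
  tvm::runtime::TypedPackedFunc<tvm::IRModule(tvm::IRModule, tvm::transform::PassContext)>
      pass_func = [](tvm::IRModule mod, tvm::transform::PassContext ctx) -> tvm::IRModule {
        // TODO: outline the expression groups introduced by the previous pass
        // into fresh global functions (e.g. @outlined_bn_0) and record them
        // in the tracing map.
        return mod;
      };
  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0,
                                          "OutLiner", /*required=*/{});
}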
What the process looks like
Take the Relay pass SimplifyInference for example: it unpacks certain calls such as the batch_norm op. The following
image is part of the result after the SimplifyInference transformation in our Explorer.
It takes the batch_norm call and its TupleGetItem as source exprs and unpacks them into a set of basic operations.
Now, the following is the process once we introduce the OutLiner pass.
Going back to the IR pretty print, we would start from IR["main"] here:
def main(...) {
%0 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
%1 = nn.batch_norm(%0,...) /* si=torch.batch_norm_8 */;
%2 = %1.0 /* si=torch.batch_norm_8 */;
}
After SimplifyInference, the IR["main"] becomes:
def main(...) {
%0 = add(%model.bn1.running_var, 1e-05f) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%1 = sqrt(%0) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%2 = divide(1f , %1) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%3 = multiply(%2, %model.bn1.weight) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%4 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
%5 = expand_dims(%3, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%6 = negative(%model.bn1.running_mean) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%7 = multiply(%6, %3) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%8 = add(%7, %model.bn1.bias) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%9 = multiply(%4, %5) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%10 = expand_dims(%8, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%11 = add(%9, %10) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
}
Now it is time to invoke OutLiner. It generates another global function, outlined_bn_0.
def main(...) {
%0 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
%1 = @outlined_bn_0(%0,...)
}
def outlined_bn_0(%i1...) {
%0 = add(%model.bn1.running_var, 1e-05f) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%1 = sqrt(%0) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%2 = divide(1f , %1) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%3 = multiply(%2, %model.bn1.weight) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%4 = expand_dims(%3, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%5 = negative(%model.bn1.running_mean) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%6 = multiply(%5, %3) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%7 = add(%6, %model.bn1.bias) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%8 = multiply(%i1, %4) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%9 = expand_dims(%7, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%10 = add(%8, %9) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
}
# Perhaps we would also need to keep the original main as a reference
def main_before_SimplifyInference_0() {
# ...
}
At the same time, we maintain our tracing map like this (the key and value should each be a Var, yet I am not quite sure
how to express them in Var form):
# key: the transformed result
# value: the original exprs
map = {
    hash(outlined_bn_0): {%1-batch_norm, %2-%1.0}
}
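If it helps, one possible way to express this in C++ might be to key on the GlobalVar of the outlined function and store the original source exprs as the value; this is only a sketch, and the outline_trace / RecordOutlinedRegion names are made up for illustration:

#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>

using tvm::GlobalVar;
using tvm::relay::Expr;

// Hypothetical tracing map for outlined functions:
//   key:   the GlobalVar of the outlined function (e.g. @outlined_bn_0)
//   value: the original exprs it replaces (e.g. the batch_norm Call and its
//          TupleGetItem from the imported IR)
tvm::Map<GlobalVar, tvm::Array<Expr>> outline_trace;

void RecordOutlinedRegion(const GlobalVar& outlined_fn,
                          const tvm::Array<Expr>& original_exprs) {
  outline_trace.Set(outlined_fn, original_exprs);
}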
Using the graph constructed from the tracing map, we should be able to trace an IR back to its very original form.
Perhaps the functionality of OutLiner could be implemented based on StructuralEqual, but we haven't come up with a good
approach for this yet. Still, if this OutLiner is implementable, it will be really convenient.
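As far as we understand, StructuralEqual can at least compare two IR fragments directly, which might serve as a building block for matching a pre-pass region against its post-pass counterpart; the snippet below is only the call we have in mind, not a design:

#include <tvm/node/structural_equal.h>
#include <tvm/relay/expr.h>

// Returns true when the two Relay expressions are structurally equal; this
// could be one building block for deciding which part of the new IR
// corresponds to a region of the old IR before outlining it.
bool SameComputation(const tvm::relay::Expr& before, const tvm::relay::Expr& after) {
  return tvm::StructuralEqual()(before, after);
}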
Questions
Here are some questions we came up with about this strategy:
- Which IRModule would be used once OutLiner is invoked? It should be IR1 below (the outlined form) rather than IR2 (the flat form), right?
IR1:
def main(...) {
%0 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
%1 = @outlined_bn_0(%0,...)
}
def outlined_bn_0(%i1...) {
%0 = add(%model.bn1.running_var, 1e-05f) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%1 = sqrt(%0) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%2 = divide(1f , %1) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%3 = multiply(%2, %model.bn1.weight) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%4 = expand_dims(%3, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%5 = negative(%model.bn1.running_mean) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%6 = multiply(%5, %3) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%7 = add(%6, %model.bn1.bias) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%8 = multiply(%i1, %4) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%9 = expand_dims(%7, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%10 = add(%8, %9) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
}
IR2:
def main(...) {
%0 = add(%model.bn1.running_var, 1e-05f) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%1 = sqrt(%0) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%2 = divide(1f , %1) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%3 = multiply(%2, %model.bn1.weight) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%4 = nn.conv2d(%input, %model.conv1.weight,...) /* si=torch._convolution_3 */;
%5 = expand_dims(%3, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%6 = negative(%model.bn1.running_mean) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%7 = multiply(%6, %3) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%8 = add(%7, %model.bn1.bias) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%9 = multiply(%4, %5) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%10 = expand_dims(%8, axis=1, num_newaxis=2) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
%11 = add(%9, %10) /* si=[ torch.batch_norm_8, torch.batch_norm_8 ] */;
}
- If we choose IR1 and continue with the transformations of the remaining passes, the IR might end up in a deeply nested form and readability would become very poor. Perhaps an unpacking (inlining) pass for the outlined functions is required too, right?
- Still about the nested form: if we use a nested form like IR1, many pattern-matching routines may need to be rewritten, because they would now have to look into the outlined functions in the graph. The complexity of implementing a pass might increase.
Thank you for reading such a long post. It feels great that we can try to figure out a better way to maintain the source
information.