Hi
I was trying to run standard TVM transformation passes on a ResNet101 model and measure the time spent in each pass.
Here is the code I used for testing:
```python
import tvm
from tvm.relay.build_module import bind_params_by_name
import tvm.relay.transform as _transform
from tvm.ir.instrument import (
    PassTimingInstrument,
    pass_instrument,
)
import tensorflow as tf

if __name__ == '__main__':
    model = tf.keras.applications.resnet.ResNet101(
        include_top=False,
        weights='imagenet',
        pooling=None,
        classes=1000,
        input_shape=(224, 224, 3),
    )
    shape_dict = {'input_1': (1, 224, 224, 3)}
    mod, params = tvm.relay.frontend.from_keras(model, shape_dict, layout="NHWC")
    print(mod)
    mod["main"] = bind_params_by_name(mod["main"], params)
    seq = tvm.transform.Sequential(
        [
            _transform.InferType(),
            _transform.SimplifyInference(),
            _transform.FoldConstant(),
            _transform.FoldScaleAxis(),
        ]
    )
    timing_inst = PassTimingInstrument()
    with tvm.transform.PassContext(opt_level=3, instruments=[timing_inst]):
        mod = seq(mod)
        profiles = timing_inst.render()
    print(profiles)
```
The output is:
```
sequential: 235372886us [154us] (100.00%; 100.00%)
  InferType: 1159522us [1159522us] (0.49%; 0.49%)
  InferType: 1160940us [1160940us] (0.49%; 0.49%)
  SimplifyInference: 1378713us [88561us] (0.59%; 0.59%)
    InferType: 1290152us [1290152us] (0.55%; 93.58%)
  FoldConstant: 171760759us [170613491us] (72.97%; 72.97%)
    InferType: 1147268us [1147268us] (0.49%; 0.67%)
  FoldScaleAxis: 59912799us [217us] (25.45%; 25.45%)
    InferType: 1089359us [1089359us] (0.46%; 1.82%)
    BackwardFoldScaleAxis: 1566861us [387372us] (0.67%; 2.62%)
      InferType: 1179489us [1179489us] (0.50%; 75.28%)
    InferType: 1112379us [1112379us] (0.47%; 1.86%)
    ForwardFoldScaleAxis: 1767860us [619704us] (0.75%; 2.95%)
      InferType: 1148156us [1148156us] (0.49%; 64.95%)
    FoldConstant: 54376124us [53259735us] (23.10%; 90.76%)
      InferType: 1116389us [1116389us] (0.47%; 2.05%)
```
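A note on reading the log. As far as I can tell from the numbers, the first value on each line is the pass's total time. The bracketed value is its self time, meaning the total minus the nested passes it invokes. The two percentages are relative to the whole run and to the parent pass. For example, `sequential`'s self time of 154 us is, up to rounding, its 235372886 us total minus the totals of its five child passes.

Below is a toy, pure-Python nested timer that prints lines in the same layout, just to illustrate that bookkeeping. The `ToyPassTimer` class is hypothetical and is not TVM's `PassTimingInstrument`.

```python
import time
from contextlib import contextmanager


class ToyPassTimer:
    """Hypothetical nested timer mimicking the layout of the timing log.

    Illustration only; this is not how TVM's PassTimingInstrument is implemented.
    """

    def __init__(self):
        self.root = None
        self._stack = []

    @contextmanager
    def scope(self, name):
        # Each scope becomes a child of the scope that is currently open.
        node = {"name": name, "total": 0.0, "children": []}
        if self._stack:
            self._stack[-1]["children"].append(node)
        else:
            self.root = node
        self._stack.append(node)
        start = time.perf_counter()
        try:
            yield
        finally:
            node["total"] = time.perf_counter() - start
            self._stack.pop()

    def render(self):
        lines = []
        root_total = self.root["total"]

        def visit(node, depth, parent_total):
            # Self time = total time minus the total time of the nested scopes.
            self_time = node["total"] - sum(c["total"] for c in node["children"])
            lines.append(
                f"{'  ' * depth}{node['name']}: {int(node['total'] * 1e6)}us "
                f"[{int(self_time * 1e6)}us] "
                f"({100 * node['total'] / root_total:.2f}%; "
                f"{100 * node['total'] / parent_total:.2f}%)"
            )
            for child in node["children"]:
                visit(child, depth + 1, node["total"])

        visit(self.root, 0, root_total)
        return "\n".join(lines)


timer = ToyPassTimer()
with timer.scope("sequential"):
    with timer.scope("InferType"):
        time.sleep(0.01)
    with timer.scope("FoldConstant"):
        time.sleep(0.02)
        with timer.scope("InferType"):
            time.sleep(0.01)
report = timer.render()
print(report)
```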
All of these passes are implemented in C++, yet they take about 235 s in total. Constant folding alone accounts for about 172 s + 54 s ≈ 226 s (well over 3 min).
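To double-check that arithmetic, the short script below sums the bracketed self times from the log by pass name. The log text is pasted in verbatim; the regex and the `self_time_by_pass` helper are my own. Using self times avoids double counting, since the nested `InferType` and `FoldConstant` runs are already included in their parents' totals.

```python
import re
from collections import defaultdict

# Timing log copied from the PassTimingInstrument output above.
LOG = """\
sequential: 235372886us [154us] (100.00%; 100.00%)
  InferType: 1159522us [1159522us] (0.49%; 0.49%)
  InferType: 1160940us [1160940us] (0.49%; 0.49%)
  SimplifyInference: 1378713us [88561us] (0.59%; 0.59%)
    InferType: 1290152us [1290152us] (0.55%; 93.58%)
  FoldConstant: 171760759us [170613491us] (72.97%; 72.97%)
    InferType: 1147268us [1147268us] (0.49%; 0.67%)
  FoldScaleAxis: 59912799us [217us] (25.45%; 25.45%)
    InferType: 1089359us [1089359us] (0.46%; 1.82%)
    BackwardFoldScaleAxis: 1566861us [387372us] (0.67%; 2.62%)
      InferType: 1179489us [1179489us] (0.50%; 75.28%)
    InferType: 1112379us [1112379us] (0.47%; 1.86%)
    ForwardFoldScaleAxis: 1767860us [619704us] (0.75%; 2.95%)
      InferType: 1148156us [1148156us] (0.49%; 64.95%)
    FoldConstant: 54376124us [53259735us] (23.10%; 90.76%)
      InferType: 1116389us [1116389us] (0.47%; 2.05%)
"""

# "<name>: <total>us [<self>us]" at the start of each (possibly indented) line.
LINE = re.compile(r"^\s*(\w+): (\d+)us \[(\d+)us\]")


def self_time_by_pass(log):
    """Sum the bracketed self time (in us) of every entry, grouped by pass name."""
    totals = defaultdict(int)
    for line in log.splitlines():
        m = LINE.match(line)
        if m:
            name, _total_us, self_us = m.groups()
            totals[name] += int(self_us)
    return dict(totals)


by_pass = self_time_by_pass(LOG)
grand_total = sum(by_pass.values())
for name, us in sorted(by_pass.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:>22}: {us / 1e6:8.2f} s ({100 * us / grand_total:5.2f}%)")
```

With these numbers, `FoldConstant`'s own time comes to about 224 s, roughly 95% of the run. The nine `InferType` invocations add up to about 10 s.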
ResNet101 is a batchnorm-heavy network (104 batchnorm ops in total). The passes above are used to fold each batchnorm into its preceding conv.
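For context, the rewrite these passes perform together amounts to the algebra below. It is shown as a NumPy sketch on a 1x1 NHWC convolution, written as a matmul over the channel axis, using random data.

Roughly, the passes divide the work as follows:
- `SimplifyInference` turns `batch_norm` into a per-channel multiply and add.
- `FoldScaleAxis` pushes the multiply into the conv weights.
- `FoldConstant` precomputes the new weights.

The sketch only checks that the folded conv matches conv followed by inference-mode batchnorm. It is not TVM's implementation, and the shapes are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, w, c_in, c_out = 1, 4, 4, 8, 16  # arbitrary illustrative sizes
eps = 1e-5

x = rng.standard_normal((n, h, w, c_in))
kernel = rng.standard_normal((c_in, c_out))  # 1x1 HWIO kernel with H=W=1 squeezed out
gamma = rng.standard_normal(c_out)
beta = rng.standard_normal(c_out)
mean = rng.standard_normal(c_out)
var = rng.random(c_out) + 0.5  # variances must be positive


def conv1x1(inp, k):
    # In NHWC, a 1x1 convolution is a matmul over the channel axis.
    return inp @ k


# Reference: conv followed by inference-mode batch norm.
y_ref = (conv1x1(x, kernel) - mean) / np.sqrt(var + eps) * gamma + beta

# Batch norm rewritten as a per-channel multiply and add.
scale = gamma / np.sqrt(var + eps)
shift = beta - mean * scale

# Scale folded into the weights (broadcast over the output-channel axis),
# leaving only conv + bias add at inference time.
folded_kernel = kernel * scale
y_folded = conv1x1(x, folded_kernel) + shift

print(np.allclose(y_ref, y_folded))
```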
My question is: given that these passes are written in C++, is this much compile time expected? Is there a way to make these transformation passes faster?