Hello! For some reason I need to replace all the batch_norm (BN) ops in a Relay graph with relay.add. Here's my example code:
```python
import tvm
import tvm.relay as relay
from tvm.relay.testing import run_opt_pass
from tvm.relay.testing.mobilenet import conv_block, separable_conv_block, get_workload
from tvm.relay.dataflow_pattern import rewrite, is_var, wildcard, is_op, TupleGetItemPattern, DFPatternCallback

target_str = 'llvm -mcpu=core-avx2'
target = tvm.target.Target(target_str)


class ReplaceBatchNormCallback(DFPatternCallback):
    # A callback class that rewrites a matched batch_norm pattern to an add op.
    def __init__(self, layout="NHWC"):
        super(ReplaceBatchNormCallback, self).__init__()
        self.layout = layout
        self.x = is_var() | wildcard()
        self.var = is_var()
        self.mean = is_var()
        self.beta = is_var()
        self.gamma = is_var()
        # Match batch_norm followed by the TupleGetItem(0) that extracts its output.
        pattern = is_op('nn.batch_norm')(self.x, self.gamma, self.beta, self.mean, self.var)
        tuple_get_item_node = TupleGetItemPattern(pattern, 0)
        self.pattern = tuple_get_item_node

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        beta = node_map[self.beta][0]
        # if self.layout == "NCHW":
        #     beta = relay.expand_dims(beta, axis=-1, num_newaxis=2)
        add = relay.add(x, beta)
        return add


def example(image_shape, layout='NHWC'):
    shape = (1, 224, 224, 3) if layout == 'NHWC' else (1, 3, 224, 224)
    data = relay.var('data', shape=shape)
    body = conv_block(data, "conv_block_1", 32, strides=(2, 2), layout=layout)
    body = separable_conv_block(body, 'separable_conv_block_1', 32, 64, layout=layout)
    body = separable_conv_block(body, 'separable_conv_block_2', 64, 128, downsample=True, layout=layout)
    _, model_params = get_workload(batch_size=1, dtype='float32', image_shape=image_shape, layout=layout)
    # Keep only the parameters used by the three blocks above.
    params = {}
    for k, v in model_params.items():
        if ("conv_block_1" in k) or ('separable_conv_block_1' in k) or ('separable_conv_block_2' in k):
            params[k] = v
    return relay.Function(relay.analysis.free_vars(body), body), params


f, params = example((3, 224, 224), layout='NCHW')
mod = tvm.IRModule.from_expr(f)
mod = relay.transform.InferType()(mod)
tmp_f = mod['main']
print("============ Rewrite")
tmp_f = rewrite(ReplaceBatchNormCallback(layout='NCHW'), tmp_f)
tmp_f = run_opt_pass(tmp_f, relay.transform.InferType())
```
However, I got this error log:
```
============ Rewrite
Incompatible broadcast type TensorType([1, 32, 112, 112], float32) and TensorType([32], float32)
Incompatible broadcast type TensorType([1, 32, 112, 112], float32) and TensorType([32], float32)
The type inference pass was unable to infer a type for this expression.
This usually occurs when an operator call is under constrained in some way, check other reported errors for hints of what may of happened.
[the last two lines repeat 12 times in total]
```
I think this is because, in the NCHW layout, the 1D beta tensor of shape [32] must be expanded to [32, 1, 1] before it can broadcast against the [1, 32, 112, 112] activation. Uncommenting the two lines in the callback indeed fixes the error. However, the graph-processing steps that follow in my pipeline need the bias tensor to stay 1D, so I wonder if there's another way of doing this?
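The closest workaround I've found so far is to emit relay.nn.bias_add instead of relay.add: it accepts a 1D bias and broadcasts it along a chosen channel axis, so beta stays 1D in the graph. A minimal sketch of the modified callback (only tried on this toy graph, and assuming bias_add is acceptable to my later passes):

```python
# Inside ReplaceBatchNormCallback:
def callback(self, pre, post, node_map):
    x = node_map[self.x][0]
    beta = node_map[self.beta][0]
    # bias_add's type relation broadcasts the 1D beta along the given
    # channel axis, so no expand_dims is needed for NCHW.
    axis = 1 if self.layout == "NCHW" else -1
    return relay.nn.bias_add(x, beta, axis=axis)
```

But this only sidesteps the problem by reusing an op TVM already ships, which leads to my follow-up below.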
A more general follow-up question: how can I specify functions like FInferCorrectLayout, FTVMConvertLayout, the type relation (XXXRel), etc. for a custom op, if I want to replace some ops in a graph with that op in a similar case, i.e. the op has 1D inputs and the layout is NCHW?
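From what I can tell, only some of these hooks are exposed from Python: FTVMConvertLayout can be attached to an existing op with register_convert_op_layout, while a brand-new op, its type relation (the XXXRel function), and FInferCorrectLayout appear to require C++ registration via RELAY_REGISTER_OP / set_attr. A rough, hypothetical sketch of the Python-side hook, using nn.bias_add only as a stand-in for the custom op (I haven't verified this end to end):

```python
from tvm import relay
from tvm.relay.op import op as reg

# Hypothetical FTVMConvertLayout handler; desired_layouts[0] is the
# data layout requested by the ConvertLayout pass.
@reg.register_convert_op_layout("nn.bias_add")
def convert_bias_add(attrs, inputs, tinfos, desired_layouts):
    data, bias = inputs
    new_attrs = dict(attrs)
    # Keep the bias 1D and just point the op at the channel axis that
    # matches the desired data layout.
    new_attrs["axis"] = 1 if str(desired_layouts[0]) == "NCHW" else -1
    return relay.nn.bias_add(data, bias, **new_attrs)
```

Is this the intended way, or is C++ registration unavoidable for the XXXRel / FInferCorrectLayout parts?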
Thanks in advance!