Hi @mbrookhart – No problem. I am including the complete code.
class ReWriteInputs(ExprMutator):
"""This pass partitions the subgraph based on the if conditin
"""
def __init__(self):
super().__init__()
self.inputs = []
def visit_function(self, fn):
"""This function returns concatenated add operators for a one add operator.
It creates multiple add operators.
:param call:
:return:
"""
new_params = []
for x in range(len(fn.params)):
if len(fn.params[x].type_annotation.shape)==1:
d = fn.params[x].type_annotation.shape
var_new = relay.var(fn.params[x].name_hint, shape=(1, 1, 1, d[0]), dtype=fn.params[x].type_annotation.dtype)
new_params.append(var_new)
else:
new_params.append(fn.params[x])
new_body = self.visit(fn.body)
print("Visited all", new_body)
func = relay.Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
return func
def visit_var(self, var):
if len(var.type_annotation.shape) == 1:
print("Dimension needs to change from 1d to 4d")
d = var.type_annotation.shape
var_new = relay.var(var.name_hint, shape=(1, 1, 1, int(d[0])), dtype=var.type_annotation.dtype)
return var_new
else:
print("Do nothing for other cases")
return var
This is how I am calling.
re_write_inputs = ReWriteInputs.ReWriteInputs()
f = module['main']
new_function = re_write_inputs.visit(f)
new_module = tvm.IRModule({"main": new_function})
print(new_module)
Here are the original module and the module after running 1D to 4D.
Original module
def @main(%input: Tensor[(1, 1, 28, 28), float32], %condition: Tensor[(1), bool], %v10: Tensor[(32, 32, 3, 3), float32], %v11: Tensor[(32), float32], %v12: Tensor[(32, 32, 3, 3), float32], %v13: Tensor[(32), float32], %v2: Tensor[(32, 1, 3, 3), float32], %v3: Tensor[(32), float32], %v4: Tensor[(32, 32, 1, 1), float32], %v5: Tensor[(32), float32], %v6: Tensor[(32, 64, 1, 1), float32], %v7: Tensor[(32), float32], %v8: Tensor[(64, 32, 1, 1), float32], %v9: Tensor[(64), float32]) {
%0 = take(%condition, 0);
%11 = if (%0) {
%1 = nn.conv2d(%input, %v2, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
%2 = nn.bias_add(%1, %v3);
%3 = nn.relu(%2);
%4 = nn.conv2d(%3, %v8, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
%5 = nn.bias_add(%4, %v9);
%6 = nn.relu(%5);
%7 = nn.conv2d(%6, %v6, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
%8 = nn.bias_add(%7, %v7);
nn.relu(%8)
} else {
%9 = nn.conv2d(%3, %v4, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
%10 = nn.bias_add(%9, %v5);
nn.relu(%10)
};
%12 = nn.conv2d(%11, %v10, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
%13 = nn.bias_add(%12, %v11);
%14 = nn.relu(%13);
%15 = nn.conv2d(%14, %v12, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
%16 = nn.bias_add(%15, %v13);
nn.relu(%16)
}
The module after 1D to 4D pass. It creates a DANGLING free vars with RANDOM names (see: free_var %v111)
def @main(%input: Tensor[(1, 1, 28, 28), float32], %condition: Tensor[(1, 1, 1, 1), bool], %v10: Tensor[(32, 32, 3, 3), float32], %v11: Tensor[(1, 1, 1, 32), float32], %v12: Tensor[(32, 32, 3, 3), float32], %v13: Tensor[(1, 1, 1, 32), float32], %v2: Tensor[(32, 1, 3, 3), float32], %v3: Tensor[(1, 1, 1, 32), float32], %v4: Tensor[(32, 32, 1, 1), float32], %v5: Tensor[(1, 1, 1, 32), float32], %v6: Tensor[(32, 64, 1, 1), float32], %v7: Tensor[(1, 1, 1, 32), float32], %v8: Tensor[(64, 32, 1, 1), float32], %v9: Tensor[(1, 1, 1, 64), float32]) {
free_var %condition1: Tensor[(1, 1, 1, 1), bool];
%0 = take(%condition1, 0);
%11 = if (%0) {
%1 = nn.conv2d(%input, %v2, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
free_var %v31: Tensor[(1, 1, 1, 32), float32];
%2 = nn.bias_add(%1, %v31);
%3 = nn.relu(%2);
%4 = nn.conv2d(%3, %v8, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
free_var %v91: Tensor[(1, 1, 1, 64), float32];
%5 = nn.bias_add(%4, %v91);
%6 = nn.relu(%5);
%7 = nn.conv2d(%6, %v6, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
free_var %v71: Tensor[(1, 1, 1, 32), float32];
%8 = nn.bias_add(%7, %v71);
nn.relu(%8)
} else {
%9 = nn.conv2d(%3, %v4, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
free_var %v51: Tensor[(1, 1, 1, 32), float32];
%10 = nn.bias_add(%9, %v51);
nn.relu(%10)
};
%12 = nn.conv2d(%11, %v10, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
free_var %v111: Tensor[(1, 1, 1, 32), float32];
%13 = nn.bias_add(%12, %v111);
%14 = nn.relu(%13);
%15 = nn.conv2d(%14, %v12, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
free_var %v131: Tensor[(1, 1, 1, 32), float32];
%16 = nn.bias_add(%15, %v131);
nn.relu(%16)
}