Hi all, I've run into an issue while using Relax's `PyExprMutator` to add new nodes. When a new node is created, its `checked_type` and `struct_info` do not seem to be initialized automatically. In my case, a layout transform needs to read type and shape information from its input, and it throws an exception because `checked_type` has not been initialized.

Is this behavior a bug, or is there a specific reason for it? Interestingly, if I explicitly emit each new node through the block builder, which initializes these fields, the issue goes away. But doing this for every intermediate node feels redundant.
```python
from tvm import relax
from tvm.relax import Call
from tvm.relax.expr_functor import PyExprMutator

# propogate_candidates, TransformKind, A_permutation and B_permutation
# are defined elsewhere in my script.

@relax.expr_functor.mutator
class Mutator(PyExprMutator):
    """Mutator that performs transformation."""

    def visit_call_(self, call_node: Call):
        g_var = call_node.args[0]
        if g_var not in propogate_candidates.keys():
            return super().visit_call_(call_node)
        inp, weight = call_node.args[1]
        # Tile both operands into 16x16 blocks: (i, j) -> (i // 16, j // 16, i % 16, j % 16).
        inp = self.builder_.emit(relax.op.layout_transform(inp, index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16)))
        weight = self.builder_.emit(relax.op.layout_transform(weight, index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16)))
        if self.transform_level.value >= TransformKind.IntraWarpTransform.value:
            # Optionally apply the intra-warp permutations on top of the tiling.
            inp = self.builder_.emit(relax.op.layout_transform(inp, index_map=A_permutation))
            weight = self.builder_.emit(relax.op.layout_transform(weight, index_map=B_permutation))
        # call_tir produces its result in the tiled layout (M // 16, N // 16, 16, 16).
        output_shape = call_node.struct_info.shape
        new_output_shape = (output_shape[0] // 16, output_shape[1] // 16, 16, 16)
        call_node = self.builder_.emit(relax.call_tir(g_var, (inp, weight), out_sinfo=relax.TensorStructInfo(new_output_shape, call_node.struct_info.dtype)))
        # Map the tiled result back to the original 2-D layout.
        call_node = self.builder_.emit(relax.op.layout_transform(call_node, index_map=lambda i, j, ii, jj: (i * 16 + ii, j * 16 + jj)))
        return call_node
```
Making a new node without emit: [screenshot not available]
Making a new node with emit: [screenshot not available]
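To make the layout arithmetic in the mutator concrete, here is a small pure-Python check (no TVM involved) of the 16x16 tiling map applied to `inp`/`weight` and the inverse map applied to the `call_tir` result. The helper names `tile`, `untile`, `tiled_shape` and `check_round_trip` are hypothetical, and the sketch assumes both dimensions are multiples of 16, which the `// 16` in `new_output_shape` also implicitly relies on.

```python
from itertools import product

TILE = 16


def tile(i, j):
    """Forward map used for inp/weight: (i, j) -> (i // 16, j // 16, i % 16, j % 16)."""
    return (i // TILE, j // TILE, i % TILE, j % TILE)


def untile(i, j, ii, jj):
    """Inverse map applied to the call_tir output: (i, j, ii, jj) -> (i * 16 + ii, j * 16 + jj)."""
    return (i * TILE + ii, j * TILE + jj)


def tiled_shape(m, n):
    """Mirror of new_output_shape for an (m, n) tensor; assumes m and n are multiples of 16."""
    if m % TILE or n % TILE:
        raise ValueError("both dimensions must be multiples of 16")
    return (m // TILE, n // TILE, TILE, TILE)


def check_round_trip(m, n):
    """Return True if every tiled index is in bounds and untile(tile(i, j)) == (i, j)."""
    shape = tiled_shape(m, n)
    for i, j in product(range(m), range(n)):
        t = tile(i, j)
        if any(not 0 <= t[k] < shape[k] for k in range(4)):
            return False
        if untile(*t) != (i, j):
            return False
    return True


print(tiled_shape(32, 48))       # (2, 3, 16, 16)
print(check_round_trip(32, 48))  # True
```

The round trip only holds when the shapes divide evenly: here `tiled_shape` raises in that case, whereas the `// 16` in the mutator would silently truncate the remainder.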