Do we need to manually invoke BlockBuilder.emit?

Hi all, I’ve come across an issue while using the Relax `PyExprMutator` to add new nodes. When a new node is created, its `checked_type` and `struct_info` are not initialized automatically. In my case, a layout transform needs to read type and shape information from the new nodes, and it throws an exception because `checked_type` has not been initialized.

I’m wondering whether this behavior is a bug or whether there is a specific rationale behind it. If I wrap each new node in `self.builder_.emit(...)`, as in the code below, the fields get populated and the issue goes away. But this feels redundant.

```python
from tvm import relax
from tvm.relax import Call
from tvm.relax.expr_functor import PyExprMutator

# propogate_candidates, TransformKind, A_permutation and B_permutation are
# defined elsewhere in my code and are not shown here.


@relax.expr_functor.mutator
class Mutator(PyExprMutator):
    """Mutator that performs transformation."""

    def visit_call_(self, call_node: Call):
        g_var = call_node.args[0]
        if g_var not in propogate_candidates.keys():
            return super().visit_call_(call_node)
        inp, weight = call_node.args[1]
        # Tile both operands into 16x16 blocks: (i, j) -> (i // 16, j // 16, i % 16, j % 16)
        inp = self.builder_.emit(relax.op.layout_transform(inp, index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16)))
        weight = self.builder_.emit(relax.op.layout_transform(weight, index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16)))
        if self.transform_level.value >= TransformKind.IntraWarpTransform.value:
            inp = self.builder_.emit(relax.op.layout_transform(inp, index_map=A_permutation))
            weight = self.builder_.emit(relax.op.layout_transform(weight, index_map=B_permutation))
        output_shape = call_node.struct_info.shape
        new_output_shape = (output_shape[0] // 16, output_shape[1] // 16, 16, 16)

        # Run the TIR kernel on the tiled operands, then map the result back to 2-D
        call_node = self.builder_.emit(
            relax.call_tir(
                g_var,
                (inp, weight),
                out_sinfo=relax.TensorStructInfo(new_output_shape, call_node.struct_info.dtype),
            )
        )
        call_node = self.builder_.emit(
            relax.op.layout_transform(call_node, index_map=lambda i, j, ii, jj: (i * 16 + ii, j * 16 + jj))
        )
        return call_node
```

[Screenshot: making a new node without emit]

[Screenshot: making a new node with emit]

They are not initialized automatically; you can call `BlockBuilder.normalize` to initialize them. We also encourage using `struct_info` rather than `checked_type`, as it contains all the necessary information.
