Inconsistent PReLU Implementation Across TVM Frontends

I’ve encountered a discrepancy in how TVM’s frontends convert PReLU operations from different model formats (ONNX, PyTorch, and TFLite) into Relay. This inconsistency makes further development on top of the imported Relay graph more difficult.

Current Implementations

PyTorch Frontend


def prelu(self, inputs, input_types):
    data = inputs[0]
    dim = self.get_dims(data)
    ndims = len(dim)
    axis = 0 if ndims == 1 else 1
    # alpha stays a 1-D (channels,) tensor and PReLU is applied directly
    # on the channel axis of the unmodified data.
    alpha = _op.broadcast_to(inputs[1], (dim[axis]))
    return _op.nn.prelu(data, alpha, axis)

TFLite Frontend


def convert_prelu(self, op):
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"

    input_tensor = input_tensors[0]
    alpha_tensor = input_tensors[1]
    if self.has_expr(alpha_tensor.tensor_idx):
        alpha_expr = self.get_expr(alpha_tensor.tensor_idx)
    else:
        alpha_tensor_type = alpha_tensor.tensor.Type()
        alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
        alpha_expr = self.exp_tab.new_const(
            self.get_tensor_value(alpha_tensor),
            dtype=alpha_tensor_type_str,
            source_name=alpha_tensor.tensor.Name(),
        )
    in_expr = self.get_expr(input_tensor.tensor_idx)
    data_shape = to_int_list(self.get_tensor_shape(input_tensor))

    # alpha is broadcast to the full data shape, both operands are flattened,
    # PReLU is applied on axis 0, and the original shape is restored afterwards.
    alpha_expr = _op.broadcast_to(alpha_expr, data_shape)
    alpha_expr = _op.reshape(alpha_expr, [-1])
    out = _op.nn.prelu(_op.reshape(in_expr, [-1]), alpha_expr, axis=0)
    out = _op.reshape(out, data_shape)
    return out

ONNX Frontend


class Prelu(OnnxOpConverter):
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, f"Prelu need 2 inputs, {len(inputs)} given"
        input_shape = shape_of(inputs[0])
        # Same flatten/restore pattern as TFLite: alpha is broadcast to the
        # data shape, both are flattened, and the result is reshaped back.
        alpha = _op.broadcast_to_like(inputs[1], inputs[0])
        alpha = _op.reshape(alpha, [-1])
        output = _op.nn.prelu(_op.reshape(inputs[0], [-1]), alpha, axis=0)
        return _op.reshape(output, input_shape)

Observed Difference

The ONNX and TFLite frontends broadcast alpha to the full data shape, flatten both the data and alpha, apply PReLU on axis 0, and then reshape the result back to the original shape. The PyTorch frontend, by contrast, emits nn.prelu directly on the channel axis without this flatten/reshape round trip.
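
To make the difference concrete, here is a small standalone sketch (my own, not taken from the frontends) that builds both Relay patterns for an NHWC tensor of shape (1, 8, 8, 3) with a per-channel alpha of length 3:

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 8, 8, 3), dtype="float32")
alpha = relay.const(np.full((3,), 0.25, dtype="float32"))

# PyTorch-style pattern: PReLU applied directly on the channel axis,
# with alpha kept as a 1-D (channels,) tensor.
direct = relay.nn.prelu(data, alpha, axis=3)

# ONNX/TFLite-style pattern: broadcast alpha to the full data shape,
# flatten both operands, apply PReLU on axis 0, then restore the shape.
alpha_full = relay.reshape(relay.broadcast_to(alpha, (1, 8, 8, 3)), [-1])
flat_out = relay.nn.prelu(relay.reshape(data, [-1]), alpha_full, axis=0)
flattened = relay.reshape(flat_out, (1, 8, 8, 3))

print(tvm.IRModule.from_expr(relay.Function([data], direct)))
print(tvm.IRModule.from_expr(relay.Function([data], flattened)))

After import, the first form shows up as a single nn.prelu call on the original layout, while the second is wrapped in broadcast_to/reshape pairs, which is exactly the structural difference described above.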

Questions

  1. Is there a historical or empirical reason for maintaining these two different patterns?

  2. Would it be possible to simplify the implementation by adopting a single, consistent pattern across all frontends? (A rough sketch of one possible direction follows below.)
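
For question 2, one possible direction, sketched here purely for discussion and not taken from the existing code, would be a small shared helper that each frontend could call whenever alpha reduces to a per-channel vector:

from tvm import relay


def prelu_on_channel_axis(data, alpha, channel_axis):
    # Hypothetical shared helper (name and placement are mine): collapse alpha,
    # e.g. shaped (1, 1, C) or (C, 1, 1), to a 1-D (channels,) tensor and apply
    # PReLU directly on the channel axis, avoiding the flatten/reshape round trip.
    # This only covers alphas with exactly one value per channel; arbitrarily
    # broadcastable ONNX slopes would still need a fallback such as the current
    # broadcast-and-flatten pattern.
    alpha = relay.reshape(alpha, [-1])
    return relay.nn.prelu(data, alpha, axis=channel_axis)

Whether this covers all the slope shapes the ONNX and TFLite importers have to accept is exactly what I am unsure about, which is why I am asking about the history of the current pattern.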

Your insights on this matter would be greatly appreciated. Thank you for your time and consideration.