Issues with Custom PyTorch Operator Integration into TVM IR

Hello TVM Community,

I’m working on integrating a custom PyTorch operator named LTQ into TVM Relay IR, but I’m running into problems and would appreciate advice or suggestions from anyone who has experience with this.

Background:

I’ve developed a custom ResNet18 model in PyTorch that includes a custom operator, LTQ (Learnable Ternary Quantization). This operator has been successfully built and tested within PyTorch. Here’s the basic structure of my LTQ implementation in PyTorch:

import torch
import torch.nn as nn

# Load the compiled custom-op library so torch.ops.custom_ops.LTQ is available
torch.ops.load_library("build/libLTQ.so")

class LTQ(nn.Module):
    def __init__(self, num_bits):
        super(LTQ, self).__init__()
        self.num_bits = num_bits

    def forward(self, x):
        return torch.ops.custom_ops.LTQ(x, self.num_bits)
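
A minimal standalone trace of the LTQ module (the input shape is just an example) can confirm whether the op is recorded as a single custom_ops::LTQ node at trace time:

import torch

ltq = LTQ(num_bits=4)
traced = torch.jit.trace(ltq, torch.randn(1, 64, 56, 56))
print(traced.graph)  # expect a custom_ops::LTQ node if the op is opaque to the tracer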

Issue:

I’ve extended relay.frontend.pytorch.PyTorchOpConverter with a Python implementation of LTQ built from TVM Relay operations. The converter function is as follows:

LTQ converter function:

def ltq_converter(self, *args):
    inputs = args[0]
    x = inputs[0]         # activation tensor as a Relay expression
    num_bits = inputs[1]  # number of quantization bits
    init_range = 2.0
    n_val = 2 ** num_bits - 1
    interval = init_range / n_val
    zero = tvm.relay.const(0.0)
    one = tvm.relay.const(1.0)
    two = tvm.relay.const(2.0)
    eps = tvm.relay.const(1e-3)
    # Learnable LTQ parameters, introduced here as free Relay variables
    start = tvm.relay.var("start", shape=(), dtype="float32")
    a = tvm.relay.var("a", shape=(n_val,), dtype="float32")
    scale1 = tvm.relay.var("scale1", shape=(), dtype="float32")
    scale2 = tvm.relay.var("scale2", shape=(), dtype="float32")

    x = tvm.relay.multiply(x, scale1)

    x_forward = x
    x_backward = x
    step_right = zero

    # Keep the learnable interval widths strictly positive
    a_pos = tvm.relay.where(tvm.relay.greater(a, eps), a, eps)
    for i in range(n_val):
        step_right = tvm.relay.add(step_right, tvm.relay.const(interval))
        if i == 0:
            thre_forward = tvm.relay.add(
                start,
                tvm.relay.divide(tvm.relay.op.take(a_pos, tvm.relay.const(0)), two))
            thre_backward = start
            x_forward = tvm.relay.where(
                tvm.relay.greater(x, thre_forward), step_right, zero)
            x_backward = tvm.relay.where(
                tvm.relay.greater(x, thre_backward),
                tvm.relay.add(
                    tvm.relay.multiply(
                        tvm.relay.divide(tvm.relay.const(interval),
                                         tvm.relay.op.take(a_pos, tvm.relay.const(i))),
                        tvm.relay.subtract(x, thre_backward)),
                    tvm.relay.subtract(step_right, tvm.relay.const(interval))),
                zero)
        else:
            thre_forward = tvm.relay.add(
                thre_forward,
                tvm.relay.divide(
                    tvm.relay.add(tvm.relay.op.take(a_pos, tvm.relay.const(i - 1)),
                                  tvm.relay.op.take(a_pos, tvm.relay.const(i))),
                    two))
            thre_backward = tvm.relay.add(
                thre_backward, tvm.relay.op.take(a_pos, tvm.relay.const(i - 1)))
            x_forward = tvm.relay.where(
                tvm.relay.greater(x, thre_forward), step_right, x_forward)
            x_backward = tvm.relay.where(
                tvm.relay.greater(x, thre_backward),
                tvm.relay.add(
                    tvm.relay.multiply(
                        tvm.relay.divide(tvm.relay.const(interval),
                                         tvm.relay.op.take(a_pos, tvm.relay.const(i))),
                        tvm.relay.subtract(x, thre_backward)),
                    tvm.relay.subtract(step_right, tvm.relay.const(interval))),
                x_backward)

    thre_backward = tvm.relay.add(
        thre_backward, tvm.relay.op.take(a_pos, tvm.relay.const(i)))
    x_backward = tvm.relay.where(tvm.relay.greater(x, thre_backward), two, x_backward)

    # Combine the forward and backward branches and rescale the result
    out = tvm.relay.multiply(
        tvm.relay.add(
            tvm.relay.subtract(tvm.relay.add(x_forward, x_backward), x_backward),
            x_backward),
        scale2)
    return out
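
My understanding is that converters registered in the frontend's convert_map are called with the operator's inputs and their types, so args[0] above is the list of Relay inputs and args[1] the input types. With an explicit signature (assuming that calling convention), the unpacking would look like this:

# Equivalent unpacking with an explicit (inputs, input_types) signature,
# assuming the frontend calls converters this way.
def ltq_converter(self, inputs, input_types):
    x = inputs[0]         # Relay expression for the activation tensor
    num_bits = inputs[1]  # number of quantization bits from the traced graph
    # ... same Relay construction as above ...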

I then added the corresponding entry to the convert map in the create_convert_map method:

def create_convert_map(self):
    self.convert_map = {
        # ... existing ATen op mappings ...
        "custom_ops::ltq": self.ltq_converter,
    }
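
As a side note, relay.frontend.from_pytorch also accepts a custom_convert_map argument, which I understand merges extra converters into the default map without subclassing. A rough sketch, assuming that argument is available in your TVM version and reusing graph_model and input_names from the conversion script below (ltq_converter_fn is a hypothetical standalone variant of the converter above):

from tvm import relay

def ltq_converter_fn(inputs, input_types):
    # Placeholder body; a real version would build the same Relay expression
    # as the converter above from inputs[0] (tensor) and inputs[1] (num_bits).
    return inputs[0]

# Key the map by the op kind exactly as it appears in the traced TorchScript
# graph (check with print(graph_model.graph)).
custom_map = {"custom_ops::LTQ": ltq_converter_fn}

mod, params = relay.frontend.from_pytorch(
    graph_model, input_names, custom_convert_map=custom_map)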

When converting the model to TVM IR, I use the following script:

import torch
from tvm import relay

model = resnet18(4, True)  # custom ResNet18 that uses the LTQ operator
example_input = torch.randn(1, 3, 224, 224)
graph_model = torch.jit.trace(model, example_input)
input_names = [("input", example_input.shape)]
mod, params = relay.frontend.from_pytorch(graph_model, input_names)
print(mod)
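
To see whether the custom op even reaches TVM, one simple check (a sketch using the names from the script above) is to list the distinct node kinds in the traced graph; if custom_ops::LTQ is not among them, the converter mapping can never fire:

# Distinct op kinds in the traced TorchScript graph
op_kinds = {node.kind() for node in graph_model.graph.nodes()}
print(sorted(op_kinds))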

The resulting IR, however, contains no mention of LTQ. Instead, the LTQ computation appears expanded into a long sequence of elementwise operations (greater, where, take, divide, add, ...) whose spans all refer to aten:: ops. Here's a snippet of the IR output:

Current IR output (snippet):
def @main(.....) {
  %0 = nn.conv2d(%input, %aten::_convolution_0.weight, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* span=aten::_convolution_0:0:0 */;
  %1 = nn.batch_norm(%0, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* span=aten::batch_norm_0:0:0 */;
  %2 = %1.0 /* span=aten::batch_norm_0:0:0 */;
  %3 = nn.relu(%2) /* span=aten::relu_0:0:0 */;
  %4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* span=aten::max_pool2d_0:0:0 */;
  %5 = broadcast_to_like(%aten::expand_as_0.bias, %4) /* span=aten::expand_as_0:0:0 */;
  %6 = add(%4, %5) /* span=aten::add_0:0:0 */;
  %7 = broadcast_to(%aten::prelu_0.weight, shape=[64]) /* span=aten::prelu_0:0:0 */;
  %8 = nn.prelu(%6, %7) /* span=aten::prelu_0:0:0 */;
  %9 = broadcast_to_like(%aten::expand_as_1.bias, %8) /* span=aten::expand_as_1:0:0 */;
  %10 = add(%8, %9) /* span=aten::add_1:0:0 */;
  %11 = greater(%aten::gt_0.a, %aten::gt_0.eps) /* span=aten::gt_0:0:0 */;
  %12 = where(%11, %aten::gt_0.a, %aten::gt_0.eps) /* span=aten::where_0:0:0 */;
  %13 = take(%12, 0 /* span=aten::select_0:0:0 */, axis=0, mode="wrap") /* span=aten::select_0:0:0 */;
  %14 = divide(%13, 2f /* span=aten::div_0:0:0 */) /* span=aten::div_0:0:0 */;
  %15 = take(%12, 0 /* span=aten::select_2:0:0 */, axis=0, mode="wrap") /* span=aten::select_2:0:0 */;
  %16 = take(%12, 1 /* span=aten::select_3:0:0 */, axis=0, mode="wrap") /* span=aten::select_3:0:0 */;
  %17 = divide(%15, 2f /* span=aten::div_1:0:0 */) /* span=aten::div_1:0:0 */;
  %18 = divide(%16, 2f /* span=aten::div_2:0:0 */) /* span=aten::div_2:0:0 */;
  %19 = add(%aten::add_4.start, %14) /* span=aten::add_4:0:0 */;
  %20 = add(%17, %18) /* span=aten::add_8:0:0 */;
  %21 = take(%12, 1 /* span=aten::select_6:0:0 */, axis=0, mode="wrap") /* span=aten::select_6:0:0 */;
  %22 = take(%12, 2 /* span=aten::select_7:0:0 */, axis=0, mode="wrap") /* span=aten::select_7:0:0 */;
  %23 = divide(%21, 2f /* span=aten::div_3:0:0 */) /* span=aten::div_3:0:0 */;
  %24 = divide(%22, 2f /* span=aten::div_4:0:0 */) /* span=aten::div_4:0:0 */;
  %25 = add(%19, %20) /* span=aten::add_9:0:0 */;
  %26 = add(%23, %24) /* span=aten::add_13:0:0 */;
  %27 = take(%12, 2 /* span=aten::select_10:0:0 */, axis=0, mode="wrap") /* span=aten::select_10:0:0 */;
  %28 = take(%12, 3 /* span=aten::select_11:0:0 */, axis=0, mode="wrap") /* span=aten::select_11:0:0 */;
  %29 = divide(%27, 2f /* span=aten::div_5:0:0 */) /* span=aten::div_5:0:0 */;
  %30 = divide(%28, 2f /* span=aten::div_6:0:0 */) /* span=aten::div_6:0:0 */;
  %31 = add(%25, %26) /* span=aten::add_14:0:0 */;
  %32 = add(%29, %30) /* span=aten::add_18:0:0 */;
  %33 = take(%12, 3 /* span=aten::select_14:0:0 */, axis=0, mode="wrap") /* span=aten::select_14:0:0 */;
  %34 = take(%12, 4 /* span=aten::select_15:0:0 */, axis=0, mode="wrap") /* span=aten::select_15:0:0 */;
  %35 = divide(%33, 2f /* span=aten::div_7:0:0 */) /* span=aten::div_7:0:0 */;
  %36 = divide(%34, 2f /* span=aten::div_8:0:0 */) /* span=aten::div_8:0:0 */;
  %37 = add(%31, %32) /* span=aten::add_19:0:0 */;
  %38 = add(%35, %36) /* span=aten::add_23:0:0 */;
  %39 = take(%12, 4 /* span=aten::select_18:0:0 */, axis=0, mode="wrap") /* span=aten::select_18:0:0 */;
  %40 = take(%12, 5 /* span=aten::select_19:0:0 */, axis=0, mode="wrap") /* span=aten::select_19:0:0 */;
  %41 = divide(%39, 2f /* span=aten::div_9:0:0 */) /* span=aten::div_9:0:0 */;
  %42 = divide(%40, 2f /* span=aten::div_10:0:0 */) /* span=aten::div_10:0:0 */;
  %43 = add(%37, %38) /* span=aten::add_24:0:0 */;
  %44 = add(%41, %42) /* span=aten::add_28:0:0 */;
  %45 = take(%12, 5 /* span=aten::select_22:0:0 */, axis=0, mode="wrap") /* span=aten::select_22:0:0 */;
  %46 = take(%12, 6 /* span=aten::select_23:0:0 */, axis=0, mode="wrap") /* span=aten::select_23:0:0 */;
  %47 = divide(%45, 2f /* span=aten::div_11:0:0 */) /* span=aten::div_11:0:0 */;
  %48 = divide(%46, 2f /* span=aten::div_12:0:0 */) /* span=aten::div_12:0:0 */;
  %49 = add(%43, %44) /* span=aten::add_29:0:0 */;
  %50 = add(%47, %48) /* span=aten::add_33:0:0 */;
.........

Desired Outcome:

I would like my custom operator LTQ to appear as a single, explicit node in the TVM IR, along these lines:

def @main(...) {
  %0 = nn.conv2d(%input, %aten::_convolution_0.weight, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* span=aten::_convolution_0:0:0 */;
  %1 = nn.batch_norm(%0, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* span=aten::batch_norm_0:0:0 */;
  %2 = %1.0 /* span=aten::batch_norm_0:0:0 */;
  %3 = nn.relu(%2) /* span=aten::relu__0:0:0 */;
  %4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* span=aten::max_pool2d_0:0:0 */;
  %5 = broadcast_to_like(%aten::expand_as_0.bias, %4) /* span=aten::expand_as_0:0:0 */;
  %6 = add(%4, %5) /* span=aten::add_0:0:0 */;
  %7 = broadcast_to(%aten::prelu_0.weight, shape=[64]) /* span=aten::prelu_0:0:0 */;
  %8 = nn.prelu(%6, %7) /* span=aten::prelu_0:0:0 */;
  %9 = broadcast_to_like(%aten::expand_as_1.bias, %8) /* span=aten::expand_as_1:0:0 */;
  %10 = add(%8, %9) /* span=aten::add_1:0:0 */;
  %11 = ltq(%a, 0.001f /* span=custom_ops::LTQ_0:0:0 */) /* span=custom_ops::LTQ_0:0:0 */;
.........

Questions:

  1. Are there additional steps or modifications required in the TVM conversion process to ensure that custom operators like LTQ are correctly and explicitly represented in the TVM IR?
  2. Is there a way to have a custom operator appear as a single IR node rather than being expanded into its constituent sub-operations?

Any guidance, suggestions, or pointers towards relevant documentation or examples would be greatly appreciated.

Thank you in advance for your assistance!
