Hello TVM Community,
I’m currently integrating a custom PyTorch operator, LTQ, into TVM IR, and I’ve run into an issue. I’d appreciate advice or suggestions from anyone with experience in this area.
Background:
I’ve developed a custom ResNet18 model in PyTorch that includes a custom operator, LTQ (Learnable Ternary Quantization). This operator has been successfully built and tested within PyTorch. Here’s the basic structure of my LTQ implementation in PyTorch:
import torch
import torch.nn as nn

torch.ops.load_library("build/libLTQ.so")

class LTQ(nn.Module):
    def __init__(self, num_bits):
        super(LTQ, self).__init__()
        self.num_bits = num_bits

    def forward(self, x):
        return torch.ops.custom_ops.LTQ(x, self.num_bits)
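For reference, a minimal way to confirm that the traced TorchScript graph actually contains the custom node, and to see the exact kind string the Relay frontend would look up, is something like the check below. The input shape and the expected kind string "custom_ops::LTQ" are assumptions on my part, purely for illustration:

# Illustrative sanity check: print the node kinds in the traced graph.
# The kind string of the custom node is what the frontend's convert_map
# must be keyed on.
import torch

ltq = LTQ(num_bits=4)
traced = torch.jit.trace(ltq, torch.randn(1, 64, 56, 56))  # example shape
print([node.kind() for node in traced.graph.nodes()])
# expect the list to contain something like 'custom_ops::LTQ'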
Issue:
I’ve extended relay.frontend.pytorch.PyTorchOpConverter to include a Python implementation of LTQ using TVM Relay operations. The converter function is as follows:
LTQ Converter Function
# LTQ converter implementation
def ltq_converter(self, *args):
    inputs = args[0]
    x = inputs[0]
    num_bits = inputs[1]

    init_range = 2.0
    n_val = 2 ** num_bits - 1
    interval = init_range / n_val

    zero = tvm.relay.const(0.0)
    one = tvm.relay.const(1.0)
    two = tvm.relay.const(2.0)
    eps = tvm.relay.const(1e-3)

    # Learnable LTQ parameters, represented here as free Relay variables
    start = tvm.relay.var("start", shape=(), dtype="float32")
    a = tvm.relay.var("a", shape=(n_val,), dtype="float32")
    scale1 = tvm.relay.var("scale1", shape=(), dtype="float32")
    scale2 = tvm.relay.var("scale2", shape=(), dtype="float32")

    x = tvm.relay.multiply(x, scale1)
    x_forward = x
    x_backward = x
    step_right = zero

    # Clamp the learnable interval widths away from zero
    a_pos = tvm.relay.where(tvm.relay.greater(a, eps), a, eps)

    for i in range(n_val):
        step_right = tvm.relay.add(step_right, tvm.relay.const(interval))
        if i == 0:
            thre_forward = tvm.relay.add(
                start, tvm.relay.divide(tvm.relay.op.take(a_pos, tvm.relay.const(0)), two))
            thre_backward = start
            x_forward = tvm.relay.where(
                tvm.relay.greater(x, thre_forward), step_right, zero)
            x_backward = tvm.relay.where(
                tvm.relay.greater(x, thre_backward),
                tvm.relay.add(
                    tvm.relay.multiply(
                        tvm.relay.divide(tvm.relay.const(interval),
                                         tvm.relay.op.take(a_pos, tvm.relay.const(i))),
                        tvm.relay.subtract(x, thre_backward)),
                    tvm.relay.subtract(step_right, tvm.relay.const(interval))),
                zero)
        else:
            thre_forward = tvm.relay.add(
                thre_forward,
                tvm.relay.divide(
                    tvm.relay.add(tvm.relay.op.take(a_pos, tvm.relay.const(i - 1)),
                                  tvm.relay.op.take(a_pos, tvm.relay.const(i))),
                    two))
            thre_backward = tvm.relay.add(
                thre_backward, tvm.relay.op.take(a_pos, tvm.relay.const(i - 1)))
            x_forward = tvm.relay.where(
                tvm.relay.greater(x, thre_forward), step_right, x_forward)
            x_backward = tvm.relay.where(
                tvm.relay.greater(x, thre_backward),
                tvm.relay.add(
                    tvm.relay.multiply(
                        tvm.relay.divide(tvm.relay.const(interval),
                                         tvm.relay.op.take(a_pos, tvm.relay.const(i))),
                        tvm.relay.subtract(x, thre_backward)),
                    tvm.relay.subtract(step_right, tvm.relay.const(interval))),
                x_backward)

    # Handle the region above the last threshold
    thre_backward = tvm.relay.add(thre_backward, tvm.relay.op.take(a_pos, tvm.relay.const(i)))
    x_backward = tvm.relay.where(tvm.relay.greater(x, thre_backward), two, x_backward)

    # Combine the forward/backward paths and rescale
    out = tvm.relay.multiply(
        tvm.relay.add(
            tvm.relay.subtract(tvm.relay.add(x_forward, x_backward), x_backward),
            x_backward),
        scale2)
    return out
I added the mapping in the create_convert_map method:

def create_convert_map(self):
    self.convert_map = {
        # ... existing aten:: entries ...
        "custom_ops::ltq": self.ltq_converter
    }
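As far as I can tell, relay.frontend.from_pytorch also accepts a custom_convert_map argument, which might avoid subclassing altogether. Below is a rough, untested sketch of what I believe that would look like; the key "custom_ops::LTQ", the (inputs, input_types) signature, and the elided converter body are my assumptions, and graph_model / input_names are the same as in the conversion script further down:

# Rough sketch (untested): register the converter through custom_convert_map
# instead of overriding create_convert_map. Assumes the traced node kind is
# "custom_ops::LTQ" and that converters are called as fn(inputs, input_types).
def ltq_converter(inputs, input_types):
    x, num_bits = inputs[0], inputs[1]
    # ... build the same Relay expression as in the converter above ...
    return x  # placeholder output

mod, params = relay.frontend.from_pytorch(
    graph_model,
    input_names,
    custom_convert_map={"custom_ops::LTQ": ltq_converter},
)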
When converting the model to TVM IR, I use the following script:
model = resnet18(4, True)
example_input = torch.randn(1, 3, 224, 224)
graph_model = torch.jit.trace(model, example_input)
input_names = [("input", example_input.shape)]
mod, params = relay.frontend.from_pytorch(graph_model, input_names)
print(mod)
The output IR, however, does not reflect my custom operator LTQ. Instead, it shows the series of Relay operations it was lowered to, with no mention of LTQ. Here’s a snippet of the IR output:
Current IR Output (Snippet)
# ... (IR output details)
def @main(.....) {
%0 = nn.conv2d(%input, %aten::_convolution_0.weight, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* span=aten::_convolution_0:0:0 */;
%1 = nn.batch_norm(%0, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* span=aten::batch_norm_0:0:0 */;
%2 = %1.0 /* span=aten::batch_norm_0:0:0 */;
%3 = nn.relu(%2) /* span=aten::relu_0:0:0 */;
%4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* span=aten::max_pool2d_0:0:0 */;
%5 = broadcast_to_like(%aten::expand_as_0.bias, %4) /* span=aten::expand_as_0:0:0 */;
%6 = add(%4, %5) /* span=aten::add_0:0:0 */;
%7 = broadcast_to(%aten::prelu_0.weight, shape=[64]) /* span=aten::prelu_0:0:0 */;
%8 = nn.prelu(%6, %7) /* span=aten::prelu_0:0:0 */;
%9 = broadcast_to_like(%aten::expand_as_1.bias, %8) /* span=aten::expand_as_1:0:0 */;
%10 = add(%8, %9) /* span=aten::add_1:0:0 */;
%11 = greater(%aten::gt_0.a, %aten::gt_0.eps) /* span=aten::gt_0:0:0 */;
%12 = where(%11, %aten::gt_0.a, %aten::gt_0.eps) /* span=aten::where_0:0:0 */;
%13 = take(%12, 0 /* span=aten::select_0:0:0 */, axis=0, mode="wrap") /* span=aten::select_0:0:0 */;
%14 = divide(%13, 2f /* span=aten::div_0:0:0 */) /* span=aten::div_0:0:0 */;
%15 = take(%12, 0 /* span=aten::select_2:0:0 */, axis=0, mode="wrap") /* span=aten::select_2:0:0 */;
%16 = take(%12, 1 /* span=aten::select_3:0:0 */, axis=0, mode="wrap") /* span=aten::select_3:0:0 */;
%17 = divide(%15, 2f /* span=aten::div_1:0:0 */) /* span=aten::div_1:0:0 */;
%18 = divide(%16, 2f /* span=aten::div_2:0:0 */) /* span=aten::div_2:0:0 */;
%19 = add(%aten::add_4.start, %14) /* span=aten::add_4:0:0 */;
%20 = add(%17, %18) /* span=aten::add_8:0:0 */;
%21 = take(%12, 1 /* span=aten::select_6:0:0 */, axis=0, mode="wrap") /* span=aten::select_6:0:0 */;
%22 = take(%12, 2 /* span=aten::select_7:0:0 */, axis=0, mode="wrap") /* span=aten::select_7:0:0 */;
%23 = divide(%21, 2f /* span=aten::div_3:0:0 */) /* span=aten::div_3:0:0 */;
%24 = divide(%22, 2f /* span=aten::div_4:0:0 */) /* span=aten::div_4:0:0 */;
%25 = add(%19, %20) /* span=aten::add_9:0:0 */;
%26 = add(%23, %24) /* span=aten::add_13:0:0 */;
%27 = take(%12, 2 /* span=aten::select_10:0:0 */, axis=0, mode="wrap") /* span=aten::select_10:0:0 */;
%28 = take(%12, 3 /* span=aten::select_11:0:0 */, axis=0, mode="wrap") /* span=aten::select_11:0:0 */;
%29 = divide(%27, 2f /* span=aten::div_5:0:0 */) /* span=aten::div_5:0:0 */;
%30 = divide(%28, 2f /* span=aten::div_6:0:0 */) /* span=aten::div_6:0:0 */;
%31 = add(%25, %26) /* span=aten::add_14:0:0 */;
%32 = add(%29, %30) /* span=aten::add_18:0:0 */;
%33 = take(%12, 3 /* span=aten::select_14:0:0 */, axis=0, mode="wrap") /* span=aten::select_14:0:0 */;
%34 = take(%12, 4 /* span=aten::select_15:0:0 */, axis=0, mode="wrap") /* span=aten::select_15:0:0 */;
%35 = divide(%33, 2f /* span=aten::div_7:0:0 */) /* span=aten::div_7:0:0 */;
%36 = divide(%34, 2f /* span=aten::div_8:0:0 */) /* span=aten::div_8:0:0 */;
%37 = add(%31, %32) /* span=aten::add_19:0:0 */;
%38 = add(%35, %36) /* span=aten::add_23:0:0 */;
%39 = take(%12, 4 /* span=aten::select_18:0:0 */, axis=0, mode="wrap") /* span=aten::select_18:0:0 */;
%40 = take(%12, 5 /* span=aten::select_19:0:0 */, axis=0, mode="wrap") /* span=aten::select_19:0:0 */;
%41 = divide(%39, 2f /* span=aten::div_9:0:0 */) /* span=aten::div_9:0:0 */;
%42 = divide(%40, 2f /* span=aten::div_10:0:0 */) /* span=aten::div_10:0:0 */;
%43 = add(%37, %38) /* span=aten::add_24:0:0 */;
%44 = add(%41, %42) /* span=aten::add_28:0:0 */;
%45 = take(%12, 5 /* span=aten::select_22:0:0 */, axis=0, mode="wrap") /* span=aten::select_22:0:0 */;
%46 = take(%12, 6 /* span=aten::select_23:0:0 */, axis=0, mode="wrap") /* span=aten::select_23:0:0 */;
%47 = divide(%45, 2f /* span=aten::div_11:0:0 */) /* span=aten::div_11:0:0 */;
%48 = divide(%46, 2f /* span=aten::div_12:0:0 */) /* span=aten::div_12:0:0 */;
%49 = add(%43, %44) /* span=aten::add_29:0:0 */;
%50 = add(%47, %48) /* span=aten::add_33:0:0 */;
.........
Desired Outcome:
I am looking to see my custom operator LTQ explicitly represented in the TVM IR, similar to this format:
def @main(...):
%0 = nn.conv2d(%input, %aten::_convolution_0.weight, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* span=aten::_convolution_0:0:0 */;
%1 = nn.batch_norm(%0, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* span=aten::batch_norm_0:0:0 */;
%2 = %1.0 /* span=aten::batch_norm_0:0:0 */;
%3 = nn.relu(%2) /* span=aten::relu__0:0:0 */;
%4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* span=aten::max_pool2d_0:0:0 */;
%5 = broadcast_to_like(%aten::expand_as_0.bias, %4) /* span=aten::expand_as_0:0:0 */;
%6 = add(%4, %5) /* span=aten::add_0:0:0 */;
%7 = broadcast_to(%aten::prelu_0.weight, shape=[64]) /* span=aten::prelu_0:0:0 */;
%8 = nn.prelu(%6, %7) /* span=aten::prelu_0:0:0 */;
%9 = broadcast_to_like(%aten::expand_as_1.bias, %8) /* span=aten::expand_as_1:0:0 */;
%10 = add(%8, %9) /* span=aten::add_1:0:0 */;
%11 = ltq(%a, 0.001f /* span=custom_ops::LTQ_0:0:0 */) /* span=custom_ops::LTQ_0:0:0 */;
.........
Questions:
- Are there additional steps or modifications required in the TVM conversion process to ensure that custom operators like LTQ are correctly and explicitly represented in the TVM IR?
- Is there a specific way to get a single IR node representation of a custom operator instead of each of its sub-operations? (A rough sketch of the kind of grouping I mean is shown after this list.)
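For reference, here is an untested, purely illustrative sketch of the kind of single-node grouping I have in mind. It assumes tvm.relay.dataflow_pattern and relay.transform.MergeComposite are the right tools; the pattern only covers the first couple of ops of the LTQ expansion (the a_pos clamp), and as I understand it this would wrap the sub-ops in a "Composite" function rather than produce a true single ltq operator:

# Untested, illustrative sketch: group the expanded LTQ sub-ops into a single
# composite function via pattern matching. A real pattern would have to cover
# every op the converter emits, not just the greater/where pair matched here.
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

def make_ltq_pattern():
    a = wildcard()
    eps = wildcard()
    mask = is_op("greater")(a, eps)
    clamped = is_op("where")(mask, a, eps)
    return clamped

pattern_table = [("custom_ops.LTQ", make_ltq_pattern())]
mod = relay.transform.MergeComposite(pattern_table)(mod)
print(mod)  # matched ops now sit inside a function tagged Composite="custom_ops.LTQ"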
Any guidance, suggestions, or pointers towards relevant documentation or examples would be greatly appreciated.
Thank you in advance for your assistance!