Currently, loading a PyTorch module into TVM (Relay) follows the path
torch.nn.Module -> torch.jit.trace -> TorchScript -> tvm.relay, which works for most vision and NLP models. However, this conversion has its own limitations; one of them is the limited ability to trace inside customized modules. For example,
```python
import torch
import torch.nn.functional as F
from torch import autograd

def fn1(x, w):
    return F.linear(x, w)

x = torch.randn(10, 15)
w = torch.randn(20, 15)

# works
ts = torch.jit.trace(fn1, (x, w))

class MyLinear(autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x.mm(w.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        dydx = grad_output.mm(weight)
        dydw = grad_output.t().mm(input)
        return dydx, dydw

def fn2(x, w):
    return MyLinear.apply(x, w)

# does not work
ts = torch.jit.trace(fn2, (x, w))
print(ts.code)
```
In many DL tasks, such as vision and NLP, defining a custom forward and backward is common in research and development. The lack of support for customized ops prevents more people from using TVM.
Fortunately, PyTorch recently gained a new IR frontend, torch.fx, which can trace into customized ops, and its API has become stable in recent releases. I would suggest adding support for torch.fx, and I can help draft a PR.
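As a rough sketch of why torch.fx helps here (assuming a PyTorch release where torch.fx is stable, i.e. >= 1.10): `symbolic_trace` runs the function with Proxy objects and records every tensor operation as a node in a graph IR, inlining plain Python helper calls along the way. The names `my_linear` and `fn2` below are illustrative, not from any existing codebase, and the handling of `autograd.Function` subclasses specifically may still require a custom `fx.Tracer`:

```python
import torch
import torch.fx

def my_linear(x, w):
    # A plain Python helper implementing the same math as the
    # MyLinear forward above; fx traces through it and records
    # the underlying mm/t calls in its graph.
    return x.mm(w.t())

def fn2(x, w):
    return my_linear(x, w)

# Build a GraphModule; no example inputs are needed because fx
# traces symbolically rather than by executing on real tensors.
gm = torch.fx.symbolic_trace(fn2)
print(gm.code)   # generated Python with the inlined mm/t calls
print(gm.graph)  # the node-level IR a Relay converter could walk
```

A Relay frontend could then iterate over `gm.graph.nodes` and map each `call_function` / `call_method` node to the corresponding Relay op, much as the TorchScript converter walks the traced graph today.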
Feel free to comment and share your thoughts.