Currently, loading a PyTorch module into TVM (Relay) follows the path torch.nn.Module -> torch.jit.trace -> TorchScript -> tvm.relay, which works well for most vision and NLP models. However, this conversion has its own limitations; one case is the limited tracing inside customized modules. For example,
import torch
import torch.nn.functional as F

def fn1(x, w):
    return F.linear(x, w)

x = torch.randn(10, 15)
w = torch.randn(20, 15)

# works
ts = torch.jit.trace(fn1, (x, w))
class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x.mm(w.t())

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        dydx = grad_output.mm(w)
        dydw = grad_output.t().mm(x)
        return dydx, dydw

def fn2(x, w):
    return MyLinear.apply(x, w)

# does not work
ts = torch.jit.trace(fn2, (x, w))
print(ts.code)
In many DL tasks, such as vision and NLP, defining a custom forward and backward is common in research and development. Missing support for customized ops prevents more people from using TVM.
Fortunately, PyTorch recently gained a new IR frontend, torch.fx, which can trace into customized ops, and its API has become stable in recent releases. I would suggest adding support for torch.fx, and I can help sketch a PR.
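As a rough illustration of what the torch.fx path would start from (a minimal sketch using only the public torch.fx API; the conversion to Relay itself is not shown), symbolic tracing records each tensor operation as an individual node in a graph IR:

```python
import torch
import torch.fx

def fn(x, w):
    # same computation as the linear examples above
    return x.mm(w.t())

# symbolic_trace builds a torch.fx.GraphModule; its generated
# Python code lists each op (t, mm) as a separate graph node
gm = torch.fx.symbolic_trace(fn)
print(gm.code)
```

A TVM frontend could then walk `gm.graph` node by node and emit the corresponding Relay operators, instead of consuming a TorchScript trace.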
Feel free to comment and share your thoughts.