Stateful LSTM model support

I am trying to figure out how TVM can support stateful LSTM models during inference.

Here is a blog about these stateful models. A quote from the blog: “For example, it might be a part of a larger system that works on video frames. It might be required to perform some action instantly after each frame, instead of waiting for a sufficiently long sequence of video frames before being fed to the network.”

I can work around the problem by returning the state as part of the result of each inference call to the compiled TVM model and passing it back in on the next call. However, this breaks the model's interface. Every time I update the model, I may also need to update the inference code that calls it, depending on which internal state has changed. This makes the model code difficult to maintain.
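Below is a minimal sketch of how the workaround could be kept out of the calling code. It assumes the compiled model is exposed as a stateless step function that takes `(input, state)` and returns `(output, new_state)`. `StatefulRunner` and `running_sum_step` are hypothetical names, not TVM APIs, and the running sum only stands in for a compiled model. The wrapper owns the state dictionary, so callers only ever see `runner(x) -> y`. If the wrapper and the initial state are shipped together with the model, a change in the model's internal state does not have to touch the inference code.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical signature of a compiled, stateless step function:
# (input, current state) -> (output, updated state).
StepFn = Callable[[List[float], Dict[str, float]], Tuple[List[float], Dict[str, float]]]


class StatefulRunner:
    """Owns the model state so callers only ever call runner(x) -> y."""

    def __init__(self, step: StepFn, initial_state: Dict[str, float]):
        self._step = step
        # Copy so the caller's dict is never mutated.
        self._state = dict(initial_state)

    def __call__(self, x: List[float]) -> List[float]:
        # Thread the state through the stateless step and keep the new state.
        y, self._state = self._step(x, self._state)
        return y

    def reset(self, initial_state: Dict[str, float]) -> None:
        self._state = dict(initial_state)


# Toy stand-in for a compiled model: a running sum over all inputs seen so far.
def running_sum_step(x: List[float], state: Dict[str, float]):
    total = state["total"] + sum(x)
    return [total], {"total": total}


runner = StatefulRunner(running_sum_step, {"total": 0.0})
print(runner([1.0, 2.0]))  # [3.0]
print(runner([1.0, 2.0]))  # [6.0]
```

This does not make the compiled model itself stateful. The state is still passed in and out on every call, but only in one place that can be updated together with the model.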

Here is an example of a stateful LSTM model in TensorFlow 2.4:

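```python
import tensorflow as tf

class CustomModule(tf.Module):
    def __init__(self):
        super(CustomModule, self).__init__()
        # stateful=True: the final hidden/cell state of one call
        # becomes the initial state of the next call.
        self.v = tf.keras.layers.LSTM(1, stateful=True, return_sequences=True)

    @tf.function
    def __call__(self, n):
        return self.v(n)

# Input shape: [batch=2, timesteps=1, features=1]
a = tf.random.normal([2, 1, 1])
module = CustomModule()
print(module(a))
print(module(a))
```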

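The output of the program above looks like the following. Notice that the two prints give different results for the same input, because the second call starts from the state left behind by the first: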

```
tf.Tensor(
[[[ 0.11585967]]
 [[-1.1594343 ]]], shape=(2, 1, 1), dtype=float32)

tf.Tensor(
[[[ 0.00368929]]
 [[-0.09180649]]], shape=(2, 1, 1), dtype=float32)
```
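To make the difference concrete, here is a plain-Python sketch of a one-unit LSTM over scalar inputs, using the standard LSTM update (sigmoid gates, tanh cell activation). The weights are arbitrary illustrative values. Unlike the TensorFlow example, it processes a single sequence rather than a batch of two. `ScalarLSTM` is a hypothetical class, not a Keras API. The only difference between the two modes is whether the final `(h, c)` of one call is kept as the starting state of the next.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


class ScalarLSTM:
    """One-unit LSTM over scalar inputs; weights are arbitrary illustrative values."""

    def __init__(self, stateful: bool):
        self.stateful = stateful
        # (input weight, recurrent weight, bias) for the i, f, g, o gates.
        self.w = {"i": (0.5, 0.1, 0.0), "f": (0.4, 0.2, 1.0),
                  "g": (0.9, 0.3, 0.0), "o": (0.6, 0.1, 0.0)}
        self.h, self.c = 0.0, 0.0

    def _gate(self, name: str, x: float, h: float) -> float:
        wx, wh, b = self.w[name]
        return wx * x + wh * h + b

    def __call__(self, xs):
        # A stateless layer starts every call from zero state;
        # a stateful one resumes from where the previous call stopped.
        h, c = (self.h, self.c) if self.stateful else (0.0, 0.0)
        outputs = []
        for x in xs:
            i = sigmoid(self._gate("i", x, h))    # input gate
            f = sigmoid(self._gate("f", x, h))    # forget gate
            g = math.tanh(self._gate("g", x, h))  # candidate cell value
            o = sigmoid(self._gate("o", x, h))    # output gate
            c = f * c + i * g
            h = o * math.tanh(c)
            outputs.append(h)
        if self.stateful:
            self.h, self.c = h, c
        return outputs


seq = [1.0]
stateless, stateful = ScalarLSTM(stateful=False), ScalarLSTM(stateful=True)
print(stateless(seq) == stateless(seq))  # True: same input, same output
print(stateful(seq) == stateful(seq))    # False: second call sees the carried state
```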

Here is an even simpler TensorFlow stateful model, which returns an increasing integer on each inference call. I also printed the TensorFlow graph nodes for the module:

```python
import tensorflow as tf

class CustomModule(tf.Module):
    def __init__(self):
        super(CustomModule, self).__init__()
        # Counter state lives in a TF variable, initialized to 1.
        self.v = tf.Variable(1)

    @tf.function
    def __call__(self):
        # Side effect: increment the variable, then return its new value.
        self.v.assign_add(1)
        return self.v

module = CustomModule()
print(module())
print(module())
print(module())

'''
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
'''

# Trace the tf.function and list every graph node with its inputs
# ('^' marks a control dependency).
call = module.__call__.get_concrete_function()
graph = call.graph
for node in graph.as_graph_def().node:
    print(f'{node.input} -> {node.name}')

'''
[] -> Const
[] -> AssignAddVariableOp/resource
['AssignAddVariableOp/resource', 'Const'] -> AssignAddVariableOp
['AssignAddVariableOp/resource', '^AssignAddVariableOp'] -> ReadVariableOp
['ReadVariableOp', '^AssignAddVariableOp', '^ReadVariableOp'] -> Identity
'''
```
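In the graph above, the variable is reached through a resource handle (`AssignAddVariableOp/resource`), and the control edges (`^AssignAddVariableOp`) force the read to happen after the update. One way to represent the same computation without side effects is state-passing style, where the variable is lifted into an explicit input and output. The sketch below is plain Python, and `counter_step` is a hypothetical name. It is not a claim about how TensorFlow or TVM actually lowers this graph.

```python
from typing import Tuple


def counter_step(v: int) -> Tuple[int, int]:
    """Pure version of CustomModule.__call__: takes the old value of the
    variable and returns (result, new value of the variable)."""
    new_v = v + 1        # plays the role of AssignAddVariableOp with Const = 1
    result = new_v       # plays the role of ReadVariableOp -> Identity
    return result, new_v


state = 1                # plays the role of tf.Variable(1)
results = []
for _ in range(3):
    out, state = counter_step(state)
    results.append(out)
print(results)  # [2, 3, 4]
```

In this form the control dependency becomes an ordinary data dependency (`result` is computed from `new_v`), which a purely functional IR can represent directly. The cost is that the caller has to carry `state` between calls, which is the interface problem described at the start of the post.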

So, ideally, I would like Relay to support the notion of stateful models. This may cause difficulties for some Relay optimization passes, because stateful functions are no longer purely functional. The previous RFC about handling effects may be related to what I am asking here: @MarisaKirisame @tqchen @junrushao
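As an illustration of why effects matter to optimization passes, consider common subexpression elimination. For a pure function, rewriting `f() + f()` into a single call whose result is reused is always safe. For a stateful function, the same rewrite changes the result. The sketch below is plain Python, not Relay code, and `Counter` is a hypothetical class that mirrors the `tf.Variable` counter above.

```python
class Counter:
    """Stateful callable mirroring the tf.Variable counter above."""

    def __init__(self) -> None:
        self.v = 1

    def __call__(self) -> int:
        self.v += 1
        return self.v


# Program as written: two separate calls.
f = Counter()
original = f() + f()        # 2 + 3

# After a CSE-style rewrite that assumes f() is pure and reuses one result.
g = Counter()
x = g()
rewritten = x + x           # 2 + 2

print(original, rewritten)  # 5 4
```

A pass therefore needs to know which calls may have side effects before it merges, reorders, or drops them.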
