Ah yes, I don’t believe transform.gradient calls the simplify inference pass I pointed to. My point is that dropout has no implementation, so any model that contains dropout will fail with an error like the one you are seeing.
Therefore, the way to run models with dropout is to remove it with the pass I linked (which is what is supposed to be done at inference time anyway).
For your use case, however, I believe you need to define the implementation first and then the gradient.
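
To illustrate why the implementation has to come before the gradient, here is a rough plain-Python sketch of what dropout computes. This is not TVM code: `dropout_forward` and `dropout_backward` are hypothetical names, and I am assuming the usual "inverted dropout" formulation (kept values scaled by 1/(1 - rate)). The backward step reuses the mask produced by the forward step, so the gradient has nothing to work with until the forward computation exists.

```python
import random


def dropout_forward(xs, rate, training, rng):
    """Hypothetical forward step: returns (outputs, mask).

    Assumes inverted dropout: each value is kept with probability
    (1 - rate) and scaled by 1 / (1 - rate); dropped values become 0.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if not training or rate == 0.0:
        # At inference time dropout is the identity, which is why
        # it can simply be removed from the model.
        return list(xs), [1.0] * len(xs)
    keep = 1.0 - rate
    mask = [(1.0 / keep) if rng.random() < keep else 0.0 for _ in xs]
    return [x * m for x, m in zip(xs, mask)], mask


def dropout_backward(grad_out, mask):
    """Hypothetical backward step: the gradient is the incoming
    gradient multiplied by the same mask used in the forward step."""
    return [g * m for g, m in zip(grad_out, mask)]


rng = random.Random(0)
xs = [1.0, 2.0, 3.0, 4.0]

# Inference: output equals input, gradient passes through unchanged.
out_inf, mask_inf = dropout_forward(xs, 0.5, False, rng)

# Training: output and gradient both depend on the sampled mask.
out_train, mask_train = dropout_forward(xs, 0.5, True, rng)
grad_train = dropout_backward([1.0] * len(xs), mask_train)
```

In a real Relay setup the forward computation and the gradient would be registered for the operator instead of being written as free functions like this; the sketch only shows the dependency between the two.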