I think the `PropagateSharding` pass does not correctly handle nested function calls (in your code, `main` calls into `mlp`).
Also, you are right that the integration is not complete. You are welcome to extend it.