Recently, when we tried to use MultiWorkerMirroredStrategy with Keras, we found:
- When Keras wraps our passed-in dataset with an experimental distributed dataset, we cannot scale beyond a certain number of nodes, because we have to pass in a global batch size, and the global batch size has to account for the number of workers (global batch size = per-replica batch size * number of workers * number of replicas per worker). As a result, once we have a lot of workers we start seeing job failures due to OOM, which we did not see with MirroredStrategy.
- We tried to work around this by using distribute_datasets_from_function, so that we have full control over the per-replica batch size and the sharding logic (and avoid the OOM issue); a sketch of this setup is shown after this list. However, our job then failed at:
Line 733 in 1923123
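For reference, here is a minimal sketch of the input-function approach we are using; the file pattern, the interleave settings, and the per-replica batch size are placeholders for our actual pipeline:

```python
import tensorflow as tf

PER_REPLICA_BATCH_SIZE = 32  # placeholder value

def dataset_fn(input_context):
    # Each worker reads only its own shard of files, so per-worker memory
    # does not grow with the total number of workers.
    files = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
    files = files.shard(input_context.num_input_pipelines,
                        input_context.input_pipeline_id)
    ds = files.interleave(tf.data.TFRecordDataset, cycle_length=4)
    # Batch with the per-replica size directly; no global batch size needed.
    return ds.batch(PER_REPLICA_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

strategy = tf.distribute.MultiWorkerMirroredStrategy()
train_ds = strategy.distribute_datasets_from_function(dataset_fn)
```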
When we pass in a regular (non-distributed) dataset, it has UNKNOWN cardinality and Keras leverages
Line 710 in 1923123
to recreate the iterator for every epoch. Our use case is to have the validation step exhaust the dataset rather than hard-coding the number of steps. I wonder if we can relax the check at L733, together with a change around L714, so that users can omit the steps input entirely? If you agree, I can submit a PR to make the change.
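For concreteness, this is roughly the usage we would like to support once the check is relaxed (hypothetical; build_model and val_ds are placeholders, and train_ds is the distributed dataset from the sketch above):

```python
with strategy.scope():
    model = build_model()  # placeholder for our model definition and compile()

model.fit(
    train_ds,                # DistributedDataset with UNKNOWN cardinality
    validation_data=val_ds,  # also a DistributedDataset, exhausted each epoch
    epochs=10,
    # no steps_per_epoch / validation_steps -- currently rejected by the
    # check at L733 when the input is a DistributedDataset
)
```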
Please let me know if there are any downsides to doing so.