MultiWorkerMirroredStrategy with Keras: can we relax the steps check when a distributed dataset is passed in?

Recently, while using MultiWorkerMirroredStrategy with Keras, we ran into two problems:

  1. When Keras wraps the dataset we pass in with experimental_distribute_dataset, we cannot scale beyond x nodes, because it requires a global batch size, and the global batch size has to account for the number of workers (global batch size = per-replica batch size * number of workers * number of replicas per worker). As a result, with many workers, and unlike with MirroredStrategy, we start seeing jobs fail with OOM errors.
  2. We tried to work around this by using distribute_datasets_from_function instead, which gives us full control over the per-replica batch size and the sharding logic (and avoids the OOM issue). However, the job then failed at:


Line 733 (commit 1923123):

if _is_distributed_dataset(self._dataset):
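For illustration, the batch-size arithmetic from point 1 and the per-replica control that point 2 relies on can be sketched in plain Python. This is a minimal sketch that only mirrors the logic and does not call TensorFlow; it assumes "num of replica" means replicas per worker. The helpers `global_batch_size`, `per_replica_batch_size` and `shard_indices` are hypothetical. In TensorFlow, the corresponding pieces are `tf.distribute.InputContext.get_per_replica_batch_size`, `InputContext.num_input_pipelines` and `InputContext.input_pipeline_id`, available inside the `dataset_fn` passed to `distribute_datasets_from_function`.

```python
from typing import List


def global_batch_size(per_replica_batch: int, num_workers: int,
                      replicas_per_worker: int) -> int:
    """Global batch = per-replica batch * total replicas across all workers."""
    return per_replica_batch * num_workers * replicas_per_worker


def per_replica_batch_size(global_batch: int, num_replicas_in_sync: int) -> int:
    """Split a global batch evenly across replicas.

    Mirrors the behavior of InputContext.get_per_replica_batch_size, which
    rejects a global batch that does not divide evenly.
    """
    if global_batch % num_replicas_in_sync != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by "
            f"{num_replicas_in_sync} replicas"
        )
    return global_batch // num_replicas_in_sync


def shard_indices(num_examples: int, num_input_pipelines: int,
                  input_pipeline_id: int) -> List[int]:
    """Indices read by one input pipeline, in the style of Dataset.shard()."""
    return list(range(input_pipeline_id, num_examples, num_input_pipelines))


if __name__ == "__main__":
    # 8 workers with 4 replicas each, keeping 32 examples per replica.
    gb = global_batch_size(32, num_workers=8, replicas_per_worker=4)
    print(gb)                                                   # 1024
    print(per_replica_batch_size(gb, num_replicas_in_sync=32))  # 32
    print(shard_indices(10, num_input_pipelines=4, input_pipeline_id=1))  # [1, 5, 9]
```

Because `distribute_datasets_from_function` calls `dataset_fn` on each worker with an `InputContext`, the returned dataset can be batched with the per-replica size and sharded per input pipeline directly, instead of being built around the global batch.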

When we pass in a regular (non-distributed) dataset, it has UNKNOWN cardinality, and Keras relies on


Line 710 (commit 1923123):

def should_recreate_iterator(self):

to recreate the iterator every epoch. Our use case is a validation step that exhausts the dataset instead of running a hard-coded number of steps. Could we remove the check at L733 altogether, together with a change to L714, so that users do not have to pass steps? If you agree, I can submit a PR with the change.
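To make the proposal concrete, the requested behavior can be modeled in plain Python. This is a minimal sketch under stated assumptions, not the Keras implementation: `validate_steps` and `run_epochs` are hypothetical helpers, the relaxed rule (only an infinitely repeating dataset requires steps) is one reading of the proposal, and the two constants copy the values of `tf.data.INFINITE_CARDINALITY` (-1) and `tf.data.UNKNOWN_CARDINALITY` (-2).

```python
from typing import Iterable, List, Optional

# Values of tf.data.INFINITE_CARDINALITY and tf.data.UNKNOWN_CARDINALITY.
INFINITE = -1
UNKNOWN = -2


def validate_steps(cardinality: int, steps: Optional[int]) -> None:
    """Hypothetical relaxed check with no special case for distributed datasets.

    Only a dataset known to repeat forever requires an explicit `steps`;
    an UNKNOWN (finite but unsized) dataset is allowed through.
    """
    if steps is None and cardinality == INFINITE:
        raise ValueError("An infinitely repeating dataset needs an explicit `steps`.")


def run_epochs(dataset: Iterable[object], epochs: int,
               steps: Optional[int] = None,
               recreate: Optional[bool] = None) -> List[int]:
    """Return how many batches each epoch consumed.

    Hypothetical default rule: when `steps` is None the iterator is
    recreated every epoch, so each epoch runs until the data is exhausted.
    """
    if recreate is None:
        recreate = steps is None
    iterator = iter(dataset)
    counts = []
    for epoch in range(epochs):
        if recreate and epoch > 0:
            iterator = iter(dataset)  # fresh pass over the data
        seen = 0
        for _batch in iterator:
            seen += 1
            if steps is not None and seen >= steps:
                break
        counts.append(seen)
    return counts


if __name__ == "__main__":
    batches = range(5)  # stands in for a validation set of unknown size
    validate_steps(UNKNOWN, None)  # accepted under the relaxed check
    print(run_epochs(batches, 3))                  # [5, 5, 5]
    print(run_epochs(batches, 3, recreate=False))  # [5, 0, 0]
    print(run_epochs(batches, 3, steps=2))         # [2, 2, 1]
```

With no steps and a fresh iterator each epoch, every epoch sees the whole validation set; reusing one exhausted iterator would leave later epochs with zero batches, which is what recreating the iterator avoids.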

Please let me know if there are any downsides to doing so.