How does keras-core deal with CxHxW dataloaders from PyTorch?

I have been very excited since the announcement of Keras Core and all that it brings to the table. However, I stumbled across a question that has been bugging me for some time now.

In Keras Core, you can write one training loop that works with both PyTorch dataloaders and TensorFlow datasets. However, PyTorch presents images in CxHxW order, while tf datasets use HxWxC. Will any code changes be needed to deal with this discrepancy, or is it all taken care of under the hood by Keras with no action from the user?

Hi @Atia

This is a known issue and is also mentioned in the Keras Core release announcement:

  • Image layout and performance considerations with PyTorch. When using convnets, the typical image layout to use is "channels_last" (aka NHWC), which is the standard in cuDNN, TensorFlow, JAX, and others. However, PyTorch uses "channels_first". You can use any Keras Core convnet with any image layout, and you can easily switch from one default layout to the other via the keras_core.config.set_image_data_format() flag. Importantly, when using PyTorch convnets in the "channels_last" format, Keras will have to convert layouts back and forth at each layer, which is inefficient. For best performance, remember to set your default layout to "channels_first" when using convnets in PyTorch.

In the future, we hope to resolve this issue by by-passing torch.nn ops and going directly to cuDNN. Thank you.
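
To make the two options concrete, here is a minimal sketch (not an official recipe). Option 1 is the flag from the announcement quoted above, shown only as a comment. Option 2 keeps the "channels_last" default and transposes each batch before it reaches the model. The helper name `to_channels_last` is made up for illustration, and NumPy arrays stand in for the batches a PyTorch DataLoader would yield.

```python
import numpy as np

# Option 1 (from the announcement): keep PyTorch's layout and tell Keras about it.
#   import keras_core
#   keras_core.config.set_image_data_format("channels_first")

# Option 2: keep the "channels_last" default and transpose each batch yourself.
def to_channels_last(batch):
    """Convert an NCHW batch (PyTorch convention) to NHWC. Hypothetical helper."""
    batch = np.asarray(batch)
    if batch.ndim != 4:
        raise ValueError(f"expected a 4D NCHW batch, got shape {batch.shape}")
    # Move axis 1 (channels) to the end: (N, C, H, W) -> (N, H, W, C)
    return np.transpose(batch, (0, 2, 3, 1))

# Stand-in for one DataLoader batch: 8 RGB images of size 32x32.
nchw_batch = np.zeros((8, 3, 32, 32), dtype="float32")
nhwc_batch = to_channels_last(nchw_batch)
print(nhwc_batch.shape)  # (8, 32, 32, 3)
```

If the batches are torch tensors rather than arrays, `batch.permute(0, 2, 3, 1)` performs the same reordering. Per the announcement, option 1 should be the faster choice on the PyTorch backend, since it avoids the per-layer layout conversion.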