How to display multi-class segmentation results

Hi Folks,
The prediction from my U-Net model has shape (48, 256, 256, 5), where the axes are batch size, height, width and number of classes (5), respectively. I can display each class separately with plt.imshow(), but I'm looking for a visualization method better suited to multi-class output. Any advice would be appreciated, thank you.
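A common way to get a single picture per sample is sketched below. It assumes the 5 channels are per-class scores (e.g. softmax outputs), each pixel belongs to exactly one class, and NumPy and Matplotlib are available. The approach takes the argmax over the class axis to get one label map per image, colours each label from a fixed palette, and can optionally blend that mask over the input image. The palette colours, the `class N` legend labels and the helper names (`probs_to_labels`, `labels_to_rgb`, `overlay`, `show_prediction`) are hypothetical placeholders. The overlay also assumes the input image is available as a uint8 RGB array of the same height and width.

```python
import numpy as np

# Hypothetical palette: one RGB colour (0-255) per class, 5 classes as in the question.
PALETTE = np.array([
    [0, 0, 0],        # class 0 (e.g. background): black
    [230, 25, 75],    # class 1: red
    [60, 180, 75],    # class 2: green
    [0, 130, 200],    # class 3: blue
    [255, 225, 25],   # class 4: yellow
], dtype=np.uint8)


def probs_to_labels(pred):
    """Collapse per-class scores (..., H, W, C) into an integer label map (..., H, W)."""
    return np.argmax(pred, axis=-1)


def labels_to_rgb(labels, palette=PALETTE):
    """Look up each integer label in the palette, giving a (..., 3) uint8 colour image."""
    return palette[labels]


def overlay(image_rgb, mask_rgb, alpha=0.5):
    """Alpha-blend a colour mask over an RGB image (both uint8 with the same H x W)."""
    blended = (1 - alpha) * image_rgb.astype(np.float32) + alpha * mask_rgb.astype(np.float32)
    return blended.round().astype(np.uint8)


def show_prediction(pred, index=0, image=None, alpha=0.5):
    """Display one sample of the batch as a single colour-coded mask with a legend."""
    import matplotlib.pyplot as plt  # imported here so the helpers above need only NumPy
    from matplotlib.patches import Patch

    mask_rgb = labels_to_rgb(probs_to_labels(pred[index]))
    shown = mask_rgb if image is None else overlay(image, mask_rgb, alpha)
    plt.imshow(shown)
    # One legend entry per palette colour, labelled with its class index.
    handles = [Patch(color=tuple(PALETTE[c] / 255.0), label=f"class {c}")
               for c in range(len(PALETTE))]
    plt.legend(handles=handles, bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.axis("off")
    plt.show()
```

For example, `show_prediction(pred)` would show the first of the 48 samples, and passing `image=` with the matching input shows the mask blended over it. If the model is multi-label (independent sigmoids, where a pixel can belong to several classes), argmax discards overlaps. In that case, thresholding each channel and keeping the per-class view may be more appropriate.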