Discrete Change In Batch Size Causes Gradients To Be Exactly 0

Hello everyone,
I am working on a model using DeepONet from DeepXDE, which uses Keras behind the scenes, and I am seeing some REALLY weird behavior.
I am researching the effects of switching to mixed_float16 on model memory, speed, and accuracy. With a full batch size, the gradients were 0 and the model wasn't updating from iteration to iteration.
However, when I brought the batch size down below a certain number (for one model it was 512 items, for another it was 656), the model could train again and the gradients were no longer zero.
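To pin down the exact switch point, a bisection over batch sizes is one option. The sketch below is only an illustration: `largest_working_batch_size` and `fake_trains_ok` are hypothetical names, and the stand-in predicate simply mimics what I observed (training works below 656). In a real run the predicate would have to build the model, train for a few iterations, and report whether the loss changed.

```python
from typing import Callable


def largest_working_batch_size(
    trains_ok: Callable[[int], bool], lo: int, hi: int
) -> int:
    """Return the largest batch size in [lo, hi) for which trains_ok is True.

    Assumes trains_ok(lo) is True, trains_ok(hi) is False, and that there
    is a single switch point between them.
    """
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if trains_ok(mid):
            lo = mid  # training still works, so the threshold is above mid
        else:
            hi = mid  # training is stuck, so the threshold is at or below mid
    return lo


def fake_trains_ok(batch_size: int) -> bool:
    """Hypothetical stand-in for 'train briefly and check the loss moved'."""
    return batch_size < 656


print(largest_working_batch_size(fake_trains_ok, 1, 1000))  # 655
```

This needs about ten training runs to search the range 1 to 1000, instead of trying every batch size by hand.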
I have spent a LOT of time looking into this and was wondering if anyone recognizes this behavior and knows what's wrong?
For example, with this code:

```python

import deepxde as dde
import numpy as np

# Load dataset
d = np.load("antiderivative_aligned_train.npz", allow_pickle=True)
X_train = (d["X"][0].astype(np.float16), d["X"][1].astype(np.float16))
y_train = d["y"].astype(np.float16)
d = np.load("antiderivative_aligned_test.npz", allow_pickle=True)
X_test = (d["X"][0].astype(np.float16), d["X"][1].astype(np.float16))
y_test = d["y"].astype(np.float16)

data = dde.data.TripleCartesianProd(
    X_train=X_test, y_train=y_test, X_test=X_train, y_test=y_train
)

# Choose a network
m = 100
dim_x = 1
net = dde.nn.DeepONetCartesianProd(
    [m, 40, 40],
    [dim_x, 40, 40],
    "relu",  # activation, as in the linked DeepXDE example
    "Glorot normal",
)

# Define a Model
model = dde.Model(data, net)

# Compile and Train
model.compile("adam", lr=0.001, metrics=["mean l2 relative error"])
losshistory, train_state = model.train(iterations=200, batch_size=656)
```

With batch_size=656, the model does not train (the loss is the same every epoch), but when I change batch_size to 655, it trains. (The code is from the DeepXDE documentation page "Antiderivative operator from an aligned dataset", DeepXDE 1.9.2.dev8+gcd34292.)
I have talked to the author of DeepXDE about this, and he does not know what is going on, which makes me think it is a Keras/TensorFlow quirk.
Switching to float32 instead of mixed_float16 also fixes this, but my whole purpose was to try to use mixed_float16, which has worked well with other types of models in DeepXDE.
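Since float32 works and mixed_float16 does not, one general fact that may be relevant is how much narrower the float16 range is. The sketch below only illustrates that fact with the standard library's `struct` module (format character `e` is IEEE 754 half precision). It is not a diagnosis of what happens inside DeepXDE or Keras, and the helper names `to_float16` and `fits_float16` are made up for illustration.

```python
import struct

FLOAT16_MAX = 65504.0               # largest finite float16 value
FLOAT16_MIN_SUBNORMAL = 2.0 ** -24  # smallest positive float16 value


def to_float16(x: float) -> float:
    """Round a Python float to half precision and convert it back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_float16(x: float) -> bool:
    """Return False if x is too large in magnitude for float16."""
    try:
        struct.pack("<e", x)
    except OverflowError:
        return False
    return True


# A value below the smallest subnormal is flushed to exactly zero.
print(to_float16(1e-8))   # 0.0

# A value above the float16 maximum cannot be stored at all.
print(fits_float16(1e5))  # False
```

If an intermediate value in the loss or gradient computation crosses either limit, the float16 result can end up as exactly zero or non-finite even though the same computation is fine in float32. Whether that is what happens at a batch size of 656 is something I have not verified.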