Gradients exploding and becoming NaN with custom loss function

I am attempting to train an LSTM with a custom loss function. The model reads in sequential, noisy points on a curve and generates the coefficients of an nth-degree polynomial that fits the curve. The curve has three variables, x, y, and z, all of which depend on time. The model generates one set of coefficients for each variable so that, when plotted over time, each polynomial is a best fit for its respective set of points.

The input to the model is shape (batch_size, sequence_length, num_variables) and it outputs a tensor with shape (batch_size, num_variables, polynomial_degree). My custom loss function compares that output to a set of points known to lie on the polynomial being estimated, and returns a loss tensor with shape (batch_size).
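
With the batch size of 32 used below, the output and loss shapes look like this (purely illustrative dummy tensors; sequence_length and num_variables depend on train_X):

import tensorflow as tf

# 10 coefficients for each of the 3 variables, one loss value per sample:
example_output = tf.zeros((32, 3, 10))  # (batch_size, num_variables, polynomial_degree)
example_loss = tf.zeros((32,))          # (batch_size,)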

Here is my model definition:

model = tf.keras.Sequential()
model.add(tf.keras.layers.Masking(mask_value=float(-10), 
    input_shape=(train_X.shape[1], train_X.shape[2])))
model.add(tf.keras.layers.LSTM(32))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(30,
    kernel_initializer=tf.keras.initializers.zeros()))
model.add(tf.keras.layers.Reshape((3, 10)))  # (num_variables, coefficients per polynomial)

model.compile(loss=batched_polynomial_loss, optimizer=tf.keras.optimizers.Adam())

model.fit(x=train_X, y=train_y, validation_data=val, epochs=5, batch_size=32)

Here is my loss function (tf_power_outer is a replication of numpy.power.outer):

def tf_power_outer(elem):
    # elem is one sample's vector of time values, shape (100,).
    # Returns a (polynomial_model_degree, 100) matrix whose [p, t] entry is elem[t] ** p,
    # i.e. the transpose of np.power.outer(elem, np.arange(polynomial_model_degree)).
    powers = tf.convert_to_tensor(np.arange(polynomial_model_degree), dtype=tf.float32)
    expand_elem = tf.convert_to_tensor(np.full((polynomial_model_degree), 1.0), dtype=tf.float32)
    expanded_elem = tf.tensordot(expand_elem, elem, axes=0)        # (degree, 100): every row is elem
    expand_powers = tf.convert_to_tensor(np.full((100), 1.0), dtype=tf.float32)
    expanded_powers = tf.tensordot(powers, expand_powers, axes=0)  # (degree, 100): every column is powers
    return tf.pow(expanded_elem, expanded_powers)


def batched_polynomial_loss(y_true, y_pred):
    # y_true: (batch, seq_len, num_variables + 1), with time as the last channel.
    # y_pred: (batch, num_variables, polynomial_degree) predicted coefficients.
    true_tracks = tf.transpose(y_true, [0, 2, 1])                  # (batch, num_variables + 1, seq_len)
    tt_shape = tf.shape(true_tracks)
    # Split the time row off from the variable rows.
    times = tf.slice(true_tracks, [0, tt_shape[1] - 1, 0], [tt_shape[0], 1, tt_shape[2]])
    times = tf.squeeze(times, axis=1)                              # (batch, seq_len)
    variables = tf.slice(true_tracks, [0, 0, 0], [tt_shape[0], tt_shape[1] - 1, tt_shape[2]])
    # Build each sample's matrix of time powers and evaluate the predicted polynomials.
    time_mats = tf.map_fn(tf_power_outer, times)                   # (batch, degree, seq_len)
    guess_data = tf.matmul(y_pred, time_mats)                      # (batch, num_variables, seq_len)
    # Per-variable MSE over the sequence, then averaged over variables: one loss per sample.
    mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    full_error = mse(variables, guess_data)                        # (batch, num_variables)
    reduced_error = tf.math.reduce_mean(full_error, axis=1)        # (batch,)
    return reduced_error
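
To make the expected shapes concrete, here is a quick dry run with dummy data using the two functions above (assuming polynomial_model_degree is 10 to match the (3, 10) Reshape in the model; the number of points per sample has to be 100 because tf_power_outer hard-codes it):

polynomial_model_degree = 10  # assumed; matches the (3, 10) Reshape above

dummy_y_true = tf.random.uniform((4, 100, 4))  # (batch, 100 points, 3 variables + 1 time channel)
dummy_y_pred = tf.random.uniform((4, 3, 10))   # (batch, num_variables, polynomial_degree)

print(batched_polynomial_loss(dummy_y_true, dummy_y_pred))  # tf.Tensor of shape (4,)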

It was not working with model.fit, so I wrote this training loop to inspect it more closely:

opt = tf.keras.optimizers.Adam()
def step(X, y):
    with tf.GradientTape() as tape:
        pred = model(X)
        loss = batched_polynomial_loss(y, pred)
    print(loss)
    grads = tape.gradient(loss, model.trainable_variables)
    print(grads)
    opt.apply_gradients(zip(grads, model.trainable_variables))


step(train_X[0:32], train_y[0:32])
step(train_X[32:64], train_y[32:64])
step(train_X[64:96], train_y[64:96])

On the first iteration, loss looks like:

tf.Tensor(
[0.43734643 0.4373459 0.43734697 0.4373469 0.43734667 0.4373454
0.4373459 0.4373479 0.43734646 0.43734646 0.4373469 0.43734574
0.43734762 0.43734512 0.43734625 0.4373463 0.4373455 0.4373475
0.43734655 0.43734694 0.43734714 0.43734702 0.43734705 0.4373466
0.43734702 0.43734536 0.43734682 0.43734637 0.43734694 0.4373466
0.43734622 0.4373466 ], shape=(32,), dtype=float32)

and grads looks like:

[<tf.Tensor: shape=(11, 128), dtype=float32, numpy=
array([[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
…,
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.]], dtype=float32)>, <tf.Tensor: shape=(32, 128), dtype=float32, numpy=
array([[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
…,
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.]], dtype=float32)>, <tf.Tensor: shape=(128,), dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(32, 64), dtype=float32, numpy=
array([[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
…,
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.]], dtype=float32)>, <tf.Tensor: shape=(64,), dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(64, 30), dtype=float32, numpy=
array([[ 8.4006101e-02, -5.3240155e+02, -1.1512082e+06, …,
-4.7750685e+22, -8.7685334e+25, -1.6336383e+29],
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, …,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
[ 1.1525370e-02, -7.3044151e+01, -1.5794281e+05, …,
-6.5512637e+21, -1.2030189e+25, -2.2413071e+28],
…,
[ 3.0950422e-04, -1.9615617e+00, -4.2414673e+03, …,
-1.7593081e+20, -3.2306457e+23, -6.0189166e+26],
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, …,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, …,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00]], dtype=float32)>, <tf.Tensor: shape=(30,), dtype=float32, numpy=
array([ 6.0670018e-01, -3.8450579e+03, -8.3141405e+06, -1.5660871e+10,
-2.8943713e+13, -5.3658182e+16, -1.0029243e+20, -1.8912737e+23,
-3.5962617e+26, -6.8893381e+29, 1.6471731e+01, 1.5045589e+04,
1.8590212e+07, 2.6241573e+10, 4.0149233e+13, 6.4911980e+16,
1.0927340e+20, 1.8970660e+23, 3.3739135e+26, 6.1173324e+29,
-1.5370066e+01, -1.7797381e+04, -2.5041262e+07, -3.8708691e+10,
-6.3491046e+13, -1.0852857e+17, -1.9123389e+20, -3.4486026e+23,
-6.3327237e+26, -1.1798300e+30], dtype=float32)>]

On the second iteration, loss and grads have the same shapes, but every value in both is nan.

I don't see anything wrong with the loss computed in the first iteration; values of about 0.43 across the batch seem normal for an untrained model. But the gradients computed from it look wrong: every layer before the final Dense layer gets gradients that are exactly zero, while the final Dense layer's kernel and bias gradients are enormous, some with magnitudes over 1e+30. Is my loss function set up incorrectly? Why are the gradients being computed this way?
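
As a side note, the gradient magnitudes can be summarized per variable with something like this (a quick diagnostic sketch, not part of my actual training code), called right after tape.gradient(...) in step():

import tensorflow as tf

def summarize_grads(grads, variables):
    # Print the largest absolute gradient entry per trainable variable,
    # plus the global norm across all of them.
    for g, v in zip(grads, variables):
        print(v.name, float(tf.reduce_max(tf.abs(g))))
    print("global norm:", float(tf.linalg.global_norm(grads)))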