SparseTensor memory leak

Hello all,

(tensorflow 2.13 / tensorflow-metal 2.13 / macOS 13)
I am experiencing an apparent memory leak in a pipeline where the inputs to the first (dense) layer are large SparseTensors. As training data is read and passed to the model, memory use grows in proportion to batch size, and the process goes OOM after a few epochs.

  • I read column indices for the dense tensors from TFRecords, and consume the records with the two methods below.
  • The leak does not occur when I create the SparseTensors inside the method from constant tensors instead of the tensors read from the records. The constants are otherwise identical to the ones generated from the training data and are manipulated in the same way inside the method.
  • I parse the input tensors (list_a / list_b) as sequence features, as shown below, but parsing them as a different feature type shows the same behavior.

The _to_sparse_tensor method runs in graph mode (because it is passed to a Dataset map), so one possible cause of the leak is that the computation graph is being continually updated (retraced) during training. However, I can't see where this would happen: the arguments passed in are all tensors of constant shape, and neither method pulls in the surrounding scope in obvious ways.
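One way to test the retracing theory directly (a sketch, not from the original pipeline): a Python-side side effect inside the mapped function executes only while the function is being traced, not once per element, so a counter reveals how many times tracing happens.

```python
import tensorflow as tf

trace_count = {"n": 0}

def _map_fn(x):
    trace_count["n"] += 1  # Python side effect: runs only during tracing
    return x * 2

ds = tf.data.Dataset.range(8).map(_map_fn)
for _ in range(3):  # simulate three epochs
    for _ in ds:
        pass
print("traces:", trace_count["n"])  # 1 means no retracing across epochs
```

If the counter grows with the number of epochs in your pipeline, retracing is the problem; if it stays at 1, the leak is elsewhere.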

Does anyone have an idea where the problem may lie, or how I can debug it?
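For what it's worth, here is the kind of check I have been trying: tracemalloc can show Python-side allocation growth between epochs (it cannot see TensorFlow's C++ allocations; for those I watch the process RSS instead). The "epoch" below is a stand-in for a real training epoch.

```python
import tracemalloc

tracemalloc.start()
before = tracemalloc.take_snapshot()

# ... run one training epoch here; this list is a stand-in for the work ...
leaky = [bytes(1000) for _ in range(1000)]

after = tracemalloc.take_snapshot()
# Biggest Python-side allocation growth, grouped by source line:
for stat in after.compare_to(before, "lineno")[:5]:
    print(stat)
```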


def _to_sparse_tensor(list_a, list_b, scalar_a, scalar_b):
    # Sizes as int64 so they can be used directly in dense_shape.
    a_size = tf.constant(100000, dtype=tf.int64)
    b_size = tf.constant(50, dtype=tf.int64)
    # All nonzeros live in row 0 of a [1, size] one-hot row.
    row_indices = tf.constant(0, shape=[6], dtype=tf.int64)
    values = tf.constant(1, shape=[6], dtype=tf.int64)
    indices = tf.stack([row_indices, list_a], 1)
    sparse_a = tf.sparse.SparseTensor(indices=indices,
                                      values=values,
                                      dense_shape=[1, a_size])
    indices = tf.stack([row_indices, list_b], 1)
    sparse_b = tf.sparse.SparseTensor(indices=indices,
                                      values=values,
                                      dense_shape=[1, b_size])
    in_onehot = tf.sparse.concat(1, [sparse_a, sparse_b])
    in_onehot = tf.sparse.reshape(in_onehot, [a_size + b_size])
    return in_onehot, scalar_a, scalar_b

def _parse_function_list(example):
    features_dict = {"list_a": tf.io.FixedLenSequenceFeature([], dtype=tf.int64),
                     "list_b": tf.io.FixedLenSequenceFeature([], dtype=tf.int64),
                     "scalar_a": tf.io.FixedLenSequenceFeature([], dtype=tf.int64),
                     "scalar_b": tf.io.FixedLenSequenceFeature([], dtype=tf.int64)}

    context_dict = {}
    context, features = tf.io.parse_single_sequence_example(
        example, context_features=context_dict, sequence_features=features_dict)

    list_a, list_b, scalar_a, scalar_b = features['list_a'], \
        features['list_b'], \
        features['scalar_a'], \
        features['scalar_b']
    return list_a, list_b, scalar_a, scalar_b
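For reference, the SparseTensor construction itself checks out in eager mode outside the pipeline (the index values here are made-up stand-ins for one record), which is part of why I suspect the graph rather than the math:

```python
import tensorflow as tf

A_SIZE, B_SIZE, N = 100000, 50, 6
list_a = tf.constant([1, 2, 3, 4, 5, 6], dtype=tf.int64)
list_b = tf.constant([7, 8, 9, 10, 11, 12], dtype=tf.int64)

row = tf.zeros([N], dtype=tf.int64)   # all nonzeros in row 0
ones = tf.ones([N], dtype=tf.int64)
sp_a = tf.sparse.SparseTensor(indices=tf.stack([row, list_a], 1),
                              values=ones, dense_shape=[1, A_SIZE])
sp_b = tf.sparse.SparseTensor(indices=tf.stack([row, list_b], 1),
                              values=ones, dense_shape=[1, B_SIZE])
# Concatenate along the feature axis, then flatten to a single one-hot vector.
onehot = tf.sparse.reshape(tf.sparse.concat(1, [sp_a, sp_b]), [A_SIZE + B_SIZE])
print(onehot.dense_shape.numpy())
```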