OOM when calling model.predict()

Hi everyone,

I’m having an issue with model.predict() causing OOM errors. Strangely, this doesn’t happen while training, only while predicting on the val dataset.

I’m using a TFRecordDataset with batches of size 512:

val_dataset = (
    tf.data.TFRecordDataset(filenames=[val_files])
    .map(tf_parse, num_parallel_calls=6)
    .batch(BATCH_SIZE)
    .prefetch(16)
    .repeat(1)
)

def tf_parse(eg):
    example = tf.io.parse_example(
        eg[tf.newaxis],
        {"features": tf.io.FixedLenFeature(shape=(1050,), dtype=tf.float32),
         "targets": tf.io.FixedLenFeature(shape=(6,), dtype=tf.float32)})
    return example["features"][0], (example["features"][0],
                                    example["targets"][0],
                                    example["targets"][0])

As I stated above, training works fine, but when I try to predict on the entire val_dataset I get an OOM error. Predicting on smaller slices, e.g. model.predict(val_dataset.take(50)), works fine, but the full val_dataset does not. Even specifying batch_size=1 in predict doesn’t help at all.
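To frame what I mean by predicting in pieces: the alternative I have in mind is to run the forward pass batch by batch and write each result into a preallocated array, so nothing accumulates waiting for one big final concatenate. This is just a framework-agnostic sketch; `fake_predict` is a stand-in for the real model’s forward pass, and the shapes are made up:

```python
import numpy as np

# Placeholder for model.predict_on_batch(); doubles the input so the
# result is easy to check. Swap in the real model when adapting this.
def fake_predict(batch):
    return batch * 2.0

n_rows, n_features, batch_size = 1000, 8, 128
data = np.random.rand(n_rows, n_features).astype(np.float32)

# Preallocate the output once and fill it slice by slice, so per-batch
# results never pile up waiting for one giant final concatenation.
out = np.empty_like(data)
for start in range(0, n_rows, batch_size):
    stop = min(start + batch_size, n_rows)
    out[start:stop] = fake_predict(data[start:stop])
```

Each filled slice could just as well be written to disk instead of kept in memory.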

The input data is 1050 columns of numeric data, with about 540k rows. During training, GPU memory usage sits around 2.5 GB of 8 GB.

Does anyone have any suggestions?

EDIT
I’ve run some more tests: model.evaluate() works fine as well. Does TensorFlow cache results on the GPU and then send them down to the CPU in one big set, or does it send results down per batch, flush buffers, and continue? I suspect it’s the former, because my outputs end up being (1050 + 1050 + 6) × num_rows due to the architecture.
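For context, a back-of-the-envelope calculation (assuming float32 outputs and the row count above) suggests the concatenated predictions alone come to several gigabytes:

```python
# Rough memory footprint of the full predict() output, assuming
# float32 values and the three output heads described above.
rows = 540_000
cols = 1050 + 1050 + 6           # output width per row
size_gb = rows * cols * 4 / 1024**3
print(f"~{size_gb:.1f} GB")      # prints "~4.2 GB"
```

So if all batches are held somewhere before being concatenated, that alone could plausibly account for the OOM.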