OOM when calling model.predict()

Hi everyone,

I’m having an issue with model.predict() causing OOM errors. Strangely, this doesn’t happen while training, only while predicting on the val dataset.

I’m using a TFRecordDataset with batches of size 512:

val_dataset = tf.data.TFRecordDataset(filenames=[val_files]) \
    .map(tf_parse, num_parallel_calls=6) \
    .batch(BATCH_SIZE) \
    .prefetch(16) \
    .repeat(1)

def tf_parse(eg):
    example = tf.io.parse_example(
        eg[tf.newaxis],
        {"features": tf.io.FixedLenFeature(shape=(1050,), dtype=tf.float32),
         "targets": tf.io.FixedLenFeature(shape=(6,), dtype=tf.float32)})
    return example["features"][0], (example["features"][0], example["targets"][0], example["targets"][0])

As I stated above, training works fine, but when I try to predict on the entire val_dataset I get an OOM error. Trying smaller bits of the dataset with model.predict(val_dataset.take(50)), for example, works fine, but not with the entire val_dataset. Even specifying batch_size=1 in predict doesn’t help at all.

The input data is 1050 columns of numeric data, and there are about 540k rows of data. During training, GPU memory usage is around 2.5/8.0GB.

Does anyone have any suggestions?

EDIT
I’ve run some more tests: model.evaluate() works fine as well. Does TensorFlow cache results on the GPU and then send them down to the CPU in one big set, or does it send results down per batch, flush buffers, and continue? I suspect it’s the former, because my outputs end up being (1050 + 1050 + 6) × num_rows due to the architecture.
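For what it’s worth, model.predict() does accumulate every batch’s output and concatenate them into one array at the end, so with (1050 + 1050 + 6) outputs per row and ~540k rows the result alone is several gigabytes. A minimal sketch of the usual workaround, iterating the dataset and handling each batch’s predictions immediately (toy shapes here, standing in for the post’s model and val_dataset):

```python
import numpy as np
import tensorflow as tf

# Toy model and dataset (hypothetical shapes; substitute the real ones).
model = tf.keras.Sequential([tf.keras.Input(shape=(8,)),
                             tf.keras.layers.Dense(4)])
dataset = tf.data.Dataset.from_tensor_slices(
    np.random.rand(32, 8).astype("float32")).batch(8)

# Instead of model.predict(dataset), which concatenates every batch's
# output into one big array, predict per batch and reduce (or write to
# disk) right away, so only one batch of outputs is alive at a time.
running_sum = np.zeros(4, dtype="float64")
n_rows = 0
for batch in dataset:
    preds = model.predict_on_batch(batch)  # one batch of predictions
    running_sum += preds.sum(axis=0)       # reduce instead of accumulating
    n_rows += preds.shape[0]

mean_preds = running_sum / n_rows
```

If the per-row predictions themselves are needed, writing each batch to a file (e.g. with np.save or a TFRecord writer) in the loop keeps memory flat in the same way.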

Hi @TalhaAsmal

Welcome to the TensorFlow Forum!

You can use a few techniques to fix the OOM error, such as reducing the batch size, increasing the available GPU memory, or using tf.data.AUTOTUNE with .map() to configure the dataset for performance:

val_dataset = tf.data.TFRecordDataset(filenames=[val_files]) \
    .map(tf_parse, num_parallel_calls=tf.data.AUTOTUNE) \
    .batch(BATCH_SIZE).prefetch(8).repeat(1)

Please try these methods and let us know if the issue persists. Thank you.
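On the GPU memory point: TensorFlow reserves nearly all GPU memory up front by default. Enabling memory growth (a standard tf.config option, not something from the original post) makes it allocate incrementally instead; a short sketch, which must run before the GPU is first used:

```python
import tensorflow as tf

# Ask TensorFlow to grow GPU memory allocations on demand rather than
# reserving the whole device at startup. On a CPU-only machine this
# list is empty and the loop is a harmless no-op.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
```

This won’t help if the predictions themselves are what overflow memory, but it rules out TensorFlow’s allocator as the culprit.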

It may be too late, but try calling the following method after each call to predict() that you make:

model.make_predict_function(force=True)