Issue with TPUs on Google Colab when training BERT

Hi there,

I’m pretty new at this, so I’m not sure whether this is a bug or whether I’m doing something wrong.

I’m trying to train BERT in Google Colab on a TPU, but I’m getting an error message, which can be seen here.

TensorFlow version: 2.8.0
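
Since the setup involves TensorFlow, transformers and a separate training script, it may help to report the versions of every related package, not only TensorFlow. A minimal sketch using only the standard library; the package list is an assumption (edit it to match the runtime), and COLAB_TPU_ADDR is assumed to be the environment variable Colab sets on TPU runtimes:

```python
import importlib.metadata as metadata
import os

# Packages assumed to be relevant to this setup; edit the tuple as needed.
PACKAGES = ("tensorflow", "torch", "torch-xla", "transformers", "datasets")


def package_versions(names=PACKAGES):
    """Return a dict mapping each package name to its version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not installed'}")
    # Assumption: this variable is only set on Colab TPU runtimes.
    print("COLAB_TPU_ADDR:", os.environ.get("COLAB_TPU_ADDR", "not set"))
```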

The code I’m using to set up the TPU is largely based on Google’s original code for pre-training T5, taken from here:

print("Installing dependencies...")
%tensorflow_version 2.x

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

BASE_DIR = "gs://bucket-xx" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data/text.csv")
MODELS_DIR = os.path.join(BASE_DIR, "models/bert")
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "v2-8"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError:
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
  auth.authenticate_user()
  tf.enable_eager_execution()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Restore the previous verbosity even if the body raises.
    tf.logging.set_verbosity(og_level)
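
The tf_verbosity_level helper temporarily changes TensorFlow’s logging verbosity. The same restore-on-exit pattern can be sketched with Python’s standard logging module, which makes it easy to check without a TPU runtime; logger_level is a hypothetical name, and the try/finally ensures the original level comes back even if the body raises:

```python
import logging
from contextlib import contextmanager


@contextmanager
def logger_level(level, logger=None):
    """Temporarily set a logger's level, restoring the original level on exit."""
    if logger is None:
        logger = logging.getLogger()  # default to the root logger
    original_level = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        # Runs on normal exit and on exceptions alike.
        logger.setLevel(original_level)


demo = logging.getLogger("tpu_setup_demo")
demo.setLevel(logging.WARNING)
with logger_level(logging.INFO, demo):
    demo.info("visible only inside the block")
```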

This is the command I’m using to train BERT:

!python /content/scripts/run_mlm.py \
--model_name_or_path bert-base-cased \
--tpu_num_cores 8 \
--validation_split_percentage 20 \
--line_by_line \
--learning_rate 2e-5 \
--per_device_train_batch_size 128 \
--per_device_eval_batch_size 256 \
--num_train_epochs 4 \
--output_dir $MODELS_DIR \
--train_file /content/text.csv
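
The same invocation can also be assembled in Python, which makes it explicit that the value of MODELS_DIR (and not the literal text "MODELS_DIR") reaches --output_dir. A minimal sketch, assuming the script path and arguments from the command above; build_mlm_command is a hypothetical helper, and the list it returns could be passed to subprocess.run:

```python
import posixpath
import shlex

# Values taken from the post; "gs://bucket-xx" is the placeholder bucket.
BASE_DIR = "gs://bucket-xx"
# posixpath keeps forward slashes, which gs:// URIs always use.
MODELS_DIR = posixpath.join(BASE_DIR, "models/bert")


def build_mlm_command(output_dir, train_file, script="/content/scripts/run_mlm.py"):
    """Return the run_mlm.py invocation as an argv list with real values filled in."""
    return [
        "python", script,
        "--model_name_or_path", "bert-base-cased",
        "--tpu_num_cores", "8",
        "--validation_split_percentage", "20",
        "--line_by_line",
        "--learning_rate", "2e-5",
        "--per_device_train_batch_size", "128",
        "--per_device_eval_batch_size", "256",
        "--num_train_epochs", "4",
        "--output_dir", output_dir,
        "--train_file", train_file,
    ]


cmd = build_mlm_command(MODELS_DIR, "/content/text.csv")
print(shlex.join(cmd))  # inspect before running, e.g. subprocess.run(cmd, check=True)
```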

The run_mlm.py script is taken from the official transformers repository and can be seen here.
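
As far as I can tell from the PyTorch version of run_mlm.py, when no separate validation file is given, --validation_split_percentage 20 takes the first 20% of rows as validation data and the rest as training data, and --line_by_line then treats each non-empty line of the text column as one example. A rough pure-Python illustration of that behaviour under those assumptions (split_and_filter is hypothetical; the script itself uses the datasets library’s split slicing, whose rounding may differ):

```python
def split_and_filter(lines, validation_split_percentage=20):
    """Roughly mimic --validation_split_percentage followed by --line_by_line.

    Assumes no validation file: the first N% of rows become validation data,
    the remainder training data, and blank lines are dropped from each split.
    """
    n_val = len(lines) * validation_split_percentage // 100
    validation, train = lines[:n_val], lines[n_val:]

    def keep(line):
        # Skip empty and whitespace-only lines.
        return len(line) > 0 and not line.isspace()

    return [l for l in train if keep(l)], [l for l in validation if keep(l)]


rows = ["first sentence", "", "second sentence", "   ", "third",
        "fourth", "fifth", "sixth", "seventh", "eighth"]
train, val = split_and_filter(rows, 20)
print(len(train), len(val))
```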

I found a very similar issue, which can be seen here.

Any help is much appreciated, thanks.