Utilizing XLA Optimization with TensorFlow Keras on GPU

According to the documentation, enabling XLA compilation in TensorFlow Keras only requires passing `jit_compile=True` when compiling the model:

model.compile(optimizer="adam", jit_compile=True)

This works fine on CPU, both with and without `jit_compile=True`. On Colab's GPU, however, calling `model.fit` raises a "Graph execution error" when `jit_compile=True` is set; without it, training works fine. What could be causing this? Here is the traceback:

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
53 ctx.ensure_initialized()
54 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55 inputs, attrs, num_outputs)
56 except core._NotOkStatusException as e:
57 if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node 'StatefulPartitionedCall' defined at (most recent call last):
File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
app.start()
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 612, in start
self.io_loop.start()
File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
self._run_once()
File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
handle._run()
File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
self._context.run(self._callback, *self._args)
File "/usr/local/lib/python3.7/dist-packages/tornado/ioloop.py", line 758, in _run_callback
ret = callback()
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/tornado/gen.py", line 1233, in inner
self.run()
File "/usr/local/lib/python3.7/dist-packages/tornado/gen.py", line 1147, in run
yielded = self.gen.send(value)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 381, in dispatch_queue
yield self.process_one()
File "/usr/local/lib/python3.7/dist-packages/tornado/gen.py", line 346, in wrapper
runner = Runner(result, future, yielded)
File "/usr/local/lib/python3.7/dist-packages/tornado/gen.py", line 1080, in __init__
self.run()
File "/usr/local/lib/python3.7/dist-packages/tornado/gen.py", line 1147, in run
yielded = self.gen.send(value)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 365, in process_one
yield gen.maybe_future(dispatch(*args))
File "/usr/local/lib/python3.7/dist-packages/tornado/gen.py", line 326, in wrapper
yielded = next(result)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 268, in dispatch_shell
yield gen.maybe_future(handler(stream, idents, msg))
File "/usr/local/lib/python3.7/dist-packages/tornado/gen.py", line 326, in wrapper
yielded = next(result)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 545, in execute_request
user_expressions, allow_stdin,
File "/usr/local/lib/python3.7/dist-packages/tornado/gen.py", line 326, in wrapper
yielded = next(result)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 306, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 536, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2855, in run_cell
raw_cell, store_history, silent, shell_futures)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2881, in _run_cell
return runner(coro)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/async_helpers.py", line 68, in pseudo_sync_runner
coro.send(None)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 3058, in run_cell_async
interactivity=interactivity, compiler=compiler, result=result)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 3249, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 3326, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 1, in <module>
get_ipython().run_cell_magic('time', '', 'tf.random.set_seed(123)\n\nno_fb_esn_layer = NoFeedbackESN(dtdivtau=dt/tau,\n units=n,\n output_size=m,\n activation=\'tanh\',\n seed=123)\nforce_model = FORCEModel(force_layer=no_fb_esn_layer)\nforce_model.compile(metrics=["mae"], jit_compile=True)\nhistory = force_model.fit(x=inputs_with_hint,\n y=target,\n epochs=n_epoch)\npredictions_force = force_model.predict(inputs_with_hint)\n')
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2359, in run_cell_magic
result = fn(*args, **kwargs)
File "", line 2, in time
File "/usr/local/lib/python3.7/dist-packages/IPython/core/magic.py", line 187, in <lambda>
call = lambda f, *a, **k: f(*a, **k)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/magics/execution.py", line 1310, in time
exec(code, glob, local_ns)
File "<timed exec>", line 12, in <module>
File "/content/tension/tension/base.py", line 748, in fit
**kwargs)
File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1384, in fit
tmp_logs = self.train_function(iterator)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function
return step_function(self, iterator)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function
outputs = model.distribute_strategy.run(run_step, args=(data,))
Node: ‘StatefulPartitionedCall’
Trying to access resource Resource-18-at-0x7cf5480 located in device /job:localhost/replica:0/task:0/device:CPU:0 from device /job:localhost/replica:0/task:0/device:GPU:0
[[{{node StatefulPartitionedCall}}]] [Op:__inference_train_function_1161]