How to log a matplotlib 3D projection to TensorBoard

I have a custom dataset, and I would like to log its plots to TensorBoard.

The dataset is the following:

import numpy as np
from scipy.stats import multivariate_normal

num_dim = 2
num_discrete_values = 8
coords = np.linspace(-2, 2, num_discrete_values)

rv = multivariate_normal(mean=[0.0, 0.0], cov=[[1, 0], [0, 1]], seed=42)
grid_elements = np.transpose([np.tile(coords, len(coords)), np.repeat(coords, len(coords))])
prob_data = rv.pdf(grid_elements)
prob_data = prob_data / np.sum(prob_data)

I can visualise this dataset with the following:

import matplotlib.pyplot as plt
from matplotlib import cm

mesh_x, mesh_y = np.meshgrid(coords, coords)
grid_shape = (num_discrete_values, num_discrete_values)

fig, ax = plt.subplots(figsize=(9, 9), subplot_kw={"projection": "3d"})
prob_grid = np.reshape(prob_data, grid_shape)
surf = ax.plot_surface(mesh_x, mesh_y, prob_grid, cmap=cm.coolwarm, linewidth=0, antialiased=False)
fig.colorbar(surf, shrink=0.5, aspect=5)

For my use case I would like to use the TensorBoard logger/summary writer to store this image. However, TensorBoard cannot log the 3D projection figure as it is currently constructed, and trying to do so raises an error. Does anyone know how to transform this into something compatible with the TensorBoard logger?

To log a matplotlib 3D projection to TensorBoard, first render the figure to a PNG image in an in-memory buffer (io.BytesIO() passed to fig.savefig(buf, format="png")). Then read the buffer back into a NumPy array, add a leading batch dimension, and log it with tf.summary.image(), which expects data of shape [k, height, width, channels]. This works because TensorBoard stores images, not matplotlib figure objects: once rendered, a 3D projection is just pixels like any other plot.
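The steps above can be sketched as follows. This is a minimal sketch, not a definitive implementation: the helper names `figure_to_image_batch` and `log_figure` are hypothetical, and it assumes TensorFlow 2.x with a writer created by `tf.summary.create_file_writer`. TensorFlow is imported inside `log_figure` only so the conversion helper can be reused (or tested) without TensorFlow installed.

```python
import io

import matplotlib.pyplot as plt
import numpy as np


def figure_to_image_batch(fig):
    """Render a matplotlib figure (2D or 3D) to an in-memory PNG and return it
    as a float32 array of shape [1, height, width, 4] with values in [0, 1]."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png")  # rasterise the figure, 3D axes included
    plt.close(fig)  # figure is no longer needed; avoids leaking figures in a loop
    buf.seek(0)
    image = plt.imread(buf, format="png")  # (height, width, 4) RGBA float32 array
    return image[np.newaxis, ...]  # add the batch axis tf.summary.image expects


def log_figure(writer, fig, tag, step):
    """Log a matplotlib figure as an image summary via a tf.summary file writer."""
    # Lazy import (a choice made for this sketch) so figure_to_image_batch
    # does not require TensorFlow.
    import tensorflow as tf

    with writer.as_default():  # tf.summary ops need a default writer
        tf.summary.image(tag, figure_to_image_batch(fig), step=step)
    writer.flush()
```

With the question's `fig`, usage would look like `writer = tf.summary.create_file_writer("logs/prob_surface")` followed by `log_figure(writer, fig, "prob_surface", step=0)`; running `tensorboard --logdir logs` should then show the plot under the Images tab (the log directory and tag are illustrative). If the logger in question is PyTorch's `torch.utils.tensorboard.SummaryWriter` rather than TensorFlow's, its `add_figure(tag, figure, global_step)` method renders a matplotlib figure to an image itself, so it may be the simpler route there.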