I have a custom dataset, and I would like to log its plots to TensorBoard.

The dataset is the following:

```
import numpy as np
from scipy.stats import multivariate_normal

num_dim = 2
num_discrete_values = 8
coords = np.linspace(-2, 2, num_discrete_values)  # 8 evenly spaced points per axis
rv = multivariate_normal(mean=[0.0, 0.0], cov=[[1, 0], [0, 1]], seed=42)  # standard 2D normal
# All (x, y) pairs on the 8x8 grid, shape (64, 2); x varies fastest
grid_elements = np.transpose([np.tile(coords, len(coords)), np.repeat(coords, len(coords))])
prob_data = rv.pdf(grid_elements)
prob_data = prob_data / np.sum(prob_data)  # normalise into a discrete distribution
```
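As a quick sanity check on the grid layout (not part of the original code): the `tile`/`repeat` construction lists the points with x varying fastest, which is the same order you get by flattening `np.meshgrid` output. That is why `prob_data` can later be reshaped to `(8, 8)` and plotted against `mesh_x`/`mesh_y` directly. The name `grid_from_mesh` is introduced here only for the comparison.

```python
import numpy as np

num_discrete_values = 8
coords = np.linspace(-2, 2, num_discrete_values)

# Construction used in the question: x varies fastest, y slowest.
grid_elements = np.transpose([np.tile(coords, len(coords)), np.repeat(coords, len(coords))])

# Equivalent construction via meshgrid (default indexing="xy"), flattened in C order.
mesh_x, mesh_y = np.meshgrid(coords, coords)
grid_from_mesh = np.column_stack([mesh_x.ravel(), mesh_y.ravel()])

print(grid_elements.shape)                            # (64, 2)
print(np.array_equal(grid_elements, grid_from_mesh))  # True
```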

I can visualise this dataset with the following code:

```
import matplotlib.pyplot as plt
from matplotlib import cm

mesh_x, mesh_y = np.meshgrid(coords, coords)
grid_shape = (num_discrete_values, num_discrete_values)
fig, ax = plt.subplots(figsize=(9, 9), subplot_kw={"projection": "3d"})
prob_grid = np.reshape(prob_data, grid_shape)  # back to 8x8: rows indexed by y, columns by x
surf = ax.plot_surface(mesh_x, mesh_y, prob_grid, cmap=cm.coolwarm, linewidth=0, antialiased=False)
fig.colorbar(surf, shrink=0.5, aspect=5)
plt.show()
```

For my use case I would like to use the TensorBoard logger/summary writer to store this image. However, TensorBoard cannot log the 3D projection as the plot is currently constructed, and an error is raised. Does anyone know how to transform this plot into something compatible with the TensorBoard logger?
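One possible approach, sketched under the assumption that the writer accepts a plain image array: TensorBoard image summaries are raster images rather than Matplotlib objects, so the 3D figure can be rendered off-screen with Matplotlib's Agg backend and its pixels read back as an `H x W x 3` `uint8` array. `fig.canvas.draw()` and `fig.canvas.buffer_rgba()` are standard Matplotlib canvas methods; the helper name `figure_to_rgb_array` is hypothetical.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend: renders without opening a window
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.stats import multivariate_normal


def figure_to_rgb_array(fig):
    """Hypothetical helper: rasterize a figure (3D axes included) to an HxWx3 uint8 array.

    Assumes the figure's canvas is Agg-based, so that buffer_rgba() is available.
    """
    fig.canvas.draw()                             # render the full figure on the canvas
    rgba = np.asarray(fig.canvas.buffer_rgba())   # shape (H, W, 4), dtype uint8
    return rgba[..., :3].copy()                   # drop the alpha channel


# Rebuild the dataset and the 3D surface plot from the question.
num_discrete_values = 8
coords = np.linspace(-2, 2, num_discrete_values)
rv = multivariate_normal(mean=[0.0, 0.0], cov=[[1, 0], [0, 1]], seed=42)
grid_elements = np.transpose([np.tile(coords, len(coords)), np.repeat(coords, len(coords))])
prob_data = rv.pdf(grid_elements)
prob_data = prob_data / np.sum(prob_data)

mesh_x, mesh_y = np.meshgrid(coords, coords)
prob_grid = np.reshape(prob_data, (num_discrete_values, num_discrete_values))

fig, ax = plt.subplots(figsize=(9, 9), subplot_kw={"projection": "3d"})
surf = ax.plot_surface(mesh_x, mesh_y, prob_grid, cmap=cm.coolwarm, linewidth=0, antialiased=False)
fig.colorbar(surf, shrink=0.5, aspect=5)

image = figure_to_rgb_array(fig)  # plain NumPy array, independent of the 3D projection
plt.close(fig)                    # free the figure once it has been rasterized
```

If the writer is PyTorch's `torch.utils.tensorboard.SummaryWriter` (an assumption; the question does not say which writer is used), the array could be logged with `writer.add_image("prob_surface", image, dataformats="HWC")`. That class also provides `add_figure`, which accepts the Matplotlib figure directly. For TensorFlow's `tf.summary.image`, the array needs a leading batch dimension, e.g. `image[None]`.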