Is `tfp.sts.impute_missing_values` trainable?

I have a series of coordinates `[[x_i, y_i, z_i, t_i]]` in which some values are missing, and I would like to impute them. These are trajectories of an object I’m tracking, so the timepoints `t_i` are always known, but some of the measurements `[x_i, y_i, z_i]` are missing. There are roughly 100 000 trajectories, and with that much data I think I could train a better imputer than an out-of-the-box one.
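
To make the data layout concrete, here is a minimal sketch in plain Python. It assumes each trajectory is a list of `[x, y, z, t]` rows with `float('nan')` marking a missing measurement. The helper name `missing_mask` is hypothetical. Because `t` is always observed, it records per coordinate which of `x`, `y`, `z` are missing:

```python
import math

def missing_mask(trajectory):
    """Per-row, per-coordinate flags: True where x, y or z is NaN.

    Rows are [x, y, z, t]; t is assumed always observed, so only the
    first three entries of each row are checked.
    """
    return [[math.isnan(v) for v in row[:3]] for row in trajectory]

nan = math.nan
trajectory = [[0.1, 1.8, 1.9, 0], [nan, nan, nan, 1], [0.3, nan, 2.1, 2]]
mask = missing_mask(trajectory)
print(mask)  # [[False, False, False], [True, True, True], [False, True, False]]
```

If only fully missing rows matter, a row-level flag is `[any(m) for m in mask]`. As far as I know, `tfp.sts.MaskedTimeSeries` takes a boolean `is_missing` array of this kind for each univariate series, so a per-axis mask like this would map onto it.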

Is it possible to train `tfp.sts.impute_missing_values` by hiding a subset of the known values, so that it “learns” to impute in a way tailored to my data? For example, say I have a linear trajectory:

```python
time_series = [[0.1, 1.8, 1.9, 0], [0.2, 1.9, 2.0, 1], [0.3, 2.0, 2.1, 2], [0.4, 2.1, 2.2, 3], [0.5, 2.2, 2.3, 4]]
```

By masking out the measurements at one timepoint (i.e. t = 2), I get:

```python
nan = float("nan")
X_train = [[0.1, 1.8, 1.9, 0], [0.2, 1.9, 2.0, 1], [nan, nan, nan, 2], [0.4, 2.1, 2.2, 3], [0.5, 2.2, 2.3, 4]]
```
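
This dropout step is essentially self-supervised masking. You hide a value whose ground truth is known and then use it as the target. Below is a minimal sketch in plain Python (not tfp), assuming the same `[x, y, z, t]` row layout. The names `make_training_pair` and `sample_training_pairs` are hypothetical helpers. Applied across many trajectories, they produce (input, target) pairs for whatever model ends up being trained:

```python
import math
import random

def make_training_pair(trajectory, index):
    """Hide x, y, z of one fully observed row, keeping its t.

    Returns (inputs, target): a copy of the trajectory with that row's
    coordinates set to NaN, and the hidden [x, y, z] as the target.
    """
    row = trajectory[index]
    if any(math.isnan(v) for v in row[:3]):
        raise ValueError("row already has missing values; no target to learn")
    inputs = [list(r) for r in trajectory]  # copy so the original is untouched
    inputs[index][:3] = [math.nan] * 3
    return inputs, list(row[:3])

def sample_training_pairs(trajectories, rng):
    """Mask one randomly chosen fully observed row per trajectory."""
    pairs = []
    for traj in trajectories:
        observed = [i for i, r in enumerate(traj)
                    if not any(math.isnan(v) for v in r[:3])]
        if observed:  # skip trajectories with nothing left to hide
            pairs.append(make_training_pair(traj, rng.choice(observed)))
    return pairs

time_series = [[0.1, 1.8, 1.9, 0], [0.2, 1.9, 2.0, 1], [0.3, 2.0, 2.1, 2],
               [0.4, 2.1, 2.2, 3], [0.5, 2.2, 2.3, 4]]
X_train, y_train = make_training_pair(time_series, 2)
print(X_train[2], y_train)  # [nan, nan, nan, 2] [0.3, 2.0, 2.1]

pairs = sample_training_pairs([time_series], random.Random(0))
```

Seeding `random.Random` keeps the masked positions reproducible between runs. Masking a single row per trajectory is a simplification. If real gaps tend to come in runs, masking contiguous rows may match them better.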

Can the model learn to impute the missing coordinates at a masked timepoint, such as t = 2 above?
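
Whatever model is trained on such pairs, a baseline to beat on the masked points is useful. Below is a minimal sketch in plain Python, assuming rows sorted by `t`. It uses per-axis linear interpolation in time, a generic baseline unrelated to tfp, and scores the result by squared error against the hidden values. The function names are hypothetical:

```python
import math

def interpolate_axis(times, values):
    """Linearly interpolate NaNs in one coordinate; edge gaps stay NaN."""
    known = [(t, v) for t, v in zip(times, values) if not math.isnan(v)]
    out = []
    for t, v in zip(times, values):
        if not math.isnan(v):
            out.append(v)
            continue
        before = [p for p in known if p[0] < t]
        after = [p for p in known if p[0] > t]
        if not before or not after:
            out.append(math.nan)  # no observed neighbour on one side
            continue
        (t0, v0), (t1, v1) = before[-1], after[0]  # nearest neighbours in time
        out.append(v0 + (t - t0) / (t1 - t0) * (v1 - v0))
    return out

def interpolate_missing(series):
    """Fill missing x, y, z axis by axis; rows are [x, y, z, t] sorted by t."""
    times = [row[3] for row in series]
    axes = [interpolate_axis(times, [row[i] for row in series]) for i in range(3)]
    return [[axes[0][k], axes[1][k], axes[2][k], times[k]] for k in range(len(series))]

def squared_error(pred, target):
    """Sum of squared differences between predicted and true coordinates."""
    return sum((p - q) ** 2 for p, q in zip(pred, target))

nan = math.nan
X_train = [[0.1, 1.8, 1.9, 0], [0.2, 1.9, 2.0, 1], [nan, nan, nan, 2],
           [0.4, 2.1, 2.2, 3], [0.5, 2.2, 2.3, 4]]
target = [0.3, 2.0, 2.1]  # the values hidden at t = 2

filled = interpolate_missing(X_train)
print(filled[2])  # approximately [0.3, 2.0, 2.1, 2], up to float rounding
print(squared_error(filled[2][:3], target))  # close to 0
```

On this exactly linear trajectory, the baseline recovers the hidden point almost exactly. The t = 2 example is therefore too easy to show any benefit from training. Held-out masked points from curved or noisy trajectories would be a more informative test.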