Visualisation of LSTM for time series data

Hi all,

I have a trained LSTM model that performs binary classification of waveform data. The waveforms have different lengths, so they are stored as ragged tensors.
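For reference, here is one way to build a ragged batch of variable-length waveforms. This is a minimal sketch: the two waveforms are made-up single-channel data, so `n_features` is 1. Each waveform becomes one row, and all time steps are stored in a single dense `flat_values` tensor.

```python
import numpy as np
import tensorflow as tf

# Hypothetical example data: two single-channel waveforms of different lengths.
waveforms = [
    np.array([0.0, 0.5, 1.0, 0.5, 0.0], dtype=np.float32),
    np.array([0.2, -0.3, 0.1], dtype=np.float32),
]

# Concatenate all time steps into one (total_steps, n_features) array
# and record how many steps belong to each waveform.
flat_values = np.concatenate(waveforms)[:, np.newaxis]  # shape (8, 1)
row_lengths = [len(w) for w in waveforms]              # [5, 3]

# Ragged batch of shape (batch, None, n_features) = (2, None, 1).
batch = tf.RaggedTensor.from_row_lengths(flat_values, row_lengths)
```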

I want to visualise which time steps of a given waveform the model considers most important. Ideally, the end result would be a plot of the waveform with each data point coloured according to its importance.
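A minimal sketch of that plotting step follows. It assumes the gradients for one waveform are already available as a NumPy array of shape `(time_steps, n_features)`, and the function names are hypothetical. The gradient magnitudes are collapsed to one score per time step and rescaled to [0, 1]. Each point of the waveform is then coloured by that score.

```python
import numpy as np
import matplotlib.pyplot as plt

def importance_from_gradients(grads):
    """Collapse a (time_steps, n_features) gradient array to one score per step in [0, 1]."""
    grads = np.asarray(grads, dtype=np.float32)
    if grads.ndim == 1:                      # treat a 1-D array as a single feature
        grads = grads[:, np.newaxis]
    scores = np.abs(grads).sum(axis=-1)      # gradient magnitude at each time step
    span = scores.max() - scores.min()
    if span == 0:                            # flat saliency: nothing stands out
        return np.zeros_like(scores)
    return (scores - scores.min()) / span    # min-max rescale to [0, 1]

def plot_waveform_importance(waveform, importance, cmap="viridis"):
    """Draw the waveform with each sample coloured by its importance score."""
    waveform = np.asarray(waveform).reshape(-1)
    t = np.arange(len(waveform))
    fig, ax = plt.subplots(figsize=(10, 3))
    # Faint line underneath so the overall shape stays visible.
    ax.plot(t, waveform, color="lightgray", linewidth=1, zorder=1)
    # One coloured point per sample; colour encodes importance.
    points = ax.scatter(t, waveform, c=importance, cmap=cmap,
                        vmin=0.0, vmax=1.0, s=12, zorder=2)
    fig.colorbar(points, ax=ax, label="relative importance")
    ax.set_xlabel("time step")
    ax.set_ylabel("amplitude")
    return fig
```

Summing absolute values over features is one choice; taking the maximum over features is another. For a single-channel waveform the two are identical.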

I came across a saliency model that outputs the gradients of the model's prediction with respect to its input, but I can't get it to work with ragged tensors.
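For clarity, here is the quantity a gradient-based saliency map typically computes, written as a sketch under stated assumptions. Let $y = f(x)$ be the model's predicted probability for the positive class. Let each waveform $x = (x_1, \dots, x_T)$ have $D$ features per time step. The importance of time step $t$ is then the gradient magnitude, rescaled to $[0, 1]$ for colouring:

```latex
s_t = \sum_{d=1}^{D} \left| \frac{\partial y}{\partial x_{t,d}} \right|,
\qquad
\tilde{s}_t = \frac{s_t - \min_k s_k}{\max_k s_k - \min_k s_k},
\qquad t = 1, \dots, T
```

For a single-channel waveform, $D = 1$ and $s_t$ is simply $|\partial y / \partial x_t|$. The rescaling is undefined when all $s_t$ are equal, in which case every point can be given the same colour.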

Does anyone have ideas on how best to achieve this? I've pasted my current code below.

Thanks in advance!

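# Define a saliency model: it wraps the trained model and returns the gradient
# of the model's output with respect to every value of the (ragged) input.
import tensorflow as tf

class SaliencyModel(tf.keras.Model):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def call(self, inputs):
        # `inputs` is a tf.RaggedTensor of shape (batch, None, n_features).
        # Watch its dense flat_values and rebuild the ragged batch from them
        # inside the tape, so there is a plain tensor to differentiate against.
        values = inputs.flat_values
        with tf.GradientTape() as tape:
            tape.watch(values)
            output = self.base_model(inputs.with_flat_values(values))
        # Each waveform's output depends only on that waveform, so the gradient
        # of the (implicitly summed) batch output gives per-waveform gradients.
        gradients = tape.gradient(output, values)
        # Re-attach the row partition: one gradient row per input time step.
        return inputs.with_flat_values(gradients)

saliency_model = SaliencyModel(base_model=model)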
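A plain-function sketch of the same gradient idea follows. It also splits the gradients back into one NumPy array per waveform, so each waveform can be plotted on its own. It assumes that `model` accepts a `tf.RaggedTensor` of shape `(batch, None, n_features)` and returns one score per waveform. It also assumes that waveforms in a batch do not influence each other, which holds for an LSTM classifier called in inference mode. The name `saliency_per_waveform` is hypothetical.

```python
import numpy as np
import tensorflow as tf

def saliency_per_waveform(model, batch):
    """Return a list of (length_i, n_features) gradient arrays, one per waveform.

    Assumes `model` maps a ragged (batch, None, n_features) tensor to one
    score per waveform, with no interaction between waveforms in the batch.
    """
    values = batch.flat_values  # dense tensor, shape (total_steps, n_features)
    with tf.GradientTape() as tape:
        tape.watch(values)
        # Rebuild the ragged batch from the watched values so the tape can
        # trace the model's output back to them.
        output = model(batch.with_flat_values(values))
    grads = tape.gradient(output, values)  # shape (total_steps, n_features)
    # Split the flat gradient at each waveform boundary.
    lengths = batch.row_lengths().numpy()
    return np.split(grads.numpy(), np.cumsum(lengths)[:-1])
```

With the trained model, `grads = saliency_per_waveform(model, batch)` would give `grads[0]` with the same shape as the first waveform. That array can then be reduced to one score per time step and used to colour the plot.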