Implementing 1D cross-correlation similarity loss

Hey,
I’m looking to add a loss to my NLP model that regularizes it so that the output confidences over the vocabulary are structurally similar to the input one-hot sequences, while allowing the positions of the tokens to be rearranged. Can anyone recommend the right way to implement this using TensorFlow’s vectorized convolution ops?

Pseudocode looks something like this:
originals = Tensor(shape=(batch_size, sequence_length, vocabulary_size))  # one-hot inputs
confidences = Tensor(shape=(batch_size, sequence_length, vocabulary_size))  # output from model

# join the sequence and vocabulary axes into one dimension:
# reshape to (batch_size, sequence_length * vocabulary_size)
originals_flat = tf.reshape(originals, (batch_size, -1))
confidences_flat = tf.reshape(confidences, (batch_size, -1))

similarity_scores = []
for one_hot, conf in zip(originals_flat, confidences_flat):
    # "sliding window" comparison, see cross-correlation
    cross_corr = np.correlate(conf, one_hot, mode="full")

    max_corr = np.max(cross_corr)

    # normalize the maximum correlation value
    similarity = (
        2
        * max_corr
        / (np.sum(one_hot) + np.sum(conf) + tf.keras.backend.epsilon())
    )

    similarity_scores.append(similarity)

This gets me the result I’m looking for. However, I’d like to remove the loop and instead implement it on the original batched tensor using tf.nn.convolution (or conv1d, conv2d). Unfortunately, no amount of fiddling with the shapes or strides gets me the values, or even the output shapes, I’m looking for once a batch dimension is included. I managed to get it working with tf.map_fn, but that essentially just re-creates the loop I’m trying to push down into the vectorized conv1d code. The revised TensorFlow version is actually slower than the NumPy version with the for loop.
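One loop-free option that sidesteps the conv ops entirely is to compute the cross-correlation through the FFT, which batches naturally over the leading axis. The sketch below is NumPy only and is an illustration, not a drop-in loss; the function name `fft_max_cross_correlation` is made up for this example. For a differentiable version, `tf.signal.rfft` and `tf.signal.irfft` should play the same roles, though I have not benchmarked that.

```python
import numpy as np

def fft_max_cross_correlation(conf_flat, one_hot_flat):
    """Max over all lags of the full cross-correlation, per batch row.

    Both inputs have shape (batch_size, length). Hypothetical helper.
    """
    length = conf_flat.shape[-1]
    # zero-padding both signals to 2L-1 makes the circular correlation
    # equal to the linear ("full") one, with no wrap-around overlap
    n = 2 * length - 1
    conf_f = np.fft.rfft(conf_flat, n=n, axis=-1)
    one_hot_f = np.fft.rfft(one_hot_flat, n=n, axis=-1)
    # correlation theorem: corr = IFFT(FFT(conf) * conj(FFT(one_hot)))
    cross_corr = np.fft.irfft(conf_f * np.conj(one_hot_f), n=n, axis=-1)
    return cross_corr.max(axis=-1)  # shape (batch_size,)

# compare against the per-row np.correlate loop on random data
rng = np.random.default_rng(0)
a = rng.random((3, 9))
b = rng.random((3, 9))
looped = np.array([np.correlate(x, y, mode="full").max() for x, y in zip(a, b)])
batched = fft_max_cross_correlation(a, b)
```

The lags come out in a different order than with `np.correlate`, but since only the maximum is used, the ordering does not matter here.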

import tensorflow as tf

originals = tf.constant([[[0, 1, 0], [0, 1, 0], [1, 0, 0]],
                         [[1, 0, 0], [0, 0, 1], [0, 1, 0]]], dtype=tf.float32)
reconstructions = tf.constant([[[0, 1, 0], [0, 1, 0], [1, 0, 0]],
                               [[1, 0.5, 0], [0, 0.5, 0.5], [1, 0, 0]]], dtype=tf.float32)

# reshape from (batch_size, seq_len, vocab_size) to (batch_size, seq_len * vocab_size)
confidences_flat = tf.reshape(reconstructions, (reconstructions.shape[0], -1))
one_hot_sequences_flat = tf.reshape(originals, (originals.shape[0], -1))

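@tf.function()
def _tf_1d_cross_correlation(args):
    originals, reconstructions = args

    # tf.nn.conv1d expects input of shape [batch, in_width, in_channels]
    data = tf.reshape(originals, [1, -1, 1], name='data')
    # and filters of shape [filter_width, in_channels, out_channels]
    kernel = tf.reshape(reconstructions, [-1, 1, 1], name='kernel')

    # cross-correlate; with 'VALID' padding and a kernel as long as the input,
    # this yields a single value (the zero-lag correlation)
    res = tf.squeeze(tf.nn.conv1d(data, kernel, 1, 'VALID'))

    # normalize
    similarity = 2 * res / (tf.reduce_sum(originals) + tf.reduce_sum(reconstructions) + tf.keras.backend.epsilon())

    return similarity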

# apply the function to each pair of sequences
similarities = tf.map_fn(_tf_1d_cross_correlation, (one_hot_sequences_flat, confidences_flat), fn_output_signature=tf.float32)

print(similarities)

result: tf.Tensor([1. 0.46153846], shape=(2,), dtype=float32)
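One thing worth checking here: with 'VALID' padding and a kernel exactly as long as the input, the convolution has only one position to evaluate, so it returns the zero-lag overlap rather than the maximum over all shifts that `mode="full"` gives. The small NumPy sketch below illustrates the difference on a made-up example where the tokens are rotated by one position; the data is purely illustrative.

```python
import numpy as np

# tokens [A, B, C] as flattened one-hots over a vocabulary of size 3
one_hot = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1], dtype=np.float32)
# the same tokens rotated to [C, A, B]
conf = np.array([0, 0, 1, 1, 0, 0, 0, 1, 0], dtype=np.float32)

# equal lengths + "valid" -> a single value: the plain dot product at zero lag
zero_lag = np.correlate(conf, one_hot, mode="valid")

# "full" slides over every lag, so it can find the shifted match
full = np.correlate(conf, one_hot, mode="full")

print(zero_lag, full.max())
```

On the two sequences in the test batch above, the zero-lag value happens to coincide with the full-mode maximum, which may be why the results look right even though the shifts are not being searched.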

Okay, I think I have it figured out. Let me know if someone else has a better idea.

#To remove the contribution of the pad tokens
mask = [1. for _ in range(self.vocabulary_size)]
mask[self.pad_index] = 0.
mask = tf.constant(mask, dtype=tf.float32)
# set values in the pad index to 0
masked_confidences = reconstructions * mask
masked_one_hots = originals * mask

# reshape from (batch_size, seq_len, vocab_size) to (batch_size, seq_len * vocab_size)
confidences_flat = tf.reshape(masked_confidences, (masked_confidences.shape[0], -1))
one_hot_sequences_flat = tf.reshape(masked_one_hots, (masked_one_hots.shape[0], -1))

#calculate the padding needed to match the numpy "FULL" padding behavior
padding = (reconstructions.shape[1] - 1) // 2

#manually add padding to the beginning and end of the sequences
padded_originals = tf.pad(masked_one_hots, [[0, 0], [padding, padding], [0, 0]])
padded_reconstructions = tf.pad(masked_confidences, [[0, 0], [padding, padding], [0, 0]])

#flatten the sequences into one dimension
confidences_flat = tf.reshape(padded_reconstructions, (padded_reconstructions.shape[0], -1))
one_hot_sequences_flat = tf.reshape(padded_originals, (padded_originals.shape[0], -1))

#normalize each cross-correlation to the sum of the remaining input magnitude after masking of pad tokens
batched_confidences_magnitudes = tf.math.reduce_sum(confidences_flat, axis=-1)
batched_originals_magnitudes = tf.math.reduce_sum(one_hot_sequences_flat, axis=-1)
normalization_tensor = (batched_confidences_magnitudes + batched_originals_magnitudes) * 0.5 + tf.keras.backend.epsilon()
#reshape from (batch_size,) to (batch_size, 1) to avoid "non-broadcastable shapes" error
normalization_tensor = tf.expand_dims(normalization_tensor, axis=-1)

# restructure the data into the format expected by tf.nn.conv1d
# input tensor of shape [batch_size, in_width, in_channels]
data = tf.reshape(confidences_flat, [confidences_flat.shape[0], -1, 1], name='data')
# a filter / kernel tensor of shape [filter_width, in_channels, out_channels]
# note: reshaping to [-1, 1, 1] concatenates the one-hot sequences of the whole batch into a single kernel
kernel = tf.reshape(one_hot_sequences_flat, [-1, 1, 1], name='kernel')

# squeeze only the channel axis so a batch of size 1 keeps its batch dimension
cross_corr = tf.squeeze(tf.nn.conv1d(input=data, filters=kernel, stride=1, padding='SAME'), axis=-1)
# normalize by each batch item's remaining magnitude after masking pad tokens
normalized = cross_corr / normalization_tensor
similarity = tf.math.reduce_max(normalized, axis=1)
loss = (-1) * tf.reduce_mean(similarity)
return loss
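Since tf.nn.conv1d applies the same filter to every item in the batch, one possible way to give each item its own kernel is to move the batch onto the channel axis and use tf.nn.depthwise_conv2d, which convolves each channel with its own filter. The following is a rough sketch under that assumption, not a tested replacement for the loss above: the function name `batched_similarity` and the `eps` default are made up here, pad-token masking is left out, and the batch size has to be known statically.

```python
import tensorflow as tf

def batched_similarity(originals, reconstructions, eps=1e-7):
    """Normalized max of the full cross-correlation, one value per batch item."""
    batch_size = originals.shape[0]
    one_hot_flat = tf.reshape(originals, (batch_size, -1))       # (B, L)
    conf_flat = tf.reshape(reconstructions, (batch_size, -1))    # (B, L)
    length = one_hot_flat.shape[1]

    # zero-pad L-1 on both sides so a 'VALID' pass covers every lag,
    # like np.correlate(..., mode="full")
    padded = tf.pad(conf_flat, [[0, 0], [length - 1, length - 1]])  # (B, 3L-2)

    # input layout [batch, height, width, channels]: batch items become channels
    data = tf.transpose(padded)[tf.newaxis, tf.newaxis, :, :]          # (1, 1, 3L-2, B)
    # filter layout [height, width, in_channels, channel_multiplier]
    kernel = tf.transpose(one_hot_flat)[tf.newaxis, :, :, tf.newaxis]  # (1, L, B, 1)

    cross_corr = tf.nn.depthwise_conv2d(
        data, kernel, strides=[1, 1, 1, 1], padding="VALID")           # (1, 1, 2L-1, B)
    max_corr = tf.reduce_max(cross_corr, axis=[0, 1, 2])               # (B,)

    norm = tf.reduce_sum(one_hot_flat, axis=1) + tf.reduce_sum(conf_flat, axis=1) + eps
    return 2.0 * max_corr / norm

originals = tf.constant([[[0, 1, 0], [0, 1, 0], [1, 0, 0]],
                         [[1, 0, 0], [0, 0, 1], [0, 1, 0]]], dtype=tf.float32)
reconstructions = tf.constant([[[0, 1, 0], [0, 1, 0], [1, 0, 0]],
                               [[1, 0.5, 0], [0, 0.5, 0.5], [1, 0, 0]]], dtype=tf.float32)

similarity = batched_similarity(originals, reconstructions)
loss = -tf.reduce_mean(similarity)
```

Whether this is faster than tf.map_fn will depend on batch size and sequence length, so it is worth timing on real data before committing to it.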