# Implementation of soft argmax with custom gradient

I am trying to implement the soft argmax operation in TensorFlow. I want the normal (hard) argmax in the forward pass and the softmax approximation in the backward pass. To be precise, my input is in NCHW format, where C = 2 channels. The problem I am facing right now is the dimensionality reduction caused by the slicing. My gradient and my output have the same shape (a 1-channel output); however, they differ from the shape of the input, which, as far as I understand, TensorFlow does not allow.

`tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 409600 values, but the requested shape has 819200`
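The factor of two in the message (819200 = 2 × 409600) is consistent with C = 2: indexing the channel axis with a single integer drops that axis, so the sliced tensor holds half as many values as `x`, and TensorFlow presumably fails when it tries to reshape that gradient back to the shape of `x`. A small sketch, with arbitrary hypothetical sizes, shows the difference between an integer index and a length-1 range slice:

```python
import tensorflow as tf

# Hypothetical NCHW input with C = 2 (sizes chosen only for illustration).
x = tf.random.normal([4, 2, 8, 8])
p = tf.nn.softmax(x, axis=1)

dropped = p[:, 1, :, :]   # integer index removes the channel axis -> [4, 8, 8]
kept = p[:, 1:2, :, :]    # range slice keeps it as length 1     -> [4, 1, 8, 8]

print(dropped.shape, kept.shape)
# Either slice holds half as many values as x, so neither can be
# returned directly as the gradient with respect to x.
print(int(tf.size(x)), int(tf.size(kept)))
```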

My current best try:

```python
import tensorflow as tf

@tf.custom_gradient
def soft_argmax(x):

    @tf.function
    def argmax_soft(x):
        # Soft argmax is softmax*index, which simplifies to only index 1 for the 2-class case
        out = tf.nn.softmax(x, axis=1)[:, 1, :, :]
        return out
```

How could I design such a function? Note that I am specifically searching for something that defines a custom gradient or function.
To illustrate the idea, the following works in the forward pass, but it is not the solution I am looking for:

```python
out_no_grad = tf.argmax(x, axis=1)
```
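A minimal sketch of such a function, assuming a float input `x` of shape `[N, 2, H, W]` and keeping the channel axis with `tf.expand_dims` so the output is `[N, 1, H, W]`. The forward pass is the hard argmax; the backward pass uses the analytic derivative of the two-class soft argmax `p1 = softmax(x, axis=1)[:, 1]`, i.e. `dp1/dx1 = p1 * (1 - p1)` and `dp1/dx0 = -p1 * (1 - p1)`. The returned gradient therefore has the full shape of `x`, while `dy` has the shape of the output.

```python
import tensorflow as tf

@tf.custom_gradient
def soft_argmax(x):
    """Hard argmax over channels forward, two-class softmax surrogate backward.

    Assumes x is a float tensor of shape [N, 2, H, W] (NCHW, C = 2).
    """
    # Forward: hard argmax over the channel axis, cast to x's dtype and
    # re-expanded so the output keeps a channel axis: [N, 1, H, W].
    out = tf.expand_dims(tf.cast(tf.argmax(x, axis=1), x.dtype), axis=1)

    def grad(dy):
        # dy is the upstream gradient and has the shape of `out`: [N, 1, H, W].
        p1 = tf.nn.softmax(x, axis=1)[:, 1:2, :, :]  # [N, 1, H, W]
        s = p1 * (1.0 - p1)                          # d p1 / d x1
        # Stack d p1/d x0 = -s and d p1/d x1 = s along channels -> [N, 2, H, W],
        # then scale by dy, which broadcasts over the channel axis.
        return tf.concat([-s, s], axis=1) * dy

    return out, grad


# Usage: the gradient w.r.t. x has x's shape even though the output has one channel.
x = tf.constant([[[[1.0]], [[2.0]]]])  # shape [1, 2, 1, 1]
with tf.GradientTape() as tape:
    tape.watch(x)
    y = soft_argmax(x)
g = tape.gradient(y, x)
print(y.shape, g.shape)
```

The analytic gradient should agree with what autodiff gives for the soft version `tf.nn.softmax(x, axis=1)[:, 1:2]`; writing it by hand just avoids recomputing the softmax graph through a second tape.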

Edit: In the meantime, I figured out that this implementation actually works once I expand the dims of `out_no_grad`. To keep this post useful for future readers, I will state some assumptions and ask a few questions about `tf.custom_gradient`:

1. So far, my understanding is that the gradient returned by `grad_fn` only needs a shape that is broadcastable with `dy`, not necessarily the same shape. Am I correct about this?
2. However, `dy` should always have the same shape as `x`, right?
3. Can `out_no_grad` have a different shape/format than the gradient? As stated, my implementation needs `soft_argmax` to work in NCHW format, but a later optimization pass needs the result back in NHWC. Would the following still make sense? I'm getting a bit tangled up in the gradients here; my gut tells me it would not make sense and would break the gradient flow.
```python
@tf.custom_gradient
def soft_argmax(x):