Hi all, is there a way to determine, from within an op, whether the op has been invoked inside a GradientTape context?
Something like the following:
import tensorflow as tf

@tf.function
def some_op():
    gt = get_gradient_tape_somehow()  # hypothetical: the call I'm looking for
    if gt:
        # we are in a gradient tape context
        return 1
    return 0

strategy = SomeDistributionStrategy()  # placeholder for any tf.distribute strategy
with strategy.scope():
    with tf.GradientTape() as tape:
        out = strategy.run(some_op)