In the code below, there’s a line `tape.gradient([reg_loss, cat_loss], model.trainable_variables)`.
For trainable weights that affect both `reg_loss` and `cat_loss`, are the gradients for those weights just averaged or summed with respect to the two losses?
# Two-output Keras model on the Iris data, trained with a hand-written
# tf.function step: one head regresses the fourth feature column, the other
# classifies the target label, and both heads share the first two Dense layers.
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn.datasets import load_iris

# Make float64 the Keras default dtype; set before any layer is created.
tf.keras.backend.set_floatx('float64')

iris, target = load_iris(return_X_y=True)

X = iris[:, :3]  # model inputs: first three feature columns
y = iris[:, 3]   # regression target: fourth feature column
z = target       # classification target: class labels

# Dataset of (inputs, regression target, class label) triples, shuffled with a
# buffer of 150 and grouped into batches of 8.
ds = (
    tf.data.Dataset.from_tensor_slices((X, y, z))
    .shuffle(150)
    .batch(8)
)


class MyModel(Model):
    """Shared Dense trunk (d0 -> d1) feeding a regression and a softmax head."""

    def __init__(self):
        super().__init__()
        # Shared trunk: its weights influence both outputs.
        self.d0 = Dense(16, activation='relu')
        self.d1 = Dense(32, activation='relu')
        # Output heads: d2 emits one linear unit, d3 a 3-way softmax.
        self.d2 = Dense(1)
        self.d3 = Dense(3, activation='softmax')

    def call(self, x, training=None, **kwargs):
        """Run the forward pass.

        Args:
            x: Batch of input features.
            training: Accepted for Keras API compatibility; not used here.
            **kwargs: Accepted for Keras API compatibility; not used here.

        Returns:
            Tuple ``(a, b)`` where ``a`` is the output of ``d2`` (regression)
            and ``b`` is the output of ``d3`` (class probabilities).
        """
        hidden = self.d1(self.d0(x))
        return self.d2(hidden), self.d3(hidden)


model = MyModel()

# Loss objects that produce the values being differentiated.
loss_obj_reg = tf.keras.losses.MeanAbsoluteError()
loss_obj_cat = tf.keras.losses.SparseCategoricalCrossentropy()

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Running means of each loss, updated on every training step.
loss_reg = tf.keras.metrics.Mean(name='regression loss')
loss_cat = tf.keras.metrics.Mean(name='categorical loss')

# Running quality metrics for each head.
error_reg = tf.keras.metrics.MeanAbsoluteError()
error_cat = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(inputs, y_reg, y_cat):
    """Apply one optimizer update for a batch and update the running metrics.

    Args:
        inputs: Batch of input features passed to ``model``.
        y_reg: Regression targets for the batch.
        y_cat: Class labels for the batch.

    Returns:
        None. The step updates ``model``'s trainable variables through
        ``optimizer`` and accumulates into ``loss_reg``, ``loss_cat``,
        ``error_reg`` and ``error_cat``.
    """
    with tf.GradientTape() as tape:
        pred_reg, pred_cat = model(inputs)
        reg_loss = loss_obj_reg(y_reg, pred_reg)
        cat_loss = loss_obj_cat(y_cat, pred_cat)

    trainable = model.trainable_variables
    # With a list of targets, tape.gradient returns for each variable the
    # gradient of the SUM of the targets. Weights that feed both losses
    # (d0, d1) therefore receive d(reg_loss)/dw + d(cat_loss)/dw -- summed,
    # not averaged. Each head's own weights only receive the gradient of the
    # loss that depends on them.
    grads = tape.gradient([reg_loss, cat_loss], trainable)
    optimizer.apply_gradients(zip(grads, trainable))

    # Accumulate this batch into the running loss and quality metrics.
    loss_reg(reg_loss)
    loss_cat(cat_loss)
    error_reg(y_reg, pred_reg)
    error_cat(y_cat, pred_cat)