For my current research, I want to add the federated learning strategy SCAFFOLD from Karimireddy et al. to my experiment. I am using TensorFlow for my research, but I can only find PyTorch implementations online, and I ran into some issues implementing it myself.
Below I attach the algorithm from the paper:
Basically, it just adds some additional terms (control variates) to the gradient before applying it to the model (line 10 of the algorithm: `-c_i + c`). There are local control variates (`c_i`), which differ between clients, and global control variates (`c`), which are updated every round after model aggregation.
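For reference, line 10 of the algorithm is the local step y_i ← y_i − η_l (g_i(y_i) − c_i + c). A minimal NumPy sketch of that corrected gradient step (all names here are illustrative, not taken from my optimizer code):

```python
import numpy as np

def scaffold_step(weights, grads, c_local, c_global, lr):
    """One SCAFFOLD local step per layer: w <- w - lr * (g - c_i + c)."""
    return [
        w - lr * (g - ci + c)
        for w, g, ci, c in zip(weights, grads, c_local, c_global)
    ]

# Toy example: with c_i == c == 0 the step reduces to plain SGD.
w = [np.array([1.0, 2.0])]
g = [np.array([0.5, 0.5])]
zeros = [np.zeros(2)]
print(scaffold_step(w, g, zeros, zeros, lr=0.1))  # ~ [0.95, 1.95]
```

Note that when `c_i == c` the correction cancels, so SCAFFOLD and plain SGD should match exactly in that case; this is a handy sanity check against `tf.keras.optimizers.SGD`.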
Here the model weights are also aggregated the same way as in FedAvg.
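For clarity, the FedAvg-style averaging of per-layer weights can be sketched like this (illustrative names, not my actual server code):

```python
import numpy as np

def fedavg(client_weights):
    """Plain FedAvg: element-wise mean of each layer across clients."""
    return [np.mean(layers, axis=0) for layers in zip(*client_weights)]

# Two clients, one layer each: the aggregate is the per-element mean.
clients = [[np.array([1.0, 2.0])], [np.array([3.0, 4.0])]]
print(fedavg(clients))  # ~ [2.0, 3.0]
```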
I tried to write this as an optimizer in TensorFlow; due to the length of the code, I put it on Pastebin: https://pastebin.com/ZtX4DvhG
Everything works and converges when I directly use the tf.keras SGD optimizer with the same learning rate. But when I use my implementation, it doesn't converge and the model weights explode after a few rounds (approx. 5-7 aggregations); the weights and loss become NaN.
On the client side, I do the following (I simplified it to keep it short):
```python
# The Scaffold optimizer implementation is in the Pastebin link above.
optimizer = Scaffold(learning_rate=hyperparams.SGD_learning_rate)

# Load global model weights and global/local control variates into the optimizer.
optimizer.set_controls(global_model.weights, hyperparams.scaffold)

model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
model.fit()

# Compute new local control variates and store them locally for the next round.
local_controls = optimizer.get_new_client_controls(
    global_model.get_weights(),
    model.get_weights(),
    option=option,
)

# Compute the local control variate delta and send it to the server.
local_controls_diff = (
    [new - old for new, old in zip(local_controls, old_local_controls)]
    if old_local_controls
    else local_controls
)
```
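For reference, the paper's Option II update for the local control variates is c_i⁺ = c_i − c + (x − y_i)/(K·η_l), where K is the number of local steps. A minimal NumPy sketch of that computation (the function name and arguments are illustrative, not my actual `get_new_client_controls`):

```python
import numpy as np

def option_two_controls(c_local, c_global, x_global, y_local, local_steps, lr):
    """SCAFFOLD Option II: c_i+ = c_i - c + (x - y_i) / (K * eta_l)."""
    return [
        ci - c + (xg - yl) / (local_steps * lr)
        for ci, c, xg, yl in zip(c_local, c_global, x_global, y_local)
    ]

# Toy check: one local SGD step from x = 1.0 with gradient 0.5 and lr = 0.1
# gives y = 0.95, so with zero controls c_i+ recovers the average gradient 0.5.
new_c = option_two_controls(
    [np.zeros(1)], [np.zeros(1)], [np.array([1.0])], [np.array([0.95])],
    local_steps=1, lr=0.1,
)
```

The division by `local_steps * lr` is one place where a unit mismatch (e.g. passing batches instead of steps, or the wrong learning rate) scales the controls up and can blow up the weights within a few rounds.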
The aggregation looks like this:
```python
total_client_count = self.total_client_count
selected_clients_count = len(client_parameters)
global_params = self.global_weights
global_controls = self.global_controls
global_lr = self.global_lr

# Per-client weight deltas relative to the global model.
delta_weights = [
    [c_layer - g_layer for g_layer, c_layer in zip(global_params, client_i)]
    for client_i in client_parameters
]
delta_avg_weights = [
    reduce(np.add, layer_updates) / selected_clients_count
    for layer_updates in zip(*delta_weights)
]

# Per-client control variate deltas relative to the global controls.
delta_controls = [
    [c_layer - g_layer for g_layer, c_layer in zip(global_controls, client_i)]
    for client_i in client_controls
]
delta_avg_controls = [
    reduce(np.add, layer_updates) / selected_clients_count
    for layer_updates in zip(*delta_controls)
]

# Calculate new global weights for the next round:
# x = x + lr_g * delta_x
new_global_weights = [
    global_layer + global_lr * delta_avg
    for global_layer, delta_avg in zip(global_params, delta_avg_weights)
]

# Calculate new global control variates for the next round:
# c = c + |S|/N * delta_ci
new_global_controls = [
    global_layer + (selected_clients_count / total_client_count) * delta_avg
    for global_layer, delta_avg in zip(global_controls, delta_avg_controls)
]
```
The new model weights and global control variates are then sent back to each client for the next round.