Implementing Sparse Ternary Compression inside the TensorFlow Federated simple_fedavg example

Hi all,

I’m trying to implement an algorithm from the paper "Robust and Communication-Efficient Federated Learning from Non-IID Data" by Simon Wiedemann, Klaus-Robert Müller and Wojciech.

I want to implement it inside the simple_fedavg example provided by TensorFlow Federated. I have already written the algorithm, and it seems to work fine in my test cases. The real problem is integrating it into the simple_fedavg project: I don’t understand where I can change what the client sends to the server and what the server expects to receive.

Basically, instead of sending weights_delta from client_update, I want to send a simple list like [[list of negative indices], [list of positive indices], average value]. On the server side I would then rebuild the weights as explained in the paper. But I can’t figure out how to change this behaviour.
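For reference, the message format described above can be sketched in NumPy. This is a minimal, hypothetical sketch, not the poster's test_stc module: the names stc_compress and stc_decompress are made up, k is assumed to be max(1, int(sparsification_rate * size)), and the residual error accumulation that the paper combines with STC is left out.

```python
import numpy as np


def stc_compress(values, sparsification_rate):
    """Sparse ternary compression of an array (hypothetical sketch).

    Keeps the k largest-magnitude entries (k = max(1, int(rate * size)))
    and represents them by their sign plus one shared mean magnitude.
    Returns (negative_indices, positive_indices, average); the indices
    refer to positions in the flattened array.
    """
    flat = np.asarray(values, dtype=np.float32).ravel()
    k = max(1, int(sparsification_rate * flat.size))
    # Indices of the k entries with the largest absolute value.
    top_k = np.argpartition(np.abs(flat), -k)[-k:]
    average = float(np.mean(np.abs(flat[top_k])))
    negatives = np.sort(top_k[flat[top_k] < 0])
    positives = np.sort(top_k[flat[top_k] > 0])
    return negatives, positives, average


def stc_decompress(negatives, positives, average, shape):
    """Rebuild a dense array of the given shape from the ternary message."""
    dense = np.zeros(int(np.prod(shape)), dtype=np.float32)
    dense[negatives] = -average
    dense[positives] = average
    return dense.reshape(shape)


# Usage: keep 50% of a 2x2 delta.
delta = np.array([[0.5, -0.1], [-0.9, 0.05]], dtype=np.float32)
neg, pos, avg = stc_compress(delta, 0.5)  # neg = [2], pos = [0], avg ≈ 0.7
rebuilt = stc_decompress(neg, pos, avg, delta.shape)  # [[0.7, 0], [-0.7, 0]]
```

Only the index lists and one float per layer need to travel to the server; the dense shape is known there from the model.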

English is not my first language, so I hope I have explained the problem well enough.

  # Compress every layer's delta: each entry of `test` becomes the tuple
  # (negatives, positives, average) returned by stc_compression.
  test = weights_delta.copy()
  for index in range(len(weights_delta)):
    # Flatten the layer so the indices refer to positions in a 1-D tensor.
    tensor = tf.reshape(weights_delta[index], [-1])
    test[index] = test_stc.stc_compression(tensor, sparsification_rate)
    # The server would later rebuild the dense layer with something like:
    # test_stc.stc_decompression(negatives, positives, average,
    #     tensor.get_shape().as_list(), tf.shape(weights_delta[index]))
  client_weight = tf.cast(num_examples, tf.float32)
  return ClientOutput(test, client_weight, loss_sum / client_weight)

This is the behaviour I would like: stc_compression returns a tuple. Then, on the server, I would like to access the “test” variable sent by each client and rebuild all the weights.
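Outside of TFF, the server-side step described above (rebuild each client's weights from its tuples, then average them) can be sketched in plain NumPy. This is a hypothetical illustration, not TFF code: it assumes each client message is a list with one (negatives, positives, average) tuple per layer, that the indices refer to the flattened layer, and that clients are weighted by their number of examples, as simple_fedavg does with client_weight. Inside a TFF federated computation the same logic would have to be wrapped in TFF computations rather than run as ordinary Python.

```python
import numpy as np


def decompress_layer(negatives, positives, average, shape):
    """Rebuild one dense layer delta from a (negatives, positives, average) tuple."""
    dense = np.zeros(int(np.prod(shape)), dtype=np.float32)
    dense[np.asarray(negatives, dtype=np.int64)] = -average
    dense[np.asarray(positives, dtype=np.int64)] = average
    return dense.reshape(shape)


def aggregate_compressed_updates(client_messages, client_weights, layer_shapes):
    """Weighted average of STC-compressed client updates (hypothetical sketch).

    client_messages: one entry per client; each entry is a list with one
      (negatives, positives, average) tuple per model layer.
    client_weights: one weight per client, e.g. its number of examples.
    layer_shapes: the dense shape of every layer, in the same order.
    """
    total_weight = float(sum(client_weights))
    averaged = [np.zeros(shape, dtype=np.float32) for shape in layer_shapes]
    for message, weight in zip(client_messages, client_weights):
        for layer, (neg, pos, avg) in enumerate(message):
            dense = decompress_layer(neg, pos, avg, layer_shapes[layer])
            averaged[layer] += (weight / total_weight) * dense
    return averaged


# Usage: two clients, a model with a (2,) layer and a (2, 2) layer.
shapes = [(2,), (2, 2)]
client_a = [([0], [1], 1.0), ([], [3], 2.0)]
client_b = [([], [0], 3.0), ([0], [], 1.0)]
result = aggregate_compressed_updates([client_a, client_b], [1, 3], shapes)
# result[0] -> [2.0, 0.25]; result[1] -> [[-0.75, 0.0], [0.0, 0.5]]
```

The weighting mirrors what tff.federated_mean(client_outputs.weights_delta, weight=client_weight) does for dense deltas; the only extra work is decompressing each client's message first.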

@lgusm Could we get a TF Federated team member subscribed to this federated tag?


I’m using TensorFlow Federated, but I’m having trouble executing some operations on the server after reading the client updates.

This is the function:

def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated dataset with placement at the clients.

    Returns:
      A tuple of the updated `ServerState`, the `tf.Tensor` of average loss,
      and `client_outputs.weights_delta.comp`.
    """
    # Build the server message and broadcast it to the clients.
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    # Run local training on every client.
    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight

    # Example-weighted average of the client deltas.
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(
        server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(
        client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.weights_delta.comp
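For reference, the line round_model_delta = tff.federated_mean(client_outputs.weights_delta, weight=weight_denom) computes an example-weighted average over the N participating clients, where w_i is client i's client_weight (its number of examples) and Δ_i its weights_delta. Any per-client operation g, for example STC decompression (g is a placeholder here, not a TFF function), would have to be applied to each Δ_i before this average:

$$
\Delta_{\text{round}} = \frac{\sum_{i=1}^{N} w_i \, g(\Delta_i)}{\sum_{i=1}^{N} w_i}
$$

With g equal to the identity, this is exactly what run_one_round computes now.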

I want to print client_outputs.weights_delta and apply some operations to the weights that the clients sent to the server before calling tff.federated_mean, but I don’t see how to do that.

When I try to print it, I get this:

Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))

Is there any way to modify those elements?

I tried returning client_outputs.weights_delta.comp and doing the modification in the main function (that part works), and then calling a new method to perform the rest of the server update, but I get this error:

AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean'
where calculate_federated_mean is the name of the new function I created.

This is the main training loop:

    for round_num in range(FLAGS.total_rounds):
        sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
        sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

        server_state, train_metrics, value_comp =, sampled_train_data)

        print(f'Round {round_num}')
        print(f'\tTraining loss: {train_metrics:.4f}')
        if round_num % FLAGS.rounds_per_eval == 0:
            accuracy = evaluate(keras_model, test_data)
            print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
            tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

I’m using the simple_fedavg example from the TensorFlow Federated GitHub repository as the base project.