Access and modify weights sent from clients on the server in TensorFlow Federated

I'm using TensorFlow Federated, but I'm having a problem executing some operations on the server after reading the client updates.

This is the function:

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    tf.print("run_one_round")
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight


    tf.print(client_outputs.weights_delta)
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.weights_delta.comp

I want to print client_outputs.weights_delta and perform some operations on the weights that the clients sent to the server before calling tff.federated_mean, but I don't understand how to do so.

When I try to print it, I get this:
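To make the goal concrete, here is a minimal plain-Python sketch (no TFF involved) of the computation I'm after: apply some operation to each client's delta, then take the weighted mean of the results. The names `clip_delta`, `weighted_mean` and `aggregate`, and the clipping operation itself, are hypothetical and only for illustration; the real deltas are nested tensors, not flat lists.

```python
from typing import Callable, List, Sequence

# One client's weights delta, flattened to a list of floats (illustrative only).
Delta = List[float]


def clip_delta(delta: Delta, max_abs: float = 1.0) -> Delta:
    """Hypothetical per-client operation: clip each entry to [-max_abs, max_abs]."""
    return [max(-max_abs, min(max_abs, x)) for x in delta]


def weighted_mean(deltas: Sequence[Delta], weights: Sequence[float]) -> Delta:
    """Element-wise weighted mean, analogous to tff.federated_mean with `weight`."""
    total = sum(weights)
    size = len(deltas[0])
    return [
        sum(w * d[i] for d, w in zip(deltas, weights)) / total
        for i in range(size)
    ]


def aggregate(
    deltas: Sequence[Delta],
    weights: Sequence[float],
    per_client_op: Callable[[Delta], Delta],
) -> Delta:
    """Apply an operation to every client delta, then average the results."""
    modified = [per_client_op(d) for d in deltas]
    return weighted_mean(modified, weights)


if __name__ == "__main__":
    client_deltas = [[2.0, -3.0], [0.5, 0.5]]
    client_weights = [1.0, 3.0]
    print(aggregate(client_deltas, client_weights, clip_delta))
```

In other words, I need the equivalent of the `per_client_op` step to happen on the values inside `run_one_round`, before the averaging.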

Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))

Is there any way to modify those elements?

I tried returning client_outputs.weights_delta.comp and doing the modification in the main loop (which works), and then I tried to invoke a new method to do the rest of the server-update operations, but I get this error:

AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean'
where calculate_federated_mean was the name of the new function I created.

This is the main:

for round_num in range(FLAGS.total_rounds):
    print("--------------------------------------------------------")
    sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
    sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

    server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)

    print(f'Round {round_num}')
    print(f'\tTraining loss: {train_metrics:.4f}')
    if round_num % FLAGS.rounds_per_eval == 0:
        server_state.model_weights.assign_weights_to(keras_model)
        accuracy = evaluate(keras_model, test_data)
        print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
        tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

This is based on the simple_fedavg example project from the TensorFlow Federated GitHub repository.