Federated Learning using TensorFlow Federated

August 17, 2021

Centralized machine learning involves having the model and the dataset on the same device. Companies such as Google upload data to the cloud to train their machine learning models.

Federated Learning flips this paradigm. Instead of transferring data to the cloud, we send cloud-based models to our devices. These models are then trained locally on our devices.

Once these models have been trained locally, the devices send the updated models to the server instead of their data. The server aggregates these updates into the global model on the cloud. This process is referred to as Federated Learning.
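To make the server's role concrete, here is a minimal sketch of the aggregation step used by Federated Averaging (FedAvg), the algorithm behind TFF's `build_federated_averaging_process`. Each client's model weights are averaged, weighted by how many training examples that client holds. This is plain Python, not TFF. The function name `federated_average` and the representation of a model as a flat list of floats are assumptions made for illustration.

```python
from typing import List, Sequence


def federated_average(client_weights: Sequence[Sequence[float]],
                      num_examples: Sequence[int]) -> List[float]:
    """Example-weighted average of client model weights (the FedAvg server step).

    client_weights[k] is the flattened weight vector sent back by client k and
    num_examples[k] is the number of training examples client k used.
    """
    if not client_weights or len(client_weights) != len(num_examples):
        raise ValueError("expected one example count per client and at least one client")
    total = sum(num_examples)
    averaged = [0.0] * len(client_weights[0])
    for weights, n in zip(client_weights, num_examples):
        for i, w in enumerate(weights):
            # A client's influence is proportional to its share of all examples.
            averaged[i] += w * n / total
    return averaged


# The second client holds three times as much data, so it pulls the average harder.
print(federated_average([[1.0, 2.0], [3.0, 4.0]], [1, 3]))  # → [2.5, 3.5]
```

Only the weight vectors and example counts reach the server; the raw training data never leaves the clients.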

TensorFlow Federated is an open-source framework by Google that is used to implement Federated Learning.


In this article, we will learn how TensorFlow Federated can be utilized by researchers and machine learning experts to implement federated learning on datasets.


To understand the contents of this article, you need to be familiar with:

  • The Python programming language.
  • The TensorFlow machine learning framework.
  • Federated Learning.
  • The MNIST dataset.


Introduction to TensorFlow Federated (TFF)

TFF is an open-source framework for federated learning on decentralized data. It is spearheaded by Google and has gained popularity in recent years.

TFF has three main features:

  1. TFF is architecture-agnostic.

This means that TFF compiles all code into an abstract representation that is not tied to one platform. As a result, the same computation can be deployed in diverse environments.

  2. TFF saves effort.

It is designed to mitigate the pain points that we developers face when developing federated learning systems.

Some of these challenges include interleaving different types of logic, reconciling the global and local perspectives on communication, and the tension between the order of construction and the order of execution.

  3. TFF has many extensions.

Some of the available extensions include differential privacy, compression, and quantization.

TensorFlow Federated layers

TFF offers two main layers:

  1. Federated Learning (FL) API: a high-level API that implements federated training and evaluation. It can be applied to existing TensorFlow models or data.

  2. Federated Core (FC) API: a lower-level framework that sits beneath the FL API. It provides generic expressions for running and simulating custom types of computations and for controlling your own orchestration logic. It also has a local runtime that supports simulations.

In this tutorial, we will focus on the FL API and the code behind it.

Application of Federated learning

There are different ways you can get involved depending on your interest:

  1. A machine learning developer can apply Federated Learning APIs to existing TensorFlow models.

  2. A federated learning researcher can help to design new federated learning algorithms using the FC API.

  3. A systems researcher can assist in optimizing generated computation structures.

  4. A system developer can help in integrating TFF with different development environments.

The code behind TensorFlow Federated (TFF)

First, let’s briefly look at what the Keras model looks like:

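    import tensorflow as tf

    def create_compiled_keras_model():
        # One softmax layer: flattened 28x28 image (784 values) -> 10 classes.
        model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(10, activation=tf.nn.softmax,
                                  kernel_initializer='zeros', input_shape=(784,))])
        # The model must be compiled before it is handed to TFF.
        model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
                      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
        return model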

The model uses the Keras Sequential API, which builds a model layer by layer as a linear stack. This is ideal for simple neural networks.

However, it’s not suited to complex networks that share layers or have multiple inputs or outputs, such as residual and siamese networks.

For those, the Keras functional API is used instead. It is more flexible because layers can connect to more than just the previous and next layers.


We then call create_compiled_keras_model() inside model_fn(), the function TFF uses to construct the model:

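    def model_fn():
        # Build the compiled Keras model defined above.
        keras_model = create_compiled_keras_model()
        # sample_batch: a batch of example input that TFF uses to infer input shapes and types.
        return tff.learning.from_compiled_keras_model(keras_model, sample_batch)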

The model_fn() function wraps the compiled Keras model in a form that TFF can use for federated training and evaluation.

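    # Build the federated averaging process from model_fn.
    train = tff.learning.build_federated_averaging_process(model_fn)
    state = train.initialize()

    for _ in range(5):
        state, metrics = train.next(state, train_data)
        print(metrics.loss)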

In the above code, train.initialize() constructs the initial server state. Each call to train.next() then runs one round of federated training: the current server state is sent to each participating client.

Each client trains the model locally on its own data and sends an update back to the server. The server aggregates these updates into a new global model, which is returned as the new state.
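The initialize()/next() pattern can be mimicked in plain Python to show what happens inside one round. The sketch below is not TFF: the class name `ToyFedAvgProcess`, the one-parameter linear model y = w·x, the client data, and the hyperparameters are all hypothetical, chosen to keep the example tiny. Each call to `next()` broadcasts the server weight, runs a few local gradient-descent steps on every client, and replaces the server weight with the example-weighted average of the results.

```python
from typing import Dict, List, Tuple

Dataset = List[Tuple[float, float]]  # (x, y) pairs held by one client


class ToyFedAvgProcess:
    """Hypothetical stand-in for a federated averaging process (not the TFF API)."""

    def __init__(self, learning_rate: float = 0.05, local_steps: int = 5):
        self.learning_rate = learning_rate
        self.local_steps = local_steps

    def initialize(self) -> float:
        # The server state is just the model's single weight w, starting at zero.
        return 0.0

    def _local_train(self, w: float, data: Dataset) -> float:
        # Full-batch gradient descent on mean squared error, run on one client.
        for _ in range(self.local_steps):
            grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
            w -= self.learning_rate * grad
        return w

    @staticmethod
    def _loss(w: float, data: Dataset) -> float:
        return sum((w * x - y) ** 2 for x, y in data) / len(data)

    def next(self, state: float,
             client_datasets: List[Dataset]) -> Tuple[float, Dict[str, float]]:
        total = sum(len(d) for d in client_datasets)
        # Broadcast `state`, train locally on each client, then average the
        # resulting weights, weighting each client by its number of examples.
        new_state = sum(self._local_train(state, d) * len(d)
                        for d in client_datasets) / total
        # Example-weighted training loss of the new global model.
        loss = sum(self._loss(new_state, d) * len(d) for d in client_datasets) / total
        return new_state, {"loss": loss}


# Three hypothetical clients whose data all follow y = 2x over different x ranges.
clients = [
    [(1.0, 2.0), (2.0, 4.0)],
    [(3.0, 6.0)],
    [(0.5, 1.0), (1.5, 3.0), (2.5, 5.0)],
]
process = ToyFedAvgProcess()
state = process.initialize()
for round_num in range(5):
    state, metrics = process.next(state, clients)
    print(round_num, round(state, 4), round(metrics["loss"], 6))
```

Running it, the weight moves toward 2 and the reported loss shrinks every round, mirroring the shape of the TFF training loop above.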

    evaluation = tff.learning.build_federated_evaluation(model_fn)
    metrics = evaluation(state.model, test_data)

Finally, we perform federated evaluation to see how well the trained model performs. The build_federated_evaluation() function constructs the federated computation that carries out this evaluation.
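Federated evaluation follows the same pattern as training, but without updating the model: each client computes metrics on its own data, and only those metrics are sent back and combined. Below is a minimal plain-Python sketch, assuming a hypothetical `predict` function and per-client lists of `(features, label)` pairs. It combines results by summing correct-prediction and example counts across clients, so clients with more data count proportionally more; this is one reasonable aggregation choice, not necessarily what TFF does internally for every metric.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Example = Tuple[Sequence[float], int]
Predictor = Callable[[Sequence[float]], int]


def evaluate_client(predict: Predictor, data: List[Example]) -> Dict[str, int]:
    # Runs on the client: only these two counts leave the device.
    correct = sum(1 for features, label in data if predict(features) == label)
    return {"correct": correct, "num_examples": len(data)}


def federated_evaluate(predict: Predictor, client_datasets: List[List[Example]]) -> float:
    # Runs on the server: sum the counts, then compute overall accuracy.
    reports = [evaluate_client(predict, d) for d in client_datasets]
    total_correct = sum(r["correct"] for r in reports)
    total_examples = sum(r["num_examples"] for r in reports)
    return total_correct / total_examples


# Hypothetical model: predicts class 1 when the first feature is positive.
def predict(features: Sequence[float]) -> int:
    return 1 if features[0] > 0 else 0


test_clients = [
    [([0.5], 1), ([-1.0], 0), ([2.0], 0)],  # 2 of 3 correct
    [([-0.3], 0)],                           # 1 of 1 correct
]
print(federated_evaluate(predict, test_clients))  # → 0.75
```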

Here’s what the complete TFF code looks like:

    # train_data, test_data = ...  (lists of per-client datasets; loading not shown)
    # sample_batch: a batch of example input that TFF uses to infer input shapes and types.

    def model_fn():
        keras_model = create_compiled_keras_model()
        return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

    train = tff.learning.build_federated_averaging_process(model_fn)

    state = train.initialize()

    for _ in range(5):
        state, metrics = train.next(state, train_data)
        print(metrics.loss)

    evaluation = tff.learning.build_federated_evaluation(model_fn)
    metrics = evaluation(state.model, test_data)

In summary, the general components for the FL API include:

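  1. Models:
  • tff.learning.Model, the model interface that TFF trains and evaluates.

  • create_compiled_keras_model(), which defines the Keras model wrapped by model_fn().

  2. Federated computation builders: TFF provides two builder functions: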
  • tff.learning.build_federated_averaging_process generates the federated computations for federated training.

  • tff.learning.build_federated_evaluation generates the federated computations for federated evaluation.

Let’s use the MNIST training example to introduce the Federated Learning (FL) API layer of TFF.

Step 1: Installing TensorFlow Federated

Please make sure to install TensorFlow Federated before importing it into your notebook. Failure to do this might result in an error.

We install TensorFlow Federated using the following command:

    pip install --upgrade tensorflow-federated

Step 2: Importing dependencies into our notebook

    import tensorflow as tf
    import tensorflow_federated as tff

We’ve imported TensorFlow as tf and TensorFlow Federated as tff.

Step 3: Simulation dataset

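    emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

    def client_data(n):
        # Flatten each 28x28 image into a 784-value vector and pair it with its label.
        return emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[n]).map(
            lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])
        ).batch(20)  # batch size of 20 is an illustrative choice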
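The `client_data()` function flattens each 28x28 image with `tf.reshape(e['pixels'], [-1])` so that it matches the model's 784-value input, and client datasets are typically also batched before training. The plain-Python sketch below shows the same two transformations on fake data. The `preprocess` and `batch` helpers, the blank images, and the batch size of 20 are hypothetical and only illustrate the idea; in TFF these steps run as `tf.data` operations.

```python
from typing import Dict, Iterable, Iterator, List, Tuple


def preprocess(example: Dict) -> Tuple[List[float], int]:
    # Same idea as tf.reshape(e['pixels'], [-1]): flatten the rows into one vector.
    pixels = [value for row in example["pixels"] for value in row]
    return pixels, example["label"]


def batch(examples: Iterable, batch_size: int) -> Iterator[List]:
    # Group consecutive examples; the final batch may be smaller.
    current = []
    for ex in examples:
        current.append(ex)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


# A fake client with 45 blank 28x28 images, all labelled 7.
fake_client = [{"pixels": [[0.0] * 28 for _ in range(28)], "label": 7} for _ in range(45)]
batches = list(batch(map(preprocess, fake_client), 20))
print(len(batches[0][0][0]), [len(b) for b in batches])  # → 784 [20, 20, 5]
```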

The simulation dataset is EMNIST, a federated version of the MNIST dataset derived from the LEAF project, in which the handwritten images are grouped by the writer who drew them. LEAF provides a benchmarking framework for federated learning.

Why a federated version of the dataset?

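Because in FL the data comes from many users, each client’s dataset reflects that user’s own behavior (here, their handwriting). The data is therefore typically unbalanced and not independent and identically distributed (non-IID) across clients, challenges that a standard, centrally shuffled version of the dataset does not exhibit.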
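To see why per-user grouping matters, the sketch below splits a toy labelled dataset by writer, the same way EMNIST groups images by the person who wrote them. It is plain Python, and the writer names and labels are made up for illustration. Each client ends up with a different number of examples and its own skewed label distribution, which is exactly what a centrally shuffled dataset would hide.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

# Hypothetical records: (writer_id, digit_label).
records: List[Tuple[str, int]] = [
    ("alice", 1), ("alice", 1), ("alice", 7), ("alice", 1),
    ("bob", 3), ("bob", 8),
    ("carol", 0), ("carol", 0), ("carol", 0), ("carol", 9), ("carol", 0),
]


def group_by_client(rows: List[Tuple[str, int]]) -> Dict[str, List[int]]:
    # Each client's dataset holds only the examples that user produced.
    clients: Dict[str, List[int]] = defaultdict(list)
    for client_id, label in rows:
        clients[client_id].append(label)
    return dict(clients)


clients = group_by_client(records)
for client_id, labels in clients.items():
    # Unbalanced sizes and skewed label histograms: the data is non-IID.
    print(client_id, len(labels), dict(Counter(labels)))
```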

We import the federated data into the project using the load_data() function.

Step 4: Training using Federated data

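    # Use the datasets of the first three clients for training.
    train_data = [client_data(n) for n in range(3)]

    # model_fn is the function defined earlier that wraps the Keras model.
    trainer = tff.learning.build_federated_averaging_process(
        model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    state = trainer.initialize()
    for _ in range(50):
        state, metrics = trainer.next(state, train_data)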

In the training code, only a subset of client devices (here, the first three clients) is selected to participate. This is because not all devices are eligible: at any given time, only a few devices may have data relevant to your problem.

After the selected devices have trained the model locally, TFF aggregates their updates into the global model and reports training metrics such as the loss.

In this experiment, the training loss decreases from round to round, which indicates that the model is converging.

We set training to run for 50 rounds. The training loss fell from 12.931682 at the start of training to 0.02700758 at the end.

In realistic situations, users can join and exit the experiment freely. This means that one would randomly select a sample of users for each round. However, to make things simple, and allow the system to converge quickly, we’ll reuse the same users.
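In a realistic simulation you would draw a fresh cohort of clients every round instead of reusing the first three. Here is a minimal sketch of that sampling step in plain Python. The function name `sample_clients`, the cohort size, and the client IDs are hypothetical; in TFF you would then build `train_data` from the sampled IDs with the dataset's `create_tf_dataset_for_client()` method.

```python
import random
from typing import List, Sequence


def sample_clients(client_ids: Sequence[str], cohort_size: int,
                   rng: random.Random) -> List[str]:
    # Draw a cohort without replacement; different rounds get different users.
    return rng.sample(list(client_ids), cohort_size)


all_ids = [f"writer_{i:03d}" for i in range(100)]  # hypothetical client IDs
rng = random.Random(0)  # fixed seed so the simulation is reproducible
for round_num in range(3):
    cohort = sample_clients(all_ids, cohort_size=3, rng=rng)
    print(round_num, cohort)
```

Seeding the generator keeps experiments repeatable while still varying the participants from round to round.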

Summary of the implementation

Feel free to modify parameters such as the batch size, the number of users, the number of epochs, and the learning rate, or to sample a random set of users in each round.


This was a simple introduction to TensorFlow Federated, in which we used the MNIST training example to introduce the Federated Learning (FL) API layer of TFF.

The code I’ve shown above is open-source and available on GitHub. You can access it using this link.

Remember, with Federated Learning, we can learn from everyone, without learning about anyone.


Peer Review Contributions by: Collins Ayuya

About the author

Willies Ogola

Willies Ogola is pursuing his Master’s in Computer Science in Hubei University of Technology, China. His research direction is on Artificial Intelligence and Embedded Systems. He likes researching during his free time and is passionate about technology.

This article was contributed by a student member of Section's Engineering Education Program. Please report any errors or inaccuracies to enged@section.io.