TF-Replicator: Distributed Machine Learning for Researchers

At DeepMind, the Research Platform Team builds infrastructure to empower and accelerate our AI research. Today, we are excited to share how we developed TF-Replicator, a software library that helps researchers deploy their TensorFlow models on GPUs and Cloud TPUs with minimal effort and no previous experience with distributed systems. TF-Replicator’s programming model has now been open sourced as part of TensorFlow’s tf.distribute.Strategy. This blog post gives an overview of the ideas and technical challenges underlying TF-Replicator. For a more comprehensive description, please read our arXiv paper.

A recurring theme in recent AI breakthroughs, from AlphaFold to BigGAN to AlphaStar, is the need for effortless and reliable scalability. Increasing amounts of computational capacity allow researchers to train ever-larger neural networks with new capabilities. To address this need, the Research Platform Team developed TF-Replicator, which allows researchers to target different hardware accelerators for Machine Learning, scale up workloads to many devices, and seamlessly switch between different types of accelerators. While it was initially developed as a library on top of TensorFlow, TF-Replicator’s API has since been integrated into TensorFlow 2.0’s new tf.distribute.Strategy.

While TensorFlow provides direct support for CPU, GPU, and TPU (Tensor Processing Unit) devices, switching between targets requires substantial effort from the user. This typically involves specialising code for a particular hardware target, constraining research ideas to the capabilities of that platform. Some existing frameworks built on top of TensorFlow, e.g. Estimators, seek to address this problem. However, they are typically targeted at production use cases and lack the expressivity and flexibility required for rapid iteration of research ideas.

Building a Distributed Machine Learning Library

Our original motivation for developing TF-Replicator was to provide a simple API for DeepMind researchers to use TPUs. TPUs provide scalability for Machine Learning workloads, enabling research breakthroughs such as state-of-the-art image synthesis with our BigGAN model. TensorFlow’s native API for TPUs differs from how GPUs are targeted, forming a barrier to TPU adoption. TF-Replicator provides a simpler, more user-friendly API that hides the complexity of TensorFlow’s TPU API. Critically, the Research Platform Team developed the TF-Replicator API in close collaboration with researchers across various machine learning disciplines to ensure the necessary flexibility and ease of use.

The TF-Replicator API

Code written using TF-Replicator looks similar to code written in TensorFlow for a single device, leaving users free to define their own model run loop. The user simply needs to define (1) an input function that exposes a Dataset, and (2) a step function that defines the logic of their model (e.g. a single step of gradient descent).
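The exact TF-Replicator signatures are not reproduced in this post. As a minimal, framework-free sketch of the programming model, the hypothetical `ToyReplicator` below takes an input function and a step function, builds a replicated step once, and leaves the run loop to the user. Plain Python iterables stand in for TensorFlow Datasets and plain values for tensors; `ToyReplicator` and its `run` method are illustrative names, not the library's API.

```python
from typing import Any, Callable, Iterable, List


class ToyReplicator:
    """Hypothetical stand-in for a replicator (not TF-Replicator's real API)."""

    def __init__(self, num_replicas: int) -> None:
        self.num_replicas = num_replicas

    def run(self,
            step_fn: Callable[[Any], Any],
            input_fn: Callable[[int], Iterable[Any]]) -> Callable[[], List[Any]]:
        """Builds a replicated step once; calling the result executes one step."""
        # One input pipeline per replica, e.g. that replica's shard of a Dataset.
        iterators = [iter(input_fn(replica_id))
                     for replica_id in range(self.num_replicas)]

        def replicated_step() -> List[Any]:
            # Run the same single-device step logic on every replica's next input.
            return [step_fn(next(it)) for it in iterators]

        return replicated_step


# Usage: the user owns the run loop, just as in single-device code.
def input_fn(replica_id: int) -> Iterable[int]:
    # Hypothetical sharding for 4 replicas: replica r sees r, r + 4, r + 8, ...
    return range(replica_id, 100, 4)


def step_fn(example: int) -> int:
    # Stand-in for one step of gradient descent; here it just doubles the input.
    return 2 * example


repl = ToyReplicator(num_replicas=4)
train_step = repl.run(step_fn, input_fn)
for _ in range(3):
    per_replica_outputs = train_step()
```

Building the step once and executing it repeatedly mirrors graph-mode TensorFlow, where the replicated computation is constructed up front and then run many times inside the user's loop.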

Scaling computation to multiple devices requires the devices to communicate with each other. In the context of training Machine Learning models, the most common form of communication is to accumulate gradients for use in optimisation algorithms such as Stochastic Gradient Descent. We therefore provide a convenient method to wrap TensorFlow Optimizers, so that gradients are accumulated across devices before updating the model’s parameters. For more general communication patterns we provide MPI-like primitives, such as `all_reduce` and `broadcast`. These make it trivial to implement operations such as global batch normalisation, a technique that is crucial for scaling up training of our BigGAN models (see Section 3 of the paper).
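To make the semantics of these primitives concrete, here is a small pure-Python simulation in which each list index plays the role of one replica. It is a sketch under that assumption, not TF-Replicator code: `all_reduce` and `broadcast` mirror the primitive names above, while `CrossReplicaSGD` (a stand-in for a wrapped optimizer that averages gradients) and `global_batch_moments` (the statistics behind global batch normalisation) are hypothetical helpers.

```python
from typing import List, Sequence, Tuple


def all_reduce(values: Sequence[float], op: str = "sum") -> List[float]:
    """Each list index is a replica; every replica receives the reduced value."""
    if op not in ("sum", "mean"):
        raise ValueError(f"unsupported op: {op}")
    total = sum(values)
    result = total / len(values) if op == "mean" else total
    return [result] * len(values)


def broadcast(values: Sequence[float], source: int = 0) -> List[float]:
    """Every replica receives the value held by replica `source`."""
    return [values[source]] * len(values)


class CrossReplicaSGD:
    """Hypothetical optimizer wrapper: averages gradients across replicas,
    then applies a plain SGD update to each replica's parameter copy."""

    def __init__(self, learning_rate: float) -> None:
        self.learning_rate = learning_rate

    def apply(self, params: Sequence[float], grads: Sequence[float]) -> List[float]:
        mean_grads = all_reduce(grads, op="mean")
        return [p - self.learning_rate * g for p, g in zip(params, mean_grads)]


def global_batch_moments(batches: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Batch-norm mean and variance over *all* replicas' examples."""
    sums = all_reduce([sum(b) for b in batches])
    sq_sums = all_reduce([sum(x * x for x in b) for b in batches])
    counts = all_reduce([len(b) for b in batches])
    # Every replica now holds identical totals, so any index gives the same answer.
    mean = sums[0] / counts[0]
    variance = sq_sums[0] / counts[0] - mean ** 2
    return mean, variance


params = broadcast([1.0, 0.0])                              # identical starting weights
params = CrossReplicaSGD(0.5).apply(params, [1.0, 3.0])     # -> [0.0, 0.0]
mean, var = global_batch_moments([[1.0, 2.0], [3.0, 4.0]])  # -> (2.5, 1.25)
```

Computing statistics locally instead would give replica 0 a mean of 1.5 rather than 2.5; removing that discrepancy is what global batch normalisation is for. The E[x²] minus mean² variance formula is used here for brevity.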

Input data is sent from the host to each GPU, and each GPU begins processing immediately. When information needs to be exchanged between GPUs, they synchronise before sending the data.
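The behaviour described in the caption above can be mimicked on a single host with Python threads standing in for GPUs and a `threading.Barrier` standing in for the synchronisation point. This is only an analogy under that assumption: real multi-device communication happens through collective operations in the TensorFlow graph, not Python threads, and `run_devices` is a hypothetical helper.

```python
import threading
from typing import List, Sequence


def run_devices(inputs: Sequence[Sequence[float]]) -> List[float]:
    """Simulates devices that compute locally, then synchronise to exchange data."""
    num_devices = len(inputs)
    barrier = threading.Barrier(num_devices)
    exchange: List[float] = [0.0] * num_devices  # stands in for the interconnect
    results: List[float] = [0.0] * num_devices

    def device(device_id: int) -> None:
        # Start processing as soon as this device's input has arrived.
        local = float(sum(inputs[device_id]))
        exchange[device_id] = local
        # Synchronise: no device reads shared data until every device has written.
        barrier.wait()
        results[device_id] = sum(exchange)

    threads = [threading.Thread(target=device, args=(i,)) for i in range(num_devices)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_devices([[0, 1], [1, 2], [2, 3], [3, 4]]))  # [16.0, 16.0, 16.0, 16.0]
```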

Implementation

For multi-GPU computation, TF-Replicator relies on an “in-graph replication” pattern, in which the computation for each device is replicated in the same TensorFlow graph. Communication between devices is achieved by connecting nodes from the devices’ corresponding sub-graphs. Implementing this in TF-Replicator was challenging, as communication can occur at any point in the data-flow graph. The order in which computations are constructed is therefore critical.
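In-graph replication can be pictured with a toy dataflow graph in plain Python (this is not TensorFlow's `Graph` API, and the op names are hypothetical). One graph holds a copy of a tiny step for every device, and a single cross-device node connects the per-device sub-graphs. Here the construction is interleaved by hand; that becomes hard when a user's step function may request communication at any point.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Sequence


@dataclass
class Node:
    op: str
    device: str
    inputs: List[str] = field(default_factory=list)


@dataclass
class Graph:
    nodes: Dict[str, Node] = field(default_factory=dict)

    def add(self, name: str, op: str, device: str, inputs: Sequence[str] = ()) -> str:
        self.nodes[name] = Node(op, device, list(inputs))
        return name


def build_in_graph_replica_step(devices: Sequence[str]) -> Graph:
    """One graph with a copy of the step per device, joined by a
    cross-device node that consumes every device's gradient."""
    graph = Graph()
    grads = []
    for d in devices:
        x = graph.add(f"{d}/input", "Input", d)
        grads.append(graph.add(f"{d}/grad", "Grad", d, [x]))
    # The cross-device sum can only be added once every device's gradient
    # node exists, which is why construction order matters.
    total = graph.add("cross_device/sum", "AddN", devices[0], grads)
    for d in devices:
        graph.add(f"{d}/apply", "ApplyGrad", d, [total])
    return graph


graph = build_in_graph_replica_step(["gpu:0", "gpu:1"])
print(graph.nodes["cross_device/sum"].inputs)  # ['gpu:0/grad', 'gpu:1/grad']
```

Placing the cross-device node on the first device is an arbitrary choice made for the sketch.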

Our first idea was to build each device’s sub-graph concurrently in a separate Python thread. On encountering a communication primitive, the threads would synchronise and the main thread would insert the required cross-device computation; each thread would then continue building its device’s computation. However, at the time we considered this approach, TensorFlow’s graph-building API was not thread-safe, which made concurrently building sub-graphs in different threads very difficult. Instead, we used graph rewriting to insert the communication after all devices’ sub-graphs had been built. When constructing the sub-graphs, placeholders are inserted where communication is required. We then collect all matching placeholders across devices and replace them with the appropriate cross-device computation.

When TF-Replicator builds an in-graph replicated computation, it first builds the computation for each device independently and leaves placeholders where cross-device computation has been specified by the user. Once the sub-graphs for all devices have been built, TF-Replicator connects them by replacing the placeholders with actual cross-device computation.
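The placeholder-then-rewrite strategy described above can be sketched in plain Python as follows. Each device's sub-graph is built on its own, an `all_reduce` call only records a placeholder, and a rewriting pass later replaces each set of matching placeholders with one cross-device node. All names here (`DeviceBuilder`, `insert_communication`, the op strings) are hypothetical, and matching placeholders by call order is an assumption of the sketch, not a description of TF-Replicator's internals.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


class Graph:
    def __init__(self) -> None:
        self.nodes: Dict[str, dict] = {}

    def add(self, name: str, op: str, inputs: Sequence[str] = (), **attrs) -> str:
        self.nodes[name] = {"op": op, "inputs": list(inputs), **attrs}
        return name


class DeviceBuilder:
    """Builds one device's sub-graph; communication becomes a placeholder."""

    def __init__(self, graph: Graph, device: str) -> None:
        self.graph, self.device = graph, device
        self.num_collectives = 0

    def op(self, name: str, op: str, inputs: Sequence[str] = ()) -> str:
        return self.graph.add(f"{self.device}/{name}", op, inputs)

    def all_reduce(self, value: str) -> str:
        # Assumption: every device issues the same sequence of collectives,
        # so the k-th call on each device refers to the same reduction.
        key = self.num_collectives
        self.num_collectives += 1
        return self.graph.add(f"{self.device}/placeholder_{key}", "Placeholder",
                              key=key, value=value)


def insert_communication(graph: Graph) -> None:
    """Replaces matching placeholders with one cross-device all-reduce node."""
    groups: Dict[int, List[str]] = defaultdict(list)
    for name, node in graph.nodes.items():
        if node["op"] == "Placeholder":
            groups[node["key"]].append(name)
    for key, placeholders in groups.items():
        values = [graph.nodes[p]["value"] for p in placeholders]
        reduced = graph.add(f"all_reduce_{key}", "AllReduceSum", values)
        # Rewire every consumer of a placeholder to the cross-device node.
        for node in graph.nodes.values():
            node["inputs"] = [reduced if i in placeholders else i
                              for i in node["inputs"]]
        for p in placeholders:
            del graph.nodes[p]


def step_fn(b: DeviceBuilder) -> None:
    x = b.op("input", "Input")
    grad = b.op("grad", "Grad", [x])
    total = b.all_reduce(grad)  # only a placeholder at this point
    b.op("apply", "ApplyGrad", [total])


graph = Graph()
for device in ["gpu:0", "gpu:1"]:
    step_fn(DeviceBuilder(graph, device))  # sub-graphs are built one at a time
insert_communication(graph)
print(graph.nodes["gpu:1/apply"]["inputs"])  # ['all_reduce_0']
```

Because the rewrite happens after construction, the step function can request communication anywhere, and no thread-safe concurrent graph building is needed.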

Building a Platform for AI Research at DeepMind

By collaborating closely with researchers throughout the design and implementation of TF-Replicator, we were able to build a library that allows users to easily scale computation across many hardware accelerators, while leaving them with the control and flexibility required to do cutting-edge AI research. For example, we added MPI-style communication primitives such as all-reduce following discussion with researchers. TF-Replicator and other shared infrastructure allow us to build increasingly complex experiments on robust foundations and quickly spread best practices throughout DeepMind.

At the time of writing, TF-Replicator is the most widely used interface for TPU programming at DeepMind. While the library itself is not constrained to training neural networks, it is most commonly used for training on large batches of data. The BigGAN model, for example, was trained on batches of size 2048 across up to 512 cores of a TPUv3 pod. In Reinforcement Learning agents with a distributed actor-learner setup, such as our importance weighted actor-learner architectures (IMPALA), scalability is achieved by having many actors generate new experience by interacting with the environment. This data is then processed by the learner to improve the agent’s policy, represented as a neural network. To cope with an increasing number of actors, TF-Replicator can be used to easily distribute the learner across many hardware accelerators. These and other examples are described in more detail in our arXiv paper.
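Under data parallelism, each replica processes an equal slice of the global batch. As a worked example, and assuming the BigGAN batch of 2048 is split evenly over all 512 cores (the post gives only the totals), each core would process 2048 / 512 = 4 examples per step. The hypothetical helpers below show that arithmetic and the same even split applied to a batch of actor trajectories arriving at a replicated learner; they are not part of TF-Replicator.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Examples each replica processes per step under an even data-parallel split."""
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size must divide evenly across replicas")
    return global_batch_size // num_replicas


def shard(batch: Sequence[T], num_replicas: int) -> List[List[T]]:
    """Splits a global batch (e.g. actor trajectories) into equal per-replica slices."""
    size = per_replica_batch_size(len(batch), num_replicas)
    return [list(batch[i * size:(i + 1) * size]) for i in range(num_replicas)]


print(per_replica_batch_size(2048, 512))  # 4
trajectories = [f"trajectory_{i}" for i in range(8)]  # hypothetical actor output
print(shard(trajectories, 4)[0])  # ['trajectory_0', 'trajectory_1']
```

Raising on uneven splits keeps the sketch simple; a real input pipeline might instead pad or drop the remainder.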

TF-Replicator is just one of many examples of impactful technology built by DeepMind’s Research Platform Team. Many of DeepMind’s breakthroughs in AI, from AlphaGo to AlphaStar, were enabled by the team. If you share our mission and are excited about accelerating state-of-the-art AI research, look out for open Software Engineering positions in Research Platform at https://deepmind.com/careers (machine learning experience is optional for these roles).

This work was completed by the Research Platform Team at DeepMind. We’d like to thank Frederic Besse, Fabio Viola, John Aslanides, Andy Brock, Aidan Clark, Sergio Gómez Colmenarejo, Karen Simonyan, Sander Dieleman, Lasse Espeholt, Akihiro Matsukawa, Tim Harley, Jean-Baptiste Lespiau, Koray Kavukcuoglu, Dan Belov and many others at DeepMind for their valuable feedback throughout the development of TF-Replicator. We'd also like to thank Priya Gupta, Jonathan Hseu, Josh Levenberg, Martin Wicke and others at Google for making these ideas available to all TensorFlow users as part of tf.distribute.Strategy.