Cutting-edge deep learning models are growing at an exponential rate: where last year’s GPT-2 had ~1.5 billion parameters, this year’s GPT-3 has 175 billion. GPT is a somewhat extreme example; nevertheless, the “embiggening” of the SOTA is driving larger and larger models into production applications, challenging the ability of even the most powerful GPU cards to finish model training jobs in a reasonable amount of time.

To deal with these problems, practitioners are increasingly turning to distributed training. Distributed training is the set of techniques for training a deep learning model using multiple GPUs and/or multiple machines. Distributing training jobs allows you to push past the single-GPU memory bottleneck, developing ever larger and more powerful models by leveraging many GPUs simultaneously.

This blog post is an introduction to distributed training in pure PyTorch using the torch.nn.parallel.DistributedDataParallel API. We will:

  • Discuss distributed training in general and data parallelization in particular
  • Cover the relevant features of torch.distributed and DistributedDataParallel, and show how they are used by example
  • And benchmark a real training script to see the time savings in action

You can follow along in code by checking out the companion GitHub repo.
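As a quick preview of where we’re headed, here is a minimal sketch of the API in action. It assumes a CPU-only machine (hence the "gloo" backend), two worker processes, and a toy linear model with random data; on a real multi-GPU box you would use the "nccl" backend and pin each replica to its own device. The sections below unpack each piece.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, world_size):
    # Each process joins the same process group so gradients can be synced.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = torch.nn.Linear(10, 1)
    ddp_model = DDP(model)  # wraps the model; gradient sync happens automatically
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    data = torch.randn(16, 10)   # this rank's slice of the batch (toy data)
    target = torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(ddp_model(data), target)
    loss.backward()              # gradients are all-reduced across ranks here
    optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(train, args=(world_size,), nprocs=world_size)
```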

What is distributed training?

Before we can dive into DistributedDataParallel, we first need to acquire some background knowledge about distributed training in general.

There are two forms of distributed training in common use today: data parallelization and model parallelization.

In data parallelization, the model training job is split on the data. Each GPU in the job receives its own independent slice of the data batch (its own “batch slice”), which it uses to independently calculate a gradient update. For example, if you were to use two GPUs and a batch size of 32, one GPU would handle forward and backpropagation on the first 16 records, and the second GPU on the last 16. These gradient updates are then synchronized among the GPUs, averaged together, and finally applied to the model.

(The synchronization step is technically optional, but theoretically faster asynchronous update strategies are still an active area of research.)
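To make that concrete, here is a single-process sketch that simulates the two-GPU, batch-of-32 example above. The toy linear model, random data, and plain SGD step are illustrative assumptions; in real DDP training the averaging in step 2 is performed by an all-reduce across GPUs rather than a Python loop.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
batch = torch.randn(32, 10)    # full batch of 32 records, 10 features each
target = torch.randn(32, 1)

model = nn.Linear(10, 1)                               # shared model weights
workers = [copy.deepcopy(model) for _ in range(2)]     # one replica per "GPU"
loss_fn = nn.MSELoss()

# 1. Each worker runs forward and backward on its own 16-record batch slice.
for i, replica in enumerate(workers):
    x, y = batch[i * 16:(i + 1) * 16], target[i * 16:(i + 1) * 16]
    loss_fn(replica(x), y).backward()

# 2. Gradients are averaged across workers (the synchronization step), and
# 3. the averaged update is applied to the shared model (plain SGD, lr=0.01).
with torch.no_grad():
    for name, param in model.named_parameters():
        grads = [dict(r.named_parameters())[name].grad for r in workers]
        param -= 0.01 * torch.stack(grads).mean(dim=0)
```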
