MNIST Example
Let’s start by creating code that trains a classifier for the MNIST dataset. We will then modify this code to run in parallel. For simplicity, we will leave out code that evaluates our model with testing data.
Non-Distributed Code
Get Data
In the function below we download the MNIST dataset and pass it to a dataloader.
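The data-loading step might look something like the sketch below. The function name get_data, the batch size, and the normalization constants are illustrative choices, not necessarily those used in the original lesson.

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def get_data(batch_size=64):
    # Download MNIST (if not already present) and convert images to
    # normalized tensors; 0.1307/0.3081 are the standard MNIST mean/std
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = datasets.MNIST(root="data", train=True, download=True,
                               transform=transform)
    # Wrap the dataset in a DataLoader that shuffles and batches the images
    return DataLoader(train_set, batch_size=batch_size, shuffle=True)
```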
Next, let’s visualize a few images from the MNIST dataset. If you are unfamiliar with the MNIST data set you can learn more about it here.
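A quick way to inspect a few samples, for example with matplotlib, is sketched below; the helper name show_samples and the grid layout are hypothetical.

```python
import matplotlib.pyplot as plt

def show_samples(loader, n=6):
    # Grab one batch and plot the first n images with their labels
    images, labels = next(iter(loader))
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2))
    for i, ax in enumerate(axes):
        ax.imshow(images[i].squeeze(), cmap="gray")
        ax.set_title(str(labels[i].item()))
        ax.axis("off")
    plt.show()

show_samples(get_data())  # display a few digits from a shuffled batch
```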

Build Network
Next, we build a network that will be used to train our model.
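One plausible architecture is the small convolutional network sketched below; the exact layers and sizes in the lesson's Net class may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Two convolutional layers followed by two fully connected layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))      # 28x28 -> 26x26
        x = F.relu(self.conv2(x))      # 26x26 -> 24x24
        x = F.max_pool2d(x, 2)         # 24x24 -> 12x12
        x = torch.flatten(x, 1)        # flatten all dims except batch
        x = F.relu(self.fc1(x))
        return self.fc2(x)             # raw logits for the 10 digit classes
```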
Train Model
Below we create two functions. The first function, called train_loop, performs one epoch of the training process. The second function, called main, does everything we need to train a model: download the data and set up a dataloader; instantiate our model, loss, and optimizer; and finally run multiple epochs by calling the train_loop function.
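A minimal sketch of these two functions, reusing the hypothetical get_data helper and the Net class above, might look like this (hyperparameters such as the learning rate and number of epochs are illustrative):

```python
import torch
import torch.nn as nn

def train_loop(model, loader, loss_fn, optimizer, device, epoch):
    # One full pass (epoch) over the training data
    model.train()
    for batch, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            print(f"epoch {epoch} batch {batch} loss {loss.item():.4f}")

def main(epochs=3, lr=1e-3):
    # Use a GPU if one is available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loader = get_data()
    model = Net().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        train_loop(model, loader, loss_fn, optimizer, device, epoch)
    torch.save(model.state_dict(), "mnist_cnn.pt")
```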
Finally, let’s train our model by calling the main function.
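With the pieces above in place, that call could be as simple as the following sketch:

```python
if __name__ == "__main__":
    main(epochs=3)  # train for a few epochs on a single GPU or the CPU
```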
Distributed Code for Multiple GPUs on One Node
Note that we will reuse the code from above and modify it to use PyTorch’s DDP. As mentioned previously, there are five main modifications needed to run DDP:
- Create a process group
- Use PyTorch’s DistributedSampler to ensure that data passed to each GPU is different
- Wrap the model with PyTorch’s DistributedDataParallel
- Modify the training loop so the model is saved from only one GPU
- Close process group
The modifications for the five changes highlighted above are visually denoted with two lines of #. Note that we reuse the Net class defined in the serial version above. In the serial code we used the variable device to refer to the GPU or CPU on which the code runs; in the distributed implementation we will instead use the variables local_rank and world_size, where:
- local_rank: the device id of a GPU on one node
- world_size: the number of GPUs on one node
Note that world_size will change when we use multiple nodes later in this course.
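A sketch of how these five modifications could fit together is shown below. It reuses the Net class and train_loop function from the serial sketches above; the NCCL backend, the MASTER_ADDR/MASTER_PORT values, and the hyperparameters are assumptions for illustration, and each modification is bracketed by two lines of #.

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

def main(local_rank, world_size, epochs=3, lr=1e-3):
    ##########################################################
    # 1. Create a process group (the address, port, and NCCL
    #    backend here are illustrative choices)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
    ##########################################################

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])
    train_set = datasets.MNIST(root="data", train=True, download=True,
                               transform=transform)

    ##########################################################
    # 2. DistributedSampler gives each GPU a distinct shard of the data
    sampler = DistributedSampler(train_set, num_replicas=world_size,
                                 rank=local_rank)
    loader = DataLoader(train_set, batch_size=64, sampler=sampler)
    ##########################################################

    ##########################################################
    # 3. Wrap the model (the Net class from the serial version) with DDP
    model = Net().to(local_rank)
    model = DDP(model, device_ids=[local_rank])
    ##########################################################

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        train_loop(model, loader, loss_fn, optimizer, local_rank, epoch)

    ##########################################################
    # 4. Only the process driving GPU 0 writes the model to disk
    if local_rank == 0:
        torch.save(model.module.state_dict(), "mnist_cnn.pt")
    ##########################################################

    ##########################################################
    # 5. Close the process group
    dist.destroy_process_group()
    ##########################################################
```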
To run this code in parallel we will utilize the torch.multiprocessing package, which is a wrapper around Python’s native multiprocessing module. In particular, we will use the spawn method. Spawn creates new processes from the parent process, but each child process inherits only the resources necessary to run its run() method. Below we highlight the code used to execute training in a Python script called mnist_parallel.py, which we will run to execute the parallel implementation of this code.
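The launcher at the bottom of mnist_parallel.py could look like the sketch below, where main is the distributed training function from the previous sketch; using torch.cuda.device_count() for world_size is an assumption that every GPU on the node should be used.

```python
import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    # One process per GPU on this node; spawn passes the process index
    # (used here as local_rank) as the first argument to main
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size,), nprocs=world_size)
```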
Finally, we can train the MNIST classifier using DDP by running this script.