Distributed training across multiple computational resources in TensorFlow/Keras is implemented through the tf.distribute module. TensorFlow/Keras supports several strategies, depending on how one wants to distribute the computation and on which resources it will be distributed over. In this example, we use Keras to carry out synchronous, data-parallel training across multiple GPUs attached to a single CPU node, as described in the tutorial on Distributed training with Keras. We will run this code example on the GPU nodes on Frontera.

The MNIST problem is revisited in the Keras distributed training tutorial, with a Jupyter notebook providing the source code and commentary. For this particular code example, no separate link to the raw Python source code is provided, so you will want to download the notebook and extract the source code using the instructions that we have provided previously. (You can either view the Jupyter notebook in GitHub or download it directly from this link.) For the purposes of the discussion below, we save the Python source code from this notebook to a file named keras_distributed.py. We will not reproduce all the source code on this page, but will instead highlight some of the key features of the distributed multi-GPU code.
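If you prefer to extract the code by hand, the following is a minimal sketch that uses only the Python standard library. It assumes the notebook has been downloaded under the hypothetical name keras.ipynb; it is not necessarily the method described in the earlier instructions. It relies on the fact that a notebook file is JSON with a list of cells, and it comments out lines that begin with % or !, which are IPython magics or shell escapes rather than valid Python.

```python
import json
from pathlib import Path


def extract_code_cells(notebook_path, output_path):
    """Write the code cells of a Jupyter notebook to a plain Python file."""
    notebook = json.loads(Path(notebook_path).read_text(encoding="utf-8"))
    chunks = []
    for cell in notebook.get("cells", []):
        if cell.get("cell_type") != "code":
            continue  # skip markdown and raw cells
        source = cell.get("source", [])
        # The notebook format allows "source" to be a string or a list of lines.
        text = source if isinstance(source, str) else "".join(source)
        # Comment out IPython magics and shell escapes (simple prefix check).
        lines = [
            "# " + line if line.lstrip().startswith(("%", "!")) else line
            for line in text.splitlines()
        ]
        chunks.append("\n".join(lines))
    Path(output_path).write_text("\n\n".join(chunks) + "\n", encoding="utf-8")
    return len(chunks)


if __name__ == "__main__":
    # "keras.ipynb" is a hypothetical name for the downloaded notebook.
    if Path("keras.ipynb").exists():
        count = extract_code_cells("keras.ipynb", "keras_distributed.py")
        print(f"Wrote {count} code cells to keras_distributed.py")
```

The prefix check is deliberately simple, so it is worth skimming the resulting file before running it.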

The key difference between this distributed TensorFlow/Keras code and the single-GPU code that we considered previously is that this code introduces a strategy for distributed computation:

strategy = tf.distribute.MirroredStrategy()
Python

A MirroredStrategy replicates, or mirrors, the model across multiple GPUs on a single machine or CPU node, in order to implement a data-parallel training strategy. Each GPU contains a full copy of the model but processes only part of the data, and an all-reduce operation combines the gradients so that all the GPUs update the model parameters in a consistent manner. Other possible strategies are discussed in the Distributed training with TensorFlow guide.
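To make the idea of synchronous data-parallel training concrete, here is a toy sketch in plain Python. It does not use TensorFlow, and all names in it are illustrative. Four "replicas" hold identical copies of a single weight, each computes a gradient on its own shard of the batch, and a stand-in all-reduce averages those gradients so that every replica applies the same update. TensorFlow's actual implementation is more elaborate, but the bookkeeping follows the same pattern.

```python
# Toy model: y = w * x with squared-error loss, so d(loss)/dw = 2 * (w*x - y) * x.
def mean_gradient(w, examples):
    """Average gradient of the squared error over a list of (x, y) pairs."""
    return sum(2.0 * (w * x - y) * x for x, y in examples) / len(examples)


def all_reduce_mean(values):
    """Stand-in for an all-reduce: every replica receives the same mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


NUM_REPLICAS = 4
LEARNING_RATE = 0.01

# One global batch of 8 examples drawn from y = 3x.
global_batch = [(float(x), 3.0 * x) for x in range(1, 9)]

# Each replica starts with an identical copy of the weight ...
weights = [0.5] * NUM_REPLICAS
# ... and receives an equal, disjoint shard of the global batch.
shard_size = len(global_batch) // NUM_REPLICAS
shards = [global_batch[i * shard_size:(i + 1) * shard_size]
          for i in range(NUM_REPLICAS)]

# Local gradients differ per replica; the all-reduce makes them identical.
local_grads = [mean_gradient(w, shard) for w, shard in zip(weights, shards)]
synced_grads = all_reduce_mean(local_grads)
weights = [w - LEARNING_RATE * g for w, g in zip(weights, synced_grads)]

# Reference: the gradient a single device would compute on the whole batch.
single_device_grad = mean_gradient(0.5, global_batch)

print("per-replica gradients: ", local_grads)
print("all-reduced gradient:  ", synced_grads[0])
print("single-device gradient:", single_device_grad)
print("weights after update:  ", weights)
```

Because the shards are of equal size, the averaged gradient matches what a single device would compute on the full batch, and the replicas' weights remain identical after the update.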

Once a MirroredStrategy is constructed, there are downstream impacts on the rest of the code. The strategy determines how many GPUs are available and, by default, uses all of them for the computation. (One can choose to use only a subset, however, by passing an explicit list of devices to the constructor's devices argument.) Because each batch of training data is split across the GPUs, the batch size needs to be scaled accordingly. In the code linked above, this modification uses the number of model replicas held by the strategy, and is followed by a suitable subdivision of the training and evaluation datasets into batches:

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
# ...skip some lines
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
Python
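As a quick check on this arithmetic, the sketch below reproduces the step counts that appear in the output at the bottom of this page (235 training steps per epoch and 40 evaluation steps). It assumes the standard MNIST split sizes of 60,000 training and 10,000 test images, and it hardcodes the replica count instead of querying the strategy.

```python
import math

BATCH_SIZE_PER_REPLICA = 64
NUM_REPLICAS = 4  # reported as "Number of devices: 4" in the run on this page
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS

# Standard MNIST split sizes (assumed here, not read from the dataset).
NUM_TRAIN_EXAMPLES = 60_000
NUM_TEST_EXAMPLES = 10_000

# The final, partial batch still counts as a step, hence the ceiling.
train_steps = math.ceil(NUM_TRAIN_EXAMPLES / BATCH_SIZE)
eval_steps = math.ceil(NUM_TEST_EXAMPLES / BATCH_SIZE)

print(f"global batch size: {BATCH_SIZE}")
print(f"training steps per epoch: {train_steps}")
print(f"evaluation steps: {eval_steps}")
```

Each step therefore processes 256 images in total, 64 on each of the 4 GPUs.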

Finally, the last important impact of introducing the distributed strategy is that the model definition is placed within the strategy's scope (using the Python with keyword), which enables each of the replicas to hold a full copy of the model:

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
Python

Once this distributed strategy and model are set up, TensorFlow/Keras can take care of communicating information between the model replicas, and combining everything into a model ready for evaluation.

We are going to reuse the tf282 virtual environment that we created previously, along with the necessary Lmod modules. The mechanics of running this code on Frontera are the same as what we used in TensorFlow on Frontera. We will just use the interactive idev-based method here rather than the Slurm sbatch method, although you can certainly consult the previous instructions if you prefer the latter.

module load python3/3.9.2
module load cuda/11.3 cudnn nccl
module load phdf5
source tf282/bin/activate  # or point to the tf282 venv if it is elsewhere
idev -N 1 -n 4 -p rtx-dev
python3 keras_distributed.py
Plain text

The output from this run is reproduced at the bottom of this page (with some minor edits to adjust the formatting). We highlight a few points of interest. If for some reason you encounter errors when building the environment or running the code, consult the troubleshooting page to see if those errors are addressed.

Using all 4 GPUs

As intended, this distributed training code is now using all 4 of the GPUs attached to the Frontera rtx node that we have been allocated. You might want to ssh into your rtx node in a second terminal, in order to monitor the GPU utilization with the nvidia-smi command.
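If you would rather log utilization from a script than watch the nvidia-smi display, the following sketch wraps a single nvidia-smi query in Python. The function and field names are illustrative choices, and the parser assumes that every field comes back as an integer, which may not hold on all GPU models.

```python
import subprocess

QUERY_FIELDS = "index,utilization.gpu,memory.used"


def parse_gpu_csv(text):
    """Parse 'index, utilization, memory' CSV rows into a list of dicts."""
    gpus = []
    for line in text.strip().splitlines():
        index, util, mem = (field.strip() for field in line.split(","))
        gpus.append({"index": int(index),
                     "util_percent": int(util),
                     "memory_mib": int(mem)})
    return gpus


def query_gpus():
    """Run nvidia-smi once and return the parsed per-GPU readings."""
    result = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY_FIELDS}",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    )
    return parse_gpu_csv(result.stdout)


if __name__ == "__main__":
    try:
        for gpu in query_gpus():
            print(gpu)
    except (FileNotFoundError, subprocess.CalledProcessError) as err:
        print(f"nvidia-smi is not available here: {err}")
```

Run on the rtx node while training is in progress, this should print one reading for each of the 4 GPUs.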

on_train_batch_end

You might see a warning like this:

WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to
the batch time (batch time: 0.0119s vs `on_train_batch_end` time: 0.0193s).
Check your callbacks.
Plain text

It is not obvious what this means, but a web search for this type of warning suggests that, for this problem at least, breaking the dataset into batches of this size results in a slow-down of the program rather than a speed-up, due to the overhead of setting up the batches. In the warning above, the time spent in the callback (0.0193 s) exceeds the time spent computing the batch itself (0.0119 s).

Suggested code modifications and experiments

Here are some suggestions about how you might explore this code example further:

  • Modifying batch size: In the keras_distributed.py code, increase the batch size to see if (a) the code runs faster and (b) the warning about the on_train_batch_end callback method is eliminated.
  • Trying a different dataset: Use a dataset with more data elements than the mnist problem. There is, for example, the emnist dataset, which is intended as a larger version of mnist-like data. If you decide to modify the keras_distributed.py code to handle another dataset, you might need to:
    1. Reconfigure some of the data loading and splitting code, depending on how it is packaged
    2. Modify the model definition to handle the number of class labels associated with the new dataset (currently hardwired at 10 for mnist). There might be an informational data structure linked to the dataset that you can query to get the number of class labels.
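For the first suggestion above, it can help to see how the number of steps per epoch (and hence the number of per-batch callback invocations) shrinks as the per-replica batch size grows. The sketch below assumes the standard 60,000-image MNIST training split and 4 replicas. The candidate batch sizes are arbitrary choices for illustration, and the sketch says nothing about how accuracy responds to larger batches.

```python
import math

NUM_TRAIN_EXAMPLES = 60_000  # standard MNIST training split (assumed)
NUM_REPLICAS = 4             # GPUs on the rtx node used on this page


def steps_per_epoch(batch_size_per_replica, num_replicas=NUM_REPLICAS,
                    num_examples=NUM_TRAIN_EXAMPLES):
    """Number of training steps (and per-batch callback calls) in one epoch."""
    global_batch = batch_size_per_replica * num_replicas
    return math.ceil(num_examples / global_batch)


for per_replica in (64, 128, 256, 512):
    print(f"BATCH_SIZE_PER_REPLICA={per_replica:4d} -> "
          f"{steps_per_epoch(per_replica):3d} steps per epoch")
```

Comparing the measured epoch times for a few of these settings is a simple way to see whether the per-batch overhead was in fact the bottleneck.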
Output from keras_distributed.py
MPI startup(): PMI server not found. Please set I_MPI_PMI_LIBRARY variable if it is not a singleton case.
2.7.0
2022-05-27 08:16:33.296975: I tensorflow/core/platform/cpu_feature_guard.cc:151]
  This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)
  to use the following CPU instructions in performance-critical operations:  AVX2 FMA
  To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-27 08:16:42.054589: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525]
  Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14793 MB memory:  ->
  device: 0, name: Quadro RTX 5000, pci bus id: 0000:02:00.0, compute capability: 7.5
2022-05-27 08:16:42.081273: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525]
  Created device /job:localhost/replica:0/task:0/device:GPU:1 with 14793 MB memory:  ->
  device: 1, name: Quadro RTX 5000, pci bus id: 0000:03:00.0, compute capability: 7.5
2022-05-27 08:16:42.082654: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525]
  Created device /job:localhost/replica:0/task:0/device:GPU:2 with 14793 MB memory:  ->
  device: 2, name: Quadro RTX 5000, pci bus id: 0000:82:00.0, compute capability: 7.5
2022-05-27 08:16:42.084287: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525]
  Created device /job:localhost/replica:0/task:0/device:GPU:3 with 14793 MB memory:  ->
  device: 3, name: Quadro RTX 5000, pci bus id: 0000:83:00.0, compute capability: 7.5
Number of devices: 4
2022-05-27 08:16:45.555647: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:537]
  The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/12
2022-05-27 08:16:54.736253: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8201
2022-05-27 08:16:55.802513: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8201
2022-05-27 08:16:57.005261: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8201
2022-05-27 08:16:58.037802: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8201
  5/235 [..............................] - ETA: 3s - loss: 2.1001 - accuracy: 0.3219     \
  WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time
  (batch time: 0.0119s vs `on_train_batch_end` time: 0.0193s). Check your callbacks.
  WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time
  (batch time: 0.0119s vs `on_train_batch_end` time: 0.0193s). Check your callbacks.
235/235 [==============================] - ETA: 0s - loss: 0.3398 - accuracy: 0.9050
Learning rate for epoch 1 is 0.0010000000474974513
235/235 [==============================] - 17s 9ms/step - loss: 0.3398 - accuracy: 0.9050 - lr: 0.0010
Epoch 2/12
235/235 [==============================] - ETA: 0s - loss: 0.1047 - accuracy: 0.9706
Learning rate for epoch 2 is 0.0010000000474974513
235/235 [==============================] - 2s 6ms/step - loss: 0.1047 - accuracy: 0.9706 - lr: 0.0010
Epoch 3/12
227/235 [===========================>..] - ETA: 0s - loss: 0.0717 - accuracy: 0.9795
Learning rate for epoch 3 is 0.0010000000474974513
235/235 [==============================] - 2s 7ms/step - loss: 0.0713 - accuracy: 0.9796 - lr: 0.0010
Epoch 4/12
227/235 [===========================>..] - ETA: 0s - loss: 0.0504 - accuracy: 0.9864
Learning rate for epoch 4 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0501 - accuracy: 0.9865 - lr: 1.0000e-04
Epoch 5/12
231/235 [============================>.] - ETA: 0s - loss: 0.0473 - accuracy: 0.9876
Learning rate for epoch 5 is 9.999999747378752e-05
235/235 [==============================] - 2s 6ms/step - loss: 0.0471 - accuracy: 0.9876 - lr: 1.0000e-04
Epoch 6/12
233/235 [============================>.] - ETA: 0s - loss: 0.0456 - accuracy: 0.9878
Learning rate for epoch 6 is 9.999999747378752e-05
235/235 [==============================] - 1s 6ms/step - loss: 0.0454 - accuracy: 0.9879 - lr: 1.0000e-04
Epoch 7/12
233/235 [============================>.] - ETA: 0s - loss: 0.0440 - accuracy: 0.9883
Learning rate for epoch 7 is 9.999999747378752e-05
235/235 [==============================] - 2s 6ms/step - loss: 0.0439 - accuracy: 0.9883 - lr: 1.0000e-04
Epoch 8/12
233/235 [============================>.] - ETA: 0s - loss: 0.0417 - accuracy: 0.9889
Learning rate for epoch 8 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0417 - accuracy: 0.9889 - lr: 1.0000e-05
Epoch 9/12
230/235 [============================>.] - ETA: 0s - loss: 0.0415 - accuracy: 0.9890
Learning rate for epoch 9 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0415 - accuracy: 0.9890 - lr: 1.0000e-05
Epoch 10/12
231/235 [============================>.] - ETA: 0s - loss: 0.0410 - accuracy: 0.9892
Learning rate for epoch 10 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0413 - accuracy: 0.9891 - lr: 1.0000e-05
Epoch 11/12
232/235 [============================>.] - ETA: 0s - loss: 0.0410 - accuracy: 0.9892
Learning rate for epoch 11 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0411 - accuracy: 0.9892 - lr: 1.0000e-05
Epoch 12/12
229/235 [============================>.] - ETA: 0s - loss: 0.0406 - accuracy: 0.9892
Learning rate for epoch 12 is 9.999999747378752e-06
235/235 [==============================] - 2s 6ms/step - loss: 0.0410 - accuracy: 0.9892 - lr: 1.0000e-05
2022-05-27 08:17:23.972408: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:537]
  The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 8ms/step - loss: 0.0505 - accuracy: 0.9830
Eval loss: 0.05051109567284584, Eval accuracy: 0.9829999804496765
2022-05-27 08:17:26.653449: W tensorflow/python/util/util.cc:368]
  Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
40/40 [==============================] - 0s 2ms/step - loss: 0.0505 - accuracy: 0.9830
Eval loss: 0.05051109194755554, Eval Accuracy: 0.9829999804496765
2022-05-27 08:17:28.444708: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:537]
  The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 5ms/step - loss: 0.0505 - accuracy: 0.9830
Eval loss: 0.05051109567284584, Eval Accuracy: 0.9829999804496765
Plain text
 
© 2025 Cornell University, Center for Advanced Computing