Training Keras with the SLURM Scheduler#

This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model and submit the training job to an HPC cluster through the SLURM scheduler. Although NJIT's cluster is used in the examples, the approach is the same for NYU.

Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0

import tensorflow as tf
import tensorflow_datasets as tfds

Step 0: VPN and Login#

You need to install a VPN client and establish a VPN connection to the NJIT data center. Consult https://ist.njit.edu/vpn for help in doing so.

# select from the two options below and ssh into the HPC server
ssh ucid@HPC_HOST.njit.edu 
ssh ucid@wulver.njit.edu
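
Optionally, set up key-based login so you do not have to retype your password for every ssh and scp. A minimal sketch using standard OpenSSH tooling (the default key path is assumed):

# generate a key pair on your local machine (accept the defaults)
ssh-keygen -t ed25519
# copy the public key to the cluster; subsequent logins use the key
ssh-copy-id ucid@wulver.njit.edu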

Step 1: Create your input pipeline#

Start by building an efficient input pipeline using advice from:

  • the Performance tips guide (https://www.tensorflow.org/datasets/performances)

  • the Better performance with the tf.data API guide (https://www.tensorflow.org/guide/data_performance)

Load a dataset#

Load the MNIST dataset with the following arguments:

  • shuffle_files=True: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it’s good practice to shuffle them when training.

  • as_supervised=True: Returns a tuple (img, label) instead of a dictionary {'image': img, 'label': label}.

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

Upload dataset to server#

Copy files from your local machine to the cluster with scp:

scp -r /path_of_file_locally ucid@HPC_HOST.njit.edu:/path_in_server
scp -r /path_of_file_locally ucid@wulver.njit.edu:/path_in_server
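
By default, TFDS prepares datasets under ~/tensorflow_datasets, so one option is to run tfds.load once locally, upload that directory, and point tfds.load at it on the server. A minimal sketch, assuming the data was uploaded to ~/tensorflow_datasets on the cluster:

# on the local machine, upload the prepared TFDS directory:
#   scp -r ~/tensorflow_datasets ucid@wulver.njit.edu:~/

# on the server, load from the uploaded directory instead of downloading
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
    data_dir='~/tensorflow_datasets',  # assumed upload location
    download=False,                    # fail loudly if the data is missing
)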

Build a training pipeline#

Apply the following transformations:

  • tf.data.Dataset.map: TFDS provides images of type tf.uint8, while the model expects tf.float32. Therefore, you need to normalize images.

  • tf.data.Dataset.cache: As the dataset fits in memory, cache it before shuffling for better performance.
    Note: Random transformations should be applied after caching.

  • tf.data.Dataset.shuffle: For true randomness, set the shuffle buffer to the full dataset size.
    Note: For large datasets that can’t fit in memory, use buffer_size=1000 if your system allows it.

  • tf.data.Dataset.batch: Batch elements of the dataset after shuffling to get unique batches at each epoch.

  • tf.data.Dataset.prefetch: It is good practice to end the pipeline by prefetching for performance.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Build an evaluation pipeline#

Your testing pipeline is similar to the training pipeline with small differences:

  • You don’t need to call tf.data.Dataset.shuffle.

  • Caching is done after batching because batches can be the same between epochs.

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
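
Before shipping the job to the cluster, it can be worth a quick local sanity check that the pipelines yield batches of the expected shape and dtype. A minimal sketch using the pipelines defined above:

# pull one batch from the training pipeline and inspect it
images, labels = next(iter(ds_train))
print(images.shape, images.dtype)  # (128, 28, 28, 1) float32
print(labels.shape, labels.dtype)  # (128,) int64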

Run code on server#

First create a .sh job script. Here is a sample that runs your_job_name on the general partition; you may need to change nodelist, nodes, gres, and mem for your job. To comment something out in a .sh file, use several # characters (e.g. #####your comment): a single # already starts a bash comment, but lines beginning with #SBATCH are scheduler directives, so the extra # characters ensure sbatch does not parse them. Submit the finished script with sbatch, as shown after the sample.

#!/bin/bash -l

#SBATCH --job-name=your_job_name
#SBATCH --output=xxxx.out
#SBATCH --partition=general
#SBATCH --nodelist=node816
#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --mem=16G

##### load your Python environment here (e.g. module load or conda activate), then run the script
python your_script.py
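
Submit the script with sbatch and monitor it from the login node; assuming you saved it as train_mnist.sh (a hypothetical name):

sbatch train_mnist.sh   # submit; prints the job ID
squeue -u $USER         # list your pending/running jobs
scancel <jobid>         # cancel a job if needed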

Use sinfo in the terminal to check the state of the partitions and nodes; squeue -u $USER (from the previous step) shows whether your job is actually running.

Use squeue -p gpu to check which jobs are currently using the gpu partition.

It is hard to visualize things on the server, so save the output of everything into files on the server and download them to your local environment to visualize when needed (see the sketch at the end of Step 2).

Download results from the cluster to your local machine with scp:

scp ucid@HPC_HOST.njit.edu:/path_in_server /path_of_file_locally
scp ucid@wulver.njit.edu:/path_in_server /path_of_file_locally

Step 2: Create and train the model#

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.3621 - sparse_categorical_accuracy: 0.9011 - val_loss: 0.1925 - val_sparse_categorical_accuracy: 0.9463
Epoch 2/6
469/469 [==============================] - 1s 3ms/step - loss: 0.1602 - sparse_categorical_accuracy: 0.9543 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9588
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1174 - sparse_categorical_accuracy: 0.9664 - val_loss: 0.1084 - val_sparse_categorical_accuracy: 0.9693
Epoch 4/6
469/469 [==============================] - 1s 3ms/step - loss: 0.0911 - sparse_categorical_accuracy: 0.9743 - val_loss: 0.0968 - val_sparse_categorical_accuracy: 0.9714
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0738 - sparse_categorical_accuracy: 0.9790 - val_loss: 0.0881 - val_sparse_categorical_accuracy: 0.9735
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0617 - sparse_categorical_accuracy: 0.9823 - val_loss: 0.0793 - val_sparse_categorical_accuracy: 0.9749
<keras.src.callbacks.History at 0x7fc41e0cb880>
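
Since it is hard to visualize results on the server (see the note in the Run code on server section), persist everything to files and scp them back. A minimal sketch reusing the model and pipelines above; the callback and output file names are arbitrary choices, not part of the original example:

import json

# stream per-epoch metrics to a CSV file while training
csv_logger = tf.keras.callbacks.CSVLogger('training_log.csv')
history = model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
    callbacks=[csv_logger],
)

# dump the final history dict and the trained model for later download
with open('history.json', 'w') as f:
    json.dump({k: [float(v) for v in vals]
               for k, vals in history.history.items()}, f)
model.save('mnist_model.keras')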