Training Keras with the SLURM Scheduler#
This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model and submit a training job in an HPC environment. Although the commands refer to NJIT, the same approach should apply at NYU.
Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0
import tensorflow as tf
import tensorflow_datasets as tfds
2023-10-03 09:29:30.258272: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-03 09:29:30.258321: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-03 09:29:30.258358: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Step 0: VPN and Login#
You need to install a VPN client and establish a VPN connection to the NJIT data center. Consult https://ist.njit.edu/vpn for help in doing so.
# select from the two options below and ssh into the HPC server
ssh ucid@HPC_HOST.njit.edu
ssh ucid@wulver.njit.edu
Step 1: Create your input pipeline#
Start by building an efficient input pipeline using advice from:
The Performance tips guide
The Better performance with the tf.data API guide
Load a dataset#
Load the MNIST dataset with the following arguments:
shuffle_files=True: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it’s good practice to shuffle them when training.
as_supervised=True: Returns a tuple (img, label) instead of a dictionary {'image': img, 'label': label}.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2023-10-03 09:29:33.682941: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Upload dataset to server#
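Copy the dataset from your local machine to the server with one of the following commands, depending on the server you use:
scp -r /path_of_file_locally ucid@HPC_HOST.njit.edu:/path_in_server
scp -r /path_of_file_locally ucid@wulver.njit.edu:/path_in_server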
Build a training pipeline#
Apply the following transformations:
tf.data.Dataset.map: TFDS provides images of type tf.uint8, while the model expects tf.float32. Therefore, you need to normalize images.
tf.data.Dataset.cache: As you fit the dataset in memory, cache it before shuffling for better performance.
Note: Random transformations should be applied after caching.
tf.data.Dataset.shuffle: For true randomness, set the shuffle buffer to the full dataset size.
Note: For large datasets that can’t fit in memory, use buffer_size=1000 if your system allows it.
tf.data.Dataset.batch: Batch elements of the dataset after shuffling to get unique batches at each epoch.
tf.data.Dataset.prefetch: It is good practice to end the pipeline by prefetching for performance.
def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
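The advice about the shuffle buffer size can be illustrated without TensorFlow. The sketch below is a simplified pure-Python model of a fixed-size shuffle buffer, not the tf.data implementation; the names buffered_shuffle and max_lead are hypothetical helpers introduced here. It shows that with a small buffer an element can never appear much earlier than its original position, so the data is only shuffled locally, whereas a buffer as large as the dataset allows any ordering.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items the way a simplified fixed-size shuffle buffer would."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


def max_lead(order):
    """Largest number of positions by which any element moved forward."""
    return max(value - position for position, value in enumerate(order))


data = list(range(1000))
small = list(buffered_shuffle(data, buffer_size=10))
full = list(buffered_shuffle(data, buffer_size=len(data)))

print("buffer_size=10   max lead:", max_lead(small))
print("buffer_size=1000 max lead:", max_lead(full))
```

In this model an element can move forward by at most buffer_size - 1 positions, which is why the pipeline above passes the full number of training examples to shuffle.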
Build an evaluation pipeline#
Your testing pipeline is similar to the training pipeline with small differences:
You don’t need to call tf.data.Dataset.shuffle.
Caching is done after batching because batches can be the same between epochs.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
Run code on server#
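First create a .sh file. The sample below shows the header of a .sh file that runs a job named your_job_name in the general partition; you might need to change nodelist, nodes, gres, and mem for your own job. Below the #SBATCH lines, add the commands that start your program. To write a comment in the .sh file, start the line with several # characters, such as #####your comment, so that it is not mistaken for an #SBATCH directive.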
#!/bin/bash -l
#SBATCH --job-name=your_job_name
#SBATCH --output=xxxx.out
#SBATCH --partition=general
#SBATCH --nodelist=node816
#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --mem=16G
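Submit the script from the terminal with sbatch followed by the name of the .sh file.
Use: sinfo on the terminal to check the state of the partitions and nodes on the server
Use: squeue -p gpu to list the jobs in the gpu partition and see how busy the GPUs are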
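Because the sample header contains only directives, a complete job file also needs the command to run. The following is a minimal sketch in Python, assuming a training program saved as train.py and a job file called train_job.sh (both names are hypothetical); build_job_script and submit are helpers introduced here for illustration. The nodelist directive is optional in this sketch, so that the scheduler can pick any node that satisfies the request.

```python
import subprocess
from pathlib import Path


def build_job_script(job_name, command, partition="general", nodes=1,
                     gres="gpu:1", mem="16G", nodelist=None):
    """Return the text of a SLURM batch script that runs `command`."""
    lines = [
        "#!/bin/bash -l",
        f"#SBATCH --job-name={job_name}",
        f"#SBATCH --output={job_name}.out",
        f"#SBATCH --partition={partition}",
        f"#SBATCH --nodes={nodes}",
        f"#SBATCH --gres={gres}",
        f"#SBATCH --mem={mem}",
    ]
    if nodelist is not None:
        # Only pin the job to a node when one is requested explicitly.
        lines.append(f"#SBATCH --nodelist={nodelist}")
    lines += ["", "#####command that starts the program", command]
    return "\n".join(lines) + "\n"


def submit(script_text, path="train_job.sh"):
    """Write the script and hand it to sbatch (works only on the cluster)."""
    Path(path).write_text(script_text)
    result = subprocess.run(["sbatch", path], capture_output=True,
                            text=True, check=True)
    return result.stdout


# "python train.py" is an assumed command; replace it with your own program.
script = build_job_script("mnist_keras", "python train.py")
print(script)
```

On a login node, calling submit(script) would write the file and pass it to sbatch; the call is left out here because sbatch is not available on a local machine.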
It is hard to visualize things on the server, so make sure you save all output to a file on the server and download it to your local environment to visualize it if needed. Use one of the following commands, depending on the server you use:
scp ucid@HPC_HOST.njit.edu:/path_in_server /path_of_file_locally
scp ucid@wulver.njit.edu:/path_in_server /path_of_file_locally
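One way to save training output for later download is to write the per-epoch metrics to a JSON file. The sketch below assumes the metrics come from the history attribute of the object returned by model.fit in Step 2, which is a dictionary of lists; here a small stand-in dictionary with illustrative numbers is used so the example runs without TensorFlow. The helper names save_history and load_history and the file name history.json are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def save_history(history_dict, path):
    """Write per-epoch metrics to a JSON file and return its path."""
    path = Path(path)
    # Convert every value to a plain float so the dictionary is JSON-serializable.
    cleaned = {name: [float(v) for v in values]
               for name, values in history_dict.items()}
    path.write_text(json.dumps(cleaned, indent=2))
    return path


def load_history(path):
    """Read the metrics back, for example on your local machine."""
    return json.loads(Path(path).read_text())


# Stand-in for `history.history`; these numbers are illustrative only.
example_history = {"loss": [0.5, 0.25], "val_loss": [0.4, 0.3]}

# A temporary directory keeps the example self-contained; on the server you
# would choose a path in your own project directory instead.
with tempfile.TemporaryDirectory() as tmp_dir:
    saved_path = save_history(example_history, Path(tmp_dir) / "history.json")
    restored = load_history(saved_path)

print(restored)
```

After copying the JSON file to your local machine with scp, load_history returns the same dictionary, which can then be plotted with any plotting tool.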
Step 2: Create and train the model#
Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.3621 - sparse_categorical_accuracy: 0.9011 - val_loss: 0.1925 - val_sparse_categorical_accuracy: 0.9463
Epoch 2/6
469/469 [==============================] - 1s 3ms/step - loss: 0.1602 - sparse_categorical_accuracy: 0.9543 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9588
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1174 - sparse_categorical_accuracy: 0.9664 - val_loss: 0.1084 - val_sparse_categorical_accuracy: 0.9693
Epoch 4/6
469/469 [==============================] - 1s 3ms/step - loss: 0.0911 - sparse_categorical_accuracy: 0.9743 - val_loss: 0.0968 - val_sparse_categorical_accuracy: 0.9714
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0738 - sparse_categorical_accuracy: 0.9790 - val_loss: 0.0881 - val_sparse_categorical_accuracy: 0.9735
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0617 - sparse_categorical_accuracy: 0.9823 - val_loss: 0.0793 - val_sparse_categorical_accuracy: 0.9749
<keras.src.callbacks.History at 0x7fc41e0cb880>
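The 469 steps per epoch in the log above follow from the dataset size and the batch size. As a small worked example, assuming the MNIST training split has 60,000 examples and the batch size of 128 used in the pipeline, with the final partial batch kept (the helper name steps_per_epoch is hypothetical):

```python
import math


def steps_per_epoch(num_examples, batch_size, drop_remainder=False):
    """Number of batches produced in one pass over the dataset."""
    if drop_remainder:
        # Incomplete final batch is discarded.
        return num_examples // batch_size
    # Incomplete final batch is kept as a smaller batch.
    return math.ceil(num_examples / batch_size)


train_examples = 60_000  # size of the MNIST training split
batch_size = 128         # value passed to ds_train.batch above

steps = steps_per_epoch(train_examples, batch_size)
last_batch = train_examples - (steps - 1) * batch_size

print(steps)       # 469
print(last_batch)  # 96
```

The last batch of each epoch therefore holds 96 images rather than 128.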