# Training Keras with the SLURM Scheduler  

This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model and submit a training job in an HPC environment. Although the examples refer to NJIT systems, the same approach should apply at NYU.


Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/datasets/keras_example"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/datasets/blob/master/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/datasets/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

```python
import tensorflow as tf
import tensorflow_datasets as tfds
```



## Step 0: VPN and Login

You need to install a VPN client and establish a VPN connection to the NJIT data center. Consult https://ist.njit.edu/vpn for help with the setup.

```bash
# Choose one of the two hosts below and SSH into the HPC server
# (replace ucid with your own UCID).
ssh ucid@HPC_HOST.njit.edu
ssh ucid@wulver.njit.edu
```

## Step 1: Create your input pipeline

Start by building an efficient input pipeline using advice from:
* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide
* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide


### Load a dataset

Load the MNIST dataset with the following arguments:

* `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.

```python
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
```



#### Upload dataset to server

Use one of the following commands, depending on the host:

```bash
scp -r /path_of_file_locally ucid@HPC_HOST.njit.edu:/path_in_server
# or
scp -r /path_of_file_locally ucid@wulver.njit.edu:/path_in_server
```

### Build a training pipeline

Apply the following transformations:

* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to normalize images.
* `tf.data.Dataset.cache`: Because the dataset fits in memory, cache it before shuffling for better performance.<br/>
__Note:__ Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.<br/>
__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).
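To make the note about shuffle buffer size more concrete, here is a minimal pure-Python sketch of how a shuffle buffer behaves. It is a simplified model written for illustration, not the actual `tf.data` implementation, and the function name `buffered_shuffle` is hypothetical. It shows that with a small buffer an element can only move a limited distance from its original position, while a buffer as large as the dataset allows a full shuffle.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Simplified model of a shuffle buffer (illustrative, not tf.data itself).

    Fill a buffer with up to `buffer_size` items, then for every new item
    emit a randomly chosen buffered item and put the new one in its slot.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)      # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]            # emit a random buffered item
        buffer[idx] = item           # refill the freed slot
    rng.shuffle(buffer)              # drain what is left in random order
    yield from buffer


data = list(range(20))
small = list(buffered_shuffle(data, buffer_size=3))
full = list(buffered_shuffle(data, buffer_size=len(data)))
print("buffer_size=3: ", small)
print("buffer_size=20:", full)
```

With `buffer_size=3`, the element emitted at position `k` always comes from the first `k + 3` inputs, so the output stays close to the input order. This is why the pipeline below passes the full training-set size to `shuffle`.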

```python
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
```
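As a quick sanity check on the batch size, the number of batches per epoch can be worked out by hand. The sketch below assumes the standard MNIST split sizes (60,000 training and 10,000 test examples) and the default behavior of `batch`, which keeps the final partial batch. The helper name `steps_per_epoch` is hypothetical.

```python
import math


def steps_per_epoch(num_examples, batch_size, drop_remainder=False):
    """Number of batches produced when splitting a dataset into batches."""
    if drop_remainder:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


train_steps = steps_per_epoch(60_000, 128)
test_steps = steps_per_epoch(10_000, 128)
# Size of the last (partial) training batch.
last_batch = 60_000 - (train_steps - 1) * 128

print(train_steps, test_steps, last_batch)
```

The training value matches the `469` steps shown in the progress bar in Step 2.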

### Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

 * You don't need to call `tf.data.Dataset.shuffle`.
 * Caching is done after batching because batches can be the same between epochs.

```python
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
```

#### Run code on server

First create a `.sh` file. The sample below runs a job named `your_job_name` in the `general` partition; you may need to change `nodelist`, `nodes`, `gres`, and `mem` for your own job. To add a comment to the `.sh` file, start the line with several `#` characters, such as `##### your comment`, so that it is clearly a comment and not an `#SBATCH` directive.

```bash
#!/bin/bash -l

#SBATCH --job-name=your_job_name
#SBATCH --output=xxxx.out
#SBATCH --partition=general
#SBATCH --nodelist=node816
#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --mem=16G

##### Run your training script (train.py is a placeholder name)
python train.py
```
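If you submit many jobs with different resources, it can help to generate the batch script instead of editing it by hand. The following is a minimal sketch, assuming a login-shell shebang (`#!/bin/bash -l`) and a hypothetical training command `python train.py`. The function name `render_sbatch_script` is also an illustration, not part of SLURM.

```python
def render_sbatch_script(command, job_name="your_job_name", output="xxxx.out",
                         partition="general", nodes=1, gres="gpu:1",
                         mem="16G", nodelist=None):
    """Return the text of a SLURM batch script with the given options."""
    options = {
        "job-name": job_name,
        "output": output,
        "partition": partition,
        "nodelist": nodelist,   # optional: pin the job to specific nodes
        "nodes": nodes,
        "gres": gres,
        "mem": mem,
    }
    lines = ["#!/bin/bash -l", ""]
    # Emit one #SBATCH directive per option that was actually set.
    lines += [f"#SBATCH --{key}={value}"
              for key, value in options.items() if value is not None]
    lines += ["", command, ""]
    return "\n".join(lines)


# "python train.py" is a placeholder for your own training command.
script = render_sbatch_script("python train.py", nodelist="node816")
print(script)
```

You could write the returned text to a `.sh` file (for example with `pathlib.Path.write_text`) and submit it with `sbatch`.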

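Submit the script with `sbatch your_script.sh`, then run `squeue -u $USER` in a terminal to check whether your job is pending or running. `sinfo` shows the state of the partitions and nodes on the server.

Use `squeue -p gpu` to list the jobs in the `gpu` partition and gauge GPU usage.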
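The queue checks can also be scripted. The sketch below builds `squeue` command lines with the standard `-u` (user) and `-p` (partition) options and only runs them when `squeue` is present on the machine, so it is safe to try on a laptop. The helper names are hypothetical, and `ucid` is a placeholder user name.

```python
import os
import shutil
import subprocess


def squeue_command(user=None, partition=None):
    """Build a squeue command line filtered by user and/or partition."""
    cmd = ["squeue"]
    if user:
        cmd += ["-u", user]
    if partition:
        cmd += ["-p", partition]
    return cmd


def run_if_available(cmd):
    """Run the command and return its stdout, or None if the tool is missing."""
    if shutil.which(cmd[0]) is None:
        return None
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    return result.stdout


my_jobs_cmd = squeue_command(user=os.environ.get("USER", "ucid"))
gpu_jobs_cmd = squeue_command(partition="gpu")
print(my_jobs_cmd, run_if_available(my_jobs_cmd))
print(gpu_jobs_cmd, run_if_available(gpu_jobs_cmd))
```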

It is hard to visualize results on the server, so save all output to files on the server and download them to your local environment to visualize them if needed.

Use one of the following commands, depending on the host:

```bash
scp ucid@HPC_HOST.njit.edu:/path_in_server /path_of_file_locally
# or
scp ucid@wulver.njit.edu:/path_in_server /path_of_file_locally
```
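One way to make training results easy to download is to write the per-epoch metrics to a small JSON file. The sketch below assumes you keep the object returned by `model.fit` (for example `history = model.fit(...)`) and pass its `history.history` dictionary to the helper. The helper name `save_history` and the file name `history.json` are hypothetical, and the numbers in the example are placeholders, not results from this notebook.

```python
import json
import tempfile
from pathlib import Path


def save_history(history_dict, path):
    """Write a dict of per-epoch metric lists to a JSON file."""
    # Convert values to plain floats so NumPy scalars serialize cleanly.
    serializable = {name: [float(v) for v in values]
                    for name, values in history_dict.items()}
    path = Path(path)
    path.write_text(json.dumps(serializable, indent=2))
    return path


# Placeholder values for illustration only.
example_history = {
    "loss": [0.5, 0.25],
    "sparse_categorical_accuracy": [0.875, 0.9375],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    saved = save_history(example_history, Path(tmp_dir) / "history.json")
    restored = json.loads(saved.read_text())

print(restored)
```

On the server you would write to a path in your own directory rather than a temporary one, then copy the file back with `scp` and plot it locally.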

## Step 2: Create and train the model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

```python
model = tf.keras.models.Sequential([
  # TFDS MNIST images have shape (28, 28, 1); flatten them into 784 features.
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  # No activation: the model outputs logits, matching from_logits=True below.
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
```

```
Epoch 1/6
  1/469 [..............................] - ETA: 16:54 - loss: 2.5053 - sparse_categorical_accuracy: 0.0703
 20/469 [>.............................] - ETA: 1s - loss: 1.6709 - sparse_categorical_accuracy: 0.5254
 41/469 [=>............................] - ETA: 1s - loss: 1.1971 - sparse_categorical_accuracy: 0.6812
 62/469 [==>...........................] - ETA: 1s - loss: 0.9548 - sparse_categorical_accuracy: 0.7452
 84/469 [====>.........................] - ETA: 0s - loss: 0.8144 - sparse_categorical_accuracy: 0.7816
106/469 [=====>........................] - ETA: 0s - loss: 0.7238 - sparse_categorical_accuracy: 0.8053
Epoch 2/6
  1/469 [..............................] - ETA: 35s - loss: 0.1062 - sparse_categorical_accuracy: 0.9766
 23/469 [>.............................] - ETA: 1s - loss: 0.1855 - sparse_categorical_accuracy: 0.9467
 45/469 [=>............................] - ETA: 0s - loss: 0.1784 - sparse_categorical_accuracy: 0.9497
 67/469 [===>..........................] - ETA: 0s - loss: 0.1805 - sparse_categorical_accuracy: 0.9488
 89/469 [====>.........................] - ETA: 0s - loss: 0.1765 - sparse_categorical_accuracy: 0.9506
Epoch 3/6
  1/469 [..............................] - ETA: 32s - loss: 0.1546 - sparse_categorical_accuracy: 0.9609
 23/469 [>.............................] - ETA: 1s - loss: 0.1243 - sparse_categorical_accuracy: 0.9640
 45/469 [=>............................] - ETA: 0s - loss: 0.1216 - sparse_categorical_accuracy: 0.9648
 67/469 [===>..........................] - ETA: 0s - loss: 0.1272 - sparse_categorical_accuracy: 0.9628
 90/469 [====>.........................] - ETA: 0s - loss: 0.1230 - sparse_categorical_accuracy: 0.9653
Epoch 4/6
  1/469 [..............................] - ETA: 32s - loss: 0.0986 - sparse_categorical_accuracy: 0.9688
 24/469 [>.............................] - ETA: 0s - loss: 0.1070 - sparse_categorical_accuracy: 0.9684
 47/469 [==>...........................] - ETA: 0s - loss: 0.1030 - sparse_categorical_accuracy: 0.9714
 70/469 [===>..........................] - ETA: 0s - loss: 0.0953 - sparse_categorical_accuracy: 0.9733
 93/469 [====>.........................] - ETA: 0s - loss: 0.0948 - sparse_categorical_accuracy: 0.9733
Epoch 5/6
  1/469 [..............................] - ETA: 32s - loss: 0.0625 - sparse_categorical_accuracy: 0.9922
 23/469 [>.............................] - ETA: 1s - loss: 0.0698 - sparse_categorical_accuracy: 0.9820
 45/469 [=>............................] - ETA: 0s - loss: 0.0682 - sparse_categorical_accuracy: 0.9811
 68/469 [===>..........................] - ETA: 0s - loss: 0.0707 - sparse_categorical_accuracy: 0.9805
 90/469 [====>.........................] - ETA: 0s - loss: 0.0729 - sparse_categorical_accuracy: 0.9793
Epoch 6/6
  1/469 [..............................] - ETA: 32s - loss: 0.0219 - sparse_categorical_accuracy: 1.0000
 24/469 [>.............................] - ETA: 1s - loss: 0.0620 - sparse_categorical_accuracy: 0.9821
 47/469 [==>...........................] - ETA: 0s - loss: 0.0629 - sparse_categorical_accuracy: 0.9830
 70/469 [===>..........................] - ETA: 0s - loss: 0.0640 - sparse_categorical_accuracy: 0.9830
 93/469 [====>.........................] - ETA: 0s - loss: 0.0619 - sparse_categorical_accuracy: 0.9830
<keras.src.callbacks.History at 0x7fc41e0cb880>
```
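Because the last `Dense(10)` layer has no activation and the loss is built with `from_logits=True`, the model's raw outputs are logits rather than probabilities. The pure-Python sketch below illustrates, under that assumption, how a softmax turns ten logits into class probabilities and how the sparse categorical cross-entropy for one example follows from them. It is an illustration of the math, not the Keras implementation, and the function names are chosen for this example.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    shift = max(logits)                       # subtract the max for stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(logits, label):
    """Cross-entropy for one example whose true class index is `label`."""
    return -math.log(softmax(logits)[label])


# Ten equal logits: the model has no preference among the ten digits.
uniform_logits = [0.0] * 10
probs = softmax(uniform_logits)
loss = sparse_categorical_crossentropy(uniform_logits, label=3)

print(probs)
print(loss, math.log(10))
```

For a model with no preference among the ten classes, the loss equals `ln(10)`, about 2.30, which is the same order of magnitude as the loss reported at the very first training step above. In Keras, one common way to obtain probabilities from a trained model is to wrap it with a `tf.keras.layers.Softmax` layer.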