import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
# Set up parameters
n_input = 784
n_dense = 256
# Custom weight and bias initializers
class RandomNormalInitializer:
    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return nn.init.normal_(tensor, mean=self.mean, std=self.std)

class ZerosInitializer:
    def __call__(self, tensor):
        return nn.init.zeros_(tensor)

class GlorotNormalInitializer:
    def __call__(self, tensor):
        return nn.init.xavier_normal_(tensor)

class GlorotUniformInitializer:
    def __call__(self, tensor):
        return nn.init.xavier_uniform_(tensor)

class HeNormalInitializer:
    def __call__(self, tensor):
        return nn.init.kaiming_normal_(tensor, nonlinearity='relu')

class HeUniformInitializer:
    def __call__(self, tensor):
        return nn.init.kaiming_uniform_(tensor, nonlinearity='relu')
# Create a simple MLP model
class SimpleMLP(nn.Module):
    def __init__(self, n_input, n_dense, w_init, b_init):
        super(SimpleMLP, self).__init__()
        self.fc = nn.Linear(n_input, n_dense)
        # Initialize weights and biases
        w_init(self.fc.weight)
        b_init(self.fc.bias)
        self.activation = nn.ReLU()  # nn.Sigmoid() # You can change to Tanh or Sigmoid if needed

    def forward(self, x):
        x = self.fc(x)
        x = self.activation(x)
        return x
# Initialize the model
w_init = HeNormalInitializer()  # RandomNormalInitializer(std=1.0) # Replace with desired initializer
b_init = ZerosInitializer()
model = SimpleMLP(n_input, n_dense, w_init, b_init)

# Generate random input values
x = torch.randn((1, n_input))

# Forward propagate through the network
a = model(x)

# Plot the input distribution
x_np = x.detach().numpy()  # Convert to numpy for plotting
_ = plt.hist(x_np.T)
plt.title("Input Distribution")
plt.xlabel("Input Value")
plt.ylabel("Frequency")
plt.show()
Batch Normalization
Batch Normalization can be understood as a technique that addresses the shifting of the data distribution \(p_{data}(x)\) as the data propagates through the network - a phenomenon often referred to as internal covariate shift.
The input to each layer is affected by the learned parameters of all the previous layers, so the distribution of its activations can change significantly during training. This makes optimization harder: each layer has to continuously adapt to the changing distribution of its inputs, which slows down convergence and can lead to a less stable training process.
Input normalization
Traditionally we normalize the input data to have a mean of 0 and a standard deviation of 1. This centers the input around 0 and gives all features a similar scale, so that the optimization process is stable and converges faster. To justify this, consider a limiting example with two parameters, as shown below.
In this contour plot of the loss, the SGD trajectory shown is not smooth, which means that the algorithm converges very slowly. As shown in the backpropagation exercise with the single neuron, the gradient of the neuron output with respect to a weight is proportional to the input \(x\), and the proportionality factor can be very small if the dot product is either very large or very small. When all inputs are positive, the weight updates all have the same sign across parameters. Because the dynamic range along the x-axis is much larger, the gradient with respect to that parameter dominates the update direction, creating a zig-zag pattern.
The best way to correct this is to normalize the input data to have zero mean, which results in much faster SGD convergence. After normalization you get a more rounded (bivariate, in this case) distribution in which the gradient directions can be more diverse.
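To make this concrete, here is a minimal sketch of standardizing inputs to zero mean and unit standard deviation per feature; the tensor X and its two deliberately mismatched feature scales are made-up values for illustration:

import torch

torch.manual_seed(0)
# Two features with very different scales, mimicking the elongated contour example
X = torch.randn(1000, 2) * torch.tensor([10.0, 0.5]) + torch.tensor([5.0, -3.0])

mean = X.mean(dim=0)               # per-feature mean
std = X.std(dim=0)                 # per-feature standard deviation
X_norm = (X - mean) / std          # standardized inputs

print(X_norm.mean(dim=0))          # approximately [0, 0]
print(X_norm.std(dim=0))           # approximately [1, 1]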
Normalizing the input is effective but it is not enough. The input data is not the only thing that changes as the data propagates through the network: the distribution of the activations also changes, and it depends on the values of the parameters (weights).
Parameter initialization
Various techniques have been proposed to address this issue, such as careful initialization of the weights, or using activation functions that are less sensitive to the scale of the weights. In the code below we can see how various initializations can affect the output of a fully connected layer.
# Plot the output
a_np = a.detach().numpy()  # Convert to numpy for plotting
_ = plt.hist(a_np.T)
plt.title("Output Distribution")
plt.xlabel("Output Value")
plt.ylabel("Frequency")
plt.show()
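To compare the initializers side by side, a minimal sketch (reusing n_input, n_dense, SimpleMLP, and the initializer classes defined above; the batch size of 256 is an arbitrary choice) could report the spread of the layer outputs for each choice:

# Compare how different weight initializations affect the spread of the layer output
inits = {
    'RandomNormal(std=1)': RandomNormalInitializer(std=1.0),
    'GlorotUniform': GlorotUniformInitializer(),
    'HeNormal': HeNormalInitializer(),
}

xb = torch.randn((256, n_input))   # a batch of random inputs
for name, w_init in inits.items():
    mlp = SimpleMLP(n_input, n_dense, w_init, ZerosInitializer())
    out = mlp(xb)
    print(f'{name:>20}: output std = {out.std().item():.3f}')

Initializations that account for the fan-in (Glorot, He) keep the output spread close to that of the input, whereas a naive normal initialization with std=1 inflates it by more than an order of magnitude.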
Batch Normalization Steps
Ioffe and Szegedy described Batch Normalization as a technique that alleviates many of the headaches of properly initializing neural networks by explicitly forcing the activations throughout a network to take on a specific distribution during training. Applying the technique amounts to inserting a BatchNorm layer immediately after fully connected layers (or convolutional layers) and before activations, although empirically good training behavior is also obtained if batch normalization is applied after the activation function. It involves two steps:
Normalization
It normalizes the input to each layer such that the activations have zero mean and unit variance. The key idea here is to explicitly control and stabilize the distribution of the intermediate activations. For each mini-batch during training, the activations are normalized by subtracting the mean and dividing by the standard deviation of the mini-batch. This helps bring the data back to a more consistent distribution with a mean of zero and variance of one.
\[ \hat{x} = \frac{x - \mu_{\text{batch}}}{\sqrt{\sigma_{\text{batch}}^2 + \epsilon}} \]
Where:

- \(\mu_{\text{batch}}\) is the mean of the mini-batch.
- \(\sigma_{\text{batch}}^2\) is the variance of the mini-batch.
- \(\epsilon\) is a small value added to prevent division by zero.
Scaling and Shifting
After normalization, the activations are scaled and shifted using two learnable parameters: \(\gamma\) and \(\beta\). This step allows the network to recover the original representation of the data if necessary and prevents over-restricting the learned features.
\[ y = \gamma \hat{x} + \beta \]
This scaling and shifting ensures that the transformation is expressive enough to recover the original activation distribution if needed, while keeping training more stable.
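Putting the two formulas above together, here is a minimal sketch of both steps on a toy mini-batch; the tensor shape is an illustrative assumption:

import torch

torch.manual_seed(0)
x = torch.randn(32, 8) * 3.0 + 5.0      # toy mini-batch: 32 samples, 8 features
eps = 1e-5

# Step 1: normalize with the mini-batch statistics
mu = x.mean(dim=0)                      # per-feature mean of the mini-batch
var = x.var(dim=0, unbiased=False)      # per-feature (biased) variance of the mini-batch
x_hat = (x - mu) / torch.sqrt(var + eps)

# Step 2: scale and shift with learnable gamma and beta
gamma = torch.ones(8, requires_grad=True)
beta = torch.zeros(8, requires_grad=True)
y = gamma * x_hat + beta

print(x_hat.mean(dim=0))                # approximately 0 per feature
print(x_hat.std(dim=0))                 # approximately 1 per feature

With \(\gamma\) initialized to one and \(\beta\) to zero the layer starts as a pure normalization; during training the optimizer is free to move them away from these values if a different scale or shift is beneficial.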
The main benefit of Batch Normalization is faster convergence: it allows higher learning rates because it helps stabilize training. It also reduces the sensitivity to the initialization of the network parameters.
In Practice
In the provided code, Batch Normalization is implemented manually after each convolutional layer (conv1 and conv2).
- During training, the mean and variance are computed based on the mini-batch.
- The running mean and variance are updated using a momentum term.
- During inference, the running mean and variance are used to normalize the activations instead of the batch statistics, since the batch size at inference time is typically different from the one used during training.
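The same train/inference switch can be observed with PyTorch's built-in nn.BatchNorm2d, which keeps the same kind of running statistics as the custom layer below; this is only a small illustrative check with arbitrary shapes and values:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)                  # built-in layer with running statistics
x = torch.randn(16, 3, 8, 8) * 2.0 + 1.0

bn.train()                              # training mode: normalize with batch statistics, update running stats
_ = bn(x)
print(bn.running_mean)                  # moved from 0 toward the batch mean (about 1)
print(bn.running_var)                   # moved from 1 toward the batch variance (about 4)

bn.eval()                               # inference mode: normalize with the stored running statistics
y = bn(x)
print(y.mean(dim=[0, 2, 3]))            # not exactly zero, since the running stats have not converged yet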
In practice, networks that use Batch Normalization are significantly more robust to parameter initialization. Additionally, batch normalization can be interpreted as doing preprocessing at every layer of the network, but integrated into the network itself in a differentiable manner.
The effect of BN is clearly reflected in the distribution of the gradients for the same set of parameters, as shown below.
The following code demonstrates the effect of Batch Normalization on a simple neural network.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm
# Define a custom Batch Normalization layer
class CustomBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(CustomBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Register the running statistics as buffers so they move with the model (e.g. .to(device))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Calculate batch mean and variance over the batch and spatial dimensions
            batch_mean = x.mean(dim=[0, 2, 3], keepdim=True)
            batch_var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)
            # Normalize
            x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.eps)
            # Scale and shift
            out = self.gamma.view(1, -1, 1, 1) * x_hat + self.beta.view(1, -1, 1, 1)
            # Update running statistics (detached so they do not track gradients)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach().view(-1)
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var.detach().view(-1)
        else:
            # Use running mean and variance during inference
            x_hat = (x - self.running_mean.view(1, -1, 1, 1)) / torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps)
            out = self.gamma.view(1, -1, 1, 1) * x_hat + self.beta.view(1, -1, 1, 1)
        return out
# Define a simple CNN model with custom Batch Normalization
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = CustomBatchNorm(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = CustomBatchNorm(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.bn1(x)
        x = self.dropout(x)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.bn2(x)
        x = self.dropout(x)
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# Set up training parameters
batch_size = 64
learning_rate = 0.01
weight_decay = 1e-4  # L2 regularization parameter
patience = 20  # Early stopping patience

# Load the dataset
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor())
train_data, val_data = train_test_split(train_dataset, test_size=0.2, random_state=42)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

# Check if GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model, loss function, and optimizer
model = SimpleCNN().to(device)
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Training loop with Early Stopping and TQDM for Epoch Progress
num_epochs = 100
train_losses = []
val_losses = []
min_val_loss = np.inf
patience_counter = 0

for epoch in tqdm(range(num_epochs), desc="Epoch Progress", position=0):
    model.train()
    total_train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()             # Zero the gradients
        output = model(data)              # Forward pass
        loss = criterion(output, target)  # Compute the loss
        loss.backward()                   # Backpropagate the gradients
        optimizer.step()                  # Update the weights
        total_train_loss += loss.item()
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f'Epoch {epoch + 1}: Train Loss: {avg_train_loss:.6f}')

    # Validation loss
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_val_loss += loss.item()
    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    print(f'Epoch {epoch + 1}: Validation Loss: {avg_val_loss:.6f}')

    # Early stopping check
    if avg_val_loss < min_val_loss:
        min_val_loss = avg_val_loss
        patience_counter = 0
        best_model_state = model.state_dict()
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping triggered after {epoch + 1} epochs.')
            break

# Load the best model state (if early stopping was triggered)
model.load_state_dict(best_model_state)

# Plotting Train and Validation Loss vs Epochs
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Train and Validation Loss vs Epochs with Dropout, Batch Normalization, and Early Stopping')
plt.legend()
plt.grid(True)
plt.show()
Epoch 1: Train Loss: 0.770188
Epoch 1: Validation Loss: 0.224352
Epoch 2: Train Loss: 0.263023
Epoch 2: Validation Loss: 0.140612
Epoch 3: Train Loss: 0.200401
Epoch 3: Validation Loss: 0.100889
Epoch 4: Train Loss: 0.171019
Epoch 4: Validation Loss: 0.095516
Epoch 5: Train Loss: 0.154492
Epoch 5: Validation Loss: 0.089253
Epoch 6: Train Loss: 0.145619
Epoch 6: Validation Loss: 0.076544
Epoch 7: Train Loss: 0.134191
Epoch 7: Validation Loss: 0.066166
Epoch 8: Train Loss: 0.128683
Epoch 8: Validation Loss: 0.065748
Epoch 9: Train Loss: 0.122401
Epoch 9: Validation Loss: 0.058561
Epoch 10: Train Loss: 0.119494
Epoch 10: Validation Loss: 0.058964
Epoch 11: Train Loss: 0.113731
Epoch 11: Validation Loss: 0.055343
Epoch 12: Train Loss: 0.109625
Epoch 12: Validation Loss: 0.054525
Epoch 13: Train Loss: 0.105728
Epoch 13: Validation Loss: 0.052386
Epoch 14: Train Loss: 0.105186
Epoch 14: Validation Loss: 0.053551
Epoch 15: Train Loss: 0.102408
Epoch 15: Validation Loss: 0.048707
Epoch 16: Train Loss: 0.097259
Epoch 16: Validation Loss: 0.048579
Epoch 17: Train Loss: 0.097598
Epoch 17: Validation Loss: 0.048843
Epoch 18: Train Loss: 0.095315
Epoch 18: Validation Loss: 0.045575
Epoch 19: Train Loss: 0.093268
Epoch 19: Validation Loss: 0.045303
Epoch 20: Train Loss: 0.094709
Epoch 20: Validation Loss: 0.044520
Epoch 21: Train Loss: 0.091885
Epoch 21: Validation Loss: 0.043269
Epoch 22: Train Loss: 0.089536
Epoch 22: Validation Loss: 0.045272
Epoch 23: Train Loss: 0.088439
Epoch 23: Validation Loss: 0.042783
Epoch 24: Train Loss: 0.087619
Epoch 24: Validation Loss: 0.041457
Epoch 25: Train Loss: 0.084606
Epoch 25: Validation Loss: 0.039366
Epoch 26: Train Loss: 0.083805
Epoch 26: Validation Loss: 0.040424
Epoch 27: Train Loss: 0.083419
Epoch 27: Validation Loss: 0.040376
Epoch 28: Train Loss: 0.080004
Epoch 28: Validation Loss: 0.038821
Epoch 29: Train Loss: 0.083876
Epoch 29: Validation Loss: 0.037479
Epoch 30: Train Loss: 0.080369
Epoch 30: Validation Loss: 0.037347
Epoch 31: Train Loss: 0.079569
Epoch 31: Validation Loss: 0.037011
Epoch 32: Train Loss: 0.078634
Epoch 32: Validation Loss: 0.037458
Epoch 33: Train Loss: 0.080423
Epoch 33: Validation Loss: 0.038080
Epoch 34: Train Loss: 0.077929
Epoch 34: Validation Loss: 0.038055
Early stopping triggered after 34 epochs.