Vision Transformer (ViT) from Scratch in PyTorch

This notebook walks you through building a Vision Transformer (ViT) from scratch using PyTorch. We assume you’re familiar with Transformer architectures in NLP (especially decoder-based transformers).

The ViT treats an image as a sequence of patches, similar to tokens in NLP. We’ll implement each core building block step by step.

1. Patch Embedding

The first step in ViT is to split the image into fixed-size patches and project them into a latent embedding space using a convolution layer.

import torch
from torch import nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768):
        super().__init__()
        self.patch_size = patch_size
        # A convolution with kernel_size == stride == patch_size cuts the image into
        # non-overlapping patches and projects each patch to emb_size dimensions.
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)       # (B, emb_size, H/patch_size, W/patch_size)
        x = x.flatten(2)       # (B, emb_size, N_patches)
        x = x.transpose(1, 2)  # (B, N_patches, emb_size)
        return x
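As a quick sanity check (a sketch, assuming the cell above has been run; the dummy tensor is purely illustrative), the layer turns a batch of images into a sequence of patch tokens:

# Illustrative shape check: a batch of two 224x224 RGB images.
dummy = torch.randn(2, 3, 224, 224)
patches = PatchEmbedding()(dummy)
print(patches.shape)  # expected: torch.Size([2, 196, 768]), since (224/16)^2 = 196 patches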
2. Positional Encoding
Adds positional information to the patch embeddings using learnable positional embeddings.
class PositionalEncoding(nn.Module):
    def __init__(self, seq_len, emb_size):
        super().__init__()
        # One learnable embedding per token position, broadcast over the batch dimension.
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, emb_size))

    def forward(self, x):
        return x + self.pos_embedding
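A minimal usage sketch (the numbers are illustrative): the positional embedding is simply added to the token sequence, and the extra +1 position will hold the CLS token introduced in section 5.

# Illustrative check with 196 patch tokens plus one CLS token (197 positions).
pos = PositionalEncoding(seq_len=197, emb_size=768)
tokens = torch.randn(2, 197, 768)
print(pos(tokens).shape)  # torch.Size([2, 197, 768])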
3. Transformer Encoder Block
Each block contains multi-head self-attention, a feed-forward network, and residual connections.
class TransformerEncoderBlock(nn.Module):
    def __init__(self, emb_size, num_heads, ff_hidden_mult=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size)
        # batch_first=True so the attention layer accepts (B, N_tokens, emb_size) inputs.
        self.attn = nn.MultiheadAttention(emb_size, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_size)
        self.ff = nn.Sequential(
            nn.Linear(emb_size, ff_hidden_mult * emb_size),
            nn.GELU(),
            nn.Linear(ff_hidden_mult * emb_size, emb_size),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Pre-norm residual attention, followed by a pre-norm residual feed-forward network.
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.ff(self.norm2(x))
        return x
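The block should preserve the token shape, which we can confirm with a quick illustrative pass (assuming the cell above has been run; the sizes match the ViT-Base configuration used later):

# Illustrative shape check: 2 images, 197 tokens, ViT-Base width.
block = TransformerEncoderBlock(emb_size=768, num_heads=12)
tokens = torch.randn(2, 197, 768)
print(block(tokens).shape)  # torch.Size([2, 197, 768])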
4. Classification Head
Uses the CLS token to produce class predictions after the encoder blocks.
class ClassificationHead(nn.Module):
    def __init__(self, emb_size, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(emb_size)
        self.fc = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        x = self.norm(x)
        return self.fc(x[:, 0])
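Because the head indexes x[:, 0], it only reads the first token (the CLS position) and returns one logit vector per image. A small illustrative check, assuming the cell above:

# Illustrative check: 1000-class logits from a batch of 197-token sequences.
head = ClassificationHead(emb_size=768, num_classes=1000)
tokens = torch.randn(2, 197, 768)
print(head(tokens).shape)  # torch.Size([2, 1000])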
5. Vision Transformer Model
Combines all parts together into a full ViT model. Note that this model is replaced further below with a model configured for the CIFAR-10 dataset.
class VisionTransformer(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768, num_heads=12, num_layers=12, num_classes=1000, img_size=224):
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size)
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = PositionalEncoding(num_patches + 1, emb_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.encoder = nn.Sequential(*[TransformerEncoderBlock(emb_size, num_heads) for _ in range(num_layers)])
        self.head = ClassificationHead(emb_size, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_embed(x)
        x = self.encoder(x)
        return self.head(x)
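An end-to-end forward pass ties the pieces together. This is only a sketch (assuming all cells above have been run; the dummy batch is illustrative), checking that 224x224 images come out as ImageNet-sized logits:

# Illustrative end-to-end check with the default ViT-Base-style settings.
vit = VisionTransformer()
images = torch.randn(2, 3, 224, 224)
print(vit(images).shape)  # torch.Size([2, 1000])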
6. Dataset and Training Prep
Example: loading CIFAR-10 to train the Vision Transformer. Note the image transformation that is needed first.
The original Vision Transformer (ViT) architecture was trained on ImageNet, where input images are 224×224 pixels. The model expects this size because the image is split into fixed-size patches (e.g., 16×16), and 224/16 = 14, resulting in 14×14 = 196 patches. This number of tokens (patches) is what the positional embeddings and transformer layers are designed for in ViT-Base.
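As a quick check of that arithmetic (purely illustrative):

img_size, patch_size = 224, 16
num_patches = (img_size // patch_size) ** 2
print(num_patches)      # 196 patches
print(num_patches + 1)  # 197 tokens once the CLS token is prepended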
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
Having said that, CIFAR-10 images are only 32×32, so upscaling them to 224×224 makes little sense. Below we modify the model to work with CIFAR-10 images at their native resolution.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
    # No resizing needed for CIFAR-10 32x32 images
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
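To see what the loader produces (a sketch; running it requires the dataset download above to have completed), each batch keeps the native 32×32 resolution:

# Illustrative peek at one batch from the loader.
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([32, 3, 32, 32]) torch.Size([32])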
class VisionTransformer(nn.Module):
    def __init__(self, in_channels=3, patch_size=4, emb_size=192, num_heads=3, num_layers=7, num_classes=10, img_size=32):
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size)
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = PositionalEncoding(num_patches + 1, emb_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.encoder = nn.Sequential(*[TransformerEncoderBlock(emb_size, num_heads) for _ in range(num_layers)])
        self.head = ClassificationHead(emb_size, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_embed(x)
        x = self.encoder(x)
        return self.head(x)
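A quick sanity check on the smaller configuration (a sketch, assuming the cells above; the variable name is just for illustration): a 32×32 batch should produce 10 logits per image, and the model is far smaller than ViT-Base.

# Illustrative checks for the CIFAR-10-sized model.
small_vit = VisionTransformer()
images = torch.randn(2, 3, 32, 32)
print(small_vit(images).shape)  # torch.Size([2, 10])
print(sum(p.numel() for p in small_vit.parameters()))  # roughly a few million parameters, versus ~86M for ViT-Base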
7. Training Loop
Here’s a basic training loop to train the Vision Transformer on the CIFAR-10 dataset.
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionTransformer(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)

def train(model, dataloader, criterion, optimizer, device, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}, Accuracy: {100.*correct/total:.2f}%")

# Train the model
train(model, train_loader, criterion, optimizer, device)
Epoch 1/5, Loss: 2.3192, Accuracy: 9.93%
Epoch 2/5, Loss: 2.3102, Accuracy: 10.23%
Epoch 3/5, Loss: 2.3079, Accuracy: 10.20%
Epoch 4/5, Loss: 2.3071, Accuracy: 9.89%
Epoch 5/5, Loss: 2.3063, Accuracy: 9.69%