!pip install einops easydict
Stable Diffusion - Part 1
Binxu Wang
Nov.2022
This notebook walks you through how to build the UNet architecture of Stable Diffusion from scratch! All the network components are defined in this single notebook.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from collections import OrderedDict
from easydict import EasyDict as edict
Define our Unet Architecture
Build Our ResBlock
# backbone, Residual Block (Checked)
class ResBlock(nn.Module):
    def __init__(self, in_channel, time_emb_dim, out_channel=None):
        super().__init__()
        if out_channel is None:
            out_channel = in_channel
        self.norm1 = nn.GroupNorm(32, in_channel, eps=1e-05, affine=True)
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1)
        self.time_emb_proj = nn.Linear(in_features=time_emb_dim, out_features=out_channel, bias=True)
        self.norm2 = nn.GroupNorm(32, out_channel, eps=1e-05, affine=True)
        self.dropout = nn.Dropout(p=0.0, inplace=False)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = nn.SiLU()
        if in_channel == out_channel:
            self.conv_shortcut = nn.Identity()
        else:
            self.conv_shortcut = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1)

    def forward(self, x, t_emb, cond=None):
        # Input conv
        h = self.norm1(x)
        h = self.nonlinearity(h)
        h = self.conv1(h)
        # Time modulation
        if t_emb is not None:
            t_hidden = self.time_emb_proj(self.nonlinearity(t_emb))
            h = h + t_hidden[:, :, None, None]
        # Output conv
        h = self.norm2(h)
        h = self.nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        # Skip connection
        return h + self.conv_shortcut(x)
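To see how the time modulation works, here is a quick smoke test with toy shapes of my own choosing (not the Stable Diffusion sizes): the time embedding is projected to out_channel and broadcast over the spatial grid, so the output keeps the spatial size while the channel count changes as requested.
# Smoke test with made-up shapes: 64 -> 128 channels on a 16x16 grid.
blk = ResBlock(in_channel=64, time_emb_dim=256, out_channel=128)
h = blk(torch.randn(2, 64, 16, 16), torch.randn(2, 256))
assert h.shape == (2, 128, 16, 16)  # spatial size preserved, channels remapped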
# UpSampling (Checked)
class UpSample(nn.Module):
    def __init__(self, channel, scale_factor=2, mode='nearest'):
        super(UpSample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        return self.conv(x)

# DownSampling (Checked)
class DownSample(nn.Module):
    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.conv = nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)  # F.interpolate(x, scale_factor=1/self.scale_factor, mode=self.mode)
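As a sanity check (toy shapes of my choosing), DownSample halves the spatial grid with a stride-2 convolution, while UpSample doubles it with nearest-neighbor interpolation followed by a convolution:
x = torch.randn(1, 32, 16, 16)
assert DownSample(32)(x).shape == (1, 32, 8, 8)   # stride-2 conv halves H and W
assert UpSample(32)(x).shape == (1, 32, 32, 32)   # nearest upsampling doubles H and W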
Build Our Attention / Transformer
# Self and Cross Attention mechanism (Checked)
class CrossAttention(nn.Module):
    """General implementation of multi-head Cross & Self Attention."""
    def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=8):
        super(CrossAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
        if context_dim is None:
            # Self Attention
            self.to_k = nn.Linear(hidden_dim, embed_dim, bias=False)
            self.to_v = nn.Linear(hidden_dim, embed_dim, bias=False)
            self.self_attn = True
        else:
            # Cross Attention
            self.to_k = nn.Linear(context_dim, embed_dim, bias=False)
            self.to_v = nn.Linear(context_dim, embed_dim, bias=False)
            self.self_attn = False
        self.to_out = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim, bias=True)  # this could be omitted
        )

    def forward(self, tokens, context=None):
        Q = self.to_q(tokens)
        K = self.to_k(tokens) if self.self_attn else self.to_k(context)
        V = self.to_v(tokens) if self.self_attn else self.to_v(context)
        # transform heads onto the batch dimension
        Q = rearrange(Q, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim)
        K = rearrange(K, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim)
        V = rearrange(V, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim)
        scoremats = torch.einsum("BTD,BSD->BTS", Q, K)
        attnmats = F.softmax(scoremats / math.sqrt(self.head_dim), dim=-1)
        ctx_vecs = torch.einsum("BTS,BSD->BTD", attnmats, V)
        # split the heads and transform back to hidden dim.
        ctx_vecs = rearrange(ctx_vecs, '(B H) T D -> B T (H D)', H=self.num_heads, D=self.head_dim)
        # Note: `to_out` is also a linear layer; in principle it could be merged into the to_v layer.
        return self.to_out(ctx_vecs)
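The block computes the standard scaled dot-product attention, softmax(Q K^T / sqrt(d)) V, independently per head. A minimal shape check with toy dimensions (my own, not the Stable Diffusion ones): in self-attention K and V come from the tokens themselves, while in cross-attention they come from the context, so only the query sequence length appears in the output.
self_attn = CrossAttention(embed_dim=64, hidden_dim=64, context_dim=None, num_heads=8)
cross_attn = CrossAttention(embed_dim=64, hidden_dim=64, context_dim=77, num_heads=8)
tokens = torch.randn(2, 10, 64)   # (batch, query tokens, hidden_dim)
context = torch.randn(2, 5, 77)   # (batch, context tokens, context_dim)
assert self_attn(tokens).shape == (2, 10, 64)
assert cross_attn(tokens, context).shape == (2, 10, 64)  # output length follows the queries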
# Transformer layers
class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim, context_dim, num_heads=8):
        super(TransformerBlock, self).__init__()
        self.attn1 = CrossAttention(hidden_dim, hidden_dim, num_heads=num_heads)  # self attention
        self.attn2 = CrossAttention(hidden_dim, hidden_dim, context_dim, num_heads=num_heads)  # cross attention
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        # to be compatible with Diffusers; could simplify.
        self.ff = FeedForward_GEGLU(hidden_dim)
        # A more common version used in transformers:
        # self.ff = nn.Sequential(
        #     nn.Linear(hidden_dim, 3 * hidden_dim),
        #     nn.GELU(),
        #     nn.Linear(3 * hidden_dim, hidden_dim)
        # )

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

class GEGLU_proj(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(GEGLU_proj, self).__init__()
        self.proj = nn.Linear(in_dim, 2 * out_dim)

    def forward(self, x):
        x = self.proj(x)
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)

class FeedForward_GEGLU(nn.Module):
    # https://github.com/huggingface/diffusers/blob/95414bd6bf9bb34a312a7c55f10ba9b379f33890/src/diffusers/models/attention.py#L339
    # A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
    def __init__(self, hidden_dim, mult=4):
        super(FeedForward_GEGLU, self).__init__()
        self.net = nn.Sequential(
            GEGLU_proj(hidden_dim, mult * hidden_dim),
            nn.Dropout(0.0),
            nn.Linear(mult * hidden_dim, hidden_dim)
        )  # to be compatible with Diffusers; could simplify.

    def forward(self, x):
        return self.net(x)
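GEGLU projects to twice the target width, splits the result into a value half and a gate half, and returns value * GELU(gate); the feed-forward then maps back down to hidden_dim. A quick smoke test with toy sizes of my own choosing:
ff = FeedForward_GEGLU(hidden_dim=64)
assert ff(torch.randn(2, 10, 64)).shape == (2, 10, 64)  # width restored after the gated expansion
tfmr = TransformerBlock(hidden_dim=64, context_dim=77, num_heads=8)
assert tfmr(torch.randn(2, 10, 64), torch.randn(2, 5, 77)).shape == (2, 10, 64)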
class SpatialTransformer(nn.Module):
    def __init__(self, hidden_dim, context_dim, num_heads=8):
        super(SpatialTransformer, self).__init__()
        self.norm = nn.GroupNorm(32, hidden_dim, eps=1e-6, affine=True)
        self.proj_in = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1, stride=1, padding=0)
        # wrapped in nn.Sequential to be compatible with Diffusers; could simplify.
        self.transformer_blocks = nn.Sequential(
            TransformerBlock(hidden_dim, context_dim, num_heads=num_heads)
        )
        self.proj_out = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x, cond=None):
        b, c, h, w = x.shape
        x_in = x
        x = self.proj_in(self.norm(x))
        x = rearrange(x, "b c h w -> b (h w) c")  # flatten the spatial grid into a token sequence
        x = self.transformer_blocks[0](x, cond)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        return self.proj_out(x) + x_in
Container of ResBlock and Spatial Transformers
class TimeModulatedSequential(nn.Sequential):
    """Modified container: an nn.Sequential that passes time embedding and conditioning along."""
    def forward(self, x, t_emb, cond=None):
        for module in self:
            if isinstance(module, TimeModulatedSequential):
                x = module(x, t_emb, cond)
            elif isinstance(module, ResBlock):
                # For these layers, add the time modulation.
                x = module(x, t_emb)
            elif isinstance(module, SpatialTransformer):
                # For these layers, add the conditioning.
                x = module(x, cond=cond)
            else:
                x = module(x)
        return x
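In other words, the container routes the extra arguments by module type: ResBlocks receive the time embedding, SpatialTransformers receive the conditioning, and everything else is called as plain nn.Sequential would. A small illustration with toy modules:
seq = TimeModulatedSequential(
    ResBlock(64, time_emb_dim=256),
    SpatialTransformer(64, context_dim=77, num_heads=8),
)
out = seq(torch.randn(2, 64, 8, 8), t_emb=torch.randn(2, 256), cond=torch.randn(2, 5, 77))
assert out.shape == (2, 64, 8, 8)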
Putting it Together into UNet!
class UNet_SD(nn.Module):
    def __init__(self, in_channels=4,
                 base_channels=320,
                 time_emb_dim=1280,
                 context_dim=768,
                 multipliers=(1, 2, 4, 4),
                 attn_levels=(0, 1, 2),
                 nResAttn_block=2,
                 cat_unet=True):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.base_channels = base_channels
        nlevel = len(multipliers)
        level_channels = [base_channels * mult for mult in multipliers]
        # Transform time into embedding
        self.time_embedding = nn.Sequential(OrderedDict({
            "linear_1": nn.Linear(base_channels, time_emb_dim, bias=True),
            "act": nn.SiLU(),
            "linear_2": nn.Linear(time_emb_dim, time_emb_dim, bias=True),
        }))  # 2-layer MLP
        self.conv_in = nn.Conv2d(self.in_channels, base_channels, 3, stride=1, padding=1)
        # Tensor Downsample blocks
        self.down_blocks = TimeModulatedSequential()  # nn.ModuleList()
        self.down_blocks_channels = [base_channels]
        cur_chan = base_channels
        for i in range(nlevel):
            for j in range(nResAttn_block):
                res_attn_sandwich = TimeModulatedSequential()
                # input channel of the first ResBlock is different from the rest.
                res_attn_sandwich.append(ResBlock(in_channel=cur_chan, time_emb_dim=time_emb_dim, out_channel=level_channels[i]))
                if i in attn_levels:
                    # add attention except for the last level
                    res_attn_sandwich.append(SpatialTransformer(level_channels[i], context_dim=context_dim))
                cur_chan = level_channels[i]
                self.down_blocks.append(res_attn_sandwich)
                self.down_blocks_channels.append(cur_chan)
            if not i == nlevel - 1:
                self.down_blocks.append(TimeModulatedSequential(DownSample(level_channels[i])))
                self.down_blocks_channels.append(cur_chan)

        self.mid_block = TimeModulatedSequential(
            ResBlock(cur_chan, time_emb_dim),
            SpatialTransformer(cur_chan, context_dim=context_dim),
            ResBlock(cur_chan, time_emb_dim),
        )
        # Tensor Upsample blocks
        self.up_blocks = nn.ModuleList()  # TimeModulatedSequential()
        for i in reversed(range(nlevel)):
            for j in range(nResAttn_block + 1):
                res_attn_sandwich = TimeModulatedSequential()
                res_attn_sandwich.append(ResBlock(in_channel=cur_chan + self.down_blocks_channels.pop(),
                                                  time_emb_dim=time_emb_dim, out_channel=level_channels[i]))
                if i in attn_levels:
                    res_attn_sandwich.append(SpatialTransformer(level_channels[i], context_dim=context_dim))
                cur_chan = level_channels[i]
                if j == nResAttn_block and i != 0:
                    res_attn_sandwich.append(UpSample(level_channels[i]))
                self.up_blocks.append(res_attn_sandwich)
        # Read out from tensor to latent space
        self.output = nn.Sequential(
            nn.GroupNorm(32, base_channels),
            nn.SiLU(),
            nn.Conv2d(base_channels, self.out_channels, 3, padding=1),
        )
        self.to(self.device)

    def time_proj(self, time_steps, max_period: int = 10000):
        if time_steps.ndim == 0:
            time_steps = time_steps.unsqueeze(0)
        half = self.base_channels // 2
        frequencies = torch.exp(- math.log(max_period)
                                * torch.arange(start=0, end=half, dtype=torch.float32) / half
                                ).to(device=time_steps.device)
        angles = time_steps[:, None].float() * frequencies[None, :]
        return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

    def forward(self, x, time_steps, cond=None, encoder_hidden_states=None, output_dict=True):
        if cond is None and encoder_hidden_states is not None:
            cond = encoder_hidden_states
        t_emb = self.time_proj(time_steps)
        t_emb = self.time_embedding(t_emb)
        x = self.conv_in(x)
        down_x_cache = [x]
        for module in self.down_blocks:
            x = module(x, t_emb, cond)
            down_x_cache.append(x)
        x = self.mid_block(x, t_emb, cond)
        for module in self.up_blocks:
            x = module(torch.cat((x, down_x_cache.pop()), dim=1), t_emb, cond)
        x = self.output(x)
        if output_dict:
            return edict(sample=x)
        else:
            return x
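Before loading pretrained weights, a randomly initialized UNet_SD can be smoke-tested end to end. The dummy tensors below are my own stand-ins: a 4-channel 64x64 latent (which decodes to a 512x512 image), one diffusion timestep, and a random tensor in place of the 77-token CLIP text embedding of width context_dim=768.
unet = UNet_SD()  # random weights; placed on GPU if available, else CPU
latents = torch.randn(1, 4, 64, 64, device=unet.device)
t = torch.tensor(10, device=unet.device)
text_emb = torch.randn(1, 77, 768, device=unet.device)  # dummy CLIP-like conditioning
with torch.no_grad():
    pred = unet(latents, t, cond=text_emb).sample
assert pred.shape == latents.shape  # the UNet predicts noise with the latent's shape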
Unit test the components against the pretrained UNet
The tests below copy weights from the pretrained pipe.unet into our modules and assert that both produce identical outputs. They rely on the pipe object, so run the cells in the "Load Weights into our UNet!" section first.
Check ResBlock
def test_ResBlock(pipe):
    tmp_blk = ResBlock(320, 1280).cuda()
    std_blk = pipe.unet.down_blocks[0].resnets[0]
    SD = std_blk.state_dict()
    tmp_blk.load_state_dict(SD)
    lat_tmp = torch.randn(3, 320, 32, 32)
    temb = torch.randn(3, 1280)
    with torch.no_grad():
        out = pipe.unet.down_blocks[0].resnets[0](lat_tmp.cuda(), temb.cuda())
        out2 = tmp_blk(lat_tmp.cuda(), temb.cuda())
    assert torch.allclose(out2, out)

test_ResBlock(pipe)
def test_downsampler(pipe):
    tmpdsp = DownSample(320).cuda()
    stddsp = pipe.unet.down_blocks[0].downsamplers[0]
    SD = stddsp.state_dict()
    tmpdsp.load_state_dict(SD)
    lat_tmp = torch.randn(3, 320, 32, 32)
    with torch.no_grad():
        out = stddsp(lat_tmp.cuda())
        out2 = tmpdsp(lat_tmp.cuda())
    assert torch.allclose(out2, out)

def test_upsampler(pipe):
    tmpusp = UpSample(1280).cuda()
    stdusp = pipe.unet.up_blocks[1].upsamplers[0]
    SD = stdusp.state_dict()
    tmpusp.load_state_dict(SD)
    lat_tmp = torch.randn(3, 1280, 32, 32)
    with torch.no_grad():
        out = stdusp(lat_tmp.cuda())
        out2 = tmpusp(lat_tmp.cuda())
    assert torch.allclose(out2, out)

test_downsampler(pipe)
test_upsampler(pipe)
Check Attention
def test_self_attention(pipe):
    tmpSattn = CrossAttention(320, 320, context_dim=None, num_heads=8).cuda()
    stdSattn = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1
    tmpSattn.load_state_dict(stdSattn.state_dict())  # checked
    with torch.no_grad():
        lat_tmp = torch.randn(3, 32, 320)
        out = stdSattn(lat_tmp.cuda())
        out2 = tmpSattn(lat_tmp.cuda())
    assert torch.allclose(out2, out)

#%%
# Check Cross attention
def test_cross_attention(pipe):
    tmpXattn = CrossAttention(320, 320, context_dim=768, num_heads=8).cuda()
    stdXattn = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn2
    tmpXattn.load_state_dict(stdXattn.state_dict())  # checked
    with torch.no_grad():
        lat_tmp = torch.randn(3, 32, 320)
        ctx_tmp = torch.randn(3, 5, 768)
        out = stdXattn(lat_tmp.cuda(), ctx_tmp.cuda())
        out2 = tmpXattn(lat_tmp.cuda(), ctx_tmp.cuda())
    assert torch.allclose(out2, out)

test_self_attention(pipe)
test_cross_attention(pipe)
Check Transformer
#%% test TransformerBlock Implementation
def test_TransformerBlock(pipe):
    tmpTfmer = TransformerBlock(320, context_dim=768, num_heads=8).cuda()
    stdTfmer = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0]
    tmpTfmer.load_state_dict(stdTfmer.state_dict())  # checked
    with torch.no_grad():
        lat_tmp = torch.randn(3, 32, 320)
        ctx_tmp = torch.randn(3, 5, 768)
        out = tmpTfmer(lat_tmp.cuda(), ctx_tmp.cuda())
        out2 = stdTfmer(lat_tmp.cuda(), ctx_tmp.cuda())
    assert torch.allclose(out2, out)

#%% test SpatialTransformer Implementation
def test_SpatialTransformer(pipe):
    tmpSpTfmer = SpatialTransformer(320, context_dim=768, num_heads=8).cuda()
    stdSpTfmer = pipe.unet.down_blocks[0].attentions[0]
    tmpSpTfmer.load_state_dict(stdSpTfmer.state_dict())  # checked
    with torch.no_grad():
        lat_tmp = torch.randn(3, 320, 8, 8)
        ctx_tmp = torch.randn(3, 5, 768)
        out = tmpSpTfmer(lat_tmp.cuda(), ctx_tmp.cuda())
        out2 = stdSpTfmer(lat_tmp.cuda(), ctx_tmp.cuda())
    assert torch.allclose(out2, out)

test_TransformerBlock(pipe)
test_SpatialTransformer(pipe)
Load Weights into our UNet!
!pip install diffusers transformers tokenizers
from huggingface_hub import notebook_login
notebook_login()

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True
).to("cuda")

# Replace the safety checker with a no-op so it does not black out images.
def dummy_checker(images, **kwargs): return images, False
pipe.safety_checker = dummy_checker
Test the Entire UNet model
#@title Utils to load weights
def load_pipe_into_our_UNet(myUNet, pipe_unet):
    # load the pretrained weights from the pipe into our UNet.
    # Loading input and output layers.
    myUNet.output[0].load_state_dict(pipe_unet.conv_norm_out.state_dict())
    myUNet.output[2].load_state_dict(pipe_unet.conv_out.state_dict())
    myUNet.conv_in.load_state_dict(pipe_unet.conv_in.state_dict())
    myUNet.time_embedding.load_state_dict(pipe_unet.time_embedding.state_dict())
    # Loading the down blocks
    myUNet.down_blocks[0][0].load_state_dict(pipe_unet.down_blocks[0].resnets[0].state_dict())
    myUNet.down_blocks[0][1].load_state_dict(pipe_unet.down_blocks[0].attentions[0].state_dict())
    myUNet.down_blocks[1][0].load_state_dict(pipe_unet.down_blocks[0].resnets[1].state_dict())
    myUNet.down_blocks[1][1].load_state_dict(pipe_unet.down_blocks[0].attentions[1].state_dict())
    myUNet.down_blocks[2][0].load_state_dict(pipe_unet.down_blocks[0].downsamplers[0].state_dict())

    myUNet.down_blocks[3][0].load_state_dict(pipe_unet.down_blocks[1].resnets[0].state_dict())
    myUNet.down_blocks[3][1].load_state_dict(pipe_unet.down_blocks[1].attentions[0].state_dict())
    myUNet.down_blocks[4][0].load_state_dict(pipe_unet.down_blocks[1].resnets[1].state_dict())
    myUNet.down_blocks[4][1].load_state_dict(pipe_unet.down_blocks[1].attentions[1].state_dict())
    myUNet.down_blocks[5][0].load_state_dict(pipe_unet.down_blocks[1].downsamplers[0].state_dict())

    myUNet.down_blocks[6][0].load_state_dict(pipe_unet.down_blocks[2].resnets[0].state_dict())
    myUNet.down_blocks[6][1].load_state_dict(pipe_unet.down_blocks[2].attentions[0].state_dict())
    myUNet.down_blocks[7][0].load_state_dict(pipe_unet.down_blocks[2].resnets[1].state_dict())
    myUNet.down_blocks[7][1].load_state_dict(pipe_unet.down_blocks[2].attentions[1].state_dict())
    myUNet.down_blocks[8][0].load_state_dict(pipe_unet.down_blocks[2].downsamplers[0].state_dict())

    myUNet.down_blocks[9][0].load_state_dict(pipe_unet.down_blocks[3].resnets[0].state_dict())
    myUNet.down_blocks[10][0].load_state_dict(pipe_unet.down_blocks[3].resnets[1].state_dict())

    # Loading the middle blocks
    myUNet.mid_block[0].load_state_dict(pipe_unet.mid_block.resnets[0].state_dict())
    myUNet.mid_block[1].load_state_dict(pipe_unet.mid_block.attentions[0].state_dict())
    myUNet.mid_block[2].load_state_dict(pipe_unet.mid_block.resnets[1].state_dict())

    # Loading the up blocks
    # upblock 0
    myUNet.up_blocks[0][0].load_state_dict(pipe_unet.up_blocks[0].resnets[0].state_dict())
    myUNet.up_blocks[1][0].load_state_dict(pipe_unet.up_blocks[0].resnets[1].state_dict())
    myUNet.up_blocks[2][0].load_state_dict(pipe_unet.up_blocks[0].resnets[2].state_dict())
    myUNet.up_blocks[2][1].load_state_dict(pipe_unet.up_blocks[0].upsamplers[0].state_dict())
    # upblock 1
    myUNet.up_blocks[3][0].load_state_dict(pipe_unet.up_blocks[1].resnets[0].state_dict())
    myUNet.up_blocks[3][1].load_state_dict(pipe_unet.up_blocks[1].attentions[0].state_dict())
    myUNet.up_blocks[4][0].load_state_dict(pipe_unet.up_blocks[1].resnets[1].state_dict())
    myUNet.up_blocks[4][1].load_state_dict(pipe_unet.up_blocks[1].attentions[1].state_dict())
    myUNet.up_blocks[5][0].load_state_dict(pipe_unet.up_blocks[1].resnets[2].state_dict())
    myUNet.up_blocks[5][1].load_state_dict(pipe_unet.up_blocks[1].attentions[2].state_dict())
    myUNet.up_blocks[5][2].load_state_dict(pipe_unet.up_blocks[1].upsamplers[0].state_dict())
    # upblock 2
    myUNet.up_blocks[6][0].load_state_dict(pipe_unet.up_blocks[2].resnets[0].state_dict())
    myUNet.up_blocks[6][1].load_state_dict(pipe_unet.up_blocks[2].attentions[0].state_dict())
    myUNet.up_blocks[7][0].load_state_dict(pipe_unet.up_blocks[2].resnets[1].state_dict())
    myUNet.up_blocks[7][1].load_state_dict(pipe_unet.up_blocks[2].attentions[1].state_dict())
    myUNet.up_blocks[8][0].load_state_dict(pipe_unet.up_blocks[2].resnets[2].state_dict())
    myUNet.up_blocks[8][1].load_state_dict(pipe_unet.up_blocks[2].attentions[2].state_dict())
    myUNet.up_blocks[8][2].load_state_dict(pipe_unet.up_blocks[2].upsamplers[0].state_dict())
    # upblock 3
    myUNet.up_blocks[9][0].load_state_dict(pipe_unet.up_blocks[3].resnets[0].state_dict())
    myUNet.up_blocks[9][1].load_state_dict(pipe_unet.up_blocks[3].attentions[0].state_dict())
    myUNet.up_blocks[10][0].load_state_dict(pipe_unet.up_blocks[3].resnets[1].state_dict())
    myUNet.up_blocks[10][1].load_state_dict(pipe_unet.up_blocks[3].attentions[1].state_dict())
    myUNet.up_blocks[11][0].load_state_dict(pipe_unet.up_blocks[3].resnets[2].state_dict())
    myUNet.up_blocks[11][1].load_state_dict(pipe_unet.up_blocks[3].attentions[2].state_dict())
myunet = UNet_SD()
original_unet = pipe.unet.cpu()
load_pipe_into_our_UNet(myunet, original_unet)
pipe.unet = myunet.cuda()

prompt = "A ballerina riding a Harley Motorcycle, CG Art"
with autocast("cuda"):
    image = pipe(prompt)["sample"][0]
image.save("astronaut_rides_horse.png")