from huggingface_hub import notebook_login
notebook_login()
Stable Diffusion Playground
Binxu Wang
- Play with generating art from prompts.
- See the effect of the parameters on the generation process.
- Visualize the diffusion process and the latents.
- Look under the hood of the sampling function.
- Inspect the internal network architecture of the components of Stable Diffusion.
Note that you need to log in to use the pre-trained weights! Register an account at Hugging Face, then use an Access Token to log in via the notebook_login() block at the top of this notebook.
Make sure you have a runtime with GPU!
import torch
assert torch.cuda.is_available()
!nvidia-smi
Wed May 22 03:21:47 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 36C P8 10W / 70W | 3MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
Loading Stable Diffusion
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt
def plt_show_image(image):
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
Here the fp16 checkpoint is loaded just to save memory and compute time. If you have a powerful GPU, you can remove the arguments revision="fp16", torch_dtype=torch.float16.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    revision="fp16", torch_dtype=torch.float16,
).to("cuda")

# Disable the safety checkers
def dummy_checker(images, **kwargs): return images, [False] * images.shape[0]
pipe.safety_checker = dummy_checker
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with:
```
pip install accelerate
```
.
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning:
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
warnings.warn(
/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_loading_utils.py:212: FutureWarning: You are loading the variant fp16 from CompVis/stable-diffusion-v1-4 via `revision='fp16'` even though you can load it via `variant=`fp16`. Loading model variants via `revision='fp16'` is deprecated and will be removed in diffusers v1. Please use `variant='fp16'` instead.
warnings.warn(
text_encoder/model.safetensors not found
Keyword arguments {'use_auth_token': True} are not expected by StableDiffusionPipeline and will be ignored.
/usr/local/lib/python3.10/dist-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.
warnings.warn(
Generative Playground
= "a lovely cat running in the desert in Van Gogh style, trending art."
prompt = pipe(prompt).images[0] # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
image
# Now to display an image you can do either save it such as:
f"lovely_cat.png")
image.save( image
# "a lovely cat running in the desert in Van Gogh style, trending art." image
Fixing the random seed
generator = torch.Generator("cuda").manual_seed(1024)

prompt = "a sleeping cat enjoying the sunshine."
image = pipe(prompt, generator=generator).images[0]  # image here is in PIL format (https://pillow.readthedocs.io/en/stable/)

# Save the image, or just display it as the cell output:
image.save(f"lovely_cat_sun.png")
image
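As a quick check (a sketch, assuming the pipeline loaded above), re-seeding a fresh generator with the same value reproduces the same starting noise and hence essentially the same image:

import numpy as np

prompt = "a sleeping cat enjoying the sunshine."
img_a = pipe(prompt, generator=torch.Generator("cuda").manual_seed(1024)).images[0]
img_b = pipe(prompt, generator=torch.Generator("cuda").manual_seed(1024)).images[0]
# Same seed -> same initial latent noise -> (near-)identical samples.
diff = np.abs(np.asarray(img_a, dtype=np.int16) - np.asarray(img_b, dtype=np.int16))
print("max pixel difference:", diff.max())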
Changing (Denoising) Diffusion steps
= "a sleeping cat enjoying the sunshine."
prompt = pipe(prompt, num_inference_steps=25).images[0] # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
image
# Now to display an image you can do either save it such as:
f"lovely_cat_sun.png")
image.save( image
Adding Negative prompt
Adding a negative prompt lets you control what you do not want in the image.
= "a sleeping cat enjoying the sunshine."
prompt = pipe(prompt, generator=generator,
image ="tree and leaves").images[0] # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
negative_prompt
# Now to display an image you can do either save it such as:
f"lovely_cat_sun_no_trees.png")
image.save( image
Visualizing the Diffusion in Action
First, import some utilities for showing videos in Colab.
# https://colab.research.google.com/github/google/mediapy/blob/main/mediapy_examples.ipynb#scrollTo=u0kuKXep2pfr
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import itertools
import math
import mediapy as media
!mkdir diffprocess
image_reservoir = []
latents_reservoir = []

@torch.no_grad()
def plot_show_callback(i, t, latents):
    latents_reservoir.append(latents.detach().cpu())
    image = pipe.vae.decode(1 / 0.18215 * latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
    # plt_show_image(image)
    plt.imsave(f"diffprocess/sample_{i:02d}.png", image)
    image_reservoir.append(image)

@torch.no_grad()
def save_latents(i, t, latents):
    latents_reservoir.append(latents.detach().cpu())

@torch.no_grad()
def saveimg_callback(i, t, latents):
    latents_reservoir.append(latents.detach().cpu())
    image = pipe.vae.decode(1 / 0.18215 * latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
    # plt_show_image(image)
    plt.imsave(f"diffprocess/sample_{i:02d}.png", image)
    image_reservoir.append(image)
These callback functions save the intermediate images into the list image_reservoir and the latents into latents_reservoir.
= "a handsome cat dressed like Lincoln, trending art."
prompt with torch.no_grad():
= pipe(prompt, callback=plot_show_callback).images[0] # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
image
# Now to display an image you can do either save it such as:
f"lovely_cat_lincoln.png")
image.save( image
= "a lovely cat running in the desert in Van Gogh style, trending art."
prompt with torch.no_grad():
= pipe(prompt, callback=plot_show_callback, callback_steps=1).images[0] # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
image
# Now to display an image you can do either save it such as:
f"lovely_cat.png")
image.save(
image# video1 = media.moving_circle((65, 65), num_images=10)
=5) media.show_video(image_reservoir, fps
/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:851: FutureWarning: `callback_steps` is deprecated and will be removed in version 1.0.0. Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`
deprecate(
= "a lovely cat running in the desert in Van Gogh style, trending art."
prompt = []
image_reservoir with torch.no_grad():
= pipe(prompt, callback=plot_show_callback, callback_steps=1).images[0] # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
image
# Now to display an image you can do either save it such as:
f"lovely_cat2.png")
image.save(
image=5) media.show_video(image_reservoir, fps
/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:845: FutureWarning: `callback` is deprecated and will be removed in version 1.0.0. Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`
deprecate(
/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:851: FutureWarning: `callback_steps` is deprecated and will be removed in version 1.0.0. Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`
deprecate(
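The warnings above note that callback and callback_steps are deprecated in newer diffusers releases in favor of callback_on_step_end. A rough sketch of an equivalent latent-saving callback under that newer API (assuming a diffusers version that supports it) could look like this:

# Sketch only: the newer-style callback receives the pipeline, the step index, the timestep,
# and a dict of tensors (containing "latents" by default), and must return that dict.
def save_latents_on_step_end(pipeline, step, timestep, callback_kwargs):
    latents_reservoir.append(callback_kwargs["latents"].detach().cpu())
    return callback_kwargs

# image = pipe(prompt, callback_on_step_end=save_latents_on_step_end).images[0]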
= "a lovely cat running in the desert in Van Gogh style, trending art."
prompt = []
image_reservoir with torch.no_grad():
= pipe(prompt, callback=plot_show_callback, callback_steps=1).images[0] # image here is in [PIL format](https://pillow.readthedocs.io/en/stable/)
image
# Now to display an image you can do either save it such as:
f"lovely_cat3.png")
image.save(
image=5) media.show_video(image_reservoir, fps
Visualizing Image sequence
# video1 = media.moving_circle((65, 65), num_images=10)
media.show_video(image_reservoir, fps=5)
Visualizing latents
What about the latents? How do they change in the diffusion process?
latents_reservoir[0].shape
torch.Size([1, 4, 64, 64])
Since we have 4 channels in the latent tensor, we can choose any 3 of them to visualize as RGB. Put any of 0, 1, 2, 3 into the Chan2RGB list and see what it visualizes.
Chan2RGB = [0, 1, 2]
latents_np_seq = [tsr[0, Chan2RGB].permute(1, 2, 0).numpy() for tsr in latents_reservoir]
media.show_video(latents_np_seq, fps=5)
Write a simple text2img sampling function
Here I provide a simplified version of the sampling function, so you can see what happens under the hood when you run pipe(prompt).
Feel free to print out tensors and record their shapes within this function!
@torch.no_grad()
def generate_simplified(
        prompt = ["a lovely cat"],
        negative_prompt = [""],
        num_inference_steps = 50,
        guidance_scale = 7.5):
    # do_classifier_free_guidance
    batch_size = 1
    height, width = 512, 512
    generator = None
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.

    # get prompt text embeddings
    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    text_embeddings = pipe.text_encoder(text_input_ids.to(pipe.device))[0]
    bs_embed, seq_len, _ = text_embeddings.shape

    # get negative prompts text embedding
    max_length = text_input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    seq_len = uncond_embeddings.shape[1]
    uncond_embeddings = uncond_embeddings.repeat(batch_size, 1, 1)
    uncond_embeddings = uncond_embeddings.view(batch_size, seq_len, -1)

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # get the initial random noise unless the user supplied it
    # Unlike in other pipelines, latents need to be generated in the target device
    # for 1-to-1 results reproducibility with the CompVis implementation.
    # However this currently doesn't work in `mps`.
    latents_shape = (batch_size, pipe.unet.in_channels, height // 8, width // 8)
    latents_dtype = text_embeddings.dtype
    latents = torch.randn(latents_shape, generator=generator, device=pipe.device, dtype=latents_dtype)

    # set timesteps
    pipe.scheduler.set_timesteps(num_inference_steps)
    # Some schedulers like PNDM have timesteps as arrays
    # It's more optimized to move all timesteps to correct device beforehand
    timesteps_tensor = pipe.scheduler.timesteps.to(pipe.device)
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * pipe.scheduler.init_noise_sigma

    # Main diffusion process
    for i, t in enumerate(pipe.progress_bar(timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image
image = generate_simplified(
    prompt = ["a lovely cat"],
    negative_prompt = ["Sunshine"],)
plt_show_image(image[0])
image = generate_simplified(
    prompt = ["a cat dressed like a ballerina"],
    negative_prompt = [""],)
plt_show_image(image[0])
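With the simplified sampler in hand, a small sweep (a sketch using the function defined above) shows how guidance_scale trades off sample diversity against prompt adherence:

# Sweep the classifier-free guidance weight and compare the results.
for w in [1.0, 3.0, 7.5, 12.0]:
    img = generate_simplified(prompt=["a lovely cat"], guidance_scale=w)
    plt_show_image(img[0])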
Image to Image Translation Playground
from diffusers import StableDiffusionImg2ImgPipeline
= "cuda"
device = "CompVis/stable-diffusion-v1-4"
model_path
= StableDiffusionImg2ImgPipeline.from_pretrained(
pipe
model_path,="fp16", torch_dtype=torch.float16,
revision=True
use_auth_token
)= pipe.to(device) pipe
import requests
from io import BytesIO
from PIL import Image
= "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
url
= requests.get(url)
response = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((768, 512))
init_img init_img
= "A fantasy landscape, trending on artstation"
prompt = torch.Generator(device=device).manual_seed(1024)
generator with autocast("cuda"):
= pipe(prompt=prompt, init_image=init_img,
image =0.75, guidance_scale=7.5,
strength=generator).images[0]
generator
image
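The strength parameter controls how much of the initial sketch survives: small values stay close to the input image, while values near 1 behave almost like pure text-to-image. A quick sweep (a sketch reusing the pipeline, prompt, and init_img above; note that newer diffusers versions call the init_image argument image):

with autocast("cuda"):
    for s in [0.3, 0.5, 0.75, 0.9]:
        # Re-seed each run so only `strength` changes between samples.
        image = pipe(prompt=prompt, init_image=init_img, strength=s,
                     guidance_scale=7.5,
                     generator=torch.Generator(device=device).manual_seed(1024)).images[0]
        plt_show_image(image)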
Write a simple img2img sampling function
@torch.no_grad()
def generate_img2img_simplified(init_image, generator=None,
                                num_inference_steps=50, guidance_scale=7.5):
    prompt = ["A fantasy landscape, trending on artstation"]
    negative_prompt = [""]
    strength = 0.5  # strength of the image conditioning
    batch_size = 1
    do_classifier_free_guidance = guidance_scale > 1.0
    extra_step_kwargs = {}
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.

    # set timesteps
    pipe.scheduler.set_timesteps(num_inference_steps)

    # get prompt text embeddings
    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    text_embeddings = pipe.text_encoder(text_input_ids.to(pipe.device))[0]

    # get unconditional embeddings for classifier free guidance
    uncond_tokens = negative_prompt
    max_length = text_input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # encode the init image into latents and scale the latents
    latents_dtype = text_embeddings.dtype
    if isinstance(init_image, Image.Image):
        # `preprocess` (from the diffusers img2img pipeline code) converts the PIL image
        # into a normalized tensor in [-1, 1] of shape (1, 3, H, W)
        init_image = preprocess(init_image)
    init_image = init_image.to(device=pipe.device, dtype=latents_dtype)
    init_latent_dist = pipe.vae.encode(init_image).latent_dist
    init_latents = init_latent_dist.sample(generator=generator)
    init_latents = 0.18215 * init_latents

    # get the original timestep using init_timestep
    offset = pipe.scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)

    timesteps = pipe.scheduler.timesteps[-init_timestep]
    timesteps = torch.tensor([timesteps] * batch_size, device=pipe.device)

    # add noise to latents using the timesteps
    noise = torch.randn(init_latents.shape, generator=generator, device=pipe.device, dtype=latents_dtype)
    init_latents = pipe.scheduler.add_noise(init_latents, noise, timesteps)

    latents = init_latents

    t_start = max(num_inference_steps - init_timestep + offset, 0)
    # Some schedulers like PNDM have timesteps as arrays
    # It's more optimized to move all timesteps to correct device beforehand
    timesteps = pipe.scheduler.timesteps[t_start:].to(pipe.device)

    for i, t in enumerate(pipe.progress_bar(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

    latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    return image
The Internal Structure of the Model
First, I'll define a function to help you see the internal structure of the diffusion model. Use the deepest argument to choose how much detail you want to see.
def recursive_print(module, prefix="", depth=0, deepest=3):
    """Simulating print(module) for torch.nn.Modules
    but with depth control. Print to the `deepest` level. `deepest=0` means no print.
    """
    if depth == 0:
        print(f"[{type(module).__name__}]")
    if depth >= deepest:
        return
    for name, child in module.named_children():
        if len([*child.named_children()]) == 0:
            print(f"{prefix}({name}): {child}")
        else:
            if isinstance(child, nn.ModuleList):
                print(f"{prefix}({name}): {type(child).__name__} len={len(child)}")
            else:
                print(f"{prefix}({name}): {type(child).__name__}")
            recursive_print(child, prefix + " ", depth + 1, deepest)
Text encoding model
Now let’s look at the text encoding model.
# The text model or our CLIP model
recursive_print(pipe.text_encoder, deepest=3)
[CLIPTextModel]
(text_model): CLIPTextTransformer
(embeddings): CLIPTextEmbeddings
(token_embedding): Embedding(49408, 768)
(position_embedding): Embedding(77, 768)
(encoder): CLIPEncoder
(layers): ModuleList len=12
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
Let's zoom in on the encoder: you can see it's basically a series of Transformer blocks.
# the internal structure of a text encoder is basically a series of Transformer blocks
recursive_print(pipe.text_encoder.text_model.encoder, deepest=3)
[CLIPEncoder]
(layers): ModuleList len=12
(0): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(1): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(2): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(3): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(4): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(5): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(6): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(7): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(8): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(9): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(10): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(11): CLIPEncoderLayer
(self_attn): CLIPAttention
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
What’s inside one of these blocks? Can you guess from what you’ve learned about attention?
recursive_print(pipe.text_encoder.text_model.encoder.layers[0], deepest=3)
[CLIPEncoderLayer]
(self_attn): CLIPAttention
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(in_features=768, out_features=768, bias=True)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
(activation_fn): QuickGELUActivation()
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
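All four projections are 768 x 768 because the attention keeps the full hidden size and splits it across heads internally. A quick check (a sketch; the attribute names follow the transformers CLIPAttention implementation):

attn = pipe.text_encoder.text_model.encoder.layers[0].self_attn
print(attn.num_heads, attn.head_dim)  # number of heads and per-head dimension; their product is 768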
UNet model
Let’s dive into the most complicated part of the Stable Diffusion model, the UNet.
In the following block you can see the grand design of the UNet: the down_blocks, up_blocks, and mid_block.
recursive_print(pipe.unet, deepest=2)
[UNet2DConditionModel]
(conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_proj): Timesteps()
(time_embedding): TimestepEmbedding
(linear_1): Linear(in_features=320, out_features=1280, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=1280, out_features=1280, bias=True)
(down_blocks): ModuleList len=4
(0): CrossAttnDownBlock2D
(1): CrossAttnDownBlock2D
(2): CrossAttnDownBlock2D
(3): DownBlock2D
(up_blocks): ModuleList len=4
(0): UpBlock2D
(1): CrossAttnUpBlock2D
(2): CrossAttnUpBlock2D
(3): CrossAttnUpBlock2D
(mid_block): UNetMidBlock2DCrossAttn
(attentions): ModuleList len=1
(resnets): ModuleList len=2
(conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
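To see how the 64x64 latent is progressively downsampled while the channel count grows, you can hook the down blocks and run a dummy forward pass. This is a sketch, assuming the fp16 pipeline loaded above:

def shape_hook(name):
    def hook(module, inputs, output):
        # down blocks return a tuple; the first element is the block's output feature map
        feat = output[0] if isinstance(output, tuple) else output
        print(name, tuple(feat.shape))
    return hook

hooks = [blk.register_forward_hook(shape_hook(f"down_block {i}"))
         for i, blk in enumerate(pipe.unet.down_blocks)]
with torch.no_grad():
    sample = torch.randn(1, 4, 64, 64, device=pipe.device, dtype=pipe.unet.dtype)
    cond = torch.randn(1, 77, 768, device=pipe.device, dtype=pipe.unet.dtype)
    pipe.unet(sample, torch.tensor(10, device=pipe.device), encoder_hidden_states=cond)
for h in hooks:
    h.remove()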
So what's inside one CrossAttnDownBlock2D? Basically, it's a double sandwich of attentions and resnets. You can see similar things for the CrossAttnUpBlock2D.
recursive_print(pipe.unet.down_blocks[0], deepest=2)
[CrossAttnDownBlock2D]
(attentions): ModuleList len=2
(0): SpatialTransformer
(1): SpatialTransformer
(resnets): ModuleList len=2
(0): ResnetBlock2D
(1): ResnetBlock2D
(downsamplers): ModuleList len=1
(0): Downsample2D
recursive_print(pipe.unet.up_blocks[2], deepest=2)
[CrossAttnUpBlock2D]
(attentions): ModuleList len=3
(0): SpatialTransformer
(1): SpatialTransformer
(2): SpatialTransformer
(resnets): ModuleList len=3
(0): ResnetBlock2D
(1): ResnetBlock2D
(2): ResnetBlock2D
(upsamplers): ModuleList len=1
(0): Upsample2D
Thus I hope I've convinced you that once you understand the SpatialTransformer and the ResnetBlock2D, you basically understand the building blocks of the network.
Spatial Transformer
In the lecture, you have seen that the spatial transformer is basically a composition of self-attention, cross-attention, and a feed-forward network. You can see this structure in one of the SpatialTransformer layers.
recursive_print(pipe.unet.down_blocks[0].attentions[0], deepest=3)
[SpatialTransformer]
(norm): GroupNorm(32, 320, eps=1e-06, affine=True)
(proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList len=1
(0): BasicTransformerBlock
(attn1): CrossAttention
(ff): FeedForward
(attn2): CrossAttention
(norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
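Schematically, the BasicTransformerBlock applies these three sub-layers with residual connections, roughly like this (a conceptual sketch, not the literal diffusers code):

def spatial_transformer_block_sketch(x, context, attn1, attn2, ff, norm1, norm2, norm3):
    # x: image tokens (batch, height*width, channels); context: CLIP text embeddings (batch, 77, 768)
    x = x + attn1(norm1(x))           # self-attention: image tokens attend to each other
    x = x + attn2(norm2(x), context)  # cross-attention: image queries, text keys/values
    x = x + ff(norm3(x))              # GEGLU feed-forward applied to each token
    return x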
Next, let's focus on the two attention layers, attn1 and attn2.
Quiz time: Can you guess what will be different between them? Specifically, which weight tensors will have a different shape: Q, K, or V?
# the img2img cross attention
recursive_print(pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1, deepest=3)
# the img2text cross attention
recursive_print(pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn2, deepest=3)
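To answer the quiz, you can print the projection weight shapes directly (a sketch; the to_q/to_k/to_v attribute names follow the diffusers attention module). In attn1 all projections map 320 to 320, while in attn2 the keys and values are projected from the 768-dimensional text embeddings, so only K and V differ:

blk = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0]
for name, attn in [("attn1", blk.attn1), ("attn2", blk.attn2)]:
    print(name,
          "Q:", tuple(attn.to_q.weight.shape),
          "K:", tuple(attn.to_k.weight.shape),
          "V:", tuple(attn.to_v.weight.shape))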
Finally, we take a look at the feed-forward network. Here it uses a special kind of activation function, GEGLU, where the output of proj is split in half: one half is passed through a GELU gate that multiplies the other half to produce the activation. Most of the time, though, simpler activations like GeLU or SiLU are used.
recursive_print(pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].ff, deepest=3)
[FeedForward]
(net): Sequential
(0): GEGLU
(proj): Linear(in_features=320, out_features=2560, bias=True)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=1280, out_features=320, bias=True)
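In code, GEGLU is roughly the following (a conceptual sketch consistent with the shapes printed above: proj expands 320 to 2560 = 2 x 1280, and the second Linear maps 1280 back to 320):

def geglu_sketch(x, proj):
    # proj: Linear(320 -> 2560); split its output into a value half and a gate half
    value, gate = proj(x).chunk(2, dim=-1)
    return value * F.gelu(gate)  # the GELU-activated gate multiplies the other half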
ResnetBlock2D
The ResNet block is the simplest part of the UNet: basically a CNN block as in ResNet.
recursive_print(pipe.unet.down_blocks[0].resnets[0], deepest=3)
[ResnetBlock2D]
(norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
(conv1): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
(norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
Notice the time_emb_proj: this is the linear projection that outputs the time-modulating signal for each channel.
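Schematically, inside the ResNet block the time embedding is projected to one offset per channel and added to the intermediate feature map, something like this (a conceptual sketch, not the literal diffusers code):

def resnet_time_modulation_sketch(h, temb, time_emb_proj, nonlinearity):
    # h: feature map (batch, channels, height, width); temb: time embedding (batch, 1280)
    temb = time_emb_proj(nonlinearity(temb))  # project to (batch, channels)
    return h + temb[:, :, None, None]         # broadcast one offset per channel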
Time Embedding
The function that creates the sin/cos Fourier basis reads like this (adapted from the original code).
import math
def time_proj(time_steps, max_period: int = 10000, time_emb_dim=320):
    if time_steps.ndim == 0:
        time_steps = time_steps.unsqueeze(0)
    half = time_emb_dim // 2
    frequencies = torch.exp(- math.log(max_period)
                            * torch.arange(start=0, end=half, dtype=torch.float32) / half
                            ).to(device=time_steps.device)
    angles = time_steps[:, None].float() * frequencies[None, :]
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
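A quick usage check (a sketch): for a single timestep the function returns a 320-dimensional embedding, which is exactly the in_features of linear_1 in the time_embedding MLP below.

emb = time_proj(torch.tensor(981))
print(emb.shape)  # torch.Size([1, 320])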
These outputs are sent to the time_embedding network, which is also really simple: basically a 2-layer MLP that expands the dimensionality.
recursive_print(pipe.unet.time_embedding)
[TimestepEmbedding]
(linear_1): Linear(in_features=320, out_features=1280, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=1280, out_features=1280, bias=True)
Autoencoder model
The autoencoder is basically a ResNet-based CNN.
# The autoencoderKL or VAE
recursive_print(pipe.vae, deepest=3)
[AutoencoderKL]
(encoder): Encoder
(conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(down_blocks): ModuleList len=4
(0): DownEncoderBlock2D
(1): DownEncoderBlock2D
(2): DownEncoderBlock2D
(3): DownEncoderBlock2D
(mid_block): UNetMidBlock2D
(attentions): ModuleList len=1
(resnets): ModuleList len=2
(conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(512, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder): Decoder
(conv_in): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(up_blocks): ModuleList len=4
(0): UpDecoderBlock2D
(1): UpDecoderBlock2D
(2): UpDecoderBlock2D
(3): UpDecoderBlock2D
(mid_block): UNetMidBlock2D
(attentions): ModuleList len=1
(resnets): ModuleList len=2
(conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(quant_conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
(post_quant_conv): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
If you look at the VAE encoder, you will find it's very similar to the encoder side of the UNet, but without any SpatialTransformer. It only has ResnetBlock2D and downsamplers. That is reasonable: the VAE is not modulated by words or time, so it doesn't need cross-attention.
recursive_print(pipe.vae.encoder.down_blocks, deepest=3)
[ModuleList]
(0): DownEncoderBlock2D
(resnets): ModuleList len=2
(0): ResnetBlock2D
(1): ResnetBlock2D
(downsamplers): ModuleList len=1
(0): Downsample2D
(1): DownEncoderBlock2D
(resnets): ModuleList len=2
(0): ResnetBlock2D
(1): ResnetBlock2D
(downsamplers): ModuleList len=1
(0): Downsample2D
(2): DownEncoderBlock2D
(resnets): ModuleList len=2
(0): ResnetBlock2D
(1): ResnetBlock2D
(downsamplers): ModuleList len=1
(0): Downsample2D
(3): DownEncoderBlock2D
(resnets): ModuleList len=2
(0): ResnetBlock2D
(1): ResnetBlock2D
So based on this knowledge, can you guess what will be different in the ResnetBlock2D here compared to the ResnetBlock2D in the UNet? What feature will be missing?
recursive_print(pipe.vae.encoder.down_blocks[0].resnets[0], deepest=3)
[ResnetBlock2D]
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
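Indeed, the time_emb_proj is missing: the VAE is not conditioned on the diffusion timestep, so its ResNet blocks carry no time modulation. You can verify this directly (a sketch; in diffusers the attribute is simply left as None when no time embedding is used):

print(pipe.unet.down_blocks[0].resnets[0].time_emb_proj)         # a Linear layer in the UNet
print(pipe.vae.encoder.down_blocks[0].resnets[0].time_emb_proj)  # None in the VAE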