Stable Diffusion Playground

Binxu Wang

Note: you need to log in to use the pre-trained weights. Register an account at Hugging Face, then use one of your Access Tokens to log in with the following block.

from huggingface_hub import notebook_login

notebook_login()

Make sure you are using a GPU runtime!

import torch
assert torch.cuda.is_available()
!nvidia-smi
Wed May 22 03:21:47 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8              10W /  70W |      3MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

Loading Stable Diffusion

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt

def plt_show_image(image):
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

Here the fp16 checkpoint is loaded just to save memory and compute time. If you have a GPU with plenty of memory, you can remove the fp16 arguments (the "fp16" checkpoint selection and torch_dtype=torch.float16) to load the full-precision weights.

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    # variant= replaces the deprecated revision="fp16"; the login above supplies the token
    variant="fp16", torch_dtype=torch.float16
).to("cuda")
# Disable the safety checkers
def dummy_checker(images, **kwargs): return images, [False] * images.shape[0]
pipe.safety_checker = dummy_checker
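A back-of-the-envelope estimate shows why the fp16 weights help. The sketch below takes the approximate sizes quoted in the CompVis Stable Diffusion README (an 860M-parameter UNet and a 123M-parameter text encoder) as assumptions, leaves out the VAE and all activation memory, and only counts the bytes needed to hold the weights.

```python
# Approximate parameter counts from the Stable Diffusion v1 README (assumptions;
# the VAE is left out, so this slightly underestimates the total).
UNET_PARAMS = 860_000_000
TEXT_ENCODER_PARAMS = 123_000_000


def weight_memory_gb(n_params: int, bytes_per_param: int) -> float:
    """Memory needed just to store the weights, in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


n_params = UNET_PARAMS + TEXT_ENCODER_PARAMS
fp32_gb = weight_memory_gb(n_params, 4)  # float32: 4 bytes per parameter
fp16_gb = weight_memory_gb(n_params, 2)  # float16: 2 bytes per parameter
print(f"fp32: {fp32_gb:.2f} GB, fp16: {fp16_gb:.2f} GB")  # fp32: 3.93 GB, fp16: 1.97 GB
```

Halving the bytes per parameter halves the weight memory, which leaves more room for activations on the 15 GB T4 shown above.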

Generative Playground

prompt = "a lovely cat running in the desert in Van Gogh style, trending art."
image = pipe(prompt).images[0]  # the pipeline returns PIL images: https://pillow.readthedocs.io/en/stable/

# Save the image to disk; leaving it as the last expression of the cell also displays it
image.save("lovely_cat.png")
image


Fixing the random seed

generator = torch.Generator("cuda").manual_seed(1024)

prompt = "a sleeping cat enjoying the sunshine."
image = pipe(prompt, generator=generator).images[0]  # PIL image

# Save the image to disk, then display it
image.save("lovely_cat_sun.png")
image
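Why fix the seed through a generator object? The same seed always produces the same initial noise, but every draw advances the generator's internal state, so reusing a generator that has already been consumed does not reproduce the earlier image. The sketch below illustrates this with NumPy's random generator as a stand-in for torch.Generator (an illustration only; the torch object behaves analogously, and manual_seed resets its state).

```python
import numpy as np


def initial_noise(rng, shape=(1, 4, 64, 64)):
    """Draw a latent-shaped block of standard normal noise."""
    return rng.standard_normal(shape)


gen = np.random.default_rng(1024)
first = initial_noise(gen)
second = initial_noise(gen)  # same generator object, but its state has moved on
reseeded = initial_noise(np.random.default_rng(1024))  # fresh generator, same seed

print(np.array_equal(first, reseeded))  # True: re-seeding reproduces the noise
print(np.array_equal(first, second))    # False: reusing the generator does not
```

In this notebook that means calling torch.Generator("cuda").manual_seed(1024) again before each pipe(...) call you want to reproduce.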

Changing the number of (denoising) diffusion steps

prompt = "a sleeping cat enjoying the sunshine."
generator = torch.Generator("cuda").manual_seed(1024)  # same seed, so only the step count changes
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]  # the default is 50 steps

# Save under a new name so the 50-step result is not overwritten
image.save("lovely_cat_sun_25steps.png")
image
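The model was trained with 1000 noise levels, and num_inference_steps controls how many of them the sampler actually visits. A minimal sketch, assuming evenly spaced ("leading") timesteps and steps_offset=1, which is the value in the SD v1.4 scheduler config as far as I know. The PNDM scheduler this pipeline uses by default spaces its steps the same way but repeats one step, so treat this as a simplification.

```python
import numpy as np


def leading_timesteps(num_inference_steps, num_train_timesteps=1000, steps_offset=1):
    """Evenly spaced training timesteps, ordered from noisiest to cleanest."""
    step_ratio = num_train_timesteps // num_inference_steps
    timesteps = np.arange(0, num_inference_steps) * step_ratio + steps_offset
    return timesteps[::-1]


for n in (25, 50):
    ts = leading_timesteps(n)
    print(n, "steps:", ts[:3], "...", ts[-1])
```

With 25 steps each update jumps 40 noise levels instead of 20: faster, but coarser.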

Adding a negative prompt

A negative prompt lets you describe what you do not want to see in the image.

prompt = "a sleeping cat enjoying the sunshine."
generator = torch.Generator("cuda").manual_seed(1024)  # re-seed: the earlier call advanced the generator
image = pipe(prompt, generator=generator,
             negative_prompt="tree and leaves").images[0]  # PIL image

# Save the image to disk, then display it
image.save("lovely_cat_sun_no_trees.png")
image
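How does the negative prompt work? In generate_simplified later in this notebook, the negative prompt is encoded in place of the empty unconditional prompt, and classifier-free guidance then extrapolates from the negative-prompt prediction toward the text prediction. A toy NumPy sketch of that combination step, with two-element arrays standing in for UNet noise predictions:

```python
import numpy as np


def classifier_free_guidance(noise_pred_neg, noise_pred_text, guidance_scale=7.5):
    """Combine two noise predictions the way the sampling loop does.

    noise_pred_neg is the prediction conditioned on the negative prompt
    (the empty string "" when no negative prompt is given).
    """
    return noise_pred_neg + guidance_scale * (noise_pred_text - noise_pred_neg)


eps_neg = np.array([1.0, 0.0])   # toy prediction for the negative prompt
eps_text = np.array([0.0, 1.0])  # toy prediction for the text prompt
guided = classifier_free_guidance(eps_neg, eps_text, guidance_scale=7.5)
print(guided)  # [-6.5  7.5]
```

With guidance_scale=1 the result is just the text-conditioned prediction; larger values push the sample further away from the negative prompt.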

Visualizing the Diffusion in Action

First, install and import some utilities for showing videos in Colab.

# https://colab.research.google.com/github/google/mediapy/blob/main/mediapy_examples.ipynb#scrollTo=u0kuKXep2pfr
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import itertools
import math
import mediapy as media

!mkdir -p diffprocess
image_reservoir = []
latents_reservoir = []


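@torch.no_grad()
def plot_show_callback(i, t, latents):
    """Called by the pipeline at step i (timestep t): store the latents, decode them, save the image."""
    latents_reservoir.append(latents.detach().cpu())
    # undo the latent scaling and decode to image space, roughly [-1, 1]
    image = pipe.vae.decode(1 / 0.18215 * latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    # NCHW -> NHWC, cast to float32 for NumPy/matplotlib, keep the first image of the batch
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
    # plt_show_image(image)
    plt.imsave(f"diffprocess/sample_{i:02d}.png", image)
    image_reservoir.append(image)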


@torch.no_grad()
def save_latents(i, t, latents):
    latents_reservoir.append(latents.detach().cpu())



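The pipeline calls these callbacks during sampling, every callback_steps steps, with the step index i, the timestep t, and the current latents. Each one appends the latents to latents_reservoir; plot_show_callback additionally decodes them with the VAE, saves the image to diffprocess/, and appends it to image_reservoir.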
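The decoding inside plot_show_callback follows a fixed post-processing recipe: divide the latents by the scaling constant 0.18215, decode, map the decoder output from [-1, 1] to [0, 1], and move channels last. A NumPy sketch of those steps, where random numbers stand in for the decoder output (the VAE itself is not run here):

```python
import numpy as np

VAE_SCALING_FACTOR = 0.18215  # the constant used in the callbacks above


def latents_for_decoder(latents):
    """Undo the latent scaling before the latents go into the VAE decoder."""
    return latents / VAE_SCALING_FACTOR


def decoded_to_image(decoded):
    """Map a decoder batch (N, 3, H, W) in roughly [-1, 1] to one (H, W, 3) image in [0, 1]."""
    image = np.clip(decoded / 2 + 0.5, 0.0, 1.0)
    return image.transpose(0, 2, 3, 1)[0]  # NCHW -> NHWC, keep the first image


fake_decoded = np.random.default_rng(0).uniform(-1.5, 1.5, size=(1, 3, 512, 512))
img = decoded_to_image(fake_decoded)
print(img.shape, img.min() >= 0.0, img.max() <= 1.0)
```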

prompt = "a handsome cat dressed like Lincoln, trending art."
with torch.no_grad():
  image = pipe(prompt, callback=plot_show_callback).images[0]  # PIL image

# Save the image to disk, then display it
image.save("lovely_cat_lincoln.png")
image

prompt = "a lovely cat running in the desert in Van Gogh style, trending art."
image_reservoir = []   # start fresh, otherwise frames from the previous run are kept
latents_reservoir = []
with torch.no_grad():
  image = pipe(prompt, callback=plot_show_callback, callback_steps=1).images[0]  # PIL image

# Save the image to disk, display it, then play the denoising sequence
image.save("lovely_cat.png")
display(image)
media.show_video(image_reservoir, fps=5)

Visualizing Image sequence

media.show_video(image_reservoir, fps=5)

Visualizing latents

What about the latents? How do they change in the diffusion process?

latents_reservoir[0].shape
torch.Size([1, 4, 64, 64])

Since the latent tensor has 4 channels, we can choose any 3 of them to visualize as RGB. Put any three of the indices 0, 1, 2, 3 in the Chan2RGB list and see what each combination shows.

Chan2RGB = [0,1,2]
latents_np_seq = [tsr[0,Chan2RGB].permute(1,2,0).numpy() for tsr in latents_reservoir]
media.show_video(latents_np_seq, fps=5)
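The latents are not confined to [0, 1] (they start as Gaussian noise), while mediapy treats float frames as values in [0, 1], so values outside that range are likely clipped. A small helper, hypothetical and not part of mediapy, rescales the whole sequence with one global min/max so that brightness stays comparable across steps:

```python
import numpy as np


def normalize_frames(frames):
    """Rescale a list of float frames to [0, 1] using one global min/max.

    A shared range (rather than per-frame) keeps brightness comparable
    across denoising steps.
    """
    lo = min(float(f.min()) for f in frames)
    hi = max(float(f.max()) for f in frames)
    scale = (hi - lo) or 1.0  # avoid division by zero for constant input
    return [(f - lo) / scale for f in frames]


demo = normalize_frames([np.full((2, 2, 3), -3.0), np.full((2, 2, 3), 5.0)])
print(demo[0].min(), demo[1].max())  # 0.0 1.0
```

In the notebook you would call media.show_video(normalize_frames(latents_np_seq), fps=5).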

Write a simple text2img sampling function

Here I provide a simplified version of the sampling function, so you can see what happens under the hood when you run pipe(prompt).

Feel free to print out tensors and record their shape within this function!

@torch.no_grad()
def generate_simplified(
    prompt = ["a lovely cat"],
    negative_prompt = [""],
    num_inference_steps = 50,
    guidance_scale = 7.5):
    # this simplified version always does classifier-free guidance (uncond + text in one batch)
    batch_size = 1
    height, width = 512, 512
    generator = None
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.

    # get prompt text embeddings
    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,  # the CLIP text encoder accepts at most 77 tokens
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    text_embeddings = pipe.text_encoder(text_input_ids.to(pipe.device))[0]
    bs_embed, seq_len, _ = text_embeddings.shape

    # get negative prompts  text embedding
    max_length = text_input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    seq_len = uncond_embeddings.shape[1]
    uncond_embeddings = uncond_embeddings.repeat(batch_size, 1, 1)
    uncond_embeddings = uncond_embeddings.view(batch_size, seq_len, -1)

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # get the initial random noise unless the user supplied it
    # Unlike in other pipelines, latents need to be generated in the target device
    # for 1-to-1 results reproducibility with the CompVis implementation.
    # However this currently doesn't work in `mps`.
    latents_shape = (batch_size, pipe.unet.config.in_channels, height // 8, width // 8)
    latents_dtype = text_embeddings.dtype
    latents = torch.randn(latents_shape, generator=generator, device=pipe.device, dtype=latents_dtype)

    # set timesteps
    pipe.scheduler.set_timesteps(num_inference_steps)
    # Some schedulers like PNDM have timesteps as arrays
    # It's more optimized to move all timesteps to correct device beforehand
    timesteps_tensor = pipe.scheduler.timesteps.to(pipe.device)
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * pipe.scheduler.init_noise_sigma

    # Main diffusion process
    for i, t in enumerate(pipe.progress_bar(timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = pipe.scheduler.step(noise_pred, t, latents, ).prev_sample

    latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image

image = generate_simplified(
    prompt = ["a lovely cat"],
    negative_prompt = ["Sunshine"],)
plt_show_image(image[0])

image = generate_simplified(
    prompt = ["a cat dressed like a ballerina"],
    negative_prompt = [""],)
plt_show_image(image[0])

Image to Image Translation Playground

from diffusers import StableDiffusionImg2ImgPipeline
device = "cuda"
model_path = "CompVis/stable-diffusion-v1-4"

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_path,
    variant="fp16", torch_dtype=torch.float16,
)
pipe = pipe.to(device)
import requests
from io import BytesIO
from PIL import Image

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((768, 512))
init_img
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=device).manual_seed(1024)
with autocast("cuda"):
    image = pipe(prompt=prompt, image=init_img,  # the argument was renamed from init_image
                 strength=0.75, guidance_scale=7.5,
                 generator=generator).images[0]

image

Write a simple img2img sampling function

import numpy as np
import PIL.Image


def preprocess(image):
    """Convert a PIL image to a (1, 3, H, W) tensor in [-1, 1], with H and W multiples of 8."""
    w, h = image.size
    w, h = w - w % 8, h - h % 8  # the VAE downsamples by a factor of 8
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = torch.from_numpy(image[None].transpose(0, 3, 1, 2))
    return 2.0 * image - 1.0


@torch.no_grad()
def generate_img2img_simplified(
    init_image,
    prompt=["A fantasy landscape, trending on artstation"],
    negative_prompt=[""],
    strength=0.5,  # how much noise to add: 0 keeps the init image, 1 ignores it
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=None):
    batch_size = 1

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    # set timesteps
    pipe.scheduler.set_timesteps(num_inference_steps)

    # get prompt text embeddings
    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    text_embeddings = pipe.text_encoder(text_input_ids.to(pipe.device))[0]

    # get unconditional (negative prompt) embeddings for classifier free guidance
    max_length = text_input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # encode the init image into latents and scale the latents
    latents_dtype = text_embeddings.dtype
    if isinstance(init_image, PIL.Image.Image):
        init_image = preprocess(init_image)
    init_image = init_image.to(device=pipe.device, dtype=latents_dtype)
    init_latent_dist = pipe.vae.encode(init_image).latent_dist
    init_latents = init_latent_dist.sample(generator=generator)
    init_latents = 0.18215 * init_latents

    # get the original timestep using init_timestep
    offset = pipe.scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)

    timesteps = pipe.scheduler.timesteps[-init_timestep]
    timesteps = timesteps.repeat(batch_size).to(pipe.device)

    # add noise to latents using the timesteps
    noise = torch.randn(init_latents.shape, generator=generator, device=pipe.device, dtype=latents_dtype)
    init_latents = pipe.scheduler.add_noise(init_latents, noise, timesteps)

    latents = init_latents

    t_start = max(num_inference_steps - init_timestep + offset, 0)
    # Some schedulers like PNDM have timesteps as arrays
    # It's more optimized to move all timesteps to correct device beforehand
    timesteps = pipe.scheduler.timesteps[t_start:].to(pipe.device)

    for i, t in enumerate(pipe.progress_bar(timesteps)):
        # expand the latents: one copy for the unconditional pass, one for the text pass
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image


image = generate_img2img_simplified(
    init_img, strength=0.75,
    generator=torch.Generator(device).manual_seed(1024))
plt_show_image(image[0])
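The strength-to-step arithmetic in generate_img2img_simplified is easy to check by hand. Below is a pure-Python copy of those few lines, assuming steps_offset=1 (the SD v1.4 scheduler setting, as far as I know). It counts steps as num_inference_steps - t_start, which assumes the scheduler lists exactly num_inference_steps timesteps; the PNDM scheduler may list one extra.

```python
def img2img_schedule(num_inference_steps=50, strength=0.75, offset=1):
    """Mirror the strength -> start-step arithmetic of the simplified img2img function."""
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    steps_run = num_inference_steps - t_start  # denoising steps actually performed
    return init_timestep, t_start, steps_run


for strength in (0.25, 0.5, 0.75, 1.0):
    print(strength, img2img_schedule(strength=strength))
```

So strength=0.75 skips roughly the first quarter of the schedule: the init image is noised to an intermediate level and only the remaining 37 or so steps are denoised.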

The Internal Structure of Model

First, I’ll define a function that shows the internal structure of the diffusion model. Use the deepest argument to choose how much detail you want to see.

def recursive_print(module, prefix="", depth=0, deepest=3):
    """Simulating print(module) for torch.nn.Modules
        but with depth control. Print to the `deepest` level. `deepest=0` means no print
    """
    if depth == 0:
        print(f"[{type(module).__name__}]")
    if depth >= deepest:
        return
    for name, child in module.named_children():
        if len([*child.named_children()]) == 0:
            print(f"{prefix}({name}): {child}")
        else:
            if isinstance(child, nn.ModuleList):
                print(f"{prefix}({name}): {type(child).__name__} len={len(child)}")
            else:
                print(f"{prefix}({name}): {type(child).__name__}")
        recursive_print(child, prefix + "  ", depth + 1, deepest)

Text encoding model

Now let’s look at the text encoding model.

# The text model or our CLIP model
recursive_print(pipe.text_encoder, deepest=3)
[CLIPTextModel]
(text_model): CLIPTextTransformer
  (embeddings): CLIPTextEmbeddings
    (token_embedding): Embedding(49408, 768)
    (position_embedding): Embedding(77, 768)
  (encoder): CLIPEncoder
    (layers): ModuleList len=12
  (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)

Let’s zoom in on the encoder: it’s basically a series of Transformer blocks.

# the internal structure of a text encoder is basically a series of Transformer blocks
recursive_print(pipe.text_encoder.text_model.encoder, deepest=3)
[CLIPEncoder]
(layers): ModuleList len=12
  (0): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (1): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (2): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (3): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (4): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (5): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (6): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (7): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (8): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (9): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (10): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (11): CLIPEncoderLayer
    (self_attn): CLIPAttention
    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): CLIPMLP
    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)

What’s inside one of these blocks? Can you guess from what you’ve learned about attention?

recursive_print(pipe.text_encoder.text_model.encoder.layers[0], deepest=3)
[CLIPEncoderLayer]
(self_attn): CLIPAttention
  (k_proj): Linear(in_features=768, out_features=768, bias=True)
  (v_proj): Linear(in_features=768, out_features=768, bias=True)
  (q_proj): Linear(in_features=768, out_features=768, bias=True)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP
  (activation_fn): QuickGELUActivation()
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (fc2): Linear(in_features=3072, out_features=768, bias=True)
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)

UNet model

Let’s dive into the most complicated part of the Stable Diffusion model, the UNet.

The following block shows the overall design of the UNet: down_blocks, mid_block, and up_blocks.

recursive_print(pipe.unet, deepest=2)
[UNet2DConditionModel]
(conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_proj): Timesteps()
(time_embedding): TimestepEmbedding
  (linear_1): Linear(in_features=320, out_features=1280, bias=True)
  (act): SiLU()
  (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
(down_blocks): ModuleList len=4
  (0): CrossAttnDownBlock2D
  (1): CrossAttnDownBlock2D
  (2): CrossAttnDownBlock2D
  (3): DownBlock2D
(up_blocks): ModuleList len=4
  (0): UpBlock2D
  (1): CrossAttnUpBlock2D
  (2): CrossAttnUpBlock2D
  (3): CrossAttnUpBlock2D
(mid_block): UNetMidBlock2DCrossAttn
  (attentions): ModuleList len=1
  (resnets): ModuleList len=2
(conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

So what’s inside one CrossAttnDownBlock2D? Basically, it’s a double sandwich of resnets and attentions: two ResnetBlock2D, each followed by a SpatialTransformer, then a downsampler. CrossAttnUpBlock2D is similar, with three of each and an upsampler.

recursive_print(pipe.unet.down_blocks[0], deepest=2)
[CrossAttnDownBlock2D]
(attentions): ModuleList len=2
  (0): SpatialTransformer
  (1): SpatialTransformer
(resnets): ModuleList len=2
  (0): ResnetBlock2D
  (1): ResnetBlock2D
(downsamplers): ModuleList len=1
  (0): Downsample2D
recursive_print(pipe.unet.up_blocks[2], deepest=2)
[CrossAttnUpBlock2D]
(attentions): ModuleList len=3
  (0): SpatialTransformer
  (1): SpatialTransformer
  (2): SpatialTransformer
(resnets): ModuleList len=3
  (0): ResnetBlock2D
  (1): ResnetBlock2D
  (2): ResnetBlock2D
(upsamplers): ModuleList len=1
  (0): Upsample2D

Hopefully this convinces you that once you understand the SpatialTransformer and the ResnetBlock2D, you basically understand the building blocks of the network.

Spatial Transformer

In the lecture, you saw that the spatial transformer is basically a composition of:
* self-attention,
* cross-attention, and
* a feed-forward network.

You can see this structure in one of the SpatialTransformer layers.

recursive_print(pipe.unet.down_blocks[0].attentions[0], deepest=3)
[SpatialTransformer]
(norm): GroupNorm(32, 320, eps=1e-06, affine=True)
(proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList len=1
  (0): BasicTransformerBlock
    (attn1): CrossAttention
    (ff): FeedForward
    (attn2): CrossAttention
    (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
(proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))

Next, let’s focus on the two attention layers, attn1 and attn2.

Quiz time: can you guess what will be different between them? Specifically, which weight tensor will have a different shape? Which of these?
* Q
* K
* V

# attn1: self-attention, image (latent) tokens attend to each other
recursive_print(pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1, deepest=3)
# attn2: cross-attention, image tokens (queries) attend to text tokens (keys/values)
recursive_print(pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn2, deepest=3)
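If you want to check your quiz answer against the printouts, here is a toy NumPy sketch of single-head scaled dot-product attention with random weights (the UNet layers use multiple heads, and the token counts here are illustrative assumptions). The same function acts as self-attention when the context is the image tokens themselves, and as cross-attention when the context is the 77×768 CLIP text embedding. NumPy weights are stored as (in, out), the transpose of nn.Linear.weight.

```python
import numpy as np


def toy_attention(x, context, inner_dim=320, seed=0):
    """Single-head attention: queries from x, keys and values from context."""
    rng = np.random.default_rng(seed)
    w_q = rng.standard_normal((x.shape[-1], inner_dim)) / np.sqrt(x.shape[-1])
    w_k = rng.standard_normal((context.shape[-1], inner_dim)) / np.sqrt(context.shape[-1])
    w_v = rng.standard_normal((context.shape[-1], inner_dim)) / np.sqrt(context.shape[-1])
    q, k, v = x @ w_q, context @ w_k, context @ w_v
    scores = q @ k.T / np.sqrt(inner_dim)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability before exp
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v, {"Q": w_q.shape, "K": w_k.shape, "V": w_v.shape}


rng = np.random.default_rng(1)
image_tokens = rng.standard_normal((16 * 16, 320))  # toy size; 64x64 tokens in down_blocks[0]
text_tokens = rng.standard_normal((77, 768))         # 77 CLIP tokens x 768 channels

self_out, self_shapes = toy_attention(image_tokens, image_tokens)   # like attn1
cross_out, cross_shapes = toy_attention(image_tokens, text_tokens)  # like attn2
print(self_shapes)
print(cross_shapes)
```

Either way the output has one 320-channel vector per image token, so the layer can be dropped into the same residual stream.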

Finally, we take a look at the feed forward network.

Here it uses a special activation function, GEGLU: the output of proj is split in half, one half passes through a GELU and acts as a gate, and the gate multiplies the other half elementwise. That is why proj outputs 2560 features while the final Linear takes only 1280. Most of the time, simpler activations such as GELU or SiLU are used.

recursive_print(pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].ff, deepest=3)
[FeedForward]
(net): Sequential
  (0): GEGLU
    (proj): Linear(in_features=320, out_features=2560, bias=True)
  (1): Dropout(p=0.0, inplace=False)
  (2): Linear(in_features=1280, out_features=320, bias=True)
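A NumPy sketch of the GEGLU step with random weights, using the 320 to 2560 projection sizes printed above. It uses the tanh approximation of GELU as a simplifying assumption (exact GELU is defined with erf), and treats the first half of the projection as the value and the second half as the gate.

```python
import numpy as np


def gelu_tanh(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def geglu(x, w_proj, b_proj):
    """Project to 2*d, split in half, and gate one half with GELU of the other."""
    h = x @ w_proj + b_proj
    value, gate = np.split(h, 2, axis=-1)
    return value * gelu_tanh(gate)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 320))                 # 4 tokens with 320 channels
w_proj = rng.standard_normal((320, 2560)) * 0.05  # like proj: 320 -> 2560
b_proj = np.zeros(2560)
out = geglu(x, w_proj, b_proj)
print(out.shape)  # (4, 1280), matching the 1280 inputs of the final Linear
```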

ResnetBlock2D

The ResNet block is the simplest part of the UNet: basically a residual CNN block, as in ResNet.

recursive_print(pipe.unet.down_blocks[0].resnets[0], deepest=3)
[ResnetBlock2D]
(norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
(conv1): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
(norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()

Notice time_emb_proj: this linear projection outputs the time-modulating signal, one value for each channel.
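A NumPy sketch of how that signal is applied, modeled on the default path of diffusers' ResnetBlock2D as I understand it: the time embedding goes through SiLU and time_emb_proj, and the resulting per-channel values are added to the feature map after conv1, broadcast over every pixel. Random weights stand in for the learned ones.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def add_time_signal(h, temb, w_time, b_time):
    """Add a per-channel time signal to a feature map.

    h:       (N, C, H, W) feature map inside the ResNet block
    temb:    (N, 1280) time embedding
    w_time:  (1280, C) and b_time: (C,), a linear layer like time_emb_proj
    """
    t = silu(temb) @ w_time + b_time  # (N, C)
    return h + t[:, :, None, None]    # broadcast over H and W


rng = np.random.default_rng(0)
h = rng.standard_normal((1, 320, 64, 64))
temb = rng.standard_normal((1, 1280))
w_time = rng.standard_normal((1280, 320)) * 0.03
b_time = np.zeros(320)
out = add_time_signal(h, temb, w_time, b_time)
shift = out - h
print(out.shape, np.allclose(shift, shift[:, :, :1, :1]))  # same shift at every pixel of a channel
```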

Time Embedding

The function that creates the sine/cosine Fourier features reads like this (adapted from the original code):

import math
def time_proj(time_steps, max_period: int = 10000, time_emb_dim=320):
    if time_steps.ndim == 0:
        time_steps = time_steps.unsqueeze(0)
    half = time_emb_dim // 2
    frequencies = torch.exp(- math.log(max_period)
                            * torch.arange(start=0, end=half, dtype=torch.float32) / half
                            ).to(device=time_steps.device)
    angles = time_steps[:, None].float() * frequencies[None, :]
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

These outputs are sent to the time_embedding network, which is also very simple: basically a 2-layer MLP that expands the dimensionality from 320 to 1280.

recursive_print(pipe.unet.time_embedding)
[TimestepEmbedding]
(linear_1): Linear(in_features=320, out_features=1280, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=1280, out_features=1280, bias=True)
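Putting the two pieces together, here is a NumPy sketch of the whole time-embedding path: the same cos/sin features as time_proj above, followed by a 2-layer MLP with the printed sizes. The random weights are placeholders for the learned linear_1 and linear_2.

```python
import numpy as np


def sinusoidal_features(timesteps, dim=320, max_period=10000):
    """NumPy version of time_proj: [cos | sin] features of shape (N, dim)."""
    timesteps = np.atleast_1d(np.asarray(timesteps, dtype=np.float64))
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    angles = timesteps[:, None] * freqs[None, :]
    return np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)


def silu(x):
    return x / (1.0 + np.exp(-x))


def time_embedding_mlp(features, w1, b1, w2, b2):
    """linear_1 (320 -> 1280), SiLU, linear_2 (1280 -> 1280)."""
    return silu(features @ w1 + b1) @ w2 + b2


rng = np.random.default_rng(0)
w1, b1 = rng.standard_normal((320, 1280)) * 0.05, np.zeros(1280)
w2, b2 = rng.standard_normal((1280, 1280)) * 0.03, np.zeros(1280)

feats = sinusoidal_features([0, 500, 999])
emb = time_embedding_mlp(feats, w1, b1, w2, b2)
print(feats.shape, emb.shape)  # (3, 320) (3, 1280)
```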

Autoencoder model

The autoencoder is basically a ResNet-based CNN.

# The autoencoderKL or VAE
recursive_print(pipe.vae, deepest=3)
[AutoencoderKL]
(encoder): Encoder
  (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (down_blocks): ModuleList len=4
    (0): DownEncoderBlock2D
    (1): DownEncoderBlock2D
    (2): DownEncoderBlock2D
    (3): DownEncoderBlock2D
  (mid_block): UNetMidBlock2D
    (attentions): ModuleList len=1
    (resnets): ModuleList len=2
  (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
  (conv_act): SiLU()
  (conv_out): Conv2d(512, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder): Decoder
  (conv_in): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (up_blocks): ModuleList len=4
    (0): UpDecoderBlock2D
    (1): UpDecoderBlock2D
    (2): UpDecoderBlock2D
    (3): UpDecoderBlock2D
  (mid_block): UNetMidBlock2D
    (attentions): ModuleList len=1
    (resnets): ModuleList len=2
  (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
  (conv_act): SiLU()
  (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(quant_conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
(post_quant_conv): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
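Note that the encoder’s conv_out and quant_conv output 8 channels while the decoder takes 4. The 8 channels parameterize a diagonal Gaussian: 4 channels of mean and 4 of log-variance. Below is a NumPy sketch of sampling from it, modeled on what diffusers’ latent_dist.sample does as far as I know (including a log-variance clamp to [-30, 20]), followed by the 0.18215 scaling used earlier in this notebook. Random numbers stand in for real encoder output.

```python
import numpy as np

SCALING_FACTOR = 0.18215  # latent scaling constant used throughout this notebook


def sample_latents(moments, rng):
    """Sample latents from the 8-channel encoder output.

    moments: (N, 8, h, w); the first 4 channels are the mean, the last 4 the log-variance.
    """
    mean, logvar = np.split(moments, 2, axis=1)
    logvar = np.clip(logvar, -30.0, 20.0)
    std = np.exp(0.5 * logvar)
    return mean + std * rng.standard_normal(mean.shape)  # reparameterized sample


rng = np.random.default_rng(0)
h = w = 512 // 8  # three downsamplers halve 512 three times -> 64
moments = rng.standard_normal((1, 8, h, w))
latents = SCALING_FACTOR * sample_latents(moments, rng)
print(latents.shape)  # (1, 4, 64, 64), the same shape as latents_reservoir[0]
```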

If you look at the VAE encoder, you can find it’s very similar to the encoder side of the UNet, but without any SpatialTransformer.

Its down blocks contain only ResnetBlock2D and downsamplers (the mid block keeps a single attention layer). This is reasonable: the VAE is not conditioned on words or time, so it needs no cross-attention.

recursive_print(pipe.vae.encoder.down_blocks, deepest=3)
[ModuleList]
(0): DownEncoderBlock2D
  (resnets): ModuleList len=2
    (0): ResnetBlock2D
    (1): ResnetBlock2D
  (downsamplers): ModuleList len=1
    (0): Downsample2D
(1): DownEncoderBlock2D
  (resnets): ModuleList len=2
    (0): ResnetBlock2D
    (1): ResnetBlock2D
  (downsamplers): ModuleList len=1
    (0): Downsample2D
(2): DownEncoderBlock2D
  (resnets): ModuleList len=2
    (0): ResnetBlock2D
    (1): ResnetBlock2D
  (downsamplers): ModuleList len=1
    (0): Downsample2D
(3): DownEncoderBlock2D
  (resnets): ModuleList len=2
    (0): ResnetBlock2D
    (1): ResnetBlock2D

Based on this, can you guess how the ResnetBlock2D here differs from the ResnetBlock2D in the UNet? What feature will be missing?

recursive_print(pipe.vae.encoder.down_blocks[0].resnets[0], deepest=3)
[ResnetBlock2D]
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()