Stable Diffusion - Part 1

Binxu Wang

Nov.2022

This notebook walks you through how to build your own UNet architecture from scratch! All the network components are defined in this single notebook.

!pip install einops
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from collections import OrderedDict
from easydict import EasyDict as edict

Define our Unet Architecture

Build Our ResBlock

# backbone, Residual Block (Checked)
class ResBlock(nn.Module):
    """Residual convolution block with an additive time-embedding injection.

    The block runs two (GroupNorm -> SiLU -> 3x3 conv) stages. Between them the
    time embedding is passed through SiLU and a linear projection and added to
    every spatial position of the feature map. The block input is added back
    through ``conv_shortcut``: a 1x1 convolution when the channel count changes,
    an identity otherwise.

    Args:
        in_channel: number of channels of the input feature map.
        time_emb_dim: feature size of the time embedding.
        out_channel: number of output channels; ``None`` means ``in_channel``.
    """

    def __init__(self, in_channel, time_emb_dim, out_channel=None, ):
        super().__init__()
        out_channel = in_channel if out_channel is None else out_channel
        # First stage: normalise, then convolve to the output width.
        self.norm1 = nn.GroupNorm(32, in_channel, eps=1e-05, affine=True)
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1)
        # Maps the time embedding to one bias value per output channel.
        self.time_emb_proj = nn.Linear(in_features=time_emb_dim, out_features=out_channel, bias=True)
        # Second stage.
        self.norm2 = nn.GroupNorm(32, out_channel, eps=1e-05, affine=True)
        self.dropout = nn.Dropout(p=0.0, inplace=False)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = nn.SiLU()
        # The skip path only needs a projection when the channel count changes.
        self.conv_shortcut = (
            nn.Identity()
            if in_channel == out_channel
            else nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1)
        )

    def forward(self, x, t_emb, cond=None):
        """Apply the block.

        Args:
            x: input feature map.
            t_emb: time embedding, or ``None`` to skip the time modulation.
            cond: accepted for interface uniformity; not used here.

        Returns:
            The residual output ``conv_shortcut(x) + f(x, t_emb)``.
        """
        act = self.nonlinearity
        hidden = self.conv1(act(self.norm1(x)))
        if t_emb is not None:
            # Broadcast the per-channel time term over the two spatial axes.
            hidden = hidden + self.time_emb_proj(act(t_emb))[:, :, None, None]
        hidden = self.conv2(self.dropout(act(self.norm2(hidden))))
        return hidden + self.conv_shortcut(x)


# UpSampling (Checked)
class UpSample(nn.Module):
    """Interpolate the feature map, then refine it with a 3x3 convolution.

    Args:
        channel: number of input (and output) channels of the convolution.
        scale_factor: factor handed to ``F.interpolate``.
        mode: interpolation mode handed to ``F.interpolate``.
    """

    def __init__(self, channel, scale_factor=2, mode='nearest'):
        super().__init__()
        self.scale_factor, self.mode = scale_factor, mode
        self.conv = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``conv(interpolate(x))``."""
        enlarged = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        return self.conv(enlarged)


# DownSampling (Checked)
class DownSample(nn.Module):
    """Reduce spatial resolution with a single 3x3 convolution of stride 2.

    Args:
        channel: number of input (and output) channels of the convolution.
    """

    def __init__(self, channel, ):
        super().__init__()
        self.conv = nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Return the strided convolution of ``x``."""
        reduced = self.conv(x)
        return reduced

Build Our Attention / Transformer

# Self and Cross Attention mechanism (Checked)
class CrossAttention(nn.Module):
    """Multi-head attention usable as self-attention or cross-attention.

    When ``context_dim`` is ``None`` keys and values are computed from the
    query tokens themselves (self-attention); otherwise they are computed from
    the ``context`` passed to ``forward`` (cross-attention).

    Args:
        embed_dim: total width of the query/key/value projections (all heads).
        hidden_dim: feature size of the input tokens and of the output.
        context_dim: feature size of the context tokens, or ``None``.
        num_heads: number of heads; each head gets ``embed_dim // num_heads``.
    """

    def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=8, ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.self_attn = context_dim is None
        # Keys/values read the tokens in self-attention, the context otherwise.
        kv_in_dim = hidden_dim if self.self_attn else context_dim
        self.to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
        self.to_k = nn.Linear(kv_in_dim, embed_dim, bias=False)
        self.to_v = nn.Linear(kv_in_dim, embed_dim, bias=False)
        # Kept as a one-element Sequential; mathematically it could be folded
        # into the value projection.
        self.to_out = nn.Sequential(nn.Linear(embed_dim, hidden_dim, bias=True))

    def forward(self, tokens, context=None):
        """Attend from ``tokens`` to themselves or to ``context``.

        Args:
            tokens: query tokens; the last axis is projected by ``to_q``.
            context: key/value source for cross-attention; ignored when the
                layer was built as self-attention.

        Returns:
            The attended tokens after the output projection.
        """
        kv_source = tokens if self.self_attn else context
        queries = self.to_q(tokens)
        keys = self.to_k(kv_source)
        values = self.to_v(kv_source)

        def split_heads(t):
            # Fold the head axis into the batch axis so heads attend independently.
            return rearrange(t, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim)

        queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)
        scores = torch.einsum("BTD,BSD->BTS", queries, keys)
        weights = F.softmax(scores / math.sqrt(self.head_dim), dim=-1)
        attended = torch.einsum("BTS,BSD->BTD", weights, values)
        # Undo the head folding: concatenate the heads along the feature axis.
        attended = rearrange(attended, '(B H) T D -> B T (H D)', H=self.num_heads, D=self.head_dim)
        return self.to_out(attended)
# Transformer layers
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is applied to a LayerNorm-ed copy of the
    tokens and its output is added back to the un-normalised tokens.

    Args:
        hidden_dim: feature size of the tokens.
        context_dim: feature size of the cross-attention context.
        num_heads: number of heads used by both attention layers.
    """

    def __init__(self, hidden_dim, context_dim, num_heads=8):
        super().__init__()
        # attn1 attends within the tokens; attn2 attends to the context.
        self.attn1 = CrossAttention(hidden_dim, hidden_dim, num_heads=num_heads)
        self.attn2 = CrossAttention(hidden_dim, hidden_dim, context_dim, num_heads=num_heads)

        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(hidden_dim) for _ in range(3))
        # GEGLU feed-forward, kept for compatibility with Diffusers; a plain
        # Linear -> GELU -> Linear MLP is the more common transformer variant.
        self.ff = FeedForward_GEGLU(hidden_dim, )

    def forward(self, x, context=None):
        """Run the three residual sub-layers on ``x`` and return the result."""
        self_attended = self.attn1(self.norm1(x))
        x = self_attended + x
        cross_attended = self.attn2(self.norm2(x), context=context)
        x = cross_attended + x
        fed_forward = self.ff(self.norm3(x))
        return fed_forward + x


class GEGLU_proj(nn.Module):
    """Gated projection: one linear layer produces values and GELU gates.

    Args:
        in_dim: feature size of the input.
        out_dim: feature size of the output; the linear layer emits twice this.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, 2 * out_dim)

    def forward(self, x):
        """Return ``values * gelu(gates)`` with both halves taken from ``proj(x)``."""
        values, gates = self.proj(x).chunk(2, dim=-1)
        return values * F.gelu(gates)


class FeedForward_GEGLU(nn.Module):
    """Feed-forward sub-layer built around a GEGLU gated projection.

    Follows the layout of the Diffusers feed-forward module
    (https://github.com/huggingface/diffusers/blob/95414bd6bf9bb34a312a7c55f10ba9b379f33890/src/diffusers/models/attention.py#L339),
    which uses a gated-linear-unit variant from https://arxiv.org/abs/2002.05202.

    Args:
        hidden_dim: feature size of the input and output.
        mult: expansion factor of the inner width.
    """

    def __init__(self, hidden_dim, mult=4):
        super().__init__()
        inner_dim = mult * hidden_dim
        # Three-slot Sequential (with a p=0 dropout) kept for compatibility
        # with Diffusers; it could be simplified.
        stages = [
            GEGLU_proj(hidden_dim, inner_dim),
            nn.Dropout(0.0),
            nn.Linear(inner_dim, hidden_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x, ):
        """Apply the gated projection, dropout and output projection to ``x``."""
        return self.net(x)


class SpatialTransformer(nn.Module):
    """Apply a transformer block to a 2D feature map, treating pixels as tokens.

    The map is group-normalised and passed through a 1x1 convolution, then
    flattened to a token sequence of length ``h * w``. One ``TransformerBlock``
    processes the tokens, after which they are reshaped back to a map,
    projected by another 1x1 convolution and added to the original input.

    Args:
        hidden_dim: number of channels of the feature map (and token width).
        context_dim: feature size of the cross-attention context.
        num_heads: number of attention heads used by the transformer block.
    """

    def __init__(self, hidden_dim, context_dim, num_heads=8):
        super(SpatialTransformer, self).__init__()
        self.norm = nn.GroupNorm(32, hidden_dim, eps=1e-6, affine=True)
        self.proj_in = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1, stride=1, padding=0)
        # Wrapped in a Sequential to be compatible with Diffusers; could be
        # simplified. `num_heads` is forwarded so the constructor argument is
        # honoured (it used to be pinned to 8 regardless of the argument).
        self.transformer_blocks = nn.Sequential(
            TransformerBlock(hidden_dim, context_dim, num_heads=num_heads)
        )
        self.proj_out = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x, cond=None):
        """Run attention over the spatial positions of ``x``.

        Args:
            x: 4D feature map; its last two axes are flattened into tokens.
            cond: context handed to the transformer block's cross-attention.

        Returns:
            A tensor shaped like ``x``: the projected transformer output plus ``x``.
        """
        _, _, h, w = x.shape
        x_in = x
        x = self.proj_in(self.norm(x))
        # Pixels become tokens: (b, c, h, w) -> (b, h*w, c).
        x = rearrange(x, "b c h w->b (h w) c")
        # Call the block directly because nn.Sequential cannot pass `cond`.
        x = self.transformer_blocks[0](x, cond)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        return self.proj_out(x) + x_in

Container of ResBlock and Spatial Transformers

# Modified Container. Modify the nn.Sequential to time modulated Sequential
class TimeModulatedSequential(nn.Sequential):
    """``nn.Sequential`` that routes time and conditioning inputs to its children.

    Each child is called with the arguments its type understands:
    nested ``TimeModulatedSequential`` gets ``(x, t_emb, cond)``, ``ResBlock``
    gets ``(x, t_emb)``, ``SpatialTransformer`` gets ``(x, cond=cond)`` and
    every other module gets ``x`` alone.
    """

    def forward(self, x, t_emb, cond=None):
        """Feed ``x`` through the children in order and return the result."""
        for layer in self:
            if isinstance(layer, TimeModulatedSequential):
                # Nested container: forward everything and let it dispatch.
                x = layer(x, t_emb, cond)
                continue
            if isinstance(layer, ResBlock):
                # Residual blocks consume the time embedding.
                x = layer(x, t_emb)
                continue
            if isinstance(layer, SpatialTransformer):
                # Attention blocks consume the conditioning context.
                x = layer(x, cond=cond)
                continue
            x = layer(x)
        return x

Putting it Together into UNet!

class UNet_SD(nn.Module):
    """Time- and context-conditioned UNet assembled from the blocks in this file.

    Layout built by ``__init__``:
      * ``time_embedding``: 2-layer MLP applied to the sinusoidal features
        produced by ``time_proj``.
      * ``conv_in``: 3x3 convolution lifting the input to ``base_channels``.
      * ``down_blocks``: for each level, ``nResAttn_block`` units of
        ResBlock (+ SpatialTransformer on ``attn_levels``), followed by a
        DownSample on every level except the last.
      * ``mid_block``: ResBlock, SpatialTransformer, ResBlock.
      * ``up_blocks``: for each level (deepest first), ``nResAttn_block + 1``
        units; each unit consumes the current features concatenated with one
        cached encoder activation. An UpSample closes every level except the
        final one.
      * ``output``: GroupNorm, SiLU, 3x3 convolution back to ``in_channels``.

    The constructor moves the module to CUDA when available, otherwise CPU,
    and records that choice in ``self.device``.

    Args:
        in_channels: channels of the input; the output has the same count.
        base_channels: channel width at the first level; also the size of the
            sinusoidal time features.
        time_emb_dim: width of the time embedding fed to the ResBlocks.
        context_dim: feature size of the cross-attention context.
        multipliers: per-level channel multipliers applied to ``base_channels``.
        attn_levels: levels that receive a SpatialTransformer after each ResBlock.
        nResAttn_block: number of ResBlock units per encoder level.
        cat_unet: accepted for interface compatibility; not used here.
    """

    def __init__(self, in_channels=4,
                 base_channels=320,
                 time_emb_dim=1280,
                 context_dim=768,
                 multipliers=(1, 2, 4, 4),
                 attn_levels=(0, 1, 2),
                 nResAttn_block=2,
                 cat_unet=True):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.base_channels = base_channels
        nlevel = len(multipliers)
        level_channels = [base_channels * mult for mult in multipliers]
        # Transform time into embedding
        self.time_embedding = nn.Sequential(OrderedDict({
            "linear_1": nn.Linear(base_channels, time_emb_dim, bias=True),
            "act": nn.SiLU(),
            "linear_2": nn.Linear(time_emb_dim, time_emb_dim, bias=True),
        })
        )  # 2 layer MLP
        self.conv_in = nn.Conv2d(self.in_channels, base_channels, 3, stride=1, padding=1)

        # Tensor Downsample blocks
        self.down_blocks = TimeModulatedSequential()
        # Channel count of every activation that forward() will cache, in
        # order; the decoder pops from this list to size its ResBlock inputs.
        self.down_blocks_channels = [base_channels]
        cur_chan = base_channels
        for i in range(nlevel):
            for _ in range(nResAttn_block):
                res_attn_sandwich = TimeModulatedSequential()
                # input_chan of first ResBlock is different from the rest.
                res_attn_sandwich.append(ResBlock(in_channel=cur_chan, time_emb_dim=time_emb_dim, out_channel=level_channels[i]))
                if i in attn_levels:
                    res_attn_sandwich.append(SpatialTransformer(level_channels[i], context_dim=context_dim))
                cur_chan = level_channels[i]
                self.down_blocks.append(res_attn_sandwich)
                self.down_blocks_channels.append(cur_chan)
            # Every level but the deepest ends with a downsampling unit.
            if i != nlevel - 1:
                self.down_blocks.append(TimeModulatedSequential(DownSample(level_channels[i])))
                self.down_blocks_channels.append(cur_chan)

        self.mid_block = TimeModulatedSequential(
            ResBlock(cur_chan, time_emb_dim),
            SpatialTransformer(cur_chan, context_dim=context_dim),
            ResBlock(cur_chan, time_emb_dim),
        )

        # Tensor Upsample blocks
        self.up_blocks = nn.ModuleList()
        for i in reversed(range(nlevel)):
            for j in range(nResAttn_block + 1):
                res_attn_sandwich = TimeModulatedSequential()
                # Input width = current features + the matching skip activation.
                res_attn_sandwich.append(ResBlock(in_channel=cur_chan + self.down_blocks_channels.pop(),
                                                  time_emb_dim=time_emb_dim, out_channel=level_channels[i]))
                if i in attn_levels:
                    res_attn_sandwich.append(SpatialTransformer(level_channels[i], context_dim=context_dim))
                cur_chan = level_channels[i]
                # The last unit of every level except level 0 upsamples.
                if j == nResAttn_block and i != 0:
                    res_attn_sandwich.append(UpSample(level_channels[i]))
                self.up_blocks.append(res_attn_sandwich)
        # Read out from tensor to latent space
        self.output = nn.Sequential(
            nn.GroupNorm(32, base_channels, ),
            nn.SiLU(),
            nn.Conv2d(base_channels, self.out_channels, 3, padding=1),
        )
        self.to(self.device)

    def time_proj(self, time_steps, max_period: int = 10000):
        """Return sinusoidal features of the time steps.

        Args:
            time_steps: a tensor of time steps (0-dim or 1-dim), or a plain
                Python/NumPy scalar, which is converted to a tensor first.
            max_period: controls the lowest frequency of the embedding.

        Returns:
            Float tensor of shape ``(len(time_steps), 2 * (base_channels // 2))``
            laid out as ``[cos(t * f_0..f_k), sin(t * f_0..f_k)]``.
        """
        # Accept non-tensor scalars, which have no usable .ndim/.unsqueeze.
        time_steps = torch.as_tensor(time_steps)
        if time_steps.ndim == 0:
            time_steps = time_steps.unsqueeze(0)
        half = self.base_channels // 2
        # Geometric frequency ladder from 1 down towards 1 / max_period.
        frequencies = torch.exp(- math.log(max_period)
                                * torch.arange(start=0, end=half, dtype=torch.float32) / half
                                ).to(device=time_steps.device)
        angles = time_steps[:, None].float() * frequencies[None, :]
        return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

    def forward(self, x, time_steps, cond=None, encoder_hidden_states=None, output_dict=True):
        """Run the UNet.

        Args:
            x: input feature map with ``in_channels`` channels.
            time_steps: time step(s) as a tensor or a plain scalar; they are
                placed on the device of ``x`` before being embedded.
            cond: cross-attention context passed to every SpatialTransformer.
            encoder_hidden_states: alternative name for ``cond``; used only
                when ``cond`` is ``None``.
            output_dict: when true, wrap the result as ``edict(sample=...)``;
                otherwise return the tensor itself.

        Returns:
            The output tensor, optionally wrapped in an ``EasyDict``.
        """
        if cond is None and encoder_hidden_states is not None:
            cond = encoder_hidden_states
        # Tolerate scalar time steps and time steps living on another device;
        # this is a no-op for tensors already on x's device.
        time_steps = torch.as_tensor(time_steps, device=x.device)
        t_emb = self.time_proj(time_steps)
        t_emb = self.time_embedding(t_emb)
        x = self.conv_in(x)
        # Cache every encoder activation; the decoder consumes them in reverse.
        down_x_cache = [x]
        for module in self.down_blocks:
            x = module(x, t_emb, cond)
            down_x_cache.append(x)
        x = self.mid_block(x, t_emb, cond)
        for module in self.up_blocks:
            x = module(torch.cat((x, down_x_cache.pop()), dim=1), t_emb, cond)
        x = self.output(x)
        if output_dict:
            return edict(sample=x)
        else:
            return x

Unit test the components with the UNet implementation

Check ResNet

def test_ResBlock(pipe):
    """Check that our ResBlock reproduces the first resnet of the pipeline UNet.

    Loads the reference block's weights into a fresh ``ResBlock(320, 1280)``
    and asserts both give matching outputs on the same random inputs (CUDA).
    """
    ours = ResBlock(320, 1280).cuda()
    reference = pipe.unet.down_blocks[0].resnets[0]
    ours.load_state_dict(reference.state_dict())
    latents = torch.randn(3, 320, 32, 32)
    time_emb = torch.randn(3, 1280)
    with torch.no_grad():
        ref_out = pipe.unet.down_blocks[0].resnets[0](latents.cuda(), time_emb.cuda())
        our_out = ours(latents.cuda(), time_emb.cuda())

    assert torch.allclose(our_out, ref_out)

# Run the ResBlock comparison against the reference UNet held by `pipe`.
# NOTE(review): `pipe` is only assigned in the "Load Weights into our UNet!"
# section further down, so that cell has to be executed before this one.
test_ResBlock(pipe)
def test_downsampler(pipe):
    """Check that our DownSample reproduces the first downsampler of the pipeline UNet.

    Loads the reference weights into a fresh ``DownSample(320)`` and asserts
    both give matching outputs on the same random input (CUDA).
    """
    ours = DownSample(320).cuda()
    reference = pipe.unet.down_blocks[0].downsamplers[0]
    ours.load_state_dict(reference.state_dict())
    latents = torch.randn(3, 320, 32, 32)
    with torch.no_grad():
        ref_out = reference(latents.cuda())
        our_out = ours(latents.cuda())

    assert torch.allclose(our_out, ref_out)

def test_upsampler(pipe):
    """Check that our UpSample reproduces an upsampler of the pipeline UNet.

    Loads the weights of ``up_blocks[1].upsamplers[0]`` into a fresh
    ``UpSample(1280)`` and asserts both give matching outputs on the same
    random input (CUDA).
    """
    ours = UpSample(1280).cuda()
    reference = pipe.unet.up_blocks[1].upsamplers[0]
    ours.load_state_dict(reference.state_dict())
    latents = torch.randn(3, 1280, 32, 32)
    with torch.no_grad():
        ref_out = reference(latents.cuda())
        our_out = ours(latents.cuda())

    assert torch.allclose(our_out, ref_out)

# Run the down/up-sampler comparisons against the reference UNet in `pipe`.
# NOTE(review): `pipe` is only assigned in the "Load Weights into our UNet!"
# section further down, so that cell has to be executed before this one.
test_downsampler(pipe)
test_upsampler(pipe)

Check Attention

def test_self_attention(pipe):
    """Compare our self-attention with ``attn1`` of the pipeline's first transformer block.

    Loads the reference weights into a fresh 8-head ``CrossAttention(320, 320)``
    and asserts both give matching outputs on the same random tokens (CUDA).
    """
    ours = CrossAttention(320, 320, context_dim=None, num_heads=8).cuda()
    reference = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1
    ours.load_state_dict(reference.state_dict())  # checked
    with torch.no_grad():
        tokens = torch.randn(3, 32, 320)
        ref_out = reference(tokens.cuda())
        our_out = ours(tokens.cuda())
    # NOTE(review): the original author annotated this check with "False",
    # which suggests it may not hold at allclose's default tolerances - confirm.
    assert torch.allclose(our_out, ref_out)  # False

#%%
# Check Cross attention
def test_cross_attention(pipe):
    """Compare our cross-attention with ``attn2`` of the pipeline's first transformer block.

    Loads the reference weights into a fresh 8-head
    ``CrossAttention(320, 320, context_dim=768)`` and asserts both give
    matching outputs on the same random tokens and context (CUDA).
    """
    ours = CrossAttention(320, 320, context_dim=768, num_heads=8).cuda()
    reference = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn2
    ours.load_state_dict(reference.state_dict())  # checked
    with torch.no_grad():
        tokens = torch.randn(3, 32, 320)
        context = torch.randn(3, 5, 768)
        ref_out = reference(tokens.cuda(), context.cuda())
        our_out = ours(tokens.cuda(), context.cuda())
    # NOTE(review): the original author annotated this check with "False",
    # which suggests it may not hold at allclose's default tolerances - confirm.
    assert torch.allclose(our_out, ref_out)  # False

# Run the self- and cross-attention comparisons against the reference UNet.
# NOTE(review): `pipe` is only assigned in the "Load Weights into our UNet!"
# section further down, so that cell has to be executed before this one.
test_self_attention(pipe)
test_cross_attention(pipe)

Check Transformer

#%% test TransformerBlock Implementation
def test_TransformerBlock(pipe):
    """Compare our TransformerBlock with the pipeline's first transformer block.

    Loads the reference weights into a fresh 8-head
    ``TransformerBlock(320, context_dim=768)`` and asserts both give matching
    outputs on the same random tokens and context (CUDA).
    """
    ours = TransformerBlock(320, context_dim=768, num_heads=8).cuda()
    reference = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0]
    ours.load_state_dict(reference.state_dict())  # checked
    with torch.no_grad():
        tokens = torch.randn(3, 32, 320)
        context = torch.randn(3, 5, 768)
        our_out = ours(tokens.cuda(), context.cuda())
        ref_out = reference(tokens.cuda(), context.cuda())
    # NOTE(review): the original author annotated this check with "False",
    # which suggests it may not hold at allclose's default tolerances - confirm.
    assert torch.allclose(ref_out, our_out)  # False


#%% test SpatialTransformer Implementation
def test_SpatialTransformer(pipe):
    """Compare our SpatialTransformer with the pipeline's first attention module.

    Loads the reference weights into a fresh 8-head
    ``SpatialTransformer(320, context_dim=768)`` and asserts both give
    matching outputs on the same random feature map and context (CUDA).
    """
    ours = SpatialTransformer(320, context_dim=768, num_heads=8).cuda()
    reference = pipe.unet.down_blocks[0].attentions[0]
    ours.load_state_dict(reference.state_dict())  # checked
    with torch.no_grad():
        feature_map = torch.randn(3, 320, 8, 8)
        context = torch.randn(3, 5, 768)
        our_out = ours(feature_map.cuda(), context.cuda())
        ref_out = reference(feature_map.cuda(), context.cuda())
    # NOTE(review): the original author annotated this check with "False",
    # which suggests it may not hold at allclose's default tolerances - confirm.
    assert torch.allclose(ref_out, our_out)  # False

# Run the transformer-level comparisons against the reference UNet in `pipe`.
# NOTE(review): `pipe` is only assigned in the "Load Weights into our UNet!"
# section further down, so that cell has to be executed before this one.
test_TransformerBlock(pipe)
test_SpatialTransformer(pipe)

Load Weights into our UNet!

# Notebook shell magic: install the Hugging Face packages used below.
!pip install diffusers transformers tokenizers
from huggingface_hub import notebook_login

# Interactive Hugging Face login; the pipeline below is loaded with
# use_auth_token=True.
notebook_login()
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt
# Load the pretrained Stable Diffusion v1.4 pipeline and move it to the GPU.
# `pipe` is the object every test_* function in this notebook expects.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True
).to("cuda")
# Replace the pipeline's safety checker with a pass-through that returns the
# images unchanged together with False.
def dummy_checker(images, **kwargs): return images, False
pipe.safety_checker = dummy_checker

Test the Entire UNet model

#@title Utils to load weights
def load_pipe_into_our_UNet(myUNet, pipe_unet):
    """Copy the pretrained weights of the pipeline UNet into our UNet.

    The index arithmetic is hard-wired to the layout this function has always
    assumed: four encoder levels of two resnets (attention and a downsampler
    on the first three), a three-module middle block, and four decoder stages
    of three resnets (attention on the last three, an upsampler on the first
    three). Our UNet stores three units per level, hence the ``3 * index``
    offsets.

    Args:
        myUNet: destination model exposing ``output``, ``conv_in``,
            ``time_embedding``, ``down_blocks``, ``mid_block``, ``up_blocks``.
        pipe_unet: source model exposing ``conv_norm_out``, ``conv_out``,
            ``conv_in``, ``time_embedding``, ``down_blocks``, ``mid_block``,
            ``up_blocks`` with ``resnets``/``attentions``/``downsamplers``/
            ``upsamplers`` children.
    """

    def _pairs():
        # Yields (destination, source) lazily so every lookup happens right
        # before the corresponding copy.
        # Input and output layers.
        yield myUNet.output[0], pipe_unet.conv_norm_out
        yield myUNet.output[2], pipe_unet.conv_out
        yield myUNet.conv_in, pipe_unet.conv_in
        yield myUNet.time_embedding, pipe_unet.time_embedding

        # Down blocks: units 3*level and 3*level+1 hold resnet(+attention),
        # unit 3*level+2 holds the downsampler.
        for level in range(4):
            shallow = level < 3  # the deepest level has no attention / downsampler
            for k in range(2):
                yield myUNet.down_blocks[3 * level + k][0], pipe_unet.down_blocks[level].resnets[k]
                if shallow:
                    yield myUNet.down_blocks[3 * level + k][1], pipe_unet.down_blocks[level].attentions[k]
            if shallow:
                yield myUNet.down_blocks[3 * level + 2][0], pipe_unet.down_blocks[level].downsamplers[0]

        # Middle block.
        yield myUNet.mid_block[0], pipe_unet.mid_block.resnets[0]
        yield myUNet.mid_block[1], pipe_unet.mid_block.attentions[0]
        yield myUNet.mid_block[2], pipe_unet.mid_block.resnets[1]

        # Up blocks: three resnet(+attention) units per stage; the upsampler
        # sits at the end of the stage's third unit.
        for stage in range(4):
            attends = stage > 0  # the first decoder stage has no attention
            for k in range(3):
                yield myUNet.up_blocks[3 * stage + k][0], pipe_unet.up_blocks[stage].resnets[k]
                if attends:
                    yield myUNet.up_blocks[3 * stage + k][1], pipe_unet.up_blocks[stage].attentions[k]
            if stage < 3:
                slot = 2 if attends else 1
                yield myUNet.up_blocks[3 * stage + 2][slot], pipe_unet.up_blocks[stage].upsamplers[0]

    for target, source in _pairs():
        target.load_state_dict(source.state_dict())
# Build our UNet with its default arguments.
myunet = UNet_SD()
# Move the pipeline's reference UNet to the CPU and keep a handle to it, then
# copy its weights into our implementation.
original_unet = pipe.unet.cpu()
load_pipe_into_our_UNet(myunet, original_unet)
# Swap our UNet (on the GPU) into the pipeline in place of the reference one.
pipe.unet = myunet.cuda()
prompt = "A ballerina riding a Harley Motorcycle, CG Art"
# Run the full text-to-image pipeline under CUDA autocast and take the first
# entry of the "sample" output.
with autocast("cuda"):
    image = pipe(prompt)["sample"][0]

# NOTE(review): the file name mentions an astronaut and a horse while the
# prompt describes a ballerina on a motorcycle - likely a leftover; confirm.
image.save("astronaut_rides_horse.png")