the loss associated with learning the score function
an attention model for conditional generation
an autoencoder
To keep training time manageable we need a reasonably small dataset, so we will work with MNIST, a set of 28x28 images of handwritten digits 0-9. By the end, our model should be able to take in a number prompt (e.g. “4”) and output an image of the digit 4.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm
from tqdm.notebook import trange, tqdm
from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
1. Basic forward/reverse diffusion
This section reviews material from the previous MLFS seminar on diffusion generative models. Skip this if you already know the basics!
The gist is that our generative model will work in the following way. We will take our training examples (e.g. images) and corrupt them with noise until they are unrecognizable. Then we will learn to ‘denoise’ them, and potentially turn pure noise into something similar to what we started with.
Basic forward diffusion
Let’s start with forward diffusion. In the simplest case, the relevant diffusion equation is \[\begin{equation}
\begin{split}
x(t + \Delta t) = x(t) + \sigma(t) \sqrt{\Delta t} \ r
\end{split}
\end{equation}\] where \(\sigma(t) > 0\) is the ‘noise strength’, \(\Delta t\) is the step size, and \(r \sim \mathcal{N}(0, 1)\) is a standard normal random variable. In essence, we repeatedly add normally-distributed noise to our sample. Often, the noise strength \(\sigma(t)\) is chosen to depend on time (i.e. it gets higher as \(t\) gets larger).
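Here is a quick sanity check on the update rule above. This minimal NumPy sketch runs many independent trajectories at once with a constant noise strength σ. The function name `forward_diffusion_many` is ours and is not part of the exercise. Each step adds independent noise of variance σ²Δt, so after total time T the variance should be close to σ²T.

```python
import numpy as np

def forward_diffusion_many(x0, sigma, nsteps, dt, nsamples, seed=0):
    """Run nsamples independent Euler-Maruyama trajectories of
    x(t + dt) = x(t) + sigma * sqrt(dt) * r, with constant sigma."""
    rng = np.random.default_rng(seed)
    x = np.full(nsamples, float(x0))
    for _ in range(nsteps):
        r = rng.standard_normal(nsamples)  # fresh r ~ N(0, 1) for every step and sample
        x = x + sigma * np.sqrt(dt) * r
    return x

sigma, nsteps, dt = 2.0, 100, 0.05
T = nsteps * dt
x_final = forward_diffusion_many(0.0, sigma, nsteps, dt, nsamples=20000)
print(np.var(x_final), sigma**2 * T)  # empirical vs. predicted variance sigma^2 * T
```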
Implement the missing part of 1D forward diffusion.
Hint: You can use np.random.randn() to generate random numbers.
# Simulate forward diffusion for N steps.
def forward_diffusion_1D(x0, noise_strength_fn, t0, nsteps, dt):
    """x0: initial sample value, scalar
    noise_strength_fn: function of time, outputs scalar noise strength
    t0: initial time
    nsteps: number of diffusion steps
    dt: time step size
    """
    # Initialize trajectory
    x = np.zeros(nsteps + 1); x[0] = x0
    t = t0 + np.arange(nsteps + 1)*dt

    # Perform many Euler-Maruyama time steps
    for i in range(nsteps):
        noise_strength = noise_strength_fn(t[i])
        ############ YOUR CODE HERE (2 lines)
        random_normal = ...
        x[i+1] = ...
        #####################################
    return x, t

# Example noise strength function: always equal to 1
def noise_strength_constant(t):
    return 1
We can reverse this diffusion process by a similar-looking update rule: \[\begin{equation}
x(t + \Delta t) = x(t) + \sigma(T - t)^2 \frac{d}{dx}\left[ \log p(x, T-t) \right] \Delta t + \sigma(T-t) \sqrt{\Delta t} \ r
\end{equation}\] where \[\begin{equation}
s(x, t) := \frac{d}{dx} \log p(x, t)
\end{equation}\] is called the score function. If we know this function, we can reverse the forward diffusion and turn noise into what we started with.
If our initial sample is always just one point at \(x_0 = 0\), and the noise strength is constant, then the score function is exactly equal to \[\begin{equation}
s(x, t) = - \frac{(x - x_0)}{\sigma^2 t} = - \frac{x}{\sigma^2 t} \ .
\end{equation}\]
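The formula can be checked numerically. This sketch differentiates log p(x, t) with central finite differences and compares the result with the closed form. It assumes p(x, t) is the Gaussian N(x_0, σ²t) obtained by diffusing a point mass at x_0 with constant σ. The helper names are ours.

```python
import numpy as np

def log_p(x, x0, sigma, t):
    """Log density of N(x0, sigma^2 t), i.e. a point mass at x0 diffused for time t."""
    var = sigma**2 * t
    return -0.5 * np.log(2 * np.pi * var) - (x - x0)**2 / (2 * var)

def score_exact(x, x0, sigma, t):
    return -(x - x0) / (sigma**2 * t)

def score_fd(x, x0, sigma, t, h=1e-5):
    # Central finite difference of log p with respect to x
    return (log_p(x + h, x0, sigma, t) - log_p(x - h, x0, sigma, t)) / (2 * h)

xs = np.linspace(-3, 3, 7)
err = np.max(np.abs(score_fd(xs, 0.0, 1.5, 2.0) - score_exact(xs, 0.0, 1.5, 2.0)))
print(err)  # tiny: only floating-point rounding remains
```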
Implement the missing part of 1D reverse diffusion. You will test it with the above score function.
Hint: You can use np.random.randn() to generate random numbers.
# Simulate reverse diffusion for N steps.
def reverse_diffusion_1D(x0, noise_strength_fn, score_fn, T, nsteps, dt):
    """x0: initial sample value, scalar
    noise_strength_fn: function of time, outputs scalar noise strength
    score_fn: score function
    T: final time of the forward diffusion (reverse diffusion starts there)
    nsteps: number of diffusion steps
    dt: time step size
    """
    # Initialize trajectory
    x = np.zeros(nsteps + 1); x[0] = x0
    t = np.arange(nsteps + 1)*dt

    # Perform many Euler-Maruyama time steps
    for i in range(nsteps):
        noise_strength = noise_strength_fn(T - t[i])
        score = score_fn(x[i], 0, noise_strength, T - t[i])
        ############ YOUR CODE HERE (2 lines)
        random_normal = ...
        x[i+1] = ...
        #####################################
    return x, t

# Example score function: exact score for a point mass at x0 with constant noise strength
def score_simple(x, x0, noise_strength, t):
    score = - (x - x0)/((noise_strength**2)*t)
    return score
Run the cell below to see if your implementation works.
nsteps = 100
t0 = 0
dt = 0.1
noise_strength_fn = noise_strength_constant
score_fn = score_simple
x0 = 0
T = 11  # nsteps*dt = 10 < T, so the score is never evaluated at time 0, where it diverges
num_tries = 5
for i in range(num_tries):
    # Draw from the noise distribution, i.e. diffusion for time T with noise strength 1:
    # this is N(0, T), whose standard deviation is sqrt(T)
    x0 = np.random.normal(loc=0, scale=np.sqrt(T))
    x, t = reverse_diffusion_1D(x0, noise_strength_fn, score_fn, T, nsteps, dt)
    plt.plot(t, x)
plt.xlabel('time', fontsize=20)
plt.ylabel('$x$', fontsize=20)
plt.title('Reverse diffusion visualized', fontsize=20)
plt.show()
Basic score function learning
In practice, we don’t already know the score function; instead, we have to learn it. One way to learn it is to train a neural network to ‘denoise’ samples via the denoising objective \[\begin{equation}
J := \mathbb{E}_{t\in (0, T), x_0 \sim p_0(x_0)}\left[ \ \Vert s(x_{noised}, t) \sigma^2(t) + (x_{noised} - x_0) \Vert^2_2 \ \right]
\end{equation}\] where \(p_0(x_0)\) is our target distribution (e.g. pictures of cats and dogs), and where \(x_{noised}\) is the target distribution sample \(x_0\) after one forward diffusion step, i.e. \(x_{noised} - x_0\) is just a normally-distributed random variable.
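Here’s another way of writing the objective, which is closer to the actual implementation. Substitute \[\begin{equation}
x_{noised} = x_0 + \sigma(t) \epsilon, \; \epsilon\sim \mathcal N(0,I)
\end{equation}\] and divide the term inside the norm by \(\sigma(t)\). Since \(\Vert s \sigma^2 + \sigma \epsilon \Vert^2_2 = \sigma^2 \Vert s \sigma + \epsilon \Vert^2_2\), this only reweights each time \(t\) by the constant \(1/\sigma^2(t)\). For a sufficiently flexible network, it therefore does not change the score function that minimizes the objective. We get the objective function \[\begin{equation}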
J := \mathbb{E}_{t\in (0, T), x_0 \sim p_0(x_0), \epsilon \sim \mathcal N(0,I)}\left[ \ \Vert s(x_0 + \sigma(t) \epsilon, t) \sigma(t) + \epsilon \Vert^2_2 \ \right]
\end{equation}\] We will implement this for you. But it’s good to understand the intuition: we are learning to predict how much noise was added to each part of our sample. We should be able to do this well at every time \(t\) in the diffusion process, and for every \(x_0\) in our original (dogs/cats/etc) distribution.
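This small NumPy check covers a single noising step at one fixed time t. The random vectors and the stand-in score values are arbitrary illustrations. Substituting x_noised = x_0 + σε makes the integrand of the first objective exactly σ² times the integrand of the second. For a point mass at x_0, the exact score -(x_noised - x_0)/σ² = -ε/σ drives the integrand to zero.

```python
import numpy as np

rng = np.random.default_rng(0)
d, sigma = 5, 0.7                      # dimension and noise level sigma(t) at one fixed t
x0 = rng.standard_normal(d)
eps = rng.standard_normal(d)           # eps ~ N(0, I)
x_noised = x0 + sigma * eps
s = rng.standard_normal(d)             # stand-in for s(x_noised, t); any values work

J1 = np.sum((s * sigma**2 + (x_noised - x0))**2)   # integrand of the first form
J2 = np.sum((s * sigma + eps)**2)                   # integrand of the second form
print(np.isclose(J1, sigma**2 * J2))                # True

# Exact score for a point mass at x0: -(x_noised - x0) / sigma^2 = -eps / sigma
s_exact = -(x_noised - x0) / sigma**2
J2_exact = np.sum((s_exact * sigma + eps)**2)
print(J2_exact)                                     # zero up to rounding
```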
2. Working with images via U-Nets
We just reviewed the very basics of diffusion models, with the takeaway that learning the score function allows us to turn pure noise into something interesting. We will learn to approximate the score function with a neural network. But when we are working with images, we need our neural network to ‘play nice’ with them, and to reflect inductive biases we associate with images.
A reasonable choice of architecture is a U-Net. It combines a CNN-like structure with downscaling/upscaling operations, which help the network pay attention to image features at different spatial scales.
Since the score function we’re trying to learn is a function of time, we also need to come up with a way to make sure our neural network properly responds to changes in time. For this purpose, we can use a time embedding.
In this section, you will fill in some missing U-Net pieces.
Helping our neural network work with time
The code below helps our neural network work with time via a time embedding. Instead of telling the network a single number (the current time), we express the current time as a large number of sinusoidal features. The hope is that if we tell the network the current time in many different ways, it will respond more easily to changes in time.
This will enable us to successfully learn a time-dependent score function \(s(x, t)\).
#@title Get some modules to let time interact
class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""
    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights (frequencies) during initialization.
        # These weights (frequencies) are fixed during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        # Cosine(2 pi freq x), Sine(2 pi freq x)
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


class Dense(nn.Module):
    """A fully connected layer that reshapes outputs to feature maps.
    Allow time repr to input additively from the side of a convolution layer.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # Broadcast the 2D tensor to 4D so the same value is added across all spatial positions.
        return self.dense(x)[..., None, None]
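The following NumPy sketch mirrors the forward pass of GaussianFourierProjection, so you can see what it produces. The function name `fourier_features` is ours, and a seeded NumPy generator replaces `torch.randn`. A batch of B scalar times becomes a (B, embed_dim) array: the first half holds sines and the second half cosines, at fixed random frequencies.

```python
import numpy as np

def fourier_features(t, embed_dim=8, scale=30.0, seed=0):
    """NumPy mirror of GaussianFourierProjection.forward for a 1D array of times."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal(embed_dim // 2) * scale     # fixed random frequencies
    x_proj = t[:, None] * W[None, :] * 2 * np.pi         # shape (B, embed_dim // 2)
    return np.concatenate([np.sin(x_proj), np.cos(x_proj)], axis=-1)

t = np.array([0.01, 0.5, 1.0])
feats = fourier_features(t)
print(feats.shape)  # (3, 8)
```

Each sine/cosine pair shares a frequency, so the squares of the two halves sum to 1 feature by feature.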
Defining the U-Net architecture
The class below defines a U-Net architecture. Fill in the missing pieces. (This shouldn’t be hard; the exercise is mainly there to get you to look at the structure.)
#@title Defining a time-dependent score-based model (double click to expand or collapse)
class UNet(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture."""

    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """Initialize a time-dependent score-based network.

        Args:
            marginal_prob_std: A function that takes time t and gives the standard
                deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
            channels: The number of channels for feature maps of each resolution.
            embed_dim: The dimensionality of Gaussian random feature embeddings.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        ########### YOUR CODE HERE (3 lines)
        self.conv3 = ...
        self.dense3 = ...
        self.gnorm3 = ...
        #########################
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t, y=None):
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.time_embed(t))
        # Encoding path
        h1 = self.conv1(x) + self.dense1(embed)  ## Incorporate information from t
        h1 = self.act(self.gnorm1(h1))           ## Group normalization
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        ########## YOUR CODE HERE (2 lines)
        h3 = ...  # conv, dense
        # apply activation function
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        ############
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))

        # Decoding path
        h = self.tconv4(h4)
        ## Skip connection from the encoding path
        h += self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(torch.cat([h, h1], dim=1))

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
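Below is code for an alternative U-Net architecture; diffusion models can apparently succeed with somewhat different architectural details. The difference from the class above is subtle. It lies only in the skip connections, and hence in the input channel counts of the transposed convolutions:
The upper one (UNet) concatenates the tensor from the down block with the decoder tensor along the channel dimension.
The lower one (UNet_res) directly adds the tensor from the down block to the decoder tensor.
The lower one is a special case of the upper. A (transposed) convolution applied to a concatenation acts on each half with its own weights, so when the two halves' weights are tied it reduces to a convolution of the sum.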
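Here is a NumPy sketch of the "special case" claim. It uses a 1×1 convolution (a matrix acting on the channel axis) as a simplified stand-in for the 3×3 transposed convolutions in the models above. The argument only needs linearity, so it carries over. Applying W = [W_1, W_1] to the concatenation of h and h_skip gives W_1 h + W_1 h_skip = W_1 (h + h_skip).

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, C_out, H, Wd = 4, 3, 5, 5
h = rng.standard_normal((C_in, H, Wd))        # decoder (up-path) features
h_skip = rng.standard_normal((C_in, H, Wd))   # features from the matching down block

def conv1x1(weight, x):
    # weight: (C_out, C), x: (C, H, W); mixes channels independently at every pixel
    return np.einsum("oc,chw->ohw", weight, x)

W1 = rng.standard_normal((C_out, C_in))
W_tied = np.concatenate([W1, W1], axis=1)     # weight for the 2*C_in concatenated channels

concat_out = conv1x1(W_tied, np.concatenate([h, h_skip], axis=0))  # UNet-style skip
add_out = conv1x1(W1, h + h_skip)                                   # UNet_res-style skip
print(np.allclose(concat_out, add_out))  # True
```

With untied weights, the concatenating version can weight the skip features differently from the decoder features. That extra freedom is what the additive version gives up.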
#@title Alternative time-dependent score-based model (double click to expand or collapse)
class UNet_res(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture."""

    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """Initialize a time-dependent score-based network.

        Args:
            marginal_prob_std: A function that takes time t and gives the standard
                deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
            channels: The number of channels for feature maps of each resolution.
            embed_dim: The dimensionality of Gaussian random feature embeddings.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)  # + channels[2]
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)  # + channels[1]
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)  # + channels[0]

        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t, y=None):
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.time_embed(t))
        # Encoding path
        h1 = self.conv1(x) + self.dense1(embed)  ## Incorporate information from t
        h1 = self.act(self.gnorm1(h1))           ## Group normalization
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))

        # Decoding path
        h = self.tconv4(h4)
        ## Skip connection from the encoding path
        h += self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(h + h3)
        h += self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(h + h2)
        h += self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
Tip: when you are unsure about the shapes of the tensors flowing through the layers, build the layers on their own and print the shapes. A pattern like the following can help.
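import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(1, 32, 3, stride=1),           # e.g. the U-Net's conv1
    nn.ConvTranspose2d(32, 1, 3, stride=1),  # e.g. the U-Net's tconv1
)
x = torch.randn(1, 1, 28, 28)  # one fake MNIST-sized image
for layer in net:
    x = layer(x)
    print(x.shape)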
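The spatial sizes can also be worked out by hand. This pure-Python sketch uses the output-size formulas given in the PyTorch Conv2d and ConvTranspose2d documentation, with padding 0, dilation 1 and kernel 3 as in the U-Nets above. It traces a 28×28 MNIST image through the encoder and decoder. It also shows why tconv3 and tconv2 need output_padding=1 for the skip connections to line up. The helper names are ours.

```python
def conv_out(size, kernel=3, stride=1, padding=0, dilation=1):
    # Conv2d: floor((H + 2p - d(k-1) - 1) / s) + 1
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def tconv_out(size, kernel=3, stride=1, padding=0, dilation=1, output_padding=0):
    # ConvTranspose2d: (H - 1)s - 2p + d(k-1) + output_padding + 1
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

h1 = conv_out(28, stride=1)   # conv1
h2 = conv_out(h1, stride=2)   # conv2
h3 = conv_out(h2, stride=2)   # conv3
h4 = conv_out(h3, stride=2)   # conv4
print("encoder:", h1, h2, h3, h4)  # encoder: 26 12 5 2

u4 = tconv_out(h4, stride=2)                    # tconv4, must match h3
u3 = tconv_out(u4, stride=2, output_padding=1)  # tconv3, must match h2
u2 = tconv_out(u3, stride=2, output_padding=1)  # tconv2, must match h1
u1 = tconv_out(u2, stride=1)                    # tconv1, back to 28
print("decoder:", u4, u3, u2, u1)  # decoder: 5 12 26 28
```

Without output_padding, tconv3 would produce 11×11 instead of 12×12, and the skip connection with h2 would fail.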
3. Train the U-Net to learn a score function
Let’s combine the U-Net we just defined with a way to learn the score function. We need to define a loss function, and then train a neural network in the usual way.
In the next cell, we will define the specific forward diffusion process \[
dx = \sigma^t dw
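\] where \(dw\) is the increment of a standard Wiener process (Brownian motion). The “diffusion constant” \(\sigma^t\), i.e. \(\sigma\) raised to the power \(t\), adds noise to \(x\) samples at a scale that grows exponentially in \(t\).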
Given this forward process and a starting point \(x(0)\), we have an analytical solution for the distribution of the sample \(x(t)\) at any time \(t\): \[
p(x(t)|x(0))=\mathcal N(x(0),\sigma(t)^2)
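\] We call \(\sigma(t)\) the marginal std, i.e. the standard deviation of the conditional distribution \(p(x(t)|x(0))\). Note that \(\sigma(t)\), a function of time, is different from \(\sigma^t\), the constant \(\sigma\) raised to the power \(t\). In this specific case, \[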
\sigma^2(t)=\int_0^t (\sigma^\tau ) ^2d\tau = \int_0^t \sigma^{2\tau} d\tau = \frac{\sigma^{2t}-1}{2\log \sigma }
\]
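The closed form can be checked numerically. This sketch compares a midpoint-rule estimate of the integral with the formula above, for σ = 25 (the value used in the next cell) and an arbitrary t = 0.7.

```python
import numpy as np

sigma, t = 25.0, 0.7
n = 100000
dtau = t / n
tau = (np.arange(n) + 0.5) * dtau               # midpoints of n equal sub-intervals of [0, t]
numeric = np.sum(sigma ** (2 * tau)) * dtau     # midpoint-rule estimate of the integral
closed_form = (sigma ** (2 * t) - 1) / (2 * np.log(sigma))
print(numeric, closed_form)                     # agree to many digits
```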
#@title Diffusion constant and noise strength
device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}

def marginal_prob_std(t, sigma):
    """Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

    Args:
        t: A vector of time steps.
        sigma: The $\sigma$ in our SDE.

    Returns:
        The standard deviation.
    """
    # torch.as_tensor avoids the copy-construct UserWarning when t is already a tensor
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
    """Compute the diffusion coefficient of our SDE.

    Args:
        t: A vector of time steps.
        sigma: The $\sigma$ in our SDE.

    Returns:
        The vector of diffusion coefficients.
    """
    return torch.as_tensor(sigma**t, device=device)

sigma = 25.0 #@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
Defining the loss function
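The loss function is mostly defined below. You need to add one part: sample standard normal noise \(\mathbf{z}\) with the same shape as \(\mathbf{x}\). Then perturb \(\mathbf{x}\) by adding \(\mathbf{z}\) scaled by the noise strength std[:, None, None, None]. Keep \(\mathbf{z}\) itself unscaled, because the loss compares score * std against -z.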
Hint: torch.randn_like() may be useful.
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.

    Args:
        model: A PyTorch model instance that represents a
            time-dependent score-based model.
        x: A mini-batch of training data.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        eps: A tolerance value for numerical stability.
    """
    # Sample time uniformly in [eps, 1)
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    # Find the noise std at the time `t`
    std = marginal_prob_std(random_t)
    ####### YOUR CODE HERE (2 lines)
    z = ...  # get normally distributed noise
    perturbed_x = x + ...
    ##############
    score = model(perturbed_x, random_t)
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
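The shape logic of loss_fn can be checked without a model. In this NumPy sketch, the batch size, the random stand-in images and the "oracle" score -z/std are illustrative assumptions. std[:, None, None, None] has shape (B, 1, 1, 1), so it broadcasts one noise level over every pixel of each image. The oracle score makes the loss vanish. A zero score gives a loss near the pixel count, since E‖z‖² = 784 for a 1×28×28 image.

```python
import numpy as np

rng = np.random.default_rng(0)
B, eps, sigma = 16, 1e-5, 25.0
x = rng.random((B, 1, 28, 28))                        # stand-in for a batch of images
random_t = rng.random(B) * (1.0 - eps) + eps          # times in [eps, 1)
std = np.sqrt((sigma ** (2 * random_t) - 1) / (2 * np.log(sigma)))  # marginal std per image

z = rng.standard_normal(x.shape)                      # unit-variance noise, same shape as x
perturbed_x = x + z * std[:, None, None, None]        # one noise level per image

def loss_from_score(score):
    # Same reduction as loss_fn: sum over (C, H, W), then mean over the batch
    return np.mean(np.sum((score * std[:, None, None, None] + z) ** 2, axis=(1, 2, 3)))

print(std[:, None, None, None].shape)                  # (16, 1, 1, 1)
print(loss_from_score(-z / std[:, None, None, None]))  # oracle score: zero up to rounding
print(loss_from_score(np.zeros_like(x)))               # zero score: close to 784
```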
Defining the sampler
#@title Sampler code
num_steps = 500 #@param {'type':'integer'}
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=num_steps,
                           device='cuda',
                           eps=1e-3, y=None):
    """Generate samples from score-based models with the Euler-Maruyama solver.

    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
        batch_size: The number of samples to generate by calling this function once.
        x_shape: The shape of a single sample, e.g. (1, 28, 28) for MNIST.
        num_steps: The number of sampling steps.
            Equivalent to the number of discretized time steps.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        eps: The smallest time step for numerical stability.
        y: Optional conditioning input, passed to score_model as y=y.

    Returns:
        Samples.
    """
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) \
        * marginal_prob_std(t)[:, None, None, None]
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
    # Do not include any noise in the last sampling step.
    return mean_x
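Here is the sampler's update rule at work without a trained network. This NumPy sketch applies the same Euler-Maruyama loop to a toy 1D problem where the exact score is known. The toy target and the names `toy_sampler` and `SIGMA` are assumptions for illustration. Data concentrated at x_0 = 0 gives p(x(t)) = N(0, σ(t)²) and s(x, t) = -x/σ(t)². Samples start with spread σ(1) ≈ 9.8 and should collapse toward 0.

```python
import numpy as np

SIGMA = 25.0

def marginal_std(t):
    return np.sqrt((SIGMA ** (2 * t) - 1.0) / (2.0 * np.log(SIGMA)))

def diffusion_coeff(t):
    return SIGMA ** t

def toy_sampler(batch_size=2000, num_steps=500, eps=1e-3, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(batch_size) * marginal_std(1.0)  # start from the t = 1 noise distribution
    time_steps = np.linspace(1.0, eps, num_steps)
    step_size = time_steps[0] - time_steps[1]
    for t in time_steps:
        g = diffusion_coeff(t)
        score = -x / marginal_std(t) ** 2                     # exact score for a point mass at 0
        mean_x = x + g**2 * score * step_size
        x = mean_x + np.sqrt(step_size) * g * rng.standard_normal(batch_size)
    return mean_x                                             # no noise in the last step

samples = toy_sampler()
print(marginal_std(1.0), np.std(samples))  # initial spread vs. final spread
```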
Training on MNIST
We will train on MNIST, and learn to generate pictures that look like 0-9 digits. No code to fill in here; just run it and see if it works!
With the following code, the loss can drop to around 40-50.
#@title Training (double click to expand or collapse)
score_model = torch.nn.DataParallel(UNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 50 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 2048 #@param {'type':'integer'}
## learning rate
lr = 5e-4 #@param {'type':'number'}

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for x, y in tqdm(data_loader):
        x = x.to(device)
        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    # Print the averaged training loss so far.
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Update the checkpoint after each epoch of training.
    torch.save(score_model.state_dict(), 'ckpt.pth')
With the following code (the alternative UNet_res model), the loss can drop to around 25, with relatively good sample quality.
#@title Training the alternate U-Net model (double click to expand or collapse)
score_model = torch.nn.DataParallel(UNet_res(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 75 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 1024 #@param {'type':'integer'}
## learning rate
lr = 10e-4 #@param {'type':'number'}

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = Adam(score_model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.98 ** epoch))
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for x, y in data_loader:
        x = x.to(device)
        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    scheduler.step()
    lr_current = scheduler.get_last_lr()[0]
    print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
    # Print the averaged training loss so far.
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Update the checkpoint after each epoch of training.
    torch.save(score_model.state_dict(), 'ckpt_res.pth')
0 Average Loss: 1070.723943 lr 9.8e-04
1 Average Loss: 312.869164 lr 9.6e-04
2 Average Loss: 209.593512 lr 9.4e-04
3 Average Loss: 159.604592 lr 9.2e-04
4 Average Loss: 131.208202 lr 9.0e-04
5 Average Loss: 112.011338 lr 8.9e-04
6 Average Loss: 99.231037 lr 8.7e-04
7 Average Loss: 89.299207 lr 8.5e-04
8 Average Loss: 81.803251 lr 8.3e-04
9 Average Loss: 75.683446 lr 8.2e-04
10 Average Loss: 70.674667 lr 8.0e-04
11 Average Loss: 66.624008 lr 7.8e-04
12 Average Loss: 62.762296 lr 7.7e-04
13 Average Loss: 59.316022 lr 7.5e-04
14 Average Loss: 56.370946 lr 7.4e-04
15 Average Loss: 54.499845 lr 7.2e-04
16 Average Loss: 51.994298 lr 7.1e-04
17 Average Loss: 49.876137 lr 7.0e-04
18 Average Loss: 48.393268 lr 6.8e-04
19 Average Loss: 46.500402 lr 6.7e-04
20 Average Loss: 45.150150 lr 6.5e-04
21 Average Loss: 43.693895 lr 6.4e-04
22 Average Loss: 42.697722 lr 6.3e-04
23 Average Loss: 41.578316 lr 6.2e-04
24 Average Loss: 40.365930 lr 6.0e-04
25 Average Loss: 39.272062 lr 5.9e-04
26 Average Loss: 39.007713 lr 5.8e-04
27 Average Loss: 38.007527 lr 5.7e-04
28 Average Loss: 37.280869 lr 5.6e-04
29 Average Loss: 36.785621 lr 5.5e-04
30 Average Loss: 35.737501 lr 5.3e-04
31 Average Loss: 35.325589 lr 5.2e-04
32 Average Loss: 34.618500 lr 5.1e-04
33 Average Loss: 34.061596 lr 5.0e-04
34 Average Loss: 34.175156 lr 4.9e-04
35 Average Loss: 33.403281 lr 4.8e-04
36 Average Loss: 32.976893 lr 4.7e-04
37 Average Loss: 32.414113 lr 4.6e-04
38 Average Loss: 32.455040 lr 4.5e-04
39 Average Loss: 31.908144 lr 4.5e-04
40 Average Loss: 31.733034 lr 4.4e-04
41 Average Loss: 31.166364 lr 4.3e-04
42 Average Loss: 30.933214 lr 4.2e-04
43 Average Loss: 30.873448 lr 4.1e-04
44 Average Loss: 30.315580 lr 4.0e-04
45 Average Loss: 30.131343 lr 3.9e-04
46 Average Loss: 29.691356 lr 3.9e-04
47 Average Loss: 29.471414 lr 3.8e-04
48 Average Loss: 29.218028 lr 3.7e-04
49 Average Loss: 29.266639 lr 3.6e-04
50 Average Loss: 28.840175 lr 3.6e-04
51 Average Loss: 28.906321 lr 3.5e-04
52 Average Loss: 28.433899 lr 3.4e-04
53 Average Loss: 28.146673 lr 3.4e-04
54 Average Loss: 27.990703 lr 3.3e-04
55 Average Loss: 27.869217 lr 3.2e-04
56 Average Loss: 27.822635 lr 3.2e-04
57 Average Loss: 27.884746 lr 3.1e-04
58 Average Loss: 27.158734 lr 3.0e-04
59 Average Loss: 27.297360 lr 3.0e-04
60 Average Loss: 26.998111 lr 2.9e-04
61 Average Loss: 27.063771 lr 2.9e-04
62 Average Loss: 26.929028 lr 2.8e-04
63 Average Loss: 26.607439 lr 2.7e-04
64 Average Loss: 26.404006 lr 2.7e-04
65 Average Loss: 26.331446 lr 2.6e-04
66 Average Loss: 26.207267 lr 2.6e-04
67 Average Loss: 26.259585 lr 2.5e-04
68 Average Loss: 26.134909 lr 2.5e-04
69 Average Loss: 25.886961 lr 2.4e-04
70 Average Loss: 25.874144 lr 2.4e-04
71 Average Loss: 25.843274 lr 2.3e-04
72 Average Loss: 25.532615 lr 2.3e-04
73 Average Loss: 25.499755 lr 2.2e-04
74 Average Loss: 25.274344 lr 2.2e-04
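The lr column in the log above follows directly from the LambdaLR schedule. This pure-Python sketch reproduces the arithmetic, assuming (as in the cell above) a base lr of 10e-4 = 1e-3 and one scheduler.step() per epoch. get_last_lr() is read after that step, so the value printed at epoch e is 1e-3 · max(0.2, 0.98^(e+1)). The helper name `printed_lr` is ours.

```python
base_lr = 10e-4   # = 1e-3, as in the training cell

def printed_lr(epoch):
    # After k calls to scheduler.step(), LambdaLR sets lr = base_lr * lr_lambda(k)
    k = epoch + 1
    return base_lr * max(0.2, 0.98 ** k)

print(f"{printed_lr(0):.1e}", f"{printed_lr(5):.1e}", f"{printed_lr(74):.1e}")
# 9.8e-04 8.9e-04 2.2e-04  (matches log lines 0, 5 and 74)

# First step count at which the 0.2 floor takes over
first_floor = next(k for k in range(1, 1000) if 0.98 ** k <= 0.2)
print(first_floor)  # 80
```

The floor in max(0.2, ...) would only take over after 80 scheduler steps. It is therefore never reached in this 75-epoch run.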
4. Using attention to get conditional generation to work
In addition to generating images of 0-9 digits in general, we would like to do conditional generation, for example specifying which digit the model should generate an image of.
Attention models, while not strictly necessary for conditional generation, have proven useful for getting it to work well. In this section, you will implement parts of an attention model.
from einops import rearrange
import math
“Word embedding” of digits
Here, instead of using a fancy CLIP model, we just define our own vector representations of the digits 0-9.
We use the nn.Embedding layer to turn each digit index 0-9 into a learnable vector.
class WordEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(WordEmbed, self).__init__()
        self.embed = nn.Embedding(vocab_size+1, embed_dim)

    def forward(self, ids):
        return self.embed(ids)
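Conceptually, nn.Embedding is a learnable lookup table. Looking up an index returns the corresponding row of a (vocab_size + 1) × embed_dim weight matrix, which is the same as multiplying a one-hot vector by that matrix. Here is a NumPy sketch, with a random matrix standing in for the learned table:

```python
import numpy as np

vocab_size, embed_dim = 10, 4
rng = np.random.default_rng(0)
table = rng.standard_normal((vocab_size + 1, embed_dim))  # stand-in for nn.Embedding's weight

ids = np.array([4, 0, 9])                    # a batch of digit labels
lookup = table[ids]                          # row lookup, shape (3, embed_dim)
one_hot = np.eye(vocab_size + 1)[ids]        # shape (3, vocab_size + 1)
print(np.allclose(lookup, one_hot @ table))  # True: lookup equals the one-hot product
```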
Let’s develop our Attention layer
We usually implement attention models using 3 parts:
* CrossAttention: a module that performs self- or cross-attention on sequences.
* TransformerBlock: combines self/cross-attention with a feed-forward neural network.
* SpatialTransformer: to use attention in a U-Net, transforms the spatial tensor into a sequential tensor and back.
You will implement a part of the CrossAttention class, and you will also pick where to add attention in your U-Net.
Here’s a brief review of the mathematics of attention models. QKV (query-key-value) attention models represent queries, keys, and values as vectors. These are tools that help us relate words/images on one side of a translation task to the other side.
These are linearly related to \(\mathbf{e}\) vectors (which represent the hidden state of the encoder) and \(\mathbf{h}\) vectors (which represent the hidden state of the decoder): \[\begin{equation}
\mathbf{q}_i = W^q \mathbf{h}_i, \quad \mathbf{k}_j = W^k \mathbf{e}_j, \quad \mathbf{v}_j = W^v \mathbf{e}_j
\end{equation}\]
To figure out what to ‘pay attention’ to, we compute the inner product (i.e. similarity) of each key \(\mathbf{k}_j\) and query \(\mathbf{q}_i\). To keep typical values from being too big or too small, we divide by the square root of the dimension of the query vectors \(\mathbf{q}_i\): \[\begin{equation}
\tilde{a}_{ij} = \frac{\mathbf{q}_i^T \mathbf{k}_j}{\sqrt{\dim \mathbf{q}_i}}
\end{equation}\]
The final attention distribution comes from applying a softmax over \(j\): \[\begin{equation}
a_{ij} = \frac{\exp(\tilde{a}_{ij})}{\sum_{j'} \exp(\tilde{a}_{ij'})}
\end{equation}\]
The attention distribution is used to pick out some relevant combination of features. For example, to translate the phrase “European Union” from English to French, getting the correct answer (“Union européenne”) requires paying attention to both words at once instead of translating each word separately. Mathematically, we weight values \(\mathbf{v}_j\) by the attention distribution: \[\begin{equation}
\mathbf{c}_i = \sum_j a_{ij} \mathbf{v}_j
\end{equation}\]
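The steps described above are: project to queries, keys and values, take scaled inner products, apply a softmax, then take a weighted sum. They fit in a short NumPy sketch of single-head attention. Several assumptions apply. The batch dimension is dropped. Scores are scaled by 1/√d, with d the query dimension. Weights are plain random matrices multiplied on the right (x @ W) rather than nn.Linear layers. As in CrossAttention, queries come from one set of vectors and keys/values from another (the context), and the value projection maps back to hidden_dim.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)   # subtract the row max for numerical stability
    ex = np.exp(a)
    return ex / ex.sum(axis=axis, keepdims=True)

def attention(h, e, Wq, Wk, Wv):
    """Single-head scaled dot-product attention (no batch dimension).
    h: (T, hidden_dim) vectors that produce queries
    e: (S, context_dim) vectors that produce keys and values
    """
    Q = h @ Wq                                # (T, d)   queries q_i
    K = e @ Wk                                # (S, d)   keys k_j
    V = e @ Wv                                # (S, d_v) values v_j
    scores = Q @ K.T / np.sqrt(Q.shape[-1])   # (T, S)   scaled inner products q_i . k_j
    attn = softmax(scores, axis=-1)           # row i is a distribution over positions j
    return attn @ V, attn                     # (T, d_v) attention-weighted sums of values

rng = np.random.default_rng(0)
T, S, hidden_dim, context_dim, d = 3, 5, 6, 4, 8
h = rng.standard_normal((T, hidden_dim))
e = rng.standard_normal((S, context_dim))
Wq = rng.standard_normal((hidden_dim, d))
Wk = rng.standard_normal((context_dim, d))
Wv = rng.standard_normal((context_dim, hidden_dim))
out, attn = attention(h, e, Wq, Wk, Wv)
print(out.shape, np.allclose(attn.sum(axis=-1), 1.0))  # (3, 6) True
```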
Implement the missing pieces of the CrossAttention class below. Also implement the missing part of the TransformerBlock class.
Note that the relevant matrix multiplications can be performed using torch.einsum. For example, to multiply an \(M \times N\) matrix \(A\) by an \(N \times K\) matrix \(B\), i.e. to obtain \(A B\), we can write:
torch.einsum("ij,jk->ik", A, B)
If, instead, we wanted to compute \(A B^T\) for a \(K \times N\) matrix \(B\), we could write
torch.einsum("ij,kj->ik", A, B)
The library takes care of moving the dimensions around properly. In machine learning it is often important to do operations like matrix multiplication in batches. In that case you might have tensors instead of matrices, but you can write a very similar expression:
torch.einsum("bij,bkj->bik", A, B)
where b here is the index that describes which batch elements we are talking about. As a final point, you can use whatever letters you like instead of i, j, etc.
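np.einsum accepts the same subscript strings as torch.einsum, so the expressions above can be checked quickly in NumPy against ordinary matrix products:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))   # M x N
B = rng.standard_normal((3, 4))   # N x K, for A B
C = rng.standard_normal((4, 3))   # K x N, for A C^T

ok_ab = np.allclose(np.einsum("ij,jk->ik", A, B), A @ B)
ok_act = np.allclose(np.einsum("ij,kj->ik", A, C), A @ C.T)

# Batched version: the shared index b is kept in the output instead of being summed
Ab = rng.standard_normal((5, 2, 3))
Cb = rng.standard_normal((5, 4, 3))
ok_batched = np.allclose(np.einsum("bij,bkj->bik", Ab, Cb), Ab @ Cb.transpose(0, 2, 1))
print(ok_ab, ok_act, ok_batched)  # True True True
```

Writing the output subscripts ("->bik") explicitly matters here. In implicit mode, einsum sums over every index that appears more than once, including b.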
Hint: F.softmax may also be helpful!
Attention Modules
from einops import rearrange  # used by SpatialTransformer below

class CrossAttention(nn.Module):
    def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=1):
        """
        Note: For simplicity, we only implement 1-head attention.
        Feel free to implement multi-head attention with fancy tensor manipulations!
        """
        super(CrossAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.query = nn.Linear(hidden_dim, embed_dim, bias=False)
        if context_dim is None:
            self.self_attn = True
            self.key = nn.Linear(hidden_dim, embed_dim, bias=False)  ###########
            self.value = nn.Linear(hidden_dim, hidden_dim, bias=False)  ############
        else:
            self.self_attn = False
            self.key = nn.Linear(context_dim, embed_dim, bias=False)  #############
            self.value = nn.Linear(context_dim, hidden_dim, bias=False)  ############

    def forward(self, tokens, context=None):
        # tokens: with shape [batch, sequence_len, hidden_dim]
        # context: with shape [batch, context_seq_len, context_dim]
        if self.self_attn:
            Q = self.query(tokens)
            K = self.key(tokens)
            V = self.value(tokens)
        else:
            # implement Q, K, V for the cross-attention
            Q = ...
            K = ...
            V = ...
        # print(Q.shape, K.shape, V.shape)
        ####### YOUR CODE HERE (2 lines)
        scoremats = ...  # inner product of Q and K, a tensor
        attnmats = ...  # softmax of scoremats
        # print(scoremats.shape, attnmats.shape)
        ctx_vecs = torch.einsum("BTS,BSH->BTH", attnmats, V)  # average of the value vectors, weighted by attnmats
        return ctx_vecs


class TransformerBlock(nn.Module):
    """The transformer block that combines self-attn, cross-attn and a feed-forward neural net"""
    def __init__(self, hidden_dim, context_dim):
        super(TransformerBlock, self).__init__()
        self.attn_self = CrossAttention(hidden_dim, hidden_dim)
        self.attn_cross = CrossAttention(hidden_dim, hidden_dim, context_dim)

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        # implement a 2-layer MLP with K*hidden_dim hidden units and nn.GELU nonlinearity #######
        self.ffn = nn.Sequential(
            # YOUR CODE HERE ##################
            ...
        )

    def forward(self, x, context=None):
        # Notice the + x as residual connections
        x = self.attn_self(self.norm1(x)) + x
        # Notice the + x as residual connections
        x = self.attn_cross(self.norm2(x), context=context) + x
        # Notice the + x as residual connections
        x = self.ffn(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    def __init__(self, hidden_dim, context_dim):
        super(SpatialTransformer, self).__init__()
        self.transformer = TransformerBlock(hidden_dim, context_dim)

    def forward(self, x, context=None):
        b, c, h, w = x.shape
        x_in = x
        # Combine the spatial dimensions and move the channel dimension to the end
        x = rearrange(x, "b c h w -> b (h w) c")
        # Apply the sequence transformer
        x = self.transformer(x, context)
        # Reverse the process
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        # Residual connection
        return x + x_in
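To make the shapes concrete, here is a NumPy sketch of single-head cross-attention. It is an illustration, not the notebook's intended solution to the exercise. The matrices `Wq`, `Wk`, `Wv` stand in for the bias-free `nn.Linear` layers; NumPy computes `x @ W`, so each matrix is stored as [in, out]. The 1/sqrt(embed_dim) scaling of the scores is an assumption borrowed from the standard Transformer. The sketch also reproduces the `rearrange` round trip of `SpatialTransformer` with `reshape`/`transpose`.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the maximum before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(tokens, context, Wq, Wk, Wv):
    """Single-head cross-attention (NumPy sketch).

    tokens:  [B, T, hidden_dim]   queries come from the image tokens
    context: [B, S, context_dim]  keys and values come from the conditioning
    Wq: [hidden_dim, embed_dim], Wk: [context_dim, embed_dim], Wv: [context_dim, hidden_dim]
    """
    Q = tokens @ Wq    # [B, T, embed_dim]
    K = context @ Wk   # [B, S, embed_dim]
    V = context @ Wv   # [B, S, hidden_dim]
    # Inner product of every query with every key; the 1/sqrt(embed_dim) scaling is an assumed convention.
    scores = np.einsum("BTE,BSE->BTS", Q, K) / np.sqrt(Q.shape[-1])
    attn = softmax(scores, axis=-1)  # each query's weights over the S context positions sum to 1
    return np.einsum("BTS,BSH->BTH", attn, V), attn


rng = np.random.default_rng(0)
B, C, H, W, D, E = 2, 8, 5, 5, 16, 4
feat = rng.standard_normal((B, C, H, W))                # a [b, c, h, w] feature map
tokens = feat.reshape(B, C, H * W).transpose(0, 2, 1)   # same as rearrange "b c h w -> b (h w) c"
context = rng.standard_normal((B, 1, D))                # one label embedding per image, like y_embed
Wq = rng.standard_normal((C, E))
Wk = rng.standard_normal((D, E))
Wv = rng.standard_normal((D, C))

out, attn = cross_attention(tokens, context, Wq, Wk, Wv)
back = out.transpose(0, 2, 1).reshape(B, C, H, W)      # same as "b (h w) c -> b c h w"
print(out.shape, attn.shape, back.shape)  # (2, 25, 8) (2, 25, 1) (2, 8, 5, 5)
```

`UNet_Tranformer` passes `y_embed = self.cond_embed(y).unsqueeze(1)`, which is a context of sequence length 1. A softmax over a single key is always 1. In that setting, cross-attention therefore contributes the same projected label vector at every spatial position. The attention weights only become informative when the context has several tokens.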
Code playground!
Putting it together, UNet Transformer
Now you can interleave your SpatialTransformer layers with the convolutional layers!
Remember to use them in your forward function. Look at the architecture, and add in extra attention layers if you wish.
class UNet_Tranformer(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture."""

    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256,
                 text_dim=256, nClass=10):
        """Initialize a time-dependent score-based network.

        Args:
          marginal_prob_std: A function that takes time t and gives the standard
            deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
          channels: The number of channels for feature maps of each resolution.
          embed_dim: The dimensionality of Gaussian random feature embeddings of time.
          text_dim: the embedding dimension of text / digits.
          nClass: number of classes you want to model.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.attn3 = SpatialTransformer(channels[2], text_dim)

        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        # YOUR CODE: interleave some attention layers with conv layers
        self.attn4 = ...
        ######################################

        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])

        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)  # + channels[2]
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])

        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)  # + channels[1]
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)  # + channels[0]

        # The swish activation function
        self.act = nn.SiLU()  # lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std
        self.cond_embed = nn.Embedding(nClass, text_dim)

    def forward(self, x, t, y=None):
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.time_embed(t))
        y_embed = self.cond_embed(y).unsqueeze(1)
        # Encoding path
        h1 = self.conv1(x) + self.dense1(embed)  # Incorporate information from t
        # Group normalization
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h3 = self.attn3(h3, y_embed)  # Use your attention layers
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))
        # Your code: Use your additional attention layers!
        h4 = ...
        ##################### ATTENTION LAYER COULD GO HERE IF ATTN4 IS DEFINED

        # Decoding path
        h = self.tconv4(h4) + self.dense5(embed)
        ## Skip connection from the encoding path
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(h + h3) + self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(h + h2) + self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
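You can check that the skip connections line up by tracking the spatial size through the network. The same exercise shows why `tconv3` and `tconv2` need `output_padding=1`. The helper names `conv_out` and `tconv_out` are hypothetical. They implement the output-size formulas from the PyTorch `Conv2d`/`ConvTranspose2d` documentation for padding 0 and dilation 1, which is what every layer here uses.

```python
def conv_out(n, k=3, s=1, p=0):
    # nn.Conv2d output size (dilation 1): floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1


def tconv_out(n, k=3, s=1, p=0, output_padding=0):
    # nn.ConvTranspose2d output size (dilation 1): (n - 1) * s - 2p + k + output_padding
    return (n - 1) * s - 2 * p + k + output_padding


# Encoding path of UNet_Tranformer: conv1 (stride 1), then conv2..conv4 (stride 2)
enc = [28]
for stride in (1, 2, 2, 2):
    enc.append(conv_out(enc[-1], s=stride))
print(enc)  # [28, 26, 12, 5, 2], so h1, h2, h3, h4 are 26, 12, 5, 2 pixels wide

# Decoding path: each output must match the skip connection it is added to
d4 = tconv_out(enc[4], s=2)                    # tconv4: 2 -> 5, matches h3
d3 = tconv_out(d4, s=2, output_padding=1)      # tconv3: 5 -> 12, matches h2
d2 = tconv_out(d3, s=2, output_padding=1)      # tconv2: 12 -> 26, matches h1
d1 = tconv_out(d2, s=1)                        # tconv1: 26 -> 28, the input size
print([d4, d3, d2, d1])  # [5, 12, 26, 28]
print(tconv_out(5, s=2))  # 11: without output_padding, h + h2 would fail (11 vs 12)
```

A stride-2 convolution maps both 25 and 26 to 12, so the transposed convolution cannot tell which size to restore. `output_padding` resolves that ambiguity. The same arithmetic shows that an `attn4` layer at the `conv4` level would attend over only 2 x 2 = 4 spatial tokens, compared with 5 x 5 = 25 tokens for `attn3`.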
Conditional Denoising Loss
Here, we modify the loss function so that the model also receives the label y during training.
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.

    Args:
      model: A PyTorch model instance that represents a
        time-dependent score-based model.
      x: A mini-batch of training data.
      y: A mini-batch of class labels (digits) used for conditioning.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      eps: A tolerance value for numerical stability.
    """
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]
    score = model(perturbed_x, random_t, y=y)
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
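The loss in `loss_fn_cond` is denoising score matching. Since x(t) = x + std * z, the score of the Gaussian perturbation kernel is -z / std, so the model is trained to make score * std + z vanish. The NumPy sketch below checks that arithmetic on a toy batch. It drops the label `y`, which is only forwarded to the model. Its `marginal_prob_std` (the variance-exploding form with sigma = 25) is an assumed stand-in for the notebook's `marginal_prob_std_fn`, whose definition is not shown in this section.

```python
import numpy as np


def marginal_prob_std(t, sigma=25.0):
    # Assumed VE-SDE perturbation std: sqrt((sigma^(2t) - 1) / (2 ln sigma))
    return np.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * np.log(sigma)))


def dsm_loss(score_fn, x, rng, eps=1e-5):
    """Denoising score matching loss, mirroring loss_fn_cond (label omitted)."""
    t = rng.random(x.shape[0]) * (1.0 - eps) + eps   # t ~ U(eps, 1)
    z = rng.standard_normal(x.shape)                  # the injected noise
    std = marginal_prob_std(t)[:, None, None, None]
    perturbed_x = x + z * std
    score = score_fn(perturbed_x, t)
    # Per-sample squared error summed over (C, H, W), then averaged over the batch
    return np.mean(np.sum((score * std + z) ** 2, axis=(1, 2, 3)))


rng = np.random.default_rng(0)
x = rng.random((64, 1, 8, 8))  # a toy batch of 8x8 "images" in [0, 1]

# Oracle that knows x: the exact kernel score, -(x_t - x) / std^2 = -z / std
oracle = lambda xt, t: -(xt - x) / marginal_prob_std(t)[:, None, None, None] ** 2
zero = lambda xt, t: np.zeros_like(xt)

print(dsm_loss(oracle, x, rng))  # about 0, up to floating-point error
print(dsm_loss(zero, x, rng))    # close to 64 = 1 * 8 * 8, i.e. E[sum z^2]
```

An oracle that knows the clean image drives the loss to zero. A model that always outputs zero pays E[sum z^2], which equals the number of pixels per sample: 64 here, or 784 for MNIST.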
Training a model that includes attention
The code below, similar to the earlier training code, trains the conditional model.
#@title Training model
continue_training = False #@param {type:"boolean"}
if not continue_training:
    print("initialize new score model...")
    score_model = torch.nn.DataParallel(UNet_Tranformer(marginal_prob_std=marginal_prob_std_fn))
    score_model = score_model.to(device)

n_epochs = 100 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 1024 #@param {'type':'integer'}
## learning rate
lr = 10e-4 #@param {'type':'number'}

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = Adam(score_model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.98 ** epoch))
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for x, y in tqdm(data_loader):
        x = x.to(device)
        loss = loss_fn_cond(score_model, x, y, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    scheduler.step()
    lr_current = scheduler.get_last_lr()[0]
    print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
    # Print the averaged training loss so far.
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Update the checkpoint after each epoch of training.
    torch.save(score_model.state_dict(), 'ckpt_transformer.pth')
initilize new score model...
0 Average Loss: 911.835488 lr 9.8e-04
1 Average Loss: 247.529141 lr 9.6e-04
2 Average Loss: 159.000987 lr 9.4e-04
3 Average Loss: 117.500008 lr 9.2e-04
4 Average Loss: 93.430158 lr 9.0e-04
5 Average Loss: 79.269252 lr 8.9e-04
6 Average Loss: 69.451496 lr 8.7e-04
7 Average Loss: 62.244141 lr 8.5e-04
8 Average Loss: 56.955125 lr 8.3e-04
9 Average Loss: 51.763428 lr 8.2e-04
10 Average Loss: 48.633911 lr 8.0e-04
11 Average Loss: 46.004020 lr 7.8e-04
12 Average Loss: 44.236268 lr 7.7e-04
13 Average Loss: 42.072793 lr 7.5e-04
14 Average Loss: 40.703189 lr 7.4e-04
15 Average Loss: 39.530541 lr 7.2e-04
16 Average Loss: 38.568226 lr 7.1e-04
17 Average Loss: 37.374072 lr 7.0e-04
18 Average Loss: 36.376193 lr 6.8e-04
19 Average Loss: 35.763268 lr 6.7e-04
20 Average Loss: 35.203006 lr 6.5e-04
21 Average Loss: 34.426941 lr 6.4e-04
22 Average Loss: 33.771495 lr 6.3e-04
23 Average Loss: 33.430943 lr 6.2e-04
24 Average Loss: 33.097876 lr 6.0e-04
25 Average Loss: 32.334015 lr 5.9e-04
26 Average Loss: 32.121282 lr 5.8e-04
27 Average Loss: 31.700731 lr 5.7e-04
28 Average Loss: 31.240735 lr 5.6e-04
29 Average Loss: 31.116556 lr 5.5e-04
30 Average Loss: 30.213783 lr 5.3e-04
31 Average Loss: 30.142528 lr 5.2e-04
32 Average Loss: 29.572817 lr 5.1e-04
33 Average Loss: 29.405506 lr 5.0e-04
34 Average Loss: 28.913737 lr 4.9e-04
35 Average Loss: 29.124497 lr 4.8e-04
36 Average Loss: 28.703883 lr 4.7e-04
37 Average Loss: 28.425991 lr 4.6e-04
38 Average Loss: 28.084901 lr 4.5e-04
39 Average Loss: 27.986393 lr 4.5e-04
40 Average Loss: 27.545136 lr 4.4e-04
41 Average Loss: 27.341096 lr 4.3e-04
42 Average Loss: 27.178352 lr 4.2e-04
43 Average Loss: 27.105629 lr 4.1e-04
44 Average Loss: 26.826516 lr 4.0e-04
45 Average Loss: 26.458968 lr 3.9e-04
46 Average Loss: 26.755344 lr 3.9e-04
47 Average Loss: 26.361928 lr 3.8e-04
48 Average Loss: 25.903759 lr 3.7e-04
49 Average Loss: 26.234516 lr 3.6e-04
50 Average Loss: 25.775030 lr 3.6e-04
51 Average Loss: 25.525633 lr 3.5e-04
52 Average Loss: 25.677089 lr 3.4e-04
53 Average Loss: 25.267192 lr 3.4e-04
54 Average Loss: 25.461929 lr 3.3e-04
55 Average Loss: 25.184782 lr 3.2e-04
56 Average Loss: 24.780046 lr 3.2e-04
57 Average Loss: 24.718920 lr 3.1e-04
58 Average Loss: 24.864525 lr 3.0e-04
59 Average Loss: 24.542736 lr 3.0e-04
60 Average Loss: 24.552311 lr 2.9e-04
61 Average Loss: 24.430597 lr 2.9e-04
62 Average Loss: 24.046150 lr 2.8e-04
63 Average Loss: 23.895412 lr 2.7e-04
64 Average Loss: 23.895568 lr 2.7e-04
65 Average Loss: 24.022242 lr 2.6e-04
66 Average Loss: 23.582576 lr 2.6e-04
67 Average Loss: 23.530880 lr 2.5e-04
68 Average Loss: 23.764094 lr 2.5e-04
69 Average Loss: 23.642832 lr 2.4e-04
70 Average Loss: 23.571634 lr 2.4e-04
71 Average Loss: 23.427607 lr 2.3e-04
72 Average Loss: 23.303141 lr 2.3e-04
73 Average Loss: 23.363988 lr 2.2e-04
74 Average Loss: 23.228841 lr 2.2e-04
75 Average Loss: 23.210248 lr 2.2e-04
76 Average Loss: 22.962327 lr 2.1e-04
77 Average Loss: 22.764237 lr 2.1e-04
78 Average Loss: 22.771735 lr 2.0e-04
79 Average Loss: 22.673690 lr 2.0e-04
80 Average Loss: 22.459341 lr 2.0e-04
81 Average Loss: 22.565987 lr 2.0e-04
82 Average Loss: 22.460520 lr 2.0e-04
83 Average Loss: 22.582685 lr 2.0e-04
84 Average Loss: 22.532176 lr 2.0e-04
85 Average Loss: 22.255016 lr 2.0e-04
86 Average Loss: 22.438315 lr 2.0e-04
87 Average Loss: 22.113862 lr 2.0e-04
88 Average Loss: 22.139238 lr 2.0e-04
89 Average Loss: 22.073406 lr 2.0e-04
90 Average Loss: 21.951703 lr 2.0e-04
91 Average Loss: 22.078716 lr 2.0e-04
92 Average Loss: 21.877473 lr 2.0e-04
93 Average Loss: 21.665942 lr 2.0e-04
94 Average Loss: 21.745823 lr 2.0e-04
95 Average Loss: 21.626989 lr 2.0e-04
96 Average Loss: 21.624281 lr 2.0e-04
97 Average Loss: 21.708963 lr 2.0e-04
98 Average Loss: 21.351878 lr 2.0e-04
99 Average Loss: 21.622979 lr 2.0e-04
The loss reaches about 23 after roughly 67 epochs. Without learning-rate tuning, it takes around 150 epochs to reach 23.
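The schedule `max(0.2, 0.98 ** epoch)` explains the `lr` column in the log above. Note that `lr=10e-4` in the cell is 1e-3. The helper `lr_at` below is hypothetical, not part of PyTorch. It is a minimal sketch of what `LambdaLR` computes, assuming `scheduler.step()` is called once per epoch as in the training cell.

```python
import math


def lr_at(step, base_lr=1e-3, decay=0.98, floor=0.2):
    # LambdaLR sets lr = base_lr * lr_lambda(step); here lr_lambda is max(floor, decay ** step)
    return base_lr * max(floor, decay ** step)


# scheduler.step() runs after each epoch, so the value printed after epoch k is lr_at(k + 1)
print(f"{lr_at(1):.1e}")    # 9.8e-04, as logged after epoch 0
print(f"{lr_at(78):.1e}")   # 2.1e-04, as logged after epoch 77

# The floor takes over once 0.98 ** n < 0.2, i.e. n > ln(0.2) / ln(0.98), about 79.7
print(math.log(0.2) / math.log(0.98))
first_floor = next(n for n in range(1000) if 0.98 ** n < 0.2)
print(first_floor)          # 80
```

The learning rate therefore decays geometrically for the first 80 epochs. Epochs 80 through 99 are trained at the floor value of 2e-4.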
#@title A handy training function
def train_diffusion_model(dataset,
                          score_model,
                          n_epochs=100,
                          batch_size=1024,
                          lr=10e-4,
                          model_name="transformer"):
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    optimizer = Adam(score_model.parameters(), lr=lr)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.98 ** epoch))
    tqdm_epoch = trange(n_epochs)
    for epoch in tqdm_epoch:
        avg_loss = 0.
        num_items = 0
        for x, y in tqdm(data_loader):
            x = x.to(device)
            loss = loss_fn_cond(score_model, x, y, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]
        scheduler.step()
        lr_current = scheduler.get_last_lr()[0]
        print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
        # Print the averaged training loss so far.
        tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
        # Update the checkpoint after each epoch of training.
        torch.save(score_model.state_dict(), f'ckpt_{model_name}.pth')
# Feel free to play with hyperparameters for training!
score_model = torch.nn.DataParallel(UNet_Tranformer(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)
train_diffusion_model(dataset, score_model,
                      n_epochs=100,
                      batch_size=1024,
                      lr=10e-4,
                      model_name="transformer")
Here is some code we can use to see how well the model does conditional generation. You can use the menu on the right to choose what you want to generate.
save_samples_cond(score_model, "_res")  # model with residual connection
save_samples_cond(score_model)  # model without residual connection
5. Latent space diffusion using an autoencoder
Finally, we get to one of the most important contributions of Rombach et al., the paper behind Stable Diffusion! Instead of diffusing in pixel space (i.e. corrupting and denoising each pixel of an image), we can try diffusing in some kind of latent space.
This has a few advantages. An obvious one is speed: compressing images before doing forward/reverse diffusion on them makes both generation and training faster. Another advantage is that the latent space, if carefully chosen, may be a more natural or interpretable space for working with images. For example, given a set of pictures of heads, perhaps some latent direction corresponds to head direction.
If we do not have any a priori bias towards one latent space or another, we can just throw an autoencoder at the problem and hope it comes up with something appropriate.
In this section, we will use an autoencoder to compress MNIST images to a smaller scale, and glue this to the rest of our diffusion pipeline.
Defining the autoencoder
Complete the missing part of the autoencoder’s forward function. An autoencoder first ‘encodes’ an image into a latent representation, and then ‘decodes’ an image from that latent representation.
Hint: You can access the encoder and decoder via self.encoder and self.decoder.
class AutoEncoder(nn.Module):
    """A convolutional autoencoder that maps 1x28x28 MNIST images to a latent feature map and back."""

    def __init__(self, channels=[4, 8, 32]):
        """Initialize the autoencoder.

        Args:
          channels: The number of channels for feature maps of each resolution.
            The last entry is the number of latent channels.
        """
        super().__init__()
        # Encoding layers where the resolution decreases
        self.encoder = nn.Sequential(
            nn.Conv2d(1, channels[0], 3, stride=1, bias=True),
            nn.BatchNorm2d(channels[0]),
            nn.SiLU(),
            nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=True),
            nn.BatchNorm2d(channels[1]),
            nn.SiLU(),
            nn.Conv2d(channels[1], channels[2], 3, stride=1, bias=True),
            nn.BatchNorm2d(channels[2]),
        )  # nn.SiLU(),
        # Decoding layers where the resolution increases
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(channels[2], channels[1], 3, stride=1, bias=True),
            nn.BatchNorm2d(channels[1]),
            nn.SiLU(),
            nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=True, output_padding=1),
            nn.BatchNorm2d(channels[0]),
            nn.SiLU(),
            nn.ConvTranspose2d(channels[0], 1, 3, stride=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        ########## YOUR CODE HERE (1 line)
        output = ...
        ###################
        return output
The following code checks whether your autoencoder was defined properly.
x_tmp = torch.randn(1, 1, 28, 28)
print(AutoEncoder()(x_tmp).shape)
assert AutoEncoder()(x_tmp).shape == x_tmp.shape, "Check the conv layer specs! The autoencoder's input and output shapes do not match."
torch.Size([1, 1, 28, 28])
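The shape check only confirms that the decoder undoes the encoder's downsampling. To see how much the encoder actually compresses, apply the standard output-size rule for an unpadded convolution with kernel \(k\) and stride \(s\) to the three encoder layers:

\[\begin{equation}
\begin{split}
n_{\text{out}} &= \left\lfloor \frac{n_{\text{in}} - k}{s} \right\rfloor + 1, \qquad 28 \xrightarrow{k=3,\ s=1} 26 \xrightarrow{k=3,\ s=2} 12 \xrightarrow{k=3,\ s=1} 10 \\
1 \times 28 \times 28 = 784 &\longrightarrow 4 \times 10 \times 10 = 400
\end{split}
\end{equation}\]

The `AutoEncoder([4, 4, 4])` trained below therefore stores each digit in 400 numbers instead of 784, a compression factor of 784/400 = 1.96. With the default `channels=[4, 8, 32]`, the latent would be 32 x 10 x 10 = 3200 numbers, which is larger than the image itself.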
Train the autoencoder with the help of a perceptual loss
Let’s train the autoencoder on MNIST images! Do this by running the cells below.
The loss can get quite small, close to 0.01.
from lpips import LPIPS
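# Define the loss function: pixel-wise MSE plus the LPIPS perceptual loss
lpips = LPIPS(net="squeeze").cuda()
# LPIPS works on 3-channel images, so the single grayscale channel is repeated.
# Note: LPIPS assumes inputs scaled to [-1, 1] by default (normalize=True rescales [0, 1] inputs);
# the [0, 1] MNIST tensors from ToTensor() are passed here unchanged.
loss_fn_ae = lambda x, x_hat: \
    nn.functional.mse_loss(x, x_hat) + \
    lpips(x.repeat(1, 3, 1, 1), x_hat.repeat(1, 3, 1, 1)).mean()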
ae_model = AutoEncoder([4, 4, 4]).cuda()
n_epochs = 50 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 2048 #@param {'type':'integer'}
## learning rate
lr = 10e-4 #@param {'type':'number'}

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
optimizer = Adam(ae_model.parameters(), lr=lr)
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for x, y in data_loader:
        x = x.to(device)
        z = ae_model.encoder(x)
        x_hat = ae_model.decoder(z)
        loss = loss_fn_ae(x, x_hat)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    print('{} Average Loss: {:5f}'.format(epoch, avg_loss / num_items))
    # Print the averaged training loss so far.
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Update the checkpoint after each epoch of training.
    torch.save(ae_model.state_dict(), 'ckpt_ae.pth')
0 Average Loss: 0.576265
1 Average Loss: 0.377764
2 Average Loss: 0.277722
3 Average Loss: 0.213910
4 Average Loss: 0.167607
5 Average Loss: 0.137186
6 Average Loss: 0.114799
7 Average Loss: 0.097592
8 Average Loss: 0.084385
9 Average Loss: 0.073949
10 Average Loss: 0.065258
11 Average Loss: 0.057647
12 Average Loss: 0.050750
13 Average Loss: 0.044420
14 Average Loss: 0.038747
15 Average Loss: 0.033954
16 Average Loss: 0.030010
17 Average Loss: 0.026624
18 Average Loss: 0.023933
19 Average Loss: 0.021834
20 Average Loss: 0.020204
21 Average Loss: 0.018909
22 Average Loss: 0.017843
23 Average Loss: 0.016944
24 Average Loss: 0.016180
25 Average Loss: 0.015518
26 Average Loss: 0.014946
27 Average Loss: 0.014435
28 Average Loss: 0.013976
29 Average Loss: 0.013558
30 Average Loss: 0.013188
31 Average Loss: 0.012850
32 Average Loss: 0.012540
33 Average Loss: 0.012256
x_hat.shape
torch.Size([2048, 1, 28, 28])
The cell below visualizes the results. The autoencoder’s output should look almost identical to the original images.
Create latent state dataset
Let’s use our autoencoder to convert MNIST images into a latent space representation. We will use these compressed images to train our diffusion generative model.
from torch.utils.data import TensorDataset
latent_dataset = TensorDataset(zdata, ydata)
Transformer UNet model for Latents
Here is a U-Net (with self- and cross-attention) similar to the one we defined before, but this time it works with compressed images instead of full-size images. You don’t need to do anything here except look at the architecture.
class Latent_UNet_Tranformer(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture."""

    def __init__(self, marginal_prob_std, channels=[4, 64, 128, 256], embed_dim=256,
                 text_dim=256, nClass=10):
        """Initialize a time-dependent score-based network.

        Args:
          marginal_prob_std: A function that takes time t and gives the standard
            deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
          channels: The number of channels for feature maps of each resolution;
            channels[0] must match the number of latent channels.
          embed_dim: The dimensionality of Gaussian random feature embeddings.
          text_dim: The embedding dimension of the class (digit) labels.
          nClass: The number of classes to model.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim))
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(channels[0], channels[1], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[1])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[1])
        self.conv2 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[2])
        self.gnorm2 = nn.GroupNorm(4, num_channels=channels[2])
        self.attn2 = SpatialTransformer(channels[2], text_dim)
        self.conv3 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[3])
        self.gnorm3 = nn.GroupNorm(4, num_channels=channels[3])
        self.attn3 = SpatialTransformer(channels[3], text_dim)

        # Decoding layers where the resolution increases
        self.tconv3 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense6 = Dense(embed_dim, channels[2])
        self.tgnorm3 = nn.GroupNorm(4, num_channels=channels[2])
        self.attn6 = SpatialTransformer(channels[2], text_dim)
        self.tconv2 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)  # + channels[2]
        self.dense7 = Dense(embed_dim, channels[1])
        self.tgnorm2 = nn.GroupNorm(4, num_channels=channels[1])
        self.tconv1 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=1)  # + channels[1]

        # The swish activation function
        self.act = nn.SiLU()  # lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std
        self.cond_embed = nn.Embedding(nClass, text_dim)

    def forward(self, x, t, y=None):
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.time_embed(t))
        y_embed = self.cond_embed(y).unsqueeze(1)
        # Encoding path
        ## Incorporate information from t
        h1 = self.conv1(x) + self.dense1(embed)
        ## Group normalization
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h2 = self.attn2(h2, y_embed)
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h3 = self.attn3(h3, y_embed)

        # Decoding path
        ## Skip connection from the encoding path
        h = self.tconv3(h3) + self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.attn6(h, y_embed)
        h = self.tconv2(h + h2)
        h += self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
Training our latent diffusion model
Finally, we will put everything together and combine our latent space representation with our fancy U-Net for learning the score function. (This may not actually work that well… but it does show that, with this many moving parts, getting good results becomes a hard engineering problem.)
Run the cell below to train our latent diffusion model!
#@title Training Latent diffusion model
continue_training = True #@param {type:"boolean"}
if not continue_training:
    print("initialize new score model...")
    latent_score_model = torch.nn.DataParallel(
        Latent_UNet_Tranformer(marginal_prob_std=marginal_prob_std_fn,
                               channels=[4, 16, 32, 64]))
    latent_score_model = latent_score_model.to(device)

n_epochs = 250 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 1024 #@param {'type':'integer'}
## learning rate
lr = 1e-4 #@param {'type':'number'}

latent_data_loader = DataLoader(latent_dataset, batch_size=batch_size, shuffle=True)
latent_score_model.train()
optimizer = Adam(latent_score_model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.5, 0.995 ** epoch))
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for z, y in latent_data_loader:
        z = z.to(device)
        loss = loss_fn_cond(latent_score_model, z, y, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * z.shape[0]
        num_items += z.shape[0]
    scheduler.step()
    lr_current = scheduler.get_last_lr()[0]
    print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
    # Print the averaged training loss so far.
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Update the checkpoint after each epoch of training.
    torch.save(latent_score_model.state_dict(), 'ckpt_latent_diff_transformer.pth')
0 Average Loss: 79.741900 lr 1.0e-04
1 Average Loss: 78.435195 lr 9.9e-05
2 Average Loss: 78.256687 lr 9.9e-05
3 Average Loss: 78.923802 lr 9.8e-05
4 Average Loss: 78.866818 lr 9.8e-05
5 Average Loss: 78.671988 lr 9.7e-05
6 Average Loss: 78.752654 lr 9.7e-05
7 Average Loss: 78.650670 lr 9.6e-05
8 Average Loss: 78.584266 lr 9.6e-05
9 Average Loss: 78.676226 lr 9.5e-05
10 Average Loss: 78.515674 lr 9.5e-05
11 Average Loss: 78.654196 lr 9.4e-05
12 Average Loss: 78.686057 lr 9.4e-05
13 Average Loss: 78.658464 lr 9.3e-05
14 Average Loss: 78.215316 lr 9.3e-05
15 Average Loss: 78.253186 lr 9.2e-05
16 Average Loss: 78.219765 lr 9.2e-05
17 Average Loss: 78.056799 lr 9.1e-05
18 Average Loss: 78.698358 lr 9.1e-05
19 Average Loss: 78.507159 lr 9.0e-05
20 Average Loss: 78.948341 lr 9.0e-05
21 Average Loss: 78.129618 lr 9.0e-05
22 Average Loss: 77.991600 lr 8.9e-05
23 Average Loss: 78.630760 lr 8.9e-05
24 Average Loss: 78.335837 lr 8.8e-05
25 Average Loss: 78.107454 lr 8.8e-05
26 Average Loss: 78.802007 lr 8.7e-05
27 Average Loss: 78.101927 lr 8.7e-05
28 Average Loss: 78.752702 lr 8.6e-05
29 Average Loss: 78.302277 lr 8.6e-05
30 Average Loss: 78.242809 lr 8.6e-05
31 Average Loss: 78.166564 lr 8.5e-05
32 Average Loss: 78.127243 lr 8.5e-05
33 Average Loss: 78.371072 lr 8.4e-05
34 Average Loss: 78.446809 lr 8.4e-05
35 Average Loss: 78.027956 lr 8.3e-05
36 Average Loss: 78.083886 lr 8.3e-05
37 Average Loss: 78.058587 lr 8.3e-05
38 Average Loss: 77.870787 lr 8.2e-05
39 Average Loss: 78.248184 lr 8.2e-05
40 Average Loss: 78.590380 lr 8.1e-05
41 Average Loss: 78.125853 lr 8.1e-05
42 Average Loss: 77.715559 lr 8.1e-05
43 Average Loss: 78.153985 lr 8.0e-05
44 Average Loss: 78.224670 lr 8.0e-05
45 Average Loss: 78.565155 lr 7.9e-05
46 Average Loss: 78.480046 lr 7.9e-05
47 Average Loss: 77.679577 lr 7.9e-05
48 Average Loss: 78.273799 lr 7.8e-05
49 Average Loss: 78.113882 lr 7.8e-05
50 Average Loss: 77.927349 lr 7.7e-05
51 Average Loss: 77.552600 lr 7.7e-05
52 Average Loss: 78.029377 lr 7.7e-05
53 Average Loss: 77.763078 lr 7.6e-05
54 Average Loss: 78.368110 lr 7.6e-05
55 Average Loss: 78.099132 lr 7.6e-05
56 Average Loss: 78.030513 lr 7.5e-05
57 Average Loss: 78.235108 lr 7.5e-05
58 Average Loss: 78.008969 lr 7.4e-05
59 Average Loss: 77.641894 lr 7.4e-05
60 Average Loss: 77.469205 lr 7.4e-05
61 Average Loss: 77.889893 lr 7.3e-05
62 Average Loss: 77.717054 lr 7.3e-05
63 Average Loss: 78.176804 lr 7.3e-05
64 Average Loss: 78.475317 lr 7.2e-05
65 Average Loss: 77.928177 lr 7.2e-05
66 Average Loss: 77.676708 lr 7.1e-05
67 Average Loss: 77.646743 lr 7.1e-05
68 Average Loss: 77.801639 lr 7.1e-05
69 Average Loss: 77.197982 lr 7.0e-05
70 Average Loss: 78.122972 lr 7.0e-05
71 Average Loss: 78.405846 lr 7.0e-05
72 Average Loss: 77.939054 lr 6.9e-05
73 Average Loss: 77.618042 lr 6.9e-05
74 Average Loss: 77.883120 lr 6.9e-05
75 Average Loss: 78.044423 lr 6.8e-05
76 Average Loss: 77.567990 lr 6.8e-05
77 Average Loss: 78.137115 lr 6.8e-05
78 Average Loss: 77.830123 lr 6.7e-05
79 Average Loss: 77.465878 lr 6.7e-05
80 Average Loss: 77.238517 lr 6.7e-05
81 Average Loss: 78.094812 lr 6.6e-05
82 Average Loss: 77.645884 lr 6.6e-05
83 Average Loss: 77.504665 lr 6.6e-05
84 Average Loss: 77.869831 lr 6.5e-05
85 Average Loss: 77.900258 lr 6.5e-05
86 Average Loss: 78.299160 lr 6.5e-05
87 Average Loss: 77.819575 lr 6.4e-05
88 Average Loss: 77.891344 lr 6.4e-05
89 Average Loss: 77.804814 lr 6.4e-05
90 Average Loss: 77.903180 lr 6.3e-05
91 Average Loss: 77.446807 lr 6.3e-05
92 Average Loss: 77.180977 lr 6.3e-05
93 Average Loss: 77.810924 lr 6.2e-05
94 Average Loss: 77.582899 lr 6.2e-05
95 Average Loss: 77.650902 lr 6.2e-05
96 Average Loss: 77.177852 lr 6.1e-05
97 Average Loss: 77.866134 lr 6.1e-05
98 Average Loss: 77.266161 lr 6.1e-05
99 Average Loss: 77.848319 lr 6.1e-05
100 Average Loss: 77.941477 lr 6.0e-05
101 Average Loss: 77.718800 lr 6.0e-05
102 Average Loss: 77.651057 lr 6.0e-05
103 Average Loss: 77.751019 lr 5.9e-05
104 Average Loss: 77.563723 lr 5.9e-05
105 Average Loss: 77.922618 lr 5.9e-05
106 Average Loss: 77.731911 lr 5.8e-05
107 Average Loss: 77.411512 lr 5.8e-05
108 Average Loss: 77.497530 lr 5.8e-05
109 Average Loss: 77.464712 lr 5.8e-05
110 Average Loss: 77.919681 lr 5.7e-05
111 Average Loss: 77.089884 lr 5.7e-05
112 Average Loss: 77.609092 lr 5.7e-05
113 Average Loss: 77.559160 lr 5.6e-05
114 Average Loss: 77.290920 lr 5.6e-05
115 Average Loss: 77.572117 lr 5.6e-05
116 Average Loss: 77.537510 lr 5.6e-05
117 Average Loss: 77.314877 lr 5.5e-05
118 Average Loss: 77.292442 lr 5.5e-05
119 Average Loss: 77.182233 lr 5.5e-05
120 Average Loss: 76.770025 lr 5.5e-05
121 Average Loss: 77.140515 lr 5.4e-05
122 Average Loss: 77.753775 lr 5.4e-05
123 Average Loss: 76.708605 lr 5.4e-05
124 Average Loss: 77.594884 lr 5.3e-05
125 Average Loss: 77.324399 lr 5.3e-05
126 Average Loss: 77.769956 lr 5.3e-05
127 Average Loss: 77.315549 lr 5.3e-05
128 Average Loss: 77.425095 lr 5.2e-05
129 Average Loss: 77.691606 lr 5.2e-05
130 Average Loss: 77.380239 lr 5.2e-05
131 Average Loss: 77.080548 lr 5.2e-05
132 Average Loss: 76.912462 lr 5.1e-05
133 Average Loss: 77.147056 lr 5.1e-05
134 Average Loss: 76.999053 lr 5.1e-05
135 Average Loss: 77.205619 lr 5.1e-05
136 Average Loss: 77.358422 lr 5.0e-05
137 Average Loss: 77.169585 lr 5.0e-05
138 Average Loss: 77.581670 lr 5.0e-05
139 Average Loss: 77.321906 lr 5.0e-05
140 Average Loss: 77.383666 lr 5.0e-05
141 Average Loss: 77.297920 lr 5.0e-05
142 Average Loss: 77.366751 lr 5.0e-05
143 Average Loss: 77.855146 lr 5.0e-05
144 Average Loss: 77.116829 lr 5.0e-05
145 Average Loss: 77.323109 lr 5.0e-05
146 Average Loss: 76.791543 lr 5.0e-05
147 Average Loss: 77.454809 lr 5.0e-05
148 Average Loss: 76.987101 lr 5.0e-05
149 Average Loss: 77.093755 lr 5.0e-05
150 Average Loss: 77.206592 lr 5.0e-05
151 Average Loss: 77.344288 lr 5.0e-05
152 Average Loss: 76.944758 lr 5.0e-05
153 Average Loss: 77.216894 lr 5.0e-05
154 Average Loss: 77.409606 lr 5.0e-05
155 Average Loss: 77.371059 lr 5.0e-05
156 Average Loss: 76.781848 lr 5.0e-05
157 Average Loss: 76.994091 lr 5.0e-05
158 Average Loss: 76.811613 lr 5.0e-05
159 Average Loss: 76.775970 lr 5.0e-05
160 Average Loss: 77.237806 lr 5.0e-05
161 Average Loss: 77.182012 lr 5.0e-05
162 Average Loss: 77.315010 lr 5.0e-05
163 Average Loss: 77.092724 lr 5.0e-05
164 Average Loss: 77.018680 lr 5.0e-05
165 Average Loss: 77.454656 lr 5.0e-05
166 Average Loss: 76.964814 lr 5.0e-05
167 Average Loss: 77.330033 lr 5.0e-05
168 Average Loss: 76.930848 lr 5.0e-05
169 Average Loss: 76.659252 lr 5.0e-05
170 Average Loss: 77.174353 lr 5.0e-05
171 Average Loss: 76.848775 lr 5.0e-05
172 Average Loss: 77.341619 lr 5.0e-05
173 Average Loss: 76.914897 lr 5.0e-05
174 Average Loss: 76.997426 lr 5.0e-05
175 Average Loss: 77.034331 lr 5.0e-05
176 Average Loss: 76.898844 lr 5.0e-05
177 Average Loss: 77.260483 lr 5.0e-05
178 Average Loss: 76.800081 lr 5.0e-05
179 Average Loss: 77.055465 lr 5.0e-05
180 Average Loss: 77.245446 lr 5.0e-05
181 Average Loss: 77.588249 lr 5.0e-05
182 Average Loss: 77.196555 lr 5.0e-05
183 Average Loss: 77.217582 lr 5.0e-05
184 Average Loss: 77.646229 lr 5.0e-05
185 Average Loss: 76.726135 lr 5.0e-05
186 Average Loss: 77.047729 lr 5.0e-05
187 Average Loss: 77.190589 lr 5.0e-05
188 Average Loss: 76.917252 lr 5.0e-05
189 Average Loss: 77.095856 lr 5.0e-05
190 Average Loss: 76.599167 lr 5.0e-05
191 Average Loss: 77.529027 lr 5.0e-05
192 Average Loss: 77.306830 lr 5.0e-05
193 Average Loss: 77.470388 lr 5.0e-05
194 Average Loss: 77.153232 lr 5.0e-05
195 Average Loss: 76.966829 lr 5.0e-05
196 Average Loss: 77.399963 lr 5.0e-05
197 Average Loss: 77.156293 lr 5.0e-05
198 Average Loss: 77.290599 lr 5.0e-05
199 Average Loss: 76.886455 lr 5.0e-05
200 Average Loss: 76.325025 lr 5.0e-05
201 Average Loss: 76.655142 lr 5.0e-05
202 Average Loss: 76.951446 lr 5.0e-05
203 Average Loss: 77.340102 lr 5.0e-05
204 Average Loss: 77.056123 lr 5.0e-05
205 Average Loss: 77.161179 lr 5.0e-05
206 Average Loss: 76.506679 lr 5.0e-05
207 Average Loss: 77.103935 lr 5.0e-05
208 Average Loss: 76.547799 lr 5.0e-05
209 Average Loss: 76.954140 lr 5.0e-05
210 Average Loss: 77.212222 lr 5.0e-05
211 Average Loss: 77.476724 lr 5.0e-05
212 Average Loss: 77.393100 lr 5.0e-05
213 Average Loss: 77.050557 lr 5.0e-05
214 Average Loss: 76.813148 lr 5.0e-05
215 Average Loss: 77.018597 lr 5.0e-05
216 Average Loss: 76.848586 lr 5.0e-05
217 Average Loss: 76.818908 lr 5.0e-05
218 Average Loss: 76.874240 lr 5.0e-05
219 Average Loss: 77.007598 lr 5.0e-05
220 Average Loss: 76.713896 lr 5.0e-05
221 Average Loss: 76.492506 lr 5.0e-05
222 Average Loss: 77.130726 lr 5.0e-05
223 Average Loss: 76.787134 lr 5.0e-05
224 Average Loss: 77.018584 lr 5.0e-05
225 Average Loss: 77.239958 lr 5.0e-05
226 Average Loss: 76.415214 lr 5.0e-05
227 Average Loss: 76.972245 lr 5.0e-05
228 Average Loss: 76.875388 lr 5.0e-05
229 Average Loss: 76.552531 lr 5.0e-05
230 Average Loss: 76.871315 lr 5.0e-05
231 Average Loss: 76.998418 lr 5.0e-05
232 Average Loss: 77.196564 lr 5.0e-05
233 Average Loss: 76.549707 lr 5.0e-05
234 Average Loss: 76.553033 lr 5.0e-05
235 Average Loss: 76.690979 lr 5.0e-05
236 Average Loss: 76.996255 lr 5.0e-05
237 Average Loss: 77.290900 lr 5.0e-05
238 Average Loss: 76.588793 lr 5.0e-05
239 Average Loss: 76.372851 lr 5.0e-05
240 Average Loss: 76.952500 lr 5.0e-05
241 Average Loss: 76.530314 lr 5.0e-05
242 Average Loss: 76.639879 lr 5.0e-05
243 Average Loss: 77.050608 lr 5.0e-05
244 Average Loss: 76.749453 lr 5.0e-05
245 Average Loss: 76.694167 lr 5.0e-05
246 Average Loss: 76.588385 lr 5.0e-05
247 Average Loss: 76.337892 lr 5.0e-05
248 Average Loss: 76.853386 lr 5.0e-05
249 Average Loss: 77.407620 lr 5.0e-05
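The `lr` column in the log follows directly from the `LambdaLR` multiplier `max(0.5, 0.995 ** epoch)`: `LambdaLR` sets the learning rate to the initial rate times the lambda evaluated at the scheduler's step count. Because `scheduler.get_last_lr()` is read after `scheduler.step()`, the value printed for loop index `epoch` uses the multiplier at `epoch + 1`. The plain-Python sketch below reproduces that column without PyTorch. The helper names `lr_multiplier`, `printed_lr` and `first_floor_epoch` are hypothetical and not part of the notebook.

```python
BASE_LR = 1e-4  # same value as `lr` in the training cell


def lr_multiplier(step: int) -> float:
    """Same lambda as the LambdaLR in the training cell."""
    return max(0.5, 0.995 ** step)


def printed_lr(epoch: int, base_lr: float = BASE_LR) -> float:
    """LR reported for loop index `epoch`.

    get_last_lr() is read after scheduler.step(), so the scheduler has
    already advanced to step epoch + 1.
    """
    return base_lr * lr_multiplier(epoch + 1)


def first_floor_epoch(n_epochs: int = 250) -> int:
    """First loop index whose printed LR sits exactly on the 0.5 floor."""
    for epoch in range(n_epochs):
        if lr_multiplier(epoch + 1) == 0.5:
            return epoch
    return -1


if __name__ == "__main__":
    for epoch in (1, 5, 50, 100, 137, 200):
        print(f"{epoch:>3} lr {printed_lr(epoch):.1e}")
    print("floor reached at epoch", first_floor_epoch())
```

For epochs 1, 5, 50, 100, 137 and 200 this prints 9.9e-05, 9.7e-05, 7.7e-05, 6.0e-05, 5.0e-05 and 5.0e-05, which matches the log. The floor takes over exactly at epoch 138. The log already shows 5.0e-05 from epoch 136 only because of one-digit rounding. In other words, the last 112 of the 250 epochs run at a constant 5e-05.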
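The per-epoch average loss in the training log often jumps by more than 0.5 between consecutive epochs, while the net decline over 250 epochs is only a few units, so individual lines say little about progress. One way to read the trend is sketched below. It assumes the log has been copied into a string, parses each `N Average Loss: X lr Y` line with a regular expression, and smooths the losses with a trailing moving average. `parse_log`, `moving_average` and `EXCERPT` are hypothetical names; the excerpt is copied verbatim from the first and last five epochs of the log.

```python
import re
from typing import List, Tuple

# Matches lines such as "42 Average Loss: 77.715559 lr 8.1e-05".
LOG_LINE = re.compile(r"^(\d+) Average Loss: ([0-9.]+) lr ([0-9.eE+-]+)$")


def parse_log(text: str) -> List[Tuple[int, float, float]]:
    """Return (epoch, avg_loss, lr) for every matching line; other lines are skipped."""
    rows = []
    for line in text.splitlines():
        m = LOG_LINE.match(line.strip())
        if m:
            rows.append((int(m.group(1)), float(m.group(2)), float(m.group(3))))
    return rows


def moving_average(values: List[float], window: int) -> List[float]:
    """Trailing mean over the last `window` values (fewer at the start)."""
    out, running = [], 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value leaving the window
        out.append(running / min(i + 1, window))
    return out


# Excerpt copied from the training log (first and last five epochs).
EXCERPT = """\
0 Average Loss: 79.741900 lr 1.0e-04
1 Average Loss: 78.435195 lr 9.9e-05
2 Average Loss: 78.256687 lr 9.9e-05
3 Average Loss: 78.923802 lr 9.8e-05
4 Average Loss: 78.866818 lr 9.8e-05
245 Average Loss: 76.694167 lr 5.0e-05
246 Average Loss: 76.588385 lr 5.0e-05
247 Average Loss: 76.337892 lr 5.0e-05
248 Average Loss: 76.853386 lr 5.0e-05
249 Average Loss: 77.407620 lr 5.0e-05
"""

if __name__ == "__main__":
    rows = parse_log(EXCERPT)
    losses = [loss for _, loss, _ in rows]
    smoothed = moving_average(losses, window=5)
    print(f"mean of first 5 epochs: {smoothed[4]:.3f}")
    print(f"mean of last 5 epochs:  {smoothed[-1]:.3f}")
```

On this excerpt the five-epoch means are about 78.845 and 76.776, a decline of roughly 2 units. Running `parse_log` over the full log and plotting the smoothed curve with matplotlib, which the notebook already imports, would show the trend across all 250 epochs.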