• Tutorials >
  • Unconditional Image Generation With DDIM Model Tutorial
Shortcuts

Unconditional Image Generation With DDIM Model Tutorial

Author: Muhammed Ayman

"""
This is an Implementation for `Denoising Diffusion Implicit Models (DDIM) <https://arxiv.org/abs/2010.02502>`__

===============================================================================================================


DDIM is one of the denoising diffusion probabilistic models family but

the key difference here it doesn�t require **a large reverse diffusion

time steps** to produce samples or images.


"""
DDIM

Setup

# =====

#

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torchmetrics.image.kid import KernelInceptionDistance
from PIL import Image

Downloading Data and Preparing the pipeline

# ===========================================

#

Here We download the Oxford Flowers Dataset for generating images of

# flowers, which is a diverse natural flowers dataset containing around 8,000

# images with 102 category.

#

import torchvision.datasets as data
data.OxfordIIITPet('./data',download=True)
data.Flowers102("./data",download=True)

Here we prepare the data pipline using Dataset and Dataloader classes

# from **torch.utils.data** instance

#

from torch.utils.data import Dataset,DataLoader
from PIL import Image
from torchvision import transforms
class PetDataset(Dataset):
  def __init__(self,pathes,img_size=(64,64),train=True):
      self.pathes = pathes
      self.img_size = img_size
      self.aug = transforms.Compose([
                                      transforms.RandomHorizontalFlip(),
                                     # transforms.RandomAdjustSharpness(2)
                                     ])
      self.processor = transforms.Compose(
      [
        transforms.Resize(self.img_size, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]), # to normalize the images from [0,255] to [-1,1]
      ]
      )
      self.train = train
  def _center_crop(self,img):
      h,w = img.size
      crop_size = min(h,w)
      img = img.crop(((h-crop_size)//2,
                     (w-crop_size)//2,crop_size,crop_size))
      return img
  def __len__(self):
      return len(self.pathes)
  def __getitem__(self,idx):
       img = Image.open(self.pathes[idx]).convert("RGB")
       img = self._center_crop(img)
       img = self.processor(img)
       if self.train:
        img = self.aug(img)
       return img



import os
import random
all_flowers_pathes =  [os.path.join('/content/data/flowers-102/jpg',x)for x in os.listdir('/content/data/flowers-102/jpg')
                if x.endswith('.jpg')] # to gather all image pathes

random.shuffle(all_flowers_pathes)
train_pathes = all_flowers_pathes[:-500]
val_pathes = all_flowers_pathes[-500:]
train_ds = PetDataset(train_pathes) # training dataset
val_ds = PetDataset(val_pathes,train=False) # validation dataset

# helper function to display the image after generation
def display_img(img):
    img = (img+1)*0.5
    img= img.permute(1,2,0)
    plt.imshow(img)
    plt.axis('off')

test= val_ds[101] # grap a sample
display_img(test)

train_iter = DataLoader(train_ds,150,shuffle=True,num_workers=2,pin_memory=True)
val_iter = DataLoader(val_ds,20,num_workers=2,pin_memory=True)

Model Architecture and Modules

# ==============================

#

import math
MAX_FREQ = 1000
def get_timestep_embedding(timesteps, embedding_dim): # sinusoidal embedding like in Transformers
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    emb = math.log(MAX_FREQ) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32,
                              num_channels=in_channels,
                              eps=1e-6, affine=True)

def nonlinearity(name):
    return getattr(nn,name)()


# Resnet Block
class ResidualBlock(nn.Module):
  def __init__(self,in_chs,
               out_chs,
               temb_dim,
               act='SiLU',dropout=0.2):
     super().__init__()
     self.time_proj = nn.Sequential(nonlinearity(act),
                                    nn.Linear(temb_dim,out_chs))

     dims = [in_chs]+2*[out_chs]
     blocks =[]
     for i in range(1,3):
        blc = nn.Sequential(Normalize(dims[i-1]),
                            nonlinearity(act),
                            nn.Conv2d(dims[i-1],dims[i],3,padding=1),)
        if i>1:
          blc.insert(2,nn.Dropout(dropout))
        blocks.append(blc)
     self.blocks= nn.ModuleList(blocks)
     self.short_cut =  False
     if in_chs!= out_chs:
        self.short_cut = True
        self.conv_short = nn.Conv2d(in_chs,out_chs,1)
  def forward(self,x,temb):
      h =x
      for i,blc in enumerate(self.blocks):
          h = blc(h)
          if i==0:
            h = h+self.time_proj(temb)[:,:,None,None]
      if self.short_cut:
         x = self.conv_short(x)
      return x+h


# Attention Module
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)   # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (c**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x+h_

# Downsize Block
class DownBlock(nn.Module):
  def __init__(self,
               out_chs,
               with_conv=True):
      super().__init__()
      self.with_conv = with_conv
      if with_conv:
        self.down_conv = nn.Conv2d(out_chs,out_chs,3,stride=2)
      else:
        self.down_conv = nn.AvgPool2d(2,2)


  def _down(self,x):
      if self.with_conv:
        pad = (0,1,0,1)  # to make the input shape equals to the output shape after convulotion op
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.down_conv(x)
      else:
        x = self.down_conv(x)
      return x


  def forward(self,x):

      return self._down(x)

# Upsample BLock
class UpBlock(nn.Module):
  def __init__(self,out_chs,
               with_conv=True,
               mode='nearest',):
    super().__init__()
    self.with_conv = with_conv
    self.mode = mode
    if with_conv:
      self.up_conv = nn.Conv2d(out_chs,out_chs,3,padding=1)

  def _up(self,x):
      x = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode=self.mode)
      if self.with_conv:
         x = self.up_conv(x)
      return x
  def forward(self,x):
      return self._up(x)


#Unet Model
class DiffUnet(nn.Module):
  def __init__(self,chs=32,
               chs_mult=[2,2,4,4,8],
               attn_res=[16,8,4],
               block_depth=2,
               act='SiLU',
               temb_dim=256,
               with_conv=True,
               res=64,dropout=0.3):
    super().__init__()
    self.chs = chs
    self.conv_in = nn.Conv2d(3,chs,3,padding=1)
    self.time_proj = nn.Sequential(nn.Linear(chs,temb_dim),
                                   nonlinearity(act),
                                   nn.Linear(temb_dim,temb_dim))
    chs_mult = [1]+chs_mult
    #down block
    down_dims = [] # to store the down features
    downs = [] # to store the down blocks of the unet model
    for i in range(1,len(chs_mult)-1):
        in_ch = chs*chs_mult[i-1]
        out_ch = chs*chs_mult[i]
        down = nn.Module()
        down.res = nn.ModuleList([ResidualBlock(in_ch,out_ch,temb_dim,act,dropout)]+
                                   [ResidualBlock(out_ch,out_ch,temb_dim,act,dropout) for _ in range(1,block_depth)])
        attn = AttnBlock(out_ch) if res in attn_res else nn.Identity()
        down.attn = attn
        down.down_blc = DownBlock(out_ch,with_conv)
        downs.append(down)
        down_dims.append(out_ch)
        res = res//2

    self.downs = nn.ModuleList(downs)

    #mid block
    last_ch_dim=  chs*chs_mult[-1]
    self.mid_res1 = ResidualBlock(out_ch,
                                 last_ch_dim,
                                 temb_dim,act,dropout)
    self.mid_attn = AttnBlock(last_ch_dim)
    self.mid_res2 = ResidualBlock(last_ch_dim,
                                 last_ch_dim,
                                 temb_dim,act,dropout)

    #up block
    down_dims = down_dims[1:]+[last_ch_dim]
    ups = []
    for i,skip_ch in zip(reversed(range(1,len(chs_mult)-1)),reversed(down_dims)):
        out_ch = chs*chs_mult[i]
        in_ch = out_ch+skip_ch
        up = nn.Module()

        up.res = nn.ModuleList([ResidualBlock(in_ch,out_ch,temb_dim,act,dropout)]+
                                   [ResidualBlock(out_ch*2,out_ch,temb_dim,act,dropout) for _ in range(1,block_depth)])
        attn = AttnBlock(out_ch) if res in attn_res else nn.Identity()
        up.attn = attn
        up.up_blc = UpBlock(skip_ch,with_conv) if i!=0 else nn.Identity()
        ups.append(up)
        res = int(res*2)
    self.ups = nn.ModuleList(ups)
    self.out = nn.Sequential(Normalize(out_ch),
                             nonlinearity(act),
                             nn.Conv2d(out_ch,3,3,padding=1))
    self.res = res
  def forward(self,x,timestep):
      t = get_timestep_embedding(timestep,self.chs)
      t = self.time_proj(t)
      h = self.conv_in(x)
      hs =[]
      #Down
      for blc in self.downs:
         for res_block in blc.res:
             h = res_block(h,t)
             h = blc.attn(h)
             hs.append(h)
         h = blc.down_blc(h)
      #Mid
      h = self.mid_res1(h,t)
      h = self.mid_attn(h)
      h = self.mid_res2(h,t)
      #Up
      for blc  in self.ups:
         h = blc.up_blc(h)
         for res_block in blc.res:
            h = torch.cat([h,hs.pop()],axis=1)
            h = res_block(h,t)
            h = blc.attn(h)
      return self.out(h)

Diffusion Model and noise scheduler

# ===================================

#


import math
import numpy as np
from typing import Optional, Tuple, List,Union

def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float32,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    elif beta_schedule == "cosv2":
         betas = []
         for i in range(num_diffusion_timesteps):
            t1 = i / num_diffusion_timesteps
            t2 = (i + 1) / num_diffusion_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), beta_end))
         betas = np.array(betas)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas




class DDIMSampler:

    def __init__(self,
                 schedule_name: str,
                 diff_train_steps: int,
                 beta_start: float = 0.001,
                 beta_end: float = 0.2):
        betas = get_beta_schedule(schedule_name,
                                  beta_start=beta_start,
                                  beta_end=beta_end,
                                  num_diffusion_timesteps=diff_train_steps)
        self.betas = torch.tensor(betas).to(torch.float32)
        self.alpha = 1 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)

        self.timesteps = np.arange(0, diff_train_steps)[::-1]
        self.num_train_steps = diff_train_steps
        self._num_inference_steps = 20
        self.eta = 0

    def _get_variance(self,
                      timestep: Union[torch.Tensor, int],
                      prev_timestep: Union[torch.Tensor, int] ):
        alpha_t = self.alpha_cumprod[timestep]
        alpha_prev = self.alpha_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
        beta_t = (1 - alpha_t)
        beta_prev = (1 - alpha_prev)
        return (beta_prev / beta_t) / (1 - alpha_t / alpha_prev)

    @staticmethod
    def treshold_sample(sample: torch.Tensor,
                        threshold: float = 0.9956,
                        max_clip: float = 1):
        batch_size, channels, height, width = sample.shape
        dtype = sample.dtype
        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * height * width)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, threshold, dim=1)
        s = torch.clamp(
            s, min=1, max=max_clip
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, height, width)
        sample = sample.to(dtype)

        return sample

    def set_infer_steps(self,
                        num_steps: int,
                        device: torch.DeviceObjType):
        self._num_inference_steps = num_steps
        step_ratio = self.num_train_steps // self._num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)

    @torch.no_grad()
    def p_sample(self,
                 x_t: torch.Tensor,
                 t_now: Union[torch.Tensor, int],
                 pred_net):
        prev_timestep = t_now - self.num_train_steps // self._num_inference_steps
        alpha_t = self.alpha_cumprod[t_now]
        alpha_prev = self.alpha_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
        var = self._get_variance(t_now, prev_timestep)
        eps = torch.randn_like(x_t).to(x_t.device)
        t_now = (torch.ones((x_t.shape[0],),
                            device=x_t.device,
                            dtype=torch.int32) * t_now).to(x_t.device)
        eta_t = pred_net(x_t, t_now)

        x0_t = (x_t - eta_t * (1 - alpha_t).sqrt()) / alpha_t.sqrt()

        c1 = self.eta * var.sqrt()
        c2 = ((1 - alpha_prev) - c1 ** 2).sqrt()
        x_tminus = alpha_prev.sqrt() * x0_t + c2 * eta_t + c1 * eps
        return x_tminus, x0_t

    def q_sample(self,
                 x_t: torch.Tensor,
                 timesteps: Union[torch.Tensor, int]):

        alpha_t = self.alpha_cumprod[timesteps].to(timesteps.device)
        alpha_t = alpha_t.flatten().to(x_t.device)[:, None, None, None]
        eps = torch.randn(*list(x_t.shape)).to(x_t.device)
        x_t = alpha_t.sqrt() * x_t + (1 - alpha_t).sqrt() * eps
        return x_t, eps


import copy

class DiffusionModel:
    def __init__(self,
                 main_net: DiffUnet,
                 ema_net: Optional[DiffUnet] = None,
                 num_steps: int = 100,
                 input_res: Union[Tuple[int, int], List[int]] = (32, 32),
                 emma: float = 0.999,
                 noise_sch_name: str = 'cosv2',
                 **noise_sch_kwargs):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.eps_net = main_net.to(self.device)
        self.ema_net = ema_net if ema_net is not None else copy.deepcopy(main_net)
        self.ema_net = self.ema_net.to(self.device)
        self.ema_net.eval()
        self.steps = num_steps
        self.res = (3,) + input_res if isinstance(input_res, tuple) else [3] + input_res
        self.num_steps = num_steps
        self.scheduler = DDIMSampler(noise_sch_name,
                                     diff_train_steps=num_steps,
                                     **noise_sch_kwargs)
        self.emma = emma

    @torch.no_grad()
    def generate(self,
                 num_samples: int = 1,
                 num_infer_steps: int = 25,
                 pred_net: Optional[str] = 'ema',
                 return_list: bool = False,
                 x_t: Optional[torch.Tensor] = None):
        shape = (num_samples,) + self.res if isinstance(self.res, tuple) else [num_samples] + self.res
        x_t = torch.randn(*shape).to(self.device) if x_t is None else x_t
        self.scheduler.set_infer_steps(num_infer_steps, x_t.device)
        pred_net = getattr(self, pred_net + "_net")
        xs = [x_t.cpu()]
        for step in range(num_infer_steps):
            t = self.scheduler.timesteps[step]
            x_t, _ = self.scheduler.p_sample(x_t, t, pred_net)
            xs.append(x_t.cpu())
        return xs[-1] if not return_list else xs

    @staticmethod
    def inverse_transform(img):
        """ Inverse transform the images after generation"""
        img = (img + 1) / 2
        img = np.clip(img, 0.0, 1.0)
        img = np.transpose(img, (1, 2, 0)) if len(img.shape) == 3 else np.transpose(img, (0, 2, 3, 1))
        return img

    @staticmethod
    def transform(img):
        """Transform the image before training converting the pixels values from [0, 255] to [-1, 1]"""
        img = img.to(torch.float32) / 127.5
        img = img - 1
        if len(img.shape) == 3:  # one sample
            img = torch.permute(img, (2, 0, 1))
        else:  # batch of samples
            img = torch.permute(img, (0, 3, 1, 2))
        return img

    def train_loss(self,
                   input_batch: torch.Tensor,
                   loss_type: Optional[str] = 'l1_loss',
                   **losskwargs):
        """Training loss"""
        bs, _, _, _ = input_batch.shape
        t = torch.randint(0, self.num_steps, size=(bs,))
        x_t, eps = self.scheduler.q_sample(input_batch, t)
        t = t.int().to(input_batch.device)
        eps_pred = self.eps_net(x_t, t)
        loss = getattr(torch.nn.functional, loss_type)(eps_pred, eps, **losskwargs)
        return loss

    def update_emma(self):
        for p_ema, p in zip(self.ema_net.parameters(), self.eps_net.parameters()):
            p_ema.data = (1 - self.emma) * p.data + p_ema.data * self.emma

    def train(self):
        self.eps_net.train()

    def eval(self):
        self.eps_net.eval()

    def parameters(self):
        return self.eps_net.parameters()

    def save(self,
             file_name: str):
        if not os.path.exists(file_name):
            os.makedirs(file_name)
        ema_path = file_name + '/ema.pt'
        net_path = file_name + "/eps.pt"
        torch.save(self.ema_net.state_dict(), ema_path)
        torch.save(self.eps_net.state_dict(), net_path)

    def load(self,
             path_nets: str):
        pathes = [os.path.join(path_nets, p) for p in os.listdir(path_nets) if ("ema" in p or "eps" in p)]
        for index in range(len(pathes)):
            if "ema" in pathes[index]:
                break
        ema_p = pathes[index]
        eps_p = pathes[int(not index)]
        map_loc = 'cpu' if not torch.cuda.is_available() else 'cuda'
        self.eps_net.load_state_dict(torch.load(eps_p, map_location=map_loc))
        self.ema_net.load_state_dict(torch.load(ema_p, map_location=map_loc))


DEVICE = torch.device('cuda' if torch.cuda.is_available()else 'cpu')
Unet = DiffUnet(block_depth=2)
Unet = Unet.to(DEVICE)
Model = DiffusionModel(Unet,num_steps=1000,input_res=(64,64))
optim = torch.optim.AdamW(Unet.parameters(),5e-4,weight_decay=0.01)

print("number of parameters:{}".format(sum([p.numel()for p in Unet.parameters()])))


kid = KernelInceptionDistance(subset_size=100,normalize=True).to(DEVICE)

# preprocess the data before KID
def prepare_kid(real,pred):

    real = F.resize(((real+1)*0.5).clamp(0,1),(299,299))
    pred = F.resize(((pred+1)*0.5).clamp(0,1),(299,299))
    return real,pred

#Trainig utilities

#

mean = lambda x:sum(x)/len(x)

def mul(args):
    res = 1
    for i in args:
      res*=i
    return res

def train_epoch(model,train_ds,opt,loss_type='mse_loss',num=5,max_norm=None,**kwargs):
  model.train()
  losses = []
  for i,inputs in enumerate(train_ds):
      inputs = inputs.to(DEVICE)
      opt.zero_grad()

      loss = model.train_loss(inputs,loss_type,**kwargs)
      loss.backward()
      if max_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
      opt.step()
      loss = loss.item()
      losses.append(loss)
      model.update_emma()
      if (i+1)%(len(train_ds)//num)==0:
        print(f"Finished training on {100*(i+1)/len(train_ds):.1f}% of the dataset and loss:{loss:.3f}")
  return mean(losses)

def plot_grid_images(imgs,grid_shape):
    n_rows,n_cols = grid_shape
    plt.figure(figsize=(n_cols * 2, n_rows * 2))
    plt.title('Generated Images')
    for row in range(n_rows):
          for col in range(n_cols):
              index = row * n_cols + col
              plt.subplot(n_rows, n_cols, index + 1)
              img = imgs[index]
              img = (img+1)/2
              img= img.permute(1,2,0)
              img = torch.clamp(img,0.0,1.0)
              plt.imshow(img)
              plt.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close()


def val_epoch(model,val_ds,loss_type='mse_loss',grid_shape=[2,3],infer_steps=20,**kwargs):
  model.eval()
  with torch.no_grad():
    losses = []
    for i,inputs in enumerate(val_ds):
        inputs = inputs.to(DEVICE)
        loss = model.train_loss(inputs,loss_type,**kwargs)
        loss = loss.item()
        losses.append(loss)
        samples = model.generate(int(inputs.shape[0]),infer_steps).cuda()
        inputs ,samples = prepare_kid(inputs,samples)
        kid.update(inputs,real=True)
        kid.update(samples,real=False)
  mean_kid,std_kid = kid.compute()
  kid.reset()
  return mean(losses),mean_kid,std_kid

Training Loop

# =============

#

import time
EPOCHS = 10
loss_type='l1_loss'
max_kid= 0.5
path_to_save = '/content/Pretrained/DDIM'

for e in range(EPOCHS):
  st = time.time()
  print(f"Started Training on:{e+1}/{EPOCHS}")
  train_loss = train_epoch(Model,train_iter,optim,loss_type,max_norm=1.0)
  val_loss,mean_kid,std_kid = val_epoch(Model,val_iter,loss_type)
  if mean_kid<max_kid:
    path = path_to_save+f"{mean_kid:.2f}_model" # the weights name includes the kid score to make it easy to identify the best weights
    Model.save(path)
    max_kid = mean_kid
  print(f"Finished Training on epoch: {e+1} in {(time.time()-st)/60:.1f} mins, train_loss: {train_loss:.3f},val_loss:{val_loss:.3f},kid(mean={mean_kid:.3f},std={std_kid:.3f})\n")

Generation Time!

# ================

#

# load the best parameters
Model.load("/content/Pretrained/Pretrained/0.05_model") # put the path  to the weights directory here

num_samples = 24
x = torch.randn(num_samples,3,64,64).cuda()

# if you want the images for througout the genration process you should make the return list attribute True
emma_samples = Model.generate(num_infer_steps=25,x_t=x)

samples = [sam.squeeze() for sam in emma_samples][-24:]
plot_grid_images(samples,[3,8])

# Make a grid from a set of images
def make_grid(images,
                padding=10,
                grid_sh=(2,10),
                color=(255,255,255)):
    rows,cols = grid_sh
    color = np.array(color,dtype=np.uint8)
    imgs_h = []
    for row in range(rows):
      imgs_w =[]
      for col in range(cols):
        idx = row * cols + col
        img= images[idx]
        sh = np.array(img.shape)+2*padding # padding for seperating the images in the grid
        sh[-1] = img.shape[-1]
        new_img = np.ones(sh,dtype=np.uint8)*color[None,None,:]
        new_img[padding:sh[0]-padding,padding:sh[1]-padding,:] = img
        imgs_w.append(new_img)
      new_img = np.concatenate(imgs_w,axis=1)
      imgs_h.append(new_img)
    return np.concatenate(imgs_h)

# inverse data transform
def inverse(img):
   img = (img+1)/2
   img = (np.clip(img,0.0,1.0)*255).astype(np.uint8)
   img = np.transpose(img,(0,2,3,1)) if len(img.shape)==4 else np.transpose(img,(1,2,0))
   return img

IMGS = Model.generate(24,25,return_list=True)
IMGS = [im.cpu().numpy() for im in IMGS]


# making a Gif to visualize the generation process
def make_gif(imgs,name='test',time=250):
   imgs = [Image.fromarray(im) for im in imgs]
   imgs[0].save(name+".gif", format="GIF", append_images=imgs,
               save_all=True, duration=time, loop=0)
IMGS =   [make_grid(inverse(ims),grid_sh=(3,8)) for ims in IMGS]
IMGS += [IMGS[-1]]*10
make_gif(IMGS)

Conclusion

That’s all for this one, folks. Congratulations, you now know the fundamentals to building a generative model! If you’re interested, you can try and build the conditional version where you generate images based on class inputs by tweaking the (model unet and resnet blocks)<https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py> and retrain the model.

Check out the other tutorials for more cool deep learning applications in PyTorch!

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources