{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DCGAN Tutorial\n==============\n\n**Author**: [Nathan Inkawhich](https://github.com/inkawhich)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Introduction\n============\n\nThis tutorial will give an introduction to DCGANs through an example. We\nwill train a generative adversarial network (GAN) to generate new\ncelebrities after showing it pictures of many real celebrities. Most of\nthe code here is from the DCGAN implementation in\n[pytorch/examples](https://github.com/pytorch/examples), and this\ndocument will give a thorough explanation of the implementation and shed\nlight on how and why this model works. No prior\nknowledge of GANs is required, although a first-timer may need to spend\nsome time reasoning about what is actually happening under the hood.\nAlso, for the sake of time it will help to have a GPU, or two. Let's\nstart from the beginning.\n\nGenerative Adversarial Networks\n===============================\n\nWhat is a GAN?\n--------------\n\nGANs are a framework for teaching a deep learning model to capture the\ntraining data distribution so we can generate new data from that same\ndistribution. GANs were invented by Ian Goodfellow in 2014 and first\ndescribed in the paper [Generative Adversarial\nNets](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf).\nThey are made of two distinct models, a *generator* and a\n*discriminator*. The job of the generator is to spawn 'fake' images that\nlook like the training images. The job of the discriminator is to look\nat an image and output whether it is a real training image or a\nfake image from the generator. 
During training, the generator is\nconstantly trying to outsmart the discriminator by generating better and\nbetter fakes, while the discriminator is working to become a better\ndetective and correctly classify the real and fake images. The\nequilibrium of this game is when the generator is generating perfect\nfakes that look as if they came directly from the training data, and the\ndiscriminator is left to always guess at 50% confidence that the\ngenerator output is real or fake.\n\nNow, let's define some notation to be used throughout the tutorial, starting\nwith the discriminator. Let $x$ be data representing an image. $D(x)$ is\nthe discriminator network which outputs the (scalar) probability that\n$x$ came from training data rather than the generator. Here, since we\nare dealing with images, the input to $D$ is an image of CHW size\n3x64x64. Intuitively, $D(x)$ should be HIGH when $x$ comes from training\ndata and LOW when $x$ comes from the generator. $D(x)$ can also be\nthought of as a traditional binary classifier.\n\nFor the generator's notation, let $z$ be a latent space vector sampled\nfrom a standard normal distribution. $G(z)$ represents the generator\nfunction which maps the latent vector $z$ to data-space. The goal of $G$\nis to estimate the distribution that the training data comes from\n($p_{data}$) so it can generate fake samples from that estimated\ndistribution ($p_g$).\n\nSo, $D(G(z))$ is the probability (scalar) that the output of the\ngenerator $G$ is a real image. As described in [Goodfellow's\npaper](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf),\n$D$ and $G$ play a minimax game in which $D$ tries to maximize the\nprobability it correctly classifies reals and fakes ($logD(x)$), and $G$\ntries to minimize the probability that $D$ will predict its outputs are\nfake ($log(1-D(G(z)))$). 
From the paper, the GAN loss function is\n\n$$\\underset{G}{\\text{min}} \\underset{D}{\\text{max}}V(D,G) = \\mathbb{E}_{x\\sim p_{data}(x)}\\big[logD(x)\\big] + \\mathbb{E}_{z\\sim p_{z}(z)}\\big[log(1-D(G(z)))\\big]$$\n\nIn theory, the solution to this minimax game is where $p_g = p_{data}$,\nand the discriminator guesses randomly if the inputs are real or fake.\nHowever, the convergence theory of GANs is still being actively\nresearched and in reality models do not always train to this point.\n\nWhat is a DCGAN?\n----------------\n\nA DCGAN is a direct extension of the GAN described above, except that it\nexplicitly uses convolutional and convolutional-transpose layers in the\ndiscriminator and generator, respectively. It was first described by\nRadford et al.\u00a0in the paper [Unsupervised Representation Learning With\nDeep Convolutional Generative Adversarial\nNetworks](https://arxiv.org/pdf/1511.06434.pdf). The discriminator is\nmade up of strided\n[convolution](https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d)\nlayers, [batch\nnorm](https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm2d)\nlayers, and\n[LeakyReLU](https://pytorch.org/docs/stable/nn.html#torch.nn.LeakyReLU)\nactivations. The input is a 3x64x64 image and the output is a\nscalar probability that the input is from the real data distribution.\nThe generator is composed of\n[convolutional-transpose](https://pytorch.org/docs/stable/nn.html#torch.nn.ConvTranspose2d)\nlayers, batch norm layers, and\n[ReLU](https://pytorch.org/docs/stable/nn.html#relu) activations. The\ninput is a latent vector, $z$, that is drawn from a standard normal\ndistribution and the output is a 3x64x64 RGB image. The strided\nconv-transpose layers allow the latent vector to be transformed into a\nvolume with the same shape as an image. 
In the paper, the authors also\ngive some tips about how to set up the optimizers, how to calculate the\nloss functions, and how to initialize the model weights, all of which\nwill be explained in the coming sections.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#%matplotlib inline\nimport argparse\nimport os\nimport random\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.optim as optim\nimport torch.utils.data\nimport torchvision.datasets as dset\nimport torchvision.transforms as transforms\nimport torchvision.utils as vutils\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport matplotlib.animation as animation\nfrom IPython.display import HTML\n\n# Set random seed for reproducibility\nmanualSeed = 999\n#manualSeed = random.randint(1, 10000) # use if you want new results\nprint(\"Random Seed: \", manualSeed)\nrandom.seed(manualSeed)\ntorch.manual_seed(manualSeed)\ntorch.use_deterministic_algorithms(True) # Needed for reproducible results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inputs\n======\n\nLet's define some inputs for the run:\n\n- `dataroot` - the path to the root of the dataset folder. We will\n talk more about the dataset in the next section.\n- `workers` - the number of worker processes for loading the data with\n the `DataLoader`.\n- `batch_size` - the batch size used in training. The DCGAN paper uses\n a batch size of 128.\n- `image_size` - the spatial size of the images used for training.\n This implementation defaults to 64x64. If another size is desired,\n the structures of D and G must be changed. See\n [here](https://github.com/pytorch/examples/issues/70) for more\n details.\n- `nc` - number of color channels in the input images. 
For color\n images this is 3.\n- `nz` - length of latent vector.\n- `ngf` - relates to the depth of feature maps carried through the\n generator.\n- `ndf` - sets the depth of feature maps propagated through the\n discriminator.\n- `num_epochs` - number of training epochs to run. Training for longer\n will probably lead to better results but will also take much longer.\n- `lr` - learning rate for training. As described in the DCGAN paper,\n this number should be 0.0002.\n- `beta1` - beta1 hyperparameter for Adam optimizers. As described in\n the paper, this number should be 0.5.\n- `ngpu` - number of GPUs available. If this is 0, code will run in\n CPU mode. If this number is greater than 0 it will run on that\n number of GPUs.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Root directory for dataset\ndataroot = \"data/celeba\"\n\n# Number of workers for dataloader\nworkers = 2\n\n# Batch size during training\nbatch_size = 128\n\n# Spatial size of training images. All images will be resized to this\n# size using a transform.\nimage_size = 64\n\n# Number of channels in the training images. For color images this is 3\nnc = 3\n\n# Size of z latent vector (i.e. size of generator input)\nnz = 100\n\n# Size of feature maps in generator\nngf = 64\n\n# Size of feature maps in discriminator\nndf = 64\n\n# Number of training epochs\nnum_epochs = 5\n\n# Learning rate for optimizers\nlr = 0.0002\n\n# Beta1 hyperparameter for Adam optimizers\nbeta1 = 0.5\n\n# Number of GPUs available. Use 0 for CPU mode.\nngpu = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Data\n====\n\nIn this tutorial we will use the [Celeb-A Faces\ndataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) which can be\ndownloaded at the linked site, or in [Google\nDrive](https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg).\nThe dataset will download as a file named `img_align_celeba.zip`. 
Once\ndownloaded, create a directory named `celeba` and extract the zip file\ninto that directory. Then, set the `dataroot` input for this notebook to\nthe `celeba` directory you just created. The resulting directory\nstructure should be:\n\n``` {.sh}\n/path/to/celeba\n    -> img_align_celeba\n        -> 188242.jpg\n        -> 173822.jpg\n        -> 284702.jpg\n        -> 537394.jpg\n           ...\n```\n\nThis is an important step because we will be using the `ImageFolder`\ndataset class, which requires there to be subdirectories in the dataset\nroot folder. Now, we can create the dataset, create the dataloader, set\nthe device to run on, and finally visualize some of the training data.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# We can use an image folder dataset the way we have it set up.\n# Create the dataset\ndataset = dset.ImageFolder(root=dataroot,\n                           transform=transforms.Compose([\n                               transforms.Resize(image_size),\n                               transforms.CenterCrop(image_size),\n                               transforms.ToTensor(),\n                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n                           ]))\n# Create the dataloader\ndataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,\n                                         shuffle=True, num_workers=workers)\n\n# Decide which device we want to run on\ndevice = torch.device(\"cuda:0\" if (torch.cuda.is_available() and ngpu > 0) else \"cpu\")\n\n# Plot some training images\nreal_batch = next(iter(dataloader))\nplt.figure(figsize=(8,8))\nplt.axis(\"off\")\nplt.title(\"Training Images\")\nplt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation\n==============\n\nWith our input parameters set and the dataset prepared, we can now get\ninto the implementation. 
We will start with the weight initialization\nstrategy, then talk about the generator, discriminator, loss functions,\nand training loop in detail.\n\nWeight Initialization\n---------------------\n\nIn the DCGAN paper, the authors specify that all model weights shall\nbe randomly initialized from a Normal distribution with `mean=0`,\n`stdev=0.02`. The `weights_init` function takes an initialized model as\ninput and reinitializes all convolutional, convolutional-transpose, and\nbatch normalization layers to meet these criteria. This function is\napplied to the models immediately after initialization.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# custom weights initialization called on ``netG`` and ``netD``\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        nn.init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find('BatchNorm') != -1:\n        nn.init.normal_(m.weight.data, 1.0, 0.02)\n        nn.init.constant_(m.bias.data, 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generator\n=========\n\nThe generator, $G$, is designed to map the latent space vector ($z$) to\ndata-space. Since our data are images, converting $z$ to data-space\nmeans ultimately creating an RGB image with the same size as the training\nimages (i.e.\u00a03x64x64). In practice, this is accomplished through a\nseries of strided two-dimensional convolutional transpose layers, each\npaired with a 2d batch norm layer and a relu activation. The output of\nthe generator is fed through a tanh function to return it to the input\ndata range of $[-1,1]$. It is worth noting the existence of the batch\nnorm functions after the conv-transpose layers, as this is a critical\ncontribution of the DCGAN paper. These layers help with the flow of\ngradients during training. 
An image of the generator from the DCGAN\npaper is shown below.\n\n![](https://pytorch.org/tutorials/_static/img/dcgan_generator.png)\n\nNotice how the inputs we set in the input section (`nz`, `ngf`, and\n`nc`) influence the generator architecture in code. `nz` is the length\nof the z input vector, `ngf` relates to the size of the feature maps\nthat are propagated through the generator, and `nc` is the number of\nchannels in the output image (set to 3 for RGB images). Below is the\ncode for the generator.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Generator Code\n\nclass Generator(nn.Module):\n    def __init__(self, ngpu):\n        super(Generator, self).__init__()\n        self.ngpu = ngpu\n        self.main = nn.Sequential(\n            # input is Z, going into a convolution\n            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),\n            nn.BatchNorm2d(ngf * 8),\n            nn.ReLU(True),\n            # state size. ``(ngf*8) x 4 x 4``\n            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ngf * 4),\n            nn.ReLU(True),\n            # state size. ``(ngf*4) x 8 x 8``\n            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ngf * 2),\n            nn.ReLU(True),\n            # state size. ``(ngf*2) x 16 x 16``\n            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ngf),\n            nn.ReLU(True),\n            # state size. ``(ngf) x 32 x 32``\n            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),\n            nn.Tanh()\n            # state size. ``(nc) x 64 x 64``\n        )\n\n    def forward(self, input):\n        return self.main(input)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can instantiate the generator and apply the `weights_init`\nfunction. 
Check out the printed model to see how the generator object is\nstructured.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Create the generator\nnetG = Generator(ngpu).to(device)\n\n# Handle multi-GPU if desired\nif (device.type == 'cuda') and (ngpu > 1):\n    netG = nn.DataParallel(netG, list(range(ngpu)))\n\n# Apply the ``weights_init`` function to randomly initialize all weights\n# to ``mean=0``, ``stdev=0.02``.\nnetG.apply(weights_init)\n\n# Print the model\nprint(netG)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Discriminator\n=============\n\nAs mentioned, the discriminator, $D$, is a binary classification network\nthat takes an image as input and outputs a scalar probability that the\ninput image is real (as opposed to fake). Here, $D$ takes a 3x64x64\ninput image, processes it through a series of Conv2d, BatchNorm2d, and\nLeakyReLU layers, and outputs the final probability through a Sigmoid\nactivation function. This architecture can be extended with more layers\nif necessary for the problem, but there is significance to the use of\nthe strided convolution, BatchNorm, and LeakyReLUs. The DCGAN paper\nmentions it is a good practice to use strided convolution rather than\npooling to downsample because it lets the network learn its own pooling\nfunction. Also, batch norm and leaky ReLU functions promote healthy\ngradient flow which is critical for the learning process of both $G$ and\n$D$.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Discriminator Code\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Discriminator(nn.Module):\n    def __init__(self, ngpu):\n        super(Discriminator, self).__init__()\n        self.ngpu = ngpu\n        self.main = nn.Sequential(\n            # input is ``(nc) x 64 x 64``\n            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),\n            nn.LeakyReLU(0.2, inplace=True),\n            # state size. 
``(ndf) x 32 x 32``\n            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ndf * 2),\n            nn.LeakyReLU(0.2, inplace=True),\n            # state size. ``(ndf*2) x 16 x 16``\n            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ndf * 4),\n            nn.LeakyReLU(0.2, inplace=True),\n            # state size. ``(ndf*4) x 8 x 8``\n            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),\n            nn.BatchNorm2d(ndf * 8),\n            nn.LeakyReLU(0.2, inplace=True),\n            # state size. ``(ndf*8) x 4 x 4``\n            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),\n            nn.Sigmoid()\n        )\n\n    def forward(self, input):\n        return self.main(input)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, as with the generator, we can create the discriminator, apply the\n`weights_init` function, and print the model's structure.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Create the Discriminator\nnetD = Discriminator(ngpu).to(device)\n\n# Handle multi-GPU if desired\nif (device.type == 'cuda') and (ngpu > 1):\n    netD = nn.DataParallel(netD, list(range(ngpu)))\n\n# Apply the ``weights_init`` function to randomly initialize all weights\n# to ``mean=0``, ``stdev=0.02``.\nnetD.apply(weights_init)\n\n# Print the model\nprint(netD)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loss Functions and Optimizers\n=============================\n\nWith $D$ and $G$ set up, we can specify how they learn through the loss\nfunctions and optimizers. We will use the Binary Cross Entropy loss\n([BCELoss](https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss))\nfunction which is defined in PyTorch as:\n\n$$\\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad l_n = - \\left[ y_n \\cdot \\log x_n + (1 - y_n) \\cdot \\log (1 - x_n) \\right]$$\n\nNotice how this function provides the calculation of both log components\nin the objective function (i.e. $log(D(x))$ and $log(1-D(G(z)))$). 
We\ncan specify what part of the BCE equation to use with the $y$ input.\nThis is accomplished in the training loop which is coming up soon, but\nit is important to understand how we can choose which component we wish\nto calculate just by changing $y$ (i.e.\u00a0the ground-truth (GT) labels).\n\nNext, we define our real label as 1 and the fake label as 0. These\nlabels will be used when calculating the losses of $D$ and $G$, and this\nis also the convention used in the original GAN paper. Finally, we set\nup two separate optimizers, one for $D$ and one for $G$. As specified in\nthe DCGAN paper, both are Adam optimizers with learning rate 0.0002 and\nBeta1 = 0.5. For keeping track of the generator's learning progression,\nwe will generate a fixed batch of latent vectors that are drawn from a\nGaussian distribution (i.e.\u00a0fixed\\_noise). In the training loop, we\nwill periodically input this fixed\\_noise into $G$, and over the\niterations we will see images form out of the noise.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Initialize the ``BCELoss`` function\ncriterion = nn.BCELoss()\n\n# Create batch of latent vectors that we will use to visualize\n# the progression of the generator\nfixed_noise = torch.randn(64, nz, 1, 1, device=device)\n\n# Establish convention for real and fake labels during training\nreal_label = 1.\nfake_label = 0.\n\n# Set up Adam optimizers for both G and D\noptimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))\noptimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training\n========\n\nFinally, now that we have all of the parts of the GAN framework defined,\nwe can train it. Be mindful that training GANs is somewhat of an art\nform, as incorrect hyperparameter settings can lead to mode collapse with\nlittle explanation of what went wrong. 
Here, we will closely follow\nAlgorithm 1 from [Goodfellow's\npaper](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf),\nwhile abiding by some of the best practices shown in\n[ganhacks](https://github.com/soumith/ganhacks). Namely, we will\n\"construct different mini-batches for real and fake\" images, and also\nadjust G's objective function to maximize $log(D(G(z)))$. Training is\nsplit up into two main parts. Part 1 updates the Discriminator and Part\n2 updates the Generator.\n\n**Part 1 - Train the Discriminator**\n\nRecall, the goal of training the discriminator is to maximize the\nprobability of correctly classifying a given input as real or fake. In\nGoodfellow's terms, we wish to \"update the discriminator by ascending\nits stochastic gradient\". Practically, we want to maximize\n$log(D(x)) + log(1-D(G(z)))$. Due to the separate mini-batch suggestion\nfrom [ganhacks](https://github.com/soumith/ganhacks), we will calculate\nthis in two steps. First, we will construct a batch of real samples from\nthe training set, forward pass through $D$, calculate the loss\n($log(D(x))$), then calculate the gradients in a backward pass.\nSecondly, we will construct a batch of fake samples with the current\ngenerator, forward pass this batch through $D$, calculate the loss\n($log(1-D(G(z)))$), and *accumulate* the gradients with a backward pass.\nNow, with the gradients accumulated from both the all-real and all-fake\nbatches, we call a step of the Discriminator's optimizer.\n\n**Part 2 - Train the Generator**\n\nAs stated in the original paper, we want to train the Generator by\nminimizing $log(1-D(G(z)))$ in an effort to generate better fakes. As\nmentioned, this was shown by Goodfellow to not provide sufficient\ngradients, especially early in the learning process. As a fix, we\ninstead wish to maximize $log(D(G(z)))$. 
In the code we accomplish this\nby: classifying the Generator output from Part 1 with the Discriminator,\ncomputing G's loss *using real labels as GT*, computing G's gradients in\na backward pass, and finally updating G's parameters with an optimizer\nstep. It may seem counter-intuitive to use the real labels as GT labels\nfor the loss function, but this allows us to use the $log(x)$ part of\nthe `BCELoss` (rather than the $log(1-x)$ part) which is exactly what we\nwant.\n\nFinally, we will do some statistic reporting, and every 500 iterations\n(as well as at the very end of training) we will push our fixed\\_noise\nbatch through the generator to visually track the progress of G's\ntraining. The training statistics reported are:\n\n- **Loss\\_D** - discriminator loss calculated as the sum of losses for\n the all real and all fake batches ($-[log(D(x)) + log(1 - D(G(z)))]$,\n averaged over each batch).\n- **Loss\\_G** - generator loss calculated as $-log(D(G(z)))$, averaged\n over the batch.\n- **D(x)** - the average output (across the batch) of the\n discriminator for the all real batch. This should start close to 1\n then theoretically converge to 0.5 when G gets better. Think about\n why this is.\n- **D(G(z))** - average discriminator outputs for the all fake batch.\n The first number is before D is updated and the second number is\n after D is updated. These numbers should start near 0 and converge\n to 0.5 as G gets better. 
Think about why this is.\n\n**Note:** This step might take a while, depending on how many epochs you\nrun and if you removed some data from the dataset.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Training Loop\n\n# Lists to keep track of progress\nimg_list = []\nG_losses = []\nD_losses = []\niters = 0\n\nprint(\"Starting Training Loop...\")\n# For each epoch\nfor epoch in range(num_epochs):\n    # For each batch in the dataloader\n    for i, data in enumerate(dataloader, 0):\n\n        ############################\n        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n        ###########################\n        ## Train with all-real batch\n        netD.zero_grad()\n        # Format batch\n        real_cpu = data[0].to(device)\n        b_size = real_cpu.size(0)\n        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)\n        # Forward pass real batch through D\n        output = netD(real_cpu).view(-1)\n        # Calculate loss on all-real batch\n        errD_real = criterion(output, label)\n        # Calculate gradients for D in backward pass\n        errD_real.backward()\n        D_x = output.mean().item()\n\n        ## Train with all-fake batch\n        # Generate batch of latent vectors\n        noise = torch.randn(b_size, nz, 1, 1, device=device)\n        # Generate fake image batch with G\n        fake = netG(noise)\n        label.fill_(fake_label)\n        # Classify all fake batch with D\n        output = netD(fake.detach()).view(-1)\n        # Calculate D's loss on the all-fake batch\n        errD_fake = criterion(output, label)\n        # Calculate the gradients for this batch, accumulated (summed) with previous gradients\n        errD_fake.backward()\n        D_G_z1 = output.mean().item()\n        # Compute error of D as sum over the fake and the real batches\n        errD = errD_real + errD_fake\n        # Update D\n        optimizerD.step()\n\n        ############################\n        # (2) Update G network: maximize log(D(G(z)))\n        ###########################\n        netG.zero_grad()\n        label.fill_(real_label)  # fake labels are real for generator cost\n        # Since we just updated 
D, perform another forward pass of all-fake batch through D\n        output = netD(fake).view(-1)\n        # Calculate G's loss based on this output\n        errG = criterion(output, label)\n        # Calculate gradients for G\n        errG.backward()\n        D_G_z2 = output.mean().item()\n        # Update G\n        optimizerG.step()\n\n        # Output training stats\n        if i % 50 == 0:\n            print('[%d/%d][%d/%d]\\tLoss_D: %.4f\\tLoss_G: %.4f\\tD(x): %.4f\\tD(G(z)): %.4f / %.4f'\n                  % (epoch, num_epochs, i, len(dataloader),\n                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))\n\n        # Save Losses for plotting later\n        G_losses.append(errG.item())\n        D_losses.append(errD.item())\n\n        # Check how the generator is doing by saving G's output on fixed_noise\n        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):\n            with torch.no_grad():\n                fake = netG(fixed_noise).detach().cpu()\n            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))\n\n        iters += 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Results\n=======\n\nFinally, let's check out how we did. Here, we will look at three\ndifferent results. First, we will see how D and G's losses changed\nduring training. Second, we will visualize G's output on the\nfixed\\_noise batch over the course of training. And third, we will look at a batch\nof real data next to a batch of fake data from G.\n\n**Loss versus training iteration**\n\nBelow is a plot of D & G's losses versus training iterations.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.figure(figsize=(10,5))\nplt.title(\"Generator and Discriminator Loss During Training\")\nplt.plot(G_losses,label=\"G\")\nplt.plot(D_losses,label=\"D\")\nplt.xlabel(\"iterations\")\nplt.ylabel(\"Loss\")\nplt.legend()\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Visualization of G's progression**\n\nRemember how we saved the generator's output on the fixed\\_noise batch\nevery 500 iterations of training. 
Now, we can visualize the training\nprogression of G with an animation. Press the play button to start the\nanimation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig = plt.figure(figsize=(8,8))\nplt.axis(\"off\")\nims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]\nani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)\n\nHTML(ani.to_jshtml())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Real Images vs.\u00a0Fake Images**\n\nFinally, let's take a look at some real images and fake images side by\nside.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Grab a batch of real images from the dataloader\nreal_batch = next(iter(dataloader))\n\n# Plot the real images\nplt.figure(figsize=(15,15))\nplt.subplot(1,2,1)\nplt.axis(\"off\")\nplt.title(\"Real Images\")\nplt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))\n\n# Plot the fake images from the last epoch\nplt.subplot(1,2,2)\nplt.axis(\"off\")\nplt.title(\"Fake Images\")\nplt.imshow(np.transpose(img_list[-1],(1,2,0)))\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Where to Go Next\n================\n\nWe have reached the end of our journey, but there are several places you\ncould go from here. 
You could:\n\n- Train for longer to see how good the results get\n- Modify this model to take a different dataset and possibly change\n the size of the images and the model architecture\n- Check out some other cool GAN projects\n [here](https://github.com/nashory/gans-awesome-applications)\n- Create GANs that generate\n [music](https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio/)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }