{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Introduction](introyt1_tutorial.html) \\|\\|\n[Tensors](tensors_deeper_tutorial.html) \\|\\|\n[Autograd](autogradyt_tutorial.html) \\|\\| **Building Models** \\|\\|\n[TensorBoard Support](tensorboardyt_tutorial.html) \\|\\| [Training\nModels](trainingyt.html) \\|\\| [Model Understanding](captumyt.html)\n\nBuilding Models with PyTorch\n============================\n\nFollow along with the video below or on\n[youtube](https://www.youtube.com/watch?v=OSqIP-mOWOI).\n\n``` {.python .jupyter-code-cell}\nfrom IPython.display import display, HTML\nhtml_code = \"\"\"\n
\n \n
\n\"\"\"\ndisplay(HTML(html_code))\n```\n\n`torch.nn.Module` and `torch.nn.Parameter`\n------------------------------------------\n\nIn this video, we'll be discussing some of the tools PyTorch makes\navailable for building deep learning networks.\n\nExcept for `Parameter`, the classes we discuss in this video are all\nsubclasses of `torch.nn.Module`. This is the PyTorch base class meant to\nencapsulate behaviors specific to PyTorch models and their components.\n\nOne important behavior of `torch.nn.Module` is registering parameters.\nIf a particular `Module` subclass has learning weights, these weights\nare expressed as instances of `torch.nn.Parameter`. The `Parameter`\nclass is a subclass of `torch.Tensor`, with the special behavior that\nwhen its instances are assigned as attributes of a `Module`, they are\nadded to the list of that module's parameters. These parameters may be\naccessed through the `parameters()` method on the `Module` class.\n\nAs an example, here's a very simple model with two linear layers\nand an activation function. 
We'll create an instance of it and ask it to\nreport on its parameters:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n\nclass TinyModel(torch.nn.Module):\n\n    def __init__(self):\n        super(TinyModel, self).__init__()\n\n        self.linear1 = torch.nn.Linear(100, 200)\n        self.activation = torch.nn.ReLU()\n        self.linear2 = torch.nn.Linear(200, 10)\n        # normalize over the feature dimension of a (batch, features) input\n        self.softmax = torch.nn.Softmax(dim=1)\n\n    def forward(self, x):\n        x = self.linear1(x)\n        x = self.activation(x)\n        x = self.linear2(x)\n        x = self.softmax(x)\n        return x\n\ntinymodel = TinyModel()\n\nprint('The model:')\nprint(tinymodel)\n\nprint('\\n\\nJust one layer:')\nprint(tinymodel.linear2)\n\nprint('\\n\\nModel params:')\nfor param in tinymodel.parameters():\n    print(param)\n\nprint('\\n\\nLayer params:')\nfor param in tinymodel.linear2.parameters():\n    print(param)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This shows the fundamental structure of a PyTorch model: there is an\n`__init__()` method that defines the layers and other components of a\nmodel, and a `forward()` method where the computation gets done. Note\nthat we can print the model, or any of its submodules, to learn about\nits structure.\n\nCommon Layer Types\n==================\n\nLinear Layers\n-------------\n\nThe most basic type of neural network layer is a *linear* or *fully\nconnected* layer. This is a layer where every input influences every\noutput of the layer to a degree specified by the layer's weights. If a\nlayer has *m* inputs and *n* outputs, the weights will be an *m* x *n*\nmatrix (which PyTorch stores transposed, with shape `(n, m)`). 
For example:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "lin = torch.nn.Linear(3, 2)\nx = torch.rand(1, 3)\nprint('Input:')\nprint(x)\n\nprint('\\n\\nWeight and Bias parameters:')\nfor param in lin.parameters():\n    print(param)\n\ny = lin(x)\nprint('\\n\\nOutput:')\nprint(y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you do the matrix multiplication of `x` by the transpose of the linear\nlayer's weight matrix, and add the biases, you'll find that you get the\noutput vector `y`.\n\nOne other important feature to note: When we printed the parameters of our\nlayer with `lin.parameters()`, each one reported itself as a `Parameter`\n(which is a subclass of `Tensor`), and let us know that it's tracking\ngradients with autograd (`requires_grad=True`). This is a default behavior\nfor `Parameter` that differs from `Tensor`.\n\nLinear layers are used widely in deep learning models. One of the most\ncommon places you'll see them is in classifier models, which will\nusually have one or more linear layers at the end. The last layer\nwill have *n* outputs, where *n* is the number of classes the classifier\naddresses.\n\nConvolutional Layers\n--------------------\n\n*Convolutional* layers are built to handle data with a high degree of\nspatial correlation. They are very commonly used in computer vision,\nwhere they detect close groupings of features which they compose into\nhigher-level features. 
They pop up in other contexts too - for example,\nin NLP applications, where a word's immediate context (that is, the\nother words nearby in the sequence) can affect the meaning of a\nsentence.\n\nWe saw convolutional layers in action in LeNet5 in an earlier video:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch.nn.functional as F\n\n\nclass LeNet(torch.nn.Module):\n\n    def __init__(self):\n        super(LeNet, self).__init__()\n        # 1 input image channel (black & white), 6 output channels, 5x5 square convolution\n        # kernel\n        self.conv1 = torch.nn.Conv2d(1, 6, 5)\n        self.conv2 = torch.nn.Conv2d(6, 16, 3)\n        # an affine operation: y = Wx + b\n        self.fc1 = torch.nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension\n        self.fc2 = torch.nn.Linear(120, 84)\n        self.fc3 = torch.nn.Linear(84, 10)\n\n    def forward(self, x):\n        # Max pooling over a (2, 2) window\n        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n        # If the window is square, you can specify just a single number\n        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n        x = x.view(-1, self.num_flat_features(x))\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n    def num_flat_features(self, x):\n        size = x.size()[1:]  # all dimensions except the batch dimension\n        num_features = 1\n        for s in size:\n            num_features *= s\n        return num_features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's break down what's happening in the convolutional layers of this\nmodel. Starting with `conv1`:\n\n-   LeNet5 is meant to take in a 1x32x32 black & white image. **The\n    first argument to a convolutional layer's constructor is the number\n    of input channels.** Here, it is 1. If we were building this model\n    to look at 3 color channels, it would be 3.\n-   A convolutional layer is like a window that scans over the image,\n    looking for a pattern it recognizes. 
These patterns are called\n    *features,* and one of the parameters of a convolutional layer is\n    the number of features we would like it to learn. **The second\n    argument to the constructor is the number of output\n    features.** Here, we're asking our layer to learn 6 features.\n-   Just above, I likened the convolutional layer to a window - but how\n    big is the window? **The third argument is the window or kernel\n    size.** Here, the \"5\" means we've chosen a 5x5 kernel. (If you want\n    a kernel with height different from width, you can specify a tuple\n    for this argument - e.g., `(3, 5)` to get a 3x5 convolution kernel.)\n\nThe output of a convolutional layer is an *activation map* - a spatial\nrepresentation of the presence of features in the input tensor. `conv1`\nwill give us an output tensor of 6x28x28; 6 is the number of features,\nand 28 is the height and width of our map. (The 28 comes from the fact\nthat when scanning a 5-pixel window over a 32-pixel row, there are only\n32 - 5 + 1 = 28 valid positions.)\n\nWe then pass the output of the convolution through a ReLU activation\nfunction (more on activation functions later), then through a max\npooling layer. The max pooling layer takes features near each other in\nthe activation map and groups them together. It does this by reducing\nthe tensor, merging every 2x2 group of cells in the output into a single\ncell, and assigning that cell the maximum value of the 4 cells that went\ninto it. This gives us a lower-resolution version of the activation map,\nwith dimensions 6x14x14.\n\nOur next convolutional layer, `conv2`, expects 6 input channels\n(corresponding to the 6 features sought by the first layer), has 16\noutput channels, and a 3x3 kernel. It puts out a 16x12x12 activation\nmap, which is again reduced by a max pooling layer to 16x6x6. 
Prior to\npassing this output to the linear layers, it is reshaped to a 16 \\* 6 \\*\n6 = 576-element vector for consumption by the next layer.\n\nThere are convolutional layers for addressing 1D, 2D, and 3D tensors.\nThere are also many more optional arguments for a conv layer\nconstructor, including stride length (e.g., only scanning every second or\nevery third position in the input), padding (so you can scan out to the\nedges of the input), and more. See the\n[documentation](https://pytorch.org/docs/stable/nn.html#convolution-layers)\nfor more information.\n\nRecurrent Layers\n----------------\n\n*Recurrent neural networks* (or *RNNs*) are used for sequential data -\nanything from time-series measurements from a scientific instrument to\nnatural language sentences to DNA nucleotides. An RNN does this by\nmaintaining a *hidden state* that acts as a sort of memory for what it\nhas seen in the sequence so far.\n\nThe internal structure of an RNN layer - or its variants, the LSTM (long\nshort-term memory) and GRU (gated recurrent unit) - is moderately\ncomplex and beyond the scope of this video, but we'll show you what one\nlooks like in action with an LSTM-based part-of-speech tagger (a type of\nclassifier that tells you if a word is a noun, verb, etc.):\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class LSTMTagger(torch.nn.Module):\n\n    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):\n        super(LSTMTagger, self).__init__()\n        self.hidden_dim = hidden_dim\n\n        self.word_embeddings = torch.nn.Embedding(vocab_size, embedding_dim)\n\n        # The LSTM takes word embeddings as inputs, and outputs hidden states\n        # with dimensionality hidden_dim.\n        self.lstm = torch.nn.LSTM(embedding_dim, hidden_dim)\n\n        # The linear layer that maps from hidden state space to tag space\n        self.hidden2tag = torch.nn.Linear(hidden_dim, tagset_size)\n\n    def forward(self, sentence):\n        embeds = 
self.word_embeddings(sentence)\n        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))\n        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))\n        tag_scores = F.log_softmax(tag_space, dim=1)\n        return tag_scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The constructor has four arguments:\n\n-   `vocab_size` is the number of words in the input vocabulary. Each\n    word is a one-hot vector (or unit vector) in a\n    `vocab_size`-dimensional space.\n-   `tagset_size` is the number of tags in the output set.\n-   `embedding_dim` is the size of the *embedding* space for the\n    vocabulary. An embedding maps a vocabulary onto a low-dimensional\n    space, where words with similar meanings are close together in the\n    space.\n-   `hidden_dim` is the size of the LSTM's memory.\n\nThe input will be a sentence with the words represented as indices of\none-hot vectors. The embedding layer will then map these down to an\n`embedding_dim`-dimensional space. The LSTM takes this sequence of\nembeddings and iterates over it, yielding an output vector of length\n`hidden_dim` for each word. The final linear layer acts as a classifier;\napplying `log_softmax()` to the output of the final layer converts the\noutput into a normalized set of estimated log-probabilities that a given\nword maps to a given tag.\n\nIf you'd like to see this network in action, check out the [Sequence\nModels and LSTM\nNetworks](https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html)\ntutorial on pytorch.org.\n\nTransformers\n------------\n\n*Transformers* are multi-purpose networks that have taken over the state\nof the art in NLP with models like BERT. A discussion of transformer\narchitecture is beyond the scope of this video, but PyTorch has a\n`Transformer` class that allows you to define the overall parameters of\na transformer model - the number of attention heads, the number of\nencoder & decoder layers, dropout and activation functions, etc. 
(You\ncan even build the BERT model from this single class, with the right\nparameters!) The `torch.nn` module also has classes to\nencapsulate the individual components (`TransformerEncoder`,\n`TransformerDecoder`) and subcomponents (`TransformerEncoderLayer`,\n`TransformerDecoderLayer`). For details, check out the\n[documentation](https://pytorch.org/docs/stable/nn.html#transformer-layers)\non transformer classes.\n\nOther Layers and Functions\n==========================\n\nData Manipulation Layers\n------------------------\n\nThere are other layer types that perform important functions in models,\nbut don't participate in the learning process themselves.\n\n**Max pooling** (and its twin, min pooling) reduces a tensor by combining\ncells, and assigning the maximum value of the input cells to the output\ncell (we saw this in the LeNet model above). For example:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "my_tensor = torch.rand(1, 6, 6)\nprint(my_tensor)\n\nmaxpool_layer = torch.nn.MaxPool2d(3)\nprint(maxpool_layer(my_tensor))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you look closely at the values above, you'll see that each of the\nvalues in the maxpooled output is the maximum value of one 3x3 quadrant of\nthe 6x6 input.\n\n**Normalization layers** re-center and normalize the output of one layer\nbefore feeding it to another. 
Centering and scaling the intermediate\ntensors has a number of beneficial effects, such as letting you use\nhigher learning rates without exploding/vanishing gradients.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "my_tensor = torch.rand(1, 4, 4) * 20 + 5\nprint(my_tensor)\n\nprint(my_tensor.mean())\n\nnorm_layer = torch.nn.BatchNorm1d(4)\nnormed_tensor = norm_layer(my_tensor)\nprint(normed_tensor)\n\nprint(normed_tensor.mean())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running the cell above, we've added a large scaling factor and offset to\nan input tensor; you should see the input tensor's `mean()` somewhere in\nthe neighborhood of 15. After running it through the normalization\nlayer, you can see that the values are smaller, and grouped around zero -\nin fact, the mean should be very small in magnitude (roughly 1e-8).\n\nThis is beneficial because many activation functions (discussed below)\nhave their strongest gradients near 0, but sometimes suffer from\nvanishing or exploding gradients for inputs that drive them far away\nfrom zero. 
Keeping the data centered around the area of steepest\ngradient will tend to mean faster, better learning and higher feasible\nlearning rates.\n\n**Dropout layers** are a tool for encouraging *sparse representations*\nin your model - that is, pushing it to do inference with less data.\n\nDropout layers work by randomly setting parts of the input tensor to zero\n*during training* - dropout layers are always turned off for inference.\nThis forces the model to learn against this masked or reduced dataset.\nFor example:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "my_tensor = torch.rand(1, 4, 4)\n\ndropout = torch.nn.Dropout(p=0.4)\nprint(dropout(my_tensor))\nprint(dropout(my_tensor))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Above, you can see the effect of dropout on a sample tensor. You can use\nthe optional `p` argument to set the probability of an individual element\ndropping out; if you don't, it defaults to 0.5.\n\nActivation Functions\n--------------------\n\nActivation functions make deep learning possible. A neural network is\nreally a program - with many parameters - that *simulates a mathematical\nfunction*. If all we did was multiply tensors by layer weights\nrepeatedly, we could only simulate *linear functions;* further, there\nwould be no point to having many layers, as the whole network\ncould be reduced to a single matrix multiplication. Inserting\n*non-linear* activation functions between layers is what allows a deep\nlearning model to simulate any function, rather than just linear ones.\n\n`torch.nn` has modules encapsulating all of the major activation\nfunctions including ReLU and its many variants, Tanh, Hardtanh, sigmoid,\nand more. 
It also includes other functions, such as Softmax, that are\nmost useful at the output stage of a model.\n\nLoss Functions\n--------------\n\nLoss functions tell us how far a model's prediction is from the correct\nanswer. PyTorch contains a variety of loss functions, including the common\nMSE (mean squared error, based on the squared L2 norm), Cross Entropy Loss\nand Negative Log Likelihood Loss (both useful for classifiers), and others.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }