{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Introduction](introyt1_tutorial.html) \\|\\| **Tensors** \\|\\|\n[Autograd](autogradyt_tutorial.html) \\|\\| [Building\nModels](modelsyt_tutorial.html) \\|\\| [TensorBoard\nSupport](tensorboardyt_tutorial.html) \\|\\| [Training\nModels](trainingyt.html) \\|\\| [Model Understanding](captumyt.html)\n\nIntroduction to PyTorch Tensors\n===============================\n\nFollow along with the video below or on\n[youtube](https://www.youtube.com/watch?v=r7QDUPb2dCM).\n\n``` {.python .jupyter-code-cell}\nfrom IPython.display import display, HTML\nhtml_code = \"\"\"\n
\n \n
\n\"\"\"\ndisplay(HTML(html_code))\n```\n\nTensors are the central data abstraction in PyTorch. This interactive\nnotebook provides an in-depth introduction to the `torch.Tensor` class.\n\nFirst things first, let's import the PyTorch module. We'll also add\nPython's math module to facilitate some of the examples.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport math" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creating Tensors\n================\n\nThe simplest way to create a tensor is with the `torch.empty()` call:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x = torch.empty(3, 4)\nprint(type(x))\nprint(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's unpack what we just did:\n\n- We created a tensor using one of the numerous factory methods\n attached to the `torch` module.\n- The tensor itself is 2-dimensional, having 3 rows and 4 columns.\n- The type of the object returned is `torch.Tensor`, which is an alias\n for `torch.FloatTensor`; by default, PyTorch tensors are populated\n with 32-bit floating point numbers. (More on data types below.)\n- You will probably see some random-looking values when printing your\n tensor. The `torch.empty()` call allocates memory for the tensor,\n but does not initialize it with any values - so what you're seeing\n is whatever was in memory at the time of allocation.\n\nA brief note about tensors and their number of dimensions, and\nterminology:\n\n- You will sometimes see a 1-dimensional tensor called a *vector.*\n- Likewise, a 2-dimensional tensor is often referred to as a *matrix.*\n- Anything with more than two dimensions is generally just called a\n tensor.\n\nMore often than not, you'll want to initialize your tensor with some\nvalue. 
Common cases are all zeros, all ones, or random values, and the\n`torch` module provides factory methods for all of these:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "zeros = torch.zeros(2, 3)\nprint(zeros)\n\nones = torch.ones(2, 3)\nprint(ones)\n\ntorch.manual_seed(1729)\nrandom = torch.rand(2, 3)\nprint(random)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The factory methods all do just what you'd expect - we have a tensor full\nof zeros, another full of ones, and another with random values between 0\nand 1.\n\nRandom Tensors and Seeding\n==========================\n\nSpeaking of the random tensor, did you notice the call to\n`torch.manual_seed()` immediately preceding it? Initializing tensors,\nsuch as a model's learning weights, with random values is common but\nthere are times - especially in research settings - where you'll want\nsome assurance of the reproducibility of your results. Manually setting\nyour random number generator's seed is the way to do this. Let's look\nmore closely:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(1729)\nrandom1 = torch.rand(2, 3)\nprint(random1)\n\nrandom2 = torch.rand(2, 3)\nprint(random2)\n\ntorch.manual_seed(1729)\nrandom3 = torch.rand(2, 3)\nprint(random3)\n\nrandom4 = torch.rand(2, 3)\nprint(random4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What you should see above is that `random1` and `random3` carry\nidentical values, as do `random2` and `random4`. 
Manually setting the\nRNG's seed resets it, so that identical computations depending on random\nnumbers should, in most settings, provide identical results.\n\nFor more information, see the [PyTorch documentation on\nreproducibility](https://pytorch.org/docs/stable/notes/randomness.html).\n\nTensor Shapes\n=============\n\nOften, when you're performing operations on two or more tensors, they\nwill need to be of the same *shape* - that is, having the same number of\ndimensions and the same number of cells in each dimension. For that, we\nhave the `torch.*_like()` methods:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x = torch.empty(2, 2, 3)\nprint(x.shape)\nprint(x)\n\nempty_like_x = torch.empty_like(x)\nprint(empty_like_x.shape)\nprint(empty_like_x)\n\nzeros_like_x = torch.zeros_like(x)\nprint(zeros_like_x.shape)\nprint(zeros_like_x)\n\nones_like_x = torch.ones_like(x)\nprint(ones_like_x.shape)\nprint(ones_like_x)\n\nrand_like_x = torch.rand_like(x)\nprint(rand_like_x.shape)\nprint(rand_like_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first new thing in the code cell above is the use of the `.shape`\nproperty on a tensor. This property contains a list of the extent of\neach dimension of a tensor - in our case, `x` is a three-dimensional\ntensor with shape 2 x 2 x 3.\n\nBelow that, we call the `.empty_like()`, `.zeros_like()`,\n`.ones_like()`, and `.rand_like()` methods. 
Using the `.shape` property,\nwe can verify that each of these methods returns a tensor of identical\ndimensionality and extent.\n\nThe last way to create a tensor that we will cover is to specify its data\ndirectly from a Python collection:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "some_constants = torch.tensor([[3.1415926, 2.71828], [1.61803, 0.0072897]])\nprint(some_constants)\n\nsome_integers = torch.tensor((2, 3, 5, 7, 11, 13, 17, 19))\nprint(some_integers)\n\nmore_integers = torch.tensor(((2, 4, 6), [3, 6, 9]))\nprint(more_integers)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using `torch.tensor()` is the most straightforward way to create a\ntensor if you already have data in a Python tuple or list. As shown\nabove, nesting the collections will result in a multi-dimensional\ntensor.\n\n```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

torch.tensor() creates a copy of the data.

\n```\n```{=html}\n
\n```\nTensor Data Types\n=================\n\nSetting the datatype of a tensor is possible in a couple of ways:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.ones((2, 3), dtype=torch.int16)\nprint(a)\n\nb = torch.rand((2, 3), dtype=torch.float64) * 20.\nprint(b)\n\nc = b.to(torch.int32)\nprint(c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The simplest way to set the underlying data type of a tensor is with an\noptional argument at creation time. In the first line of the cell above,\nwe set `dtype=torch.int16` for the tensor `a`. When we print `a`, we can\nsee that it's full of `1` rather than `1.` - Python's subtle cue that\nthis is an integer type rather than floating point.\n\nAnother thing to notice about printing `a` is that, unlike when we left\n`dtype` as the default (32-bit floating point), printing the tensor also\nspecifies its `dtype`.\n\nYou may have also spotted that we went from specifying the tensor's\nshape as a series of integer arguments, to grouping those arguments in a\ntuple. This is not strictly necessary - PyTorch will take a series of\ninitial, unlabeled integer arguments as a tensor shape - but when adding\nthe optional arguments, it can make your intent more readable.\n\nThe other way to set the datatype is with the `.to()` method. In the\ncell above, we create a random floating point tensor `b` in the usual\nway. Following that, we create `c` by converting `b` to a 32-bit integer\nwith the `.to()` method. Note that `c` contains all the same values as\n`b`, but truncated to integers.\n\nFor more information, see the [data types\ndocumentation](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype).\n\nMath & Logic with PyTorch Tensors\n=================================\n\nNow that you know some of the ways to create a tensor... 
what can you do\nwith them?\n\nLet's look at basic arithmetic first, and how tensors interact with\nsimple scalars:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ones = torch.zeros(2, 2) + 1\ntwos = torch.ones(2, 2) * 2\nthrees = (torch.ones(2, 2) * 7 - 1) / 2\nfours = twos ** 2\nsqrt2s = twos ** 0.5\n\nprint(ones)\nprint(twos)\nprint(threes)\nprint(fours)\nprint(sqrt2s)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see above, arithmetic operations between tensors and scalars,\nsuch as addition, subtraction, multiplication, division, and\nexponentiation are distributed over every element of the tensor. Because\nthe output of such an operation will be a tensor, you can chain them\ntogether with the usual operator precedence rules, as in the line where\nwe create `threes`.\n\nSimilar operations between two tensors also behave like you'd\nintuitively expect:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "powers2 = twos ** torch.tensor([[1, 2], [3, 4]])\nprint(powers2)\n\nfives = ones + fours\nprint(fives)\n\ndozens = threes * fours\nprint(dozens)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's important to note here that all of the tensors in the previous code\ncell were of identical shape. What happens when we try to perform a\nbinary operation on tensors of dissimilar shape?\n\n```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

The following cell throws a run-time error. This is intentional.

a = torch.rand(2, 3)\nb = torch.rand(3, 2)

\n

print(a * b)

\n```\n```{=html}\n
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the general case, you cannot operate on tensors of different shape\nthis way, even in a case like the cell above, where the tensors have an\nidentical number of elements.\n\nIn Brief: Tensor Broadcasting\n=============================\n\n```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

If you are familiar with broadcasting semantics in NumPy ndarrays, you\u2019ll find the same rules apply here.

\n```\n```{=html}\n
\n```\nThe exception to the same-shapes rule is *tensor broadcasting.* Here's\nan example:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rand = torch.rand(2, 4)\ndoubled = rand * (torch.ones(1, 4) * 2)\n\nprint(rand)\nprint(doubled)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What's the trick here? How is it we got to multiply a 2x4 tensor by a\n1x4 tensor?\n\nBroadcasting is a way to perform an operation between tensors that have\nsimilarities in their shapes. In the example above, the one-row,\nfour-column tensor is multiplied by *both rows* of the two-row,\nfour-column tensor.\n\nThis is an important operation in Deep Learning. The common example is\nmultiplying a tensor of learning weights by a *batch* of input tensors,\napplying the operation to each instance in the batch separately, and\nreturning a tensor of identical shape - just like our (2, 4) \\* (1, 4)\nexample above returned a tensor of shape (2, 4).\n\nThe rules for broadcasting are:\n\n- Each tensor must have at least one dimension - no empty tensors.\n- Comparing the dimension sizes of the two tensors, *going from last\n to first:*\n - Each dimension must be equal, *or*\n - One of the dimensions must be of size 1, *or*\n - The dimension does not exist in one of the tensors\n\nTensors of identical shape, of course, are trivially \"broadcastable\", as\nyou saw earlier.\n\nHere are some examples of situations that honor the above rules and\nallow broadcasting:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.ones(4, 3, 2)\n\nb = a * torch.rand( 3, 2) # 3rd & 2nd dims identical to a, dim 1 absent\nprint(b)\n\nc = a * torch.rand( 3, 1) # 3rd dim = 1, 2nd dim identical to a\nprint(c)\n\nd = a * torch.rand( 1, 2) # 3rd dim identical to a, 2nd dim = 1\nprint(d)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Look closely at 
the values of each tensor above:\n\n- The multiplication operation that created `b` was broadcast over\n every \"layer\" of `a`.\n- For `c`, the operation was broadcast over every layer and row of\n `a` - every 3-element column is identical.\n- For `d`, we switched it around - now every *row* is identical,\n across layers and columns.\n\nFor more information on broadcasting, see the [PyTorch\ndocumentation](https://pytorch.org/docs/stable/notes/broadcasting.html)\non the topic.\n\nHere are some examples of attempts at broadcasting that will fail:\n\n```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

The following cell throws a run-time error. This is intentional.

a =     torch.ones(4, 3, 2)

\n

b = a * torch.rand(4, 3) # dimensions must match last-to-first

\n

c = a * torch.rand( 2, 3) # both 3rd & 2nd dims different

\n

d = a * torch.rand((0, )) # can't broadcast with an empty tensor

\n```\n```{=html}\n
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "More Math with Tensors\n======================\n\nPyTorch tensors have over three hundred operations that can be performed\non them.\n\nHere is a small sample from some of the major categories of operations:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# common functions\na = torch.rand(2, 4) * 2 - 1\nprint('Common functions:')\nprint(torch.abs(a))\nprint(torch.ceil(a))\nprint(torch.floor(a))\nprint(torch.clamp(a, -0.5, 0.5))\n\n# trigonometric functions and their inverses\nangles = torch.tensor([0, math.pi / 4, math.pi / 2, 3 * math.pi / 4])\nsines = torch.sin(angles)\ninverses = torch.asin(sines)\nprint('\\nSine and arcsine:')\nprint(angles)\nprint(sines)\nprint(inverses)\n\n# bitwise operations\nprint('\\nBitwise XOR:')\nb = torch.tensor([1, 5, 11])\nc = torch.tensor([2, 7, 10])\nprint(torch.bitwise_xor(b, c))\n\n# comparisons:\nprint('\\nBroadcasted, element-wise equality comparison:')\nd = torch.tensor([[1., 2.], [3., 4.]])\ne = torch.ones(1, 2) # many comparison ops support broadcasting!\nprint(torch.eq(d, e)) # returns a tensor of type bool\n\n# reductions:\nprint('\\nReduction ops:')\nprint(torch.max(d)) # returns a single-element tensor\nprint(torch.max(d).item()) # extracts the value from the returned tensor\nprint(torch.mean(d)) # average\nprint(torch.std(d)) # standard deviation\nprint(torch.prod(d)) # product of all numbers\nprint(torch.unique(torch.tensor([1, 2, 1, 2, 1, 2]))) # filter unique elements\n\n# vector and linear algebra operations\nv1 = torch.tensor([1., 0., 0.]) # x unit vector\nv2 = torch.tensor([0., 1., 0.]) # y unit vector\nm1 = torch.rand(2, 2) # random matrix\nm2 = torch.tensor([[3., 0.], [0., 3.]]) # three times identity matrix\n\nprint('\\nVectors & Matrices:')\nprint(torch.linalg.cross(v2, v1)) # negative of z unit vector (v1 x v2 == -v2 x v1)\nprint(m1)\nm3 = torch.linalg.matmul(m1, 
m2)\nprint(m3) # 3 times m1\nprint(torch.linalg.svd(m3)) # singular value decomposition" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a small sample of operations. For more details and the full\ninventory of math functions, have a look at the\n[documentation](https://pytorch.org/docs/stable/torch.html#math-operations).\nFor more details and the full inventory of linear algebra operations,\nhave a look at this\n[documentation](https://pytorch.org/docs/stable/linalg.html).\n\nAltering Tensors in Place\n=========================\n\nMost binary operations on tensors will return a third, new tensor. When\nwe say `c = a * b` (where `a` and `b` are tensors), the new tensor `c`\nwill occupy a region of memory distinct from the other tensors.\n\nThere are times, though, that you may wish to alter a tensor in place\n-for example, if you're doing an element-wise computation where you can\ndiscard intermediate values. For this, most of the math functions have a\nversion with an appended underscore (`_`) that will alter a tensor in\nplace.\n\nFor example:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.tensor([0, math.pi / 4, math.pi / 2, 3 * math.pi / 4])\nprint('a:')\nprint(a)\nprint(torch.sin(a)) # this operation creates a new tensor in memory\nprint(a) # a has not changed\n\nb = torch.tensor([0, math.pi / 4, math.pi / 2, 3 * math.pi / 4])\nprint('\\nb:')\nprint(b)\nprint(torch.sin_(b)) # note the underscore\nprint(b) # b has changed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For arithmetic operations, there are functions that behave similarly:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.ones(2, 2)\nb = torch.rand(2, 2)\n\nprint('Before:')\nprint(a)\nprint(b)\nprint('\\nAfter adding:')\nprint(a.add_(b))\nprint(a)\nprint(b)\nprint('\\nAfter 
multiplying')\nprint(b.mul_(b))\nprint(b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that these in-place arithmetic functions are methods on the\n`torch.Tensor` object, not attached to the `torch` module like many\nother functions (e.g., `torch.sin()`). As you can see from `a.add_(b)`,\n*the calling tensor is the one that gets changed in place.*\n\nThere is another option for placing the result of a computation in an\nexisting, allocated tensor. Many of the methods and functions we've seen\nso far - including creation methods! - have an `out` argument that lets\nyou specify a tensor to receive the output. If the `out` tensor is the\ncorrect shape and `dtype`, this can happen without a new memory\nallocation:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.rand(2, 2)\nb = torch.rand(2, 2)\nc = torch.zeros(2, 2)\nold_id = id(c)\n\nprint(c)\nd = torch.matmul(a, b, out=c)\nprint(c) # contents of c have changed\n\nassert c is d # test c & d are same object, not just containing equal values\nassert id(c) == old_id # make sure that our new c is the same object as the old one\n\ntorch.rand(2, 2, out=c) # works for creation too!\nprint(c) # c has changed again\nassert id(c) == old_id # still the same object!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Copying Tensors\n===============\n\nAs with any object in Python, assigning a tensor to a variable makes the\nvariable a *label* of the tensor, and does not copy it. For example:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.ones(2, 2)\nb = a\n\na[0][1] = 561 # we change a...\nprint(b) # ...and b is also altered" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But what if you want a separate copy of the data to work on? 
The\n`clone()` method is there for you:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.ones(2, 2)\nb = a.clone()\n\nassert b is not a # different objects in memory...\nprint(torch.eq(a, b)) # ...but still with the same contents!\n\na[0][1] = 561 # a changes...\nprint(b) # ...but b is still all ones" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**There is an important thing to be aware of when using\n`clone()`.** If your source tensor has autograd enabled, then so\nwill the clone. **This will be covered more deeply in the video on\nautograd,** but if you want the light version of the details, continue\non.\n\n*In many cases, this will be what you want.* For example, if your model\nhas multiple computation paths in its `forward()` method, and *both* the\noriginal tensor and its clone contribute to the model's output, then to\nenable model learning you want autograd turned on for both tensors. If\nyour source tensor has autograd enabled (which it generally will if it's\na set of learning weights or derived from a computation involving the\nweights), then you'll get the result you want.\n\nOn the other hand, if you're doing a computation where *neither* the\noriginal tensor nor its clone need to track gradients, then as long as\nthe source tensor has autograd turned off, you're good to go.\n\n*There is a third case,* though: Imagine you're performing a computation\nin your model's `forward()` function, where gradients are turned on for\neverything by default, but you want to pull out some values mid-stream\nto generate some metrics. In this case, you *don't* want the cloned copy\nof your source tensor to track gradients - performance is improved with\nautograd's history tracking turned off. 
For this, you can use the\n`.detach()` method on the source tensor:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.rand(2, 2, requires_grad=True) # turn on autograd\nprint(a)\n\nb = a.clone()\nprint(b)\n\nc = a.detach().clone()\nprint(c)\n\nprint(a)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What's happening here?\n\n- We create `a` with `requires_grad=True` turned on. **We haven't\n covered this optional argument yet, but will during the unit on\n autograd.**\n- When we print `a`, it informs us that the property\n `requires_grad=True` - this means that autograd and computation\n history tracking are turned on.\n- We clone `a` and label it `b`. When we print `b`, we can see that\n it's tracking its computation history - it has inherited `a`'s\n autograd settings, and added to the computation history.\n- We clone `a` into `c`, but we call `detach()` first.\n- Printing `c`, we see no computation history, and no\n `requires_grad=True`.\n\nThe `detach()` method *detaches the tensor from its computation\nhistory.* It says, \"do whatever comes next as if autograd was off.\" It\ndoes this *without* changing `a` - you can see that when we print `a`\nagain at the end, it retains its `requires_grad=True` property.\n\nMoving to\n[Accelerator](https://pytorch.org/docs/stable/torch.html#accelerators)\n\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\--\n\nOne of the major advantages of PyTorch is its robust acceleration on an\n[accelerator](https://pytorch.org/docs/stable/torch.html#accelerators)\nsuch as CUDA, MPS, MTIA, or XPU. So far, everything we've done has been\non CPU. How do we move to the faster hardware?\n\nFirst, we should check whether an accelerator is available, with the\n`is_available()` method.\n\n```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

If you do not have an accelerator, the executable cells in this section will not execute any accelerator-related code.

\n```\n```{=html}\n
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if torch.accelerator.is_available():\n    print('We have an accelerator!')\nelse:\n    print('Sorry, CPU only.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we've determined that one or more accelerators are available, we\nneed to put our data someplace where the accelerator can see it. Your\nCPU does computation on data in your computer's RAM. Your accelerator\nhas dedicated memory attached to it. Whenever you want to perform a\ncomputation on a device, you must move *all* the data needed for that\ncomputation to memory accessible by that device. (Colloquially, \"moving\nthe data to memory accessible by the GPU\" is shortened to \"moving the\ndata to the GPU\".)\n\nThere are multiple ways to get your data onto your target device. You\nmay do it at creation time:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if torch.accelerator.is_available():\n    gpu_rand = torch.rand(2, 2, device=torch.accelerator.current_accelerator())\n    print(gpu_rand)\nelse:\n    print('Sorry, CPU only.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, new tensors are created on the CPU, so we have to specify\nwhen we want to create our tensor on the accelerator with the optional\n`device` argument. You can see when we print the new tensor, PyTorch\ninforms us which device it's on (if it's not on CPU).\n\nYou can query the number of accelerators with\n`torch.accelerator.device_count()`. If you have more than one\naccelerator, you can specify them by index. Taking CUDA as an example:\n`device='cuda:0'`, `device='cuda:1'`, etc.\n\nAs a coding practice, specifying our devices everywhere with string\nconstants is pretty fragile. In an ideal world, your code would perform\nrobustly whether you're on CPU or accelerator hardware. 
You can do this\nby creating a device handle that can be passed to your tensors instead\nof a string:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "my_device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device('cpu')\nprint('Device: {}'.format(my_device))\n\nx = torch.rand(2, 2, device=my_device)\nprint(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you have an existing tensor living on one device, you can move it to\nanother with the `to()` method. The following line of code creates a\ntensor on CPU, and moves it to whichever device handle you acquired in\nthe previous cell.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "y = torch.rand(2, 2)\ny = y.to(my_device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is important to know that in order to do computation involving two or\nmore tensors, *all of the tensors must be on the same device*. The\nfollowing code will throw a runtime error, regardless of whether you\nhave an accelerator device available, take CUDA for example:\n\n``` {.python}\nx = torch.rand(2, 2)\ny = torch.rand(2, 2, device='cuda')\nz = x + y # exception will be thrown\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Manipulating Tensor Shapes\n==========================\n\nSometimes, you'll need to change the shape of your tensor. Below, we'll\nlook at a few common cases, and how to handle them.\n\nChanging the Number of Dimensions\n---------------------------------\n\nOne case where you might need to change the number of dimensions is\npassing a single instance of input to your model. PyTorch models\ngenerally expect *batches* of input.\n\nFor example, imagine having a model that works on 3 x 226 x 226 images\n-a 226-pixel square with 3 color channels. 
When you load and transform\nan image, you'll get a tensor of shape `(3, 226, 226)`. Your model, though, is\nexpecting input of shape `(N, 3, 226, 226)`, where `N` is the number of\nimages in the batch. So how do you make a batch of one?\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.rand(3, 226, 226)\nb = a.unsqueeze(0)\n\nprint(a.shape)\nprint(b.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `unsqueeze()` method adds a dimension of extent 1. `unsqueeze(0)`\nadds it as a new zeroth dimension - now you have a batch of one!\n\nSo if that's *un*squeezing, what do we mean by squeezing? We're taking\nadvantage of the fact that any dimension of extent 1 *does not* change\nthe number of elements in the tensor.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "c = torch.rand(1, 1, 1, 1, 1)\nprint(c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Continuing the example above, let's say the model's output is a\n20-element vector for each input. You would then expect the output to\nhave shape `(N, 20)`, where `N` is the number of instances in the input\nbatch. 
That means that for our single-input batch, we'll get an output\nof shape `(1, 20)`.\n\nWhat if you want to do some *non-batched* computation with that output -\nsomething that's just expecting a 20-element vector?\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.rand(1, 20)\nprint(a.shape)\nprint(a)\n\nb = a.squeeze(0)\nprint(b.shape)\nprint(b)\n\nc = torch.rand(2, 2)\nprint(c.shape)\n\nd = c.squeeze(0)\nprint(d.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see from the shapes that our 2-dimensional tensor is now\n1-dimensional, and if you look closely at the output of the cell above\nyou'll see that printing `a` shows an \"extra\" set of square brackets\n`[]` due to having an extra dimension.\n\nYou may only `squeeze()` dimensions of extent 1. See above where we try\nto squeeze a dimension of size 2 in `c`, and get back the same shape we\nstarted with. Calls to `squeeze()` and `unsqueeze()` can only act on\ndimensions of extent 1 because to do otherwise would change the number\nof elements in the tensor.\n\nAnother place you might use `unsqueeze()` is to ease broadcasting.\nRecall the example above where we had the following code:\n\n``` {.python}\na = torch.ones(4, 3, 2)\n\nc = a * torch.rand( 3, 1) # 3rd dim = 1, 2nd dim identical to a\nprint(c)\n```\n\nThe net effect of that was to broadcast the operation over dimensions 0\nand 2, causing the random, 3 x 1 tensor to be multiplied element-wise by\nevery 3-element column in `a`.\n\nWhat if the random vector had just been a 3-element vector? We'd lose the\nability to do the broadcast, because the final dimensions would not\nmatch up according to the broadcasting rules. 
`unsqueeze()` comes to the\nrescue:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "a = torch.ones(4, 3, 2)\nb = torch.rand( 3) # trying to multiply a * b will give a runtime error\nc = b.unsqueeze(1) # change to a 2-dimensional tensor, adding new dim at the end\nprint(c.shape)\nprint(a * c) # broadcasting works again!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `squeeze()` and `unsqueeze()` methods also have in-place versions,\n`squeeze_()` and `unsqueeze_()`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "batch_me = torch.rand(3, 226, 226)\nprint(batch_me.shape)\nbatch_me.unsqueeze_(0)\nprint(batch_me.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sometimes you'll want to change the shape of a tensor more radically,\nwhile still preserving the number of elements and their contents. One\ncase where this happens is at the interface between a convolutional\nlayer of a model and a linear layer of the model - this is common in\nimage classification models. A convolution kernel will yield an output\ntensor of shape *features x width x height,* but the following linear\nlayer expects a 1-dimensional input. `reshape()` will do this for you,\nprovided that the dimensions you request yield the same number of\nelements as the input tensor has:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "output3d = torch.rand(6, 20, 20)\nprint(output3d.shape)\n\ninput1d = output3d.reshape(6 * 20 * 20)\nprint(input1d.shape)\n\n# can also call it as a method on the torch module:\nprint(torch.reshape(output3d, (6 * 20 * 20,)).shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

The (6 * 20 * 20,) argument in the final line of the cell above is because PyTorch expects a tuple when specifying a tensor shape - but when the shape is the first argument of a method, it lets us cheat and just use a series of integers. Here, we had to add the parentheses and comma to convince the method that this is really a one-element tuple.

\n```\n```{=html}\n
\n```\nWhen it can, `reshape()` will return a *view* on the tensor to be\nchanged - that is, a separate tensor object looking at the same\nunderlying region of memory. *This is important:* That means any change\nmade to the source tensor will be reflected in the view on that tensor,\nunless you `clone()` it.\n\nThere *are* conditions, beyond the scope of this introduction, where\n`reshape()` has to return a tensor carrying a copy of the data. For more\ninformation, see the\n[docs](https://pytorch.org/docs/stable/torch.html#torch.reshape).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "NumPy Bridge\n============\n\nIn the section above on broadcasting, it was mentioned that PyTorch's\nbroadcast semantics are compatible with NumPy's - but the kinship\nbetween PyTorch and NumPy goes even deeper than that.\n\nIf you have existing ML or scientific code with data stored in NumPy\nndarrays, you may wish to express that same data as PyTorch tensors,\nwhether to take advantage of PyTorch's GPU acceleration, or its\nefficient abstractions for building ML models. 
It's easy to switch\nbetween ndarrays and PyTorch tensors:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n\nnumpy_array = np.ones((2, 3))\nprint(numpy_array)\n\npytorch_tensor = torch.from_numpy(numpy_array)\nprint(pytorch_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch creates a tensor of the same shape and containing the same data\nas the NumPy array, going so far as to keep NumPy's default 64-bit float\ndata type.\n\nThe conversion can just as easily go the other way:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pytorch_rand = torch.rand(2, 3)\nprint(pytorch_rand)\n\nnumpy_rand = pytorch_rand.numpy()\nprint(numpy_rand)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is important to know that these converted objects are using *the same\nunderlying memory* as their source objects, meaning that changes to one\nare reflected in the other:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "numpy_array[1, 1] = 23\nprint(pytorch_tensor)\n\npytorch_rand[1, 1] = 17\nprint(numpy_rand)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }