{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Parametrizations Tutorial\n=========================\n\n**Author**: [Mario Lezcano](https://github.com/lezcano)\n\nRegularizing deep-learning models is a surprisingly challenging task.\nClassical techniques such as penalty methods often fall short when\napplied on deep models due to the complexity of the function being\noptimized. This is particularly problematic when working with\nill-conditioned models. Examples of these are RNNs trained on long\nsequences and GANs. A number of techniques have been proposed in recent\nyears to regularize these models and improve their convergence. On\nrecurrent models, it has been proposed to control the singular values of\nthe recurrent kernel for the RNN to be well-conditioned. This can be\nachieved, for example, by making the recurrent kernel\n[orthogonal](https://en.wikipedia.org/wiki/Orthogonal_matrix). Another\nway to regularize recurrent models is via \\\"[weight\nnormalization](https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html)\\\".\nThis approach proposes to decouple the learning of the parameters from\nthe learning of their norms. To do so, the parameter is divided by its\n[Frobenius\nnorm](https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm) and a\nseparate parameter encoding its norm is learned. A similar\nregularization was proposed for GANs under the name of \\\"[spectral\nnormalization](https://pytorch.org/docs/stable/generated/torch.nn.utils.spectral_norm.html)\\\".\nThis method controls the Lipschitz constant of the network by dividing\nits parameters by their [spectral\nnorm](https://en.wikipedia.org/wiki/Matrix_norm#Special_cases), rather\nthan their Frobenius norm.\n\nAll these methods have a common pattern: they all transform a parameter\nin an appropriate way before using it. In the first case, they make it\northogonal by using a function that maps matrices to orthogonal\nmatrices. In the case of weight and spectral normalization, they divide\nthe original parameter by its norm.\n\nMore generally, all these examples use a function to put extra structure\non the parameters. In other words, they use a function to constrain the\nparameters.\n\nIn this tutorial, you will learn how to implement and use this pattern\nto put constraints on your model. Doing so is as easy as writing your\nown `nn.Module`.\n\nRequirements: `torch>=1.9.0`\n\nImplementing parametrizations by hand\n-------------------------------------\n\nAssume that we want to have a square linear layer with symmetric\nweights, that is, with weights `X` such that `X = X\u1d40`. 
One way to do so\nis to copy the upper-triangular part of the matrix into its\nlower-triangular part\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\nimport torch.nn.utils.parametrize as parametrize\n\ndef symmetric(X):\n    return X.triu() + X.triu(1).transpose(-1, -2)\n\nX = torch.rand(3, 3)\nA = symmetric(X)\nassert torch.allclose(A, A.T) # A is symmetric\nprint(A) # Quick visual check" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then use this idea to implement a linear layer with symmetric\nweights\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class LinearSymmetric(nn.Module):\n    def __init__(self, n_features):\n        super().__init__()\n        self.weight = nn.Parameter(torch.rand(n_features, n_features))\n\n    def forward(self, x):\n        A = symmetric(self.weight)\n        return x @ A" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The layer can then be used as a regular linear layer\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "layer = LinearSymmetric(3)\nout = layer(torch.rand(8, 3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This implementation, although correct and self-contained, presents a\nnumber of problems:\n\n1) It reimplements the layer. We had to implement the linear layer as\n `x @ A`. This is not very problematic for a linear layer, but\n imagine having to reimplement a CNN or a Transformer\...\n2) It does not separate the layer and the parametrization. If the\n parametrization were more difficult, we would have to rewrite its\n code for each layer that we want to use it in.\n3) It recomputes the parametrization every time we use the layer. If we\n use the layer several times during the forward pass (imagine the\n recurrent kernel of an RNN), it would compute the same `A` every\n time that the layer is called.\n\nIntroduction to parametrizations\n================================\n\nParametrizations can solve all these problems as well as others.\n\nLet\'s start by reimplementing the code above using\n`torch.nn.utils.parametrize`. The only thing that we have to do is to\nwrite the parametrization as a regular `nn.Module`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Symmetric(nn.Module):\n    def forward(self, X):\n        return X.triu() + X.triu(1).transpose(-1, -2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is all we need to do. Once we have this, we can transform any\nregular layer into a symmetric layer by doing\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "layer = nn.Linear(3, 3)\nparametrize.register_parametrization(layer, \"weight\", Symmetric())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, the matrix of the linear layer is symmetric\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "A = layer.weight\nassert torch.allclose(A, A.T) # A is symmetric\nprint(A) # Quick visual check" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can do the same thing with any other layer. For example, we can\ncreate a CNN with\n[skew-symmetric](https://en.wikipedia.org/wiki/Skew-symmetric_matrix)\nkernels. 
We use a similar parametrization, copying the upper-triangular\npart with signs reversed into the lower-triangular part\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Skew(nn.Module):\n    def forward(self, X):\n        A = X.triu(1)\n        return A - A.transpose(-1, -2)\n\n\ncnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)\nparametrize.register_parametrization(cnn, \"weight\", Skew())\n# Print a few kernels\nprint(cnn.weight[0, 1])\nprint(cnn.weight[2, 2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inspecting a parametrized module\n================================\n\nWhen a module is parametrized, we find that the module has changed in\nthree ways:\n\n1) `model.weight` is now a property\n2) It has a new `module.parametrizations` attribute\n3) The unparametrized weight has been moved to\n `module.parametrizations.weight.original`\n\nAfter parametrizing `weight`, `layer.weight` is turned into a [Python\nproperty](https://docs.python.org/3/library/functions.html#property).\nThis property computes `parametrization(weight)` every time we request\n`layer.weight`, just as we did in our implementation of `LinearSymmetric`\nabove (we will check this below).\n\nRegistered parametrizations are stored under a `parametrizations`\nattribute within the module.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "layer = nn.Linear(3, 3)\nprint(f\"Unparametrized:\\n{layer}\")\nparametrize.register_parametrization(layer, \"weight\", Symmetric())\nprint(f\"\\nParametrized:\\n{layer}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This `parametrizations` attribute is an `nn.ModuleDict`, and it can be\naccessed as such\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(layer.parametrizations)\nprint(layer.parametrizations.weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each element of this `nn.ModuleDict` is a `ParametrizationList`, which\nbehaves like an `nn.Sequential`. This list will allow us to concatenate\nparametrizations on one weight. Since this is a list, we can access the\nparametrizations by indexing it. 
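We can also check point 1) above from here: after registering the\nparametrization, `weight` is exposed as a plain Python property on the class\nof the module (a quick sketch; the exact class is generated dynamically by\n`parametrize`)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# `weight` now lives on the parametrized class as a property\nprint(type(layer).weight)\nprint(isinstance(type(layer).weight, property))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "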
Here\\'s where our `Symmetric`\nparametrization sits\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(layer.parametrizations.weight[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The other thing that we notice is that, if we print the parameters, we\nsee that the parameter `weight` has been moved\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(dict(layer.named_parameters()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It now sits under `layer.parametrizations.weight.original`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(layer.parametrizations.weight.original)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides these three small differences, the parametrization is doing\nexactly the same as our manual implementation\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "symmetric = Symmetric()\nweight_orig = layer.parametrizations.weight.original\nprint(torch.dist(layer.weight, symmetric(weight_orig)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Parametrizations are first-class citizens\n=========================================\n\nSince `layer.parametrizations` is an `nn.ModuleList`, it means that the\nparametrizations are properly registered as submodules of the original\nmodule. As such, the same rules for registering parameters in a module\napply to register a parametrization. For example, if a parametrization\nhas parameters, these will be moved from CPU to CUDA when calling\n`model = model.cuda()`.\n\nCaching the value of a parametrization\n======================================\n\nParametrizations come with an inbuilt caching system via the context\nmanager `parametrize.cached()`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class NoisyParametrization(nn.Module):\n def forward(self, X):\n print(\"Computing the Parametrization\")\n return X\n\nlayer = nn.Linear(4, 4)\nparametrize.register_parametrization(layer, \"weight\", NoisyParametrization())\nprint(\"Here, layer.weight is recomputed every time we call it\")\nfoo = layer.weight + layer.weight.T\nbar = layer.weight.sum()\nwith parametrize.cached():\n print(\"Here, it is computed just the first time layer.weight is called\")\n foo = layer.weight + layer.weight.T\n bar = layer.weight.sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Concatenating parametrizations\n==============================\n\nConcatenating two parametrizations is as easy as registering them on the\nsame tensor. We may use this to create more complex parametrizations\nfrom simpler ones. For example, the [Cayley\nmap](https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map) maps the\nskew-symmetric matrices to the orthogonal matrices of positive\ndeterminant. 
We can concatenate `Skew` and a parametrization that\nimplements the Cayley map to get a layer with orthogonal weights\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class CayleyMap(nn.Module):\n    def __init__(self, n):\n        super().__init__()\n        self.register_buffer(\"Id\", torch.eye(n))\n\n    def forward(self, X):\n        # Assume X skew-symmetric\n        # (I + X)(I - X)^{-1}\n        return torch.linalg.solve(self.Id - X, self.Id + X)\n\nlayer = nn.Linear(3, 3)\nparametrize.register_parametrization(layer, \"weight\", Skew())\nparametrize.register_parametrization(layer, \"weight\", CayleyMap(3))\nX = layer.weight\nprint(torch.dist(X.T @ X, torch.eye(3))) # X is orthogonal" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This may also be used to prune a parametrized module, or to reuse\nparametrizations. For example, the matrix exponential maps the symmetric\nmatrices to the Symmetric Positive Definite (SPD) matrices, but the\nmatrix exponential also maps the skew-symmetric matrices to the\northogonal matrices. Using these two facts, we may reuse the\nparametrizations from before to our advantage\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MatrixExponential(nn.Module):\n    def forward(self, X):\n        return torch.matrix_exp(X)\n\nlayer_orthogonal = nn.Linear(3, 3)\nparametrize.register_parametrization(layer_orthogonal, \"weight\", Skew())\nparametrize.register_parametrization(layer_orthogonal, \"weight\", MatrixExponential())\nX = layer_orthogonal.weight\nprint(torch.dist(X.T @ X, torch.eye(3))) # X is orthogonal\n\nlayer_spd = nn.Linear(3, 3)\nparametrize.register_parametrization(layer_spd, \"weight\", Symmetric())\nparametrize.register_parametrization(layer_spd, \"weight\", MatrixExponential())\nX = layer_spd.weight\nprint(torch.dist(X, X.T)) # X is symmetric\nprint((torch.linalg.eigvalsh(X) > 0.).all()) # X is positive definite" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initializing parametrizations\n=============================\n\nParametrizations come with a mechanism to initialize them. If we\nimplement a method `right_inverse` with signature\n\n``` {.python}\ndef right_inverse(self, X: Tensor) -> Tensor\n```\n\nit will be used when assigning to the parametrized tensor.\n\nLet\'s upgrade our implementation of the `Skew` class to support this\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Skew(nn.Module):\n    def forward(self, X):\n        A = X.triu(1)\n        return A - A.transpose(-1, -2)\n\n    def right_inverse(self, A):\n        # We assume that A is skew-symmetric\n        # We take the upper-triangular elements, as these are those used in the forward\n        return A.triu(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We may now initialize a layer that is parametrized with `Skew`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "layer = nn.Linear(3, 3)\nparametrize.register_parametrization(layer, \"weight\", Skew())\nX = torch.rand(3, 3)\nX = X - X.T # X is now skew-symmetric\nlayer.weight = X # Initialize layer.weight to be X\nprint(torch.dist(layer.weight, X)) # layer.weight == X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This `right_inverse` works as expected when we concatenate\nparametrizations: when several parametrizations are registered on the same\ntensor, assigning a value applies their `right_inverse` methods in reverse\nregistration order before storing the result in `original`. 
To see this, let\\'s upgrade the Cayley parametrization\nto also support being initialized\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class CayleyMap(nn.Module):\n def __init__(self, n):\n super().__init__()\n self.register_buffer(\"Id\", torch.eye(n))\n\n def forward(self, X):\n # Assume X skew-symmetric\n # (I + X)(I - X)^{-1}\n return torch.linalg.solve(self.Id - X, self.Id + X)\n\n def right_inverse(self, A):\n # Assume A orthogonal\n # See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map\n # (A - I)(A + I)^{-1}\n return torch.linalg.solve(A + self.Id, self.Id - A)\n\nlayer_orthogonal = nn.Linear(3, 3)\nparametrize.register_parametrization(layer_orthogonal, \"weight\", Skew())\nparametrize.register_parametrization(layer_orthogonal, \"weight\", CayleyMap(3))\n# Sample an orthogonal matrix with positive determinant\nX = torch.empty(3, 3)\nnn.init.orthogonal_(X)\nif X.det() < 0.:\n X[0].neg_()\nlayer_orthogonal.weight = X\nprint(torch.dist(layer_orthogonal.weight, X)) # layer_orthogonal.weight == X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This initialization step can be written more succinctly as\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "layer_orthogonal.weight = nn.init.orthogonal_(layer_orthogonal.weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The name of this method comes from the fact that we would often expect\nthat `forward(right_inverse(X)) == X`. This is a direct way of rewriting\nthat the forward after the initialization with value `X` should return\nthe value `X`. This constraint is not strongly enforced in practice. In\nfact, at times, it might be of interest to relax this relation. For\nexample, consider the following implementation of a randomized pruning\nmethod:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class PruningParametrization(nn.Module):\n def __init__(self, X, p_drop=0.2):\n super().__init__()\n # sample zeros with probability p_drop\n mask = torch.full_like(X, 1.0 - p_drop)\n self.mask = torch.bernoulli(mask)\n\n def forward(self, X):\n return X * self.mask\n\n def right_inverse(self, A):\n return A" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, it is not true that for every matrix A\n`forward(right_inverse(A)) == A`. This is only true when the matrix `A`\nhas zeros in the same positions as the mask. 
Even then, if we assign a\ntensor to a pruned parameter, it will come as no surprise that the tensor\nwill, in fact, be pruned\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "layer = nn.Linear(3, 4)\nX = torch.rand_like(layer.weight)\nprint(f\"Initialization matrix:\\n{X}\")\nparametrize.register_parametrization(layer, \"weight\", PruningParametrization(layer.weight))\nlayer.weight = X\nprint(f\"\\nInitialized weight:\\n{layer.weight}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Removing parametrizations\n=========================\n\nWe may remove all the parametrizations from a parameter or a buffer in a\nmodule by using `parametrize.remove_parametrizations()`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "layer = nn.Linear(3, 3)\nprint(\"Before:\")\nprint(layer)\nprint(layer.weight)\nparametrize.register_parametrization(layer, \"weight\", Skew())\nprint(\"\\nParametrized:\")\nprint(layer)\nprint(layer.weight)\nparametrize.remove_parametrizations(layer, \"weight\")\nprint(\"\\nAfter. Weight has skew-symmetric values but it is unconstrained:\")\nprint(layer)\nprint(layer.weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When removing a parametrization, we may choose to leave the original\nparameter (i.e. the one in `layer.parametrizations.weight.original`) rather\nthan its parametrized version by setting the flag\n`leave_parametrized=False`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "layer = nn.Linear(3, 3)\nprint(\"Before:\")\nprint(layer)\nprint(layer.weight)\nparametrize.register_parametrization(layer, \"weight\", Skew())\nprint(\"\\nParametrized:\")\nprint(layer)\nprint(layer.weight)\nparametrize.remove_parametrizations(layer, \"weight\", leave_parametrized=False)\nprint(\"\\nAfter. Same as Before:\")\nprint(layer)\nprint(layer.weight)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }