{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(beta) Accelerating BERT with semi-structured (2:4) sparsity\n=====================================================\n\n**Author**: [Jesse Cai](https://github.com/jcaip)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Overview\n========\n\nLike other forms of sparsity, **semi-structured sparsity** is a model\noptimization technique that seeks to reduce the memory overhead and\nlatency of a neural network at the expense of some model accuracy. It is\nalso known as **fine-grained structured sparsity** or **2:4 structured\nsparsity**.\n\nSemi-structured sparsity derives its name from its unique sparsity\npattern, where n out of every 2n elements are pruned. We most often see\nn=2, hence 2:4 sparsity. Semi-structured sparsity is particularly\ninteresting because it can be efficiently accelerated on GPUs and\ndoesn't degrade model accuracy as much as other sparsity patterns.\n\nWith the introduction of [semi-structured sparsity\nsupport](https://pytorch.org/docs/2.1/sparse.html#sparse-semi-structured-tensors),\nit is possible to prune and accelerate a semi-structured sparse model\nwithout leaving PyTorch. We will explain this process in this tutorial.\n\n![image](https://pytorch.org/tutorials/_static/img/pruning_flow.jpg)\n\nBy the end of this tutorial, we will have sparsified a BERT\nquestion-answering model to be 2:4 sparse and fine-tuned it to recover\nnearly all of the lost F1 (86.92 dense vs 86.48 sparse). 
Finally, we will\naccelerate this 2:4 sparse model for inference, yielding a 1.3x speedup.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Requirements\n============\n\n- PyTorch \\>= 2.1.\n- An NVIDIA GPU with semi-structured sparsity support (Compute\n  Capability 8.0+).\n\nThis tutorial is designed for beginners to semi-structured sparsity and\nsparsity in general. For users with existing 2:4 sparse models,\naccelerating `nn.Linear` layers for inference with\n`to_sparse_semi_structured` is quite straightforward. Here is an\nexample:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor\nfrom torch.utils.benchmark import Timer\nSparseSemiStructuredTensor._FORCE_CUTLASS = True\n\n# mask Linear weight to be 2:4 sparse\nmask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()\nlinear = torch.nn.Linear(10240, 3072).half().cuda().eval()\nlinear.weight = torch.nn.Parameter(mask * linear.weight)\n\nx = torch.rand(3072, 10240).half().cuda()\n\nwith torch.inference_mode():\n    dense_output = linear(x)\n    dense_t = Timer(stmt=\"linear(x)\",\n                    globals={\"linear\": linear,\n                             \"x\": x}).blocked_autorange().median * 1e3\n\n    # accelerate via SparseSemiStructuredTensor\n    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))\n\n    sparse_output = linear(x)\n    sparse_t = Timer(stmt=\"linear(x)\",\n                     globals={\"linear\": linear,\n                              \"x\": x}).blocked_autorange().median * 1e3\n\n    # sparse and dense matmul are numerically equivalent\n    # On an A100 80GB, we see: `Dense: 0.870ms Sparse: 0.630ms | Speedup: 1.382x`\n    assert torch.allclose(sparse_output, dense_output, atol=1e-3)\n    print(f\"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What problem does semi-structured sparsity 
solve?\n=================================================\n\nThe general motivation behind sparsity is simple: if there are zeros in\nyour network, you can optimize efficiency by not storing or computing\nthose parameters. However, the specifics of sparsity are tricky. Zeroing\nout parameters doesn't affect the latency / memory overhead of our model\nout of the box.\n\nThis is because the dense tensor still contains the pruned (zero)\nelements, and the dense matrix multiplication kernel will still\noperate on these elements. In order to realize performance gains, we need\nto swap out dense kernels for sparse kernels, which skip calculations\ninvolving pruned elements.\n\nTo do this, these kernels work on sparse matrices, which do not store\nthe pruned elements and store the specified elements in a compressed\nformat.\n\nFor semi-structured sparsity, we store exactly half of the original\nparameters along with some compressed metadata about how the elements\nwere arranged.\n\n![image](https://developer-blogs.nvidia.com/wp-content/uploads/2023/06/2-4-structured-sparsity-pattern.png)\n\nThere are many different sparse layouts, each with their own benefits\nand drawbacks. The 2:4 semi-structured sparse layout is particularly\ninteresting for two reasons:\n\n- Unlike previous sparse formats, semi-structured sparsity was\n  designed to be efficiently accelerated on GPUs. In 2020, NVIDIA\n  introduced hardware support for semi-structured sparsity with their\n  Ampere architecture, and has also released fast sparse kernels via\n  CUTLASS and\n  [cuSPARSELt](https://docs.nvidia.com/cuda/cusparselt/index.html).\n- At the same time, semi-structured sparsity tends to have a milder\n  impact on model accuracy compared to other sparse formats,\n  especially when accounting for more advanced pruning / fine-tuning\n  methods. 
NVIDIA has shown in their [white\n  paper](https://arxiv.org/abs/2104.08378) that a simple paradigm of\n  magnitude pruning once to be 2:4 sparse and then retraining the\n  model yields nearly identical model accuracies.\n\nSemi-structured sparsity exists in a sweet spot, providing a 2x (theoretical)\nspeedup at a much lower sparsity level (50%), while still being granular\nenough to preserve model accuracy.\n\n| Network | Data Set | Metric | Dense FP16 | Sparse FP16 |\n|---|---|---|---|---|\n| ResNet-50 | ImageNet | Top-1 | 76.1 | 76.2 |\n| ResNeXt-101\\_32x8d | ImageNet | Top-1 | 79.3 | 79.3 |\n| Xception | ImageNet | Top-1 | 79.2 | 79.2 |\n| SSD-RN50 | COCO2017 | bbAP | 24.8 | 24.8 |\n| MaskRCNN-RN50 | COCO2017 | bbAP | 37.9 | 37.9 |\n| FairSeq Transformer | EN-DE WMT14 | BLEU | 28.2 | 28.5 |\n| BERT-Large | SQuAD v1.1 | F1 | 91.9 | 91.9 |\n\nSemi-structured sparsity has an additional advantage from a workflow\nperspective. Because the sparsity level is fixed at 50%, it is easier to\ndecompose the problem of sparsifying a model into two distinct\nsubproblems:\n\n- Accuracy - How can we find a set of 2:4 sparse weights that minimizes\n  the accuracy degradation of our model?\n- Performance - How can we accelerate our 2:4 sparse weights for\n  inference and reduced memory overhead?\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\\begin{aligned}\n\\begin{bmatrix}\n 1 & 1 & 0 & 0 \\\\\n 0 & 0 & 1 & 1 \\\\\n 1 & 0 & 0 & 0 \\\\\n 0 & 0 & 1 & 1 \\\\\n \\end{bmatrix}\n\\end{aligned}$$\n\nThe natural handoff point between these two problems is the zeroed-out\ndense tensor. Our inference solution is designed to compress and\naccelerate tensors in this format. 
We anticipate many users coming up\nwith custom masking solutions, as this is an active area of research.\n\nNow that we've learned a little more about semi-structured sparsity,\nlet's apply it to a BERT model trained on a question answering task,\nSQuAD.\n\nIntro & Setup\n=============\n\nLet's start by importing all the packages we need.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# If you are running this in Google Colab, run:\n#\n# !pip install datasets transformers evaluate accelerate pandas\n#\nimport os\nos.environ[\"WANDB_DISABLED\"] = \"true\"\n\nimport collections\nimport datasets\nimport evaluate\nimport numpy as np\nimport torch\nimport torch.utils.benchmark as benchmark\nfrom torch import nn\nfrom torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor\nfrom torch.ao.pruning import WeightNormSparsifier\nimport transformers\n\n# force CUTLASS use if ``cuSPARSELt`` is not available\nSparseSemiStructuredTensor._FORCE_CUTLASS = True\ntorch.manual_seed(100)\n\n# Set default device to \"cuda:0\" if CUDA is available, otherwise \"cpu\"\ntorch.set_default_device(torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll also need to define some helper functions that are specific to the\ndataset / task at hand. 
These were adapted from\n[this Hugging Face\ncourse](https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def preprocess_validation_function(examples, tokenizer):\n    inputs = tokenizer(\n        [q.strip() for q in examples[\"question\"]],\n        examples[\"context\"],\n        max_length=384,\n        truncation=\"only_second\",\n        return_overflowing_tokens=True,\n        return_offsets_mapping=True,\n        padding=\"max_length\",\n    )\n    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n    example_ids = []\n\n    for i in range(len(inputs[\"input_ids\"])):\n        sample_idx = sample_map[i]\n        example_ids.append(examples[\"id\"][sample_idx])\n        sequence_ids = inputs.sequence_ids(i)\n        offset = inputs[\"offset_mapping\"][i]\n        inputs[\"offset_mapping\"][i] = [\n            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)\n        ]\n\n    inputs[\"example_id\"] = example_ids\n    return inputs\n\n\ndef preprocess_train_function(examples, tokenizer):\n    inputs = tokenizer(\n        [q.strip() for q in examples[\"question\"]],\n        examples[\"context\"],\n        max_length=384,\n        truncation=\"only_second\",\n        return_offsets_mapping=True,\n        padding=\"max_length\",\n    )\n\n    offset_mapping = inputs[\"offset_mapping\"]\n    answers = examples[\"answers\"]\n    start_positions = []\n    end_positions = []\n\n    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):\n        start_char = answer[\"answer_start\"][0]\n        end_char = start_char + len(answer[\"text\"][0])\n        sequence_ids = inputs.sequence_ids(i)\n\n        # Find the start and end of the context\n        idx = 0\n        while sequence_ids[idx] != 1:\n            idx += 1\n        context_start = idx\n        while sequence_ids[idx] == 1:\n            idx += 1\n        context_end = idx - 1\n\n        # If the answer is not fully inside the context, label it (0, 0)\n        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:\n            start_positions.append(0)\n            end_positions.append(0)\n        else:\n            # Otherwise 
it's the start and end token positions\n            idx = context_start\n            while idx <= context_end and offset[idx][0] <= start_char:\n                idx += 1\n            start_positions.append(idx - 1)\n\n            idx = context_end\n            while idx >= context_start and offset[idx][1] >= end_char:\n                idx -= 1\n            end_positions.append(idx + 1)\n\n    inputs[\"start_positions\"] = start_positions\n    inputs[\"end_positions\"] = end_positions\n    return inputs\n\n\ndef compute_metrics(start_logits, end_logits, features, examples):\n    n_best = 20\n    max_answer_length = 30\n    metric = evaluate.load(\"squad\")\n\n    example_to_features = collections.defaultdict(list)\n    for idx, feature in enumerate(features):\n        example_to_features[feature[\"example_id\"]].append(idx)\n\n    predicted_answers = []\n    # for example in ``tqdm`` (examples):\n    for example in examples:\n        example_id = example[\"id\"]\n        context = example[\"context\"]\n        answers = []\n\n        # Loop through all features associated with that example\n        for feature_index in example_to_features[example_id]:\n            start_logit = start_logits[feature_index]\n            end_logit = end_logits[feature_index]\n            offsets = features[feature_index][\"offset_mapping\"]\n\n            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()\n            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()\n            for start_index in start_indexes:\n                for end_index in end_indexes:\n                    # Skip answers that are not fully in the context\n                    if offsets[start_index] is None or offsets[end_index] is None:\n                        continue\n                    # Skip answers with a length that is either < 0\n                    # or > max_answer_length\n                    if (\n                        end_index < start_index\n                        or end_index - start_index + 1 > max_answer_length\n                    ):\n                        continue\n\n                    answer = {\n                        \"text\": context[\n                            offsets[start_index][0] : offsets[end_index][1]\n                        ],\n                        \"logit_score\": start_logit[start_index] + end_logit[end_index],\n                    }\n                    answers.append(answer)\n\n        # Select the answer with the best score\n        if len(answers) > 0:\n            best_answer = max(answers, key=lambda x: x[\"logit_score\"])\n            
predicted_answers.append(\n                {\"id\": example_id, \"prediction_text\": best_answer[\"text\"]}\n            )\n        else:\n            predicted_answers.append({\"id\": example_id, \"prediction_text\": \"\"})\n\n    theoretical_answers = [\n        {\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in examples\n    ]\n    return metric.compute(predictions=predicted_answers, references=theoretical_answers)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that those are defined, we just need one additional helper function,\nwhich will help us benchmark our model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def measure_execution_time(model, batch_sizes, dataset):\n    dataset_for_model = dataset.remove_columns([\"example_id\", \"offset_mapping\"])\n    dataset_for_model.set_format(\"torch\")\n    batch_size_to_time_sec = {}\n    for batch_size in batch_sizes:\n        batch = {\n            k: dataset_for_model[k][:batch_size].cuda()\n            for k in dataset_for_model.column_names\n        }\n\n        with torch.no_grad():\n            baseline_predictions = model(**batch)\n            timer = benchmark.Timer(\n                stmt=\"model(**batch)\", globals={\"model\": model, \"batch\": batch}\n            )\n            p50 = timer.blocked_autorange().median * 1000\n            batch_size_to_time_sec[batch_size] = p50\n\n            model_c = torch.compile(model, fullgraph=True)\n            timer = benchmark.Timer(\n                stmt=\"model(**batch)\", globals={\"model\": model_c, \"batch\": batch}\n            )\n            p50 = timer.blocked_autorange().median * 1000\n            batch_size_to_time_sec[f\"{batch_size}_compile\"] = p50\n            new_predictions = model_c(**batch)\n\n    return batch_size_to_time_sec" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will get started by loading our model and tokenizer, and then setting\nup our dataset.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# load model\nmodel_name = \"bert-base-cased\"\ntokenizer = 
transformers.AutoTokenizer.from_pretrained(model_name)\nmodel = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)\nprint(f\"Loading tokenizer: {model_name}\")\nprint(f\"Loading model: {model_name}\")\n\n# set up train and val dataset\nsquad_dataset = datasets.load_dataset(\"squad\")\ntokenized_squad_dataset = {}\ntokenized_squad_dataset[\"train\"] = squad_dataset[\"train\"].map(\n    lambda x: preprocess_train_function(x, tokenizer), batched=True\n)\ntokenized_squad_dataset[\"validation\"] = squad_dataset[\"validation\"].map(\n    lambda x: preprocess_validation_function(x, tokenizer),\n    batched=True,\n    remove_columns=squad_dataset[\"train\"].column_names,\n)\ndata_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Establishing a baseline\n=======================\n\nNext, we'll train a quick baseline of our model on SQuAD. This task asks\nour model to identify spans, or segments of text, in a given context\n(Wikipedia articles) that answer a given question. Running the following\ncode gives me an F1 score of 86.9. This is quite close to the reported\nNVIDIA score, and the difference is likely due to BERT-base\nvs.\u00a0BERT-large or fine-tuning hyperparameters.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "training_args = transformers.TrainingArguments(\n    \"trainer\",\n    num_train_epochs=1,\n    lr_scheduler_type=\"constant\",\n    per_device_train_batch_size=32,\n    per_device_eval_batch_size=256,\n    logging_steps=50,\n    # Limit max steps for tutorial runners. Delete the below line to see the reported accuracy numbers. 
\n    max_steps=500,\n    report_to=None,\n)\n\ntrainer = transformers.Trainer(\n    model,\n    training_args,\n    train_dataset=tokenized_squad_dataset[\"train\"],\n    eval_dataset=tokenized_squad_dataset[\"validation\"],\n    data_collator=data_collator,\n    tokenizer=tokenizer,\n)\n\ntrainer.train()\n\n# batch sizes to compare for eval\nbatch_sizes = [4, 16, 64, 256]\n# 2:4 sparsity requires fp16, so we cast here for a fair comparison\nwith torch.autocast(\"cuda\"):\n    with torch.no_grad():\n        predictions = trainer.predict(tokenized_squad_dataset[\"validation\"])\n        start_logits, end_logits = predictions.predictions\n        fp16_baseline = compute_metrics(\n            start_logits,\n            end_logits,\n            tokenized_squad_dataset[\"validation\"],\n            squad_dataset[\"validation\"],\n        )\n        fp16_time = measure_execution_time(\n            model,\n            batch_sizes,\n            tokenized_squad_dataset[\"validation\"],\n        )\n\nprint(\"fp16\", fp16_baseline)\nprint(\"cuda_fp16 time\", fp16_time)\n\nimport pandas as pd\ndf = pd.DataFrame(trainer.state.log_history)\ndf.plot.line(x='step', y='loss', title=\"Loss vs. # steps\", ylabel=\"loss\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pruning BERT to be 2:4 sparse\n=============================\n\nNow that we have our baseline, it's time to prune BERT. There are many\ndifferent pruning strategies, but one of the most common is **magnitude\npruning**, which seeks to remove the weights with the lowest L1 norm.\nMagnitude pruning was used by NVIDIA in all their results and is a\ncommon baseline.\n\nTo do this, we will use the `torch.ao.pruning` package, which contains a\nweight-norm (magnitude) sparsifier. These sparsifiers work by applying\nmask parametrizations to the weight tensors in a model. This lets them\nsimulate sparsity by masking out the pruned weights.\n\nWe'll also have to decide what layers of the model to apply sparsity to,\nwhich in this case is all of the `nn.Linear` layers, except for the\ntask-specific head outputs. 
That's because semi-structured sparsity has\n[shape\nconstraints](https://pytorch.org/docs/2.1/sparse.html#constructing-sparse-semi-structured-tensors),\nand the task-specific `nn.Linear` layers do not satisfy them.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sparsifier = WeightNormSparsifier(\n    # apply sparsity to all blocks\n    sparsity_level=1.0,\n    # shape of 4 elements is a block\n    sparse_block_shape=(1, 4),\n    # two zeros for every block of 4\n    zeros_per_block=2\n)\n\n# add to config if ``nn.Linear`` and in the BERT model.\nsparse_config = [\n    {\"tensor_fqn\": f\"{fqn}.weight\"}\n    for fqn, module in model.named_modules()\n    if isinstance(module, nn.Linear) and \"layer\" in fqn\n]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first step for pruning the model is to insert parametrizations for\nmasking the weights of the model. This is done by the prepare step.\nAny time we access `.weight`, we will get `mask * weight`\ninstead.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Prepare the model, insert fake-sparsity parametrizations for training\nsparsifier.prepare(model, sparse_config)\nprint(model.bert.encoder.layer[0].output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we'll take a single pruning step. All pruners implement an\n`update_mask()` method that updates the mask, with the logic\ndetermined by the pruner implementation. 
The `step` method calls this\n`update_mask` function for the weights specified in the sparse config.\n\nWe will also evaluate the model to show the accuracy degradation of\nzero-shot pruning, or pruning without fine-tuning / retraining.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sparsifier.step()\nwith torch.autocast(\"cuda\"):\n    with torch.no_grad():\n        predictions = trainer.predict(tokenized_squad_dataset[\"validation\"])\n    pruned = compute_metrics(\n        *predictions.predictions,\n        tokenized_squad_dataset[\"validation\"],\n        squad_dataset[\"validation\"],\n    )\nprint(\"pruned eval metrics:\", pruned)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this state, we can start fine-tuning the model, updating the elements\nthat were not pruned to recover the lost accuracy. Once\nwe've reached a satisfactory state, we can call `squash_mask` to fuse the\nmask and the weight together. This will remove the parametrizations and\nwe are left with a zeroed-out 2:4 dense model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "trainer.train()\nsparsifier.squash_mask()\ntorch.set_printoptions(edgeitems=4)\nprint(model.bert.encoder.layer[0].intermediate.dense.weight[:8, :8])\n\ndf[\"sparse_loss\"] = pd.DataFrame(trainer.state.log_history)[\"loss\"]\ndf.plot.line(x='step', y=[\"loss\", \"sparse_loss\"], title=\"Loss vs. 
# steps\", ylabel=\"loss\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Accelerating 2:4 sparse models for inference\n============================================\n\nNow that we have a model in this format, we can accelerate it for\ninference just like in the quick-start example above.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = model.cuda().half()\n# accelerate for sparsity\nfor fqn, module in model.named_modules():\n    if isinstance(module, nn.Linear) and \"layer\" in fqn:\n        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))\n\nwith torch.no_grad():\n    predictions = trainer.predict(tokenized_squad_dataset[\"validation\"])\nstart_logits, end_logits = predictions.predictions\nmetrics_sparse = compute_metrics(\n    start_logits,\n    end_logits,\n    tokenized_squad_dataset[\"validation\"],\n    squad_dataset[\"validation\"],\n)\nprint(\"sparse eval metrics: \", metrics_sparse)\nsparse_perf = measure_execution_time(\n    model,\n    batch_sizes,\n    tokenized_squad_dataset[\"validation\"],\n)\nprint(\"sparse perf metrics: \", sparse_perf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Retraining our model after magnitude pruning has recovered nearly all of\nthe F1 that was lost when the model was pruned. At the same time, we\nhave achieved a 1.23x speedup for `bs=16`. Note that not all shapes are\namenable to performance improvements. When batch sizes are small and\nlimited time is spent in compute, sparse kernels may be slower than their\ndense counterparts.\n\nBecause semi-structured sparsity is implemented as a tensor subclass, it\nis compatible with `torch.compile`. 
When composed with\n`to_sparse_semi_structured`, we are able to achieve a total 2x speedup\non BERT.\n\n| Metrics | fp16 | 2:4 sparse | delta / speedup | compiled |\n|---|---|---|---|---|\n| Exact Match (%) | 78.53 | 78.44 | -0.09 | |\n| F1 (%) | 86.93 | 86.49 | -0.44 | |\n| Time (ms, bs=4) | 11.10 | 15.54 | 0.71x | no |\n| Time (ms, bs=16) | 19.35 | 15.74 | 1.23x | no |\n| Time (ms, bs=64) | 72.71 | 59.41 | 1.22x | no |\n| Time (ms, bs=256) | 286.65 | 247.63 | 1.14x | no |\n| Time (ms, bs=4) | 7.59 | 7.46 | 1.02x | yes |\n| Time (ms, bs=16) | 11.47 | 9.68 | 1.18x | yes |\n| Time (ms, bs=64) | 41.57 | 36.92 | 1.13x | yes |\n| Time (ms, bs=256) | 159.22 | 142.23 | 1.12x | yes |\n\nConclusion\n==========\n\nIn this tutorial, we have shown how to prune BERT to be 2:4 sparse and\nhow to accelerate a 2:4 sparse model for inference. By taking advantage\nof our `SparseSemiStructuredTensor` subclass, we were able to achieve a\n1.3x speedup over the fp16 baseline, and up to 2x with `torch.compile`.\nWe also demonstrated the benefits of 2:4 sparsity by fine-tuning BERT to\nrecover nearly all of the lost F1 (86.92 dense vs 86.48 sparse).\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }