• Tutorials >
  • (原型)使用半结构化(2:4)稀疏性加速BERT
Shortcuts

(原型)使用半结构化(2:4)稀疏性加速BERT

Created On: Oct 03, 2023 | Last Updated: Jan 16, 2024 | Last Verified: Nov 05, 2024

作者: Jesse Cai

与其他形式的稀疏性一样,半结构化稀疏性**是一种模型优化技术,旨在通过牺牲一定的模型准确性来降低神经网络的内存开销和延迟。它也被称为**细粒度结构稀疏性**或**2:4结构稀疏性

半结构化稀疏性因其独特的稀疏模式而得名,其中2n个元素中的n个元素被裁剪。我们通常看到n=2,因此称为2:4稀疏性。半结构化稀疏性特别有趣,因为它可以在GPU上高效加速,并且它对模型准确性的损害不像其他稀疏性模式那么大。

随着`半结构化稀疏性支持 <https://pytorch.org/docs/2.1/sparse.html#sparse-semi-structured-tensors>`_的引入,可以在不离开PyTorch的情况下裁剪并加速半结构化稀疏模型。我们将在本教程中解释这一过程。

../_static/img/pruning_flow.jpg

通过本教程结束时,我们将使一个BERT问答模型变为2:4稀疏,并对其进行微调以恢复几乎所有的F1损失(86.92密集 vs 86.48稀疏)。最后,我们将加速这个2:4稀疏模型进行推理,带来1.3倍的加速。

要求

  • PyTorch >= 2.1。

  • 支持半结构化稀疏性的NVIDIA GPU (计算能力8.0+)。

备注

本教程是为半结构化稀疏性/稀疏性新手设计的。对于已有2:4稀疏模型的用户,使用``to_sparse_semi_structured``加速推理的``nn.Linear``层非常简单:

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))

    sparse_output = linear(x)
    sparse_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # sparse and dense matmul are numerically equivalent
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")

在A100 80GB上,我们看到了:密集:0.870ms 稀疏:0.630ms | 加速比:1.382x

半结构化稀疏性解决了什么问题?

稀疏性的总体动机很简单:如果网络中存在零值参数,你可以避免存储/计算这些参数。然而,稀疏性本身的细节是复杂的。将参数置零不是直接带来模型的延迟/内存开销的减少。

这是因为稠密张量中仍然包含被裁剪(零值)的元素,稠密矩阵乘法内核仍然对这些元素执行操作。为了实现性能提升,我们需要将稠密内核替换为稀疏内核,从而跳过涉及被裁剪元素的计算。

为实现这一点,这些内核使用稀疏矩阵,这些矩阵不存储被裁剪的元素,并以压缩格式存储指定的元素。

对于半结构化稀疏性,我们精确保存了原始参数的一半以及关于这些元素如何排列的一些压缩元数据。

有许多不同的稀疏布局,每种布局都有其自己的优点和缺点。2:4半结构化稀疏布局因两个原因特别有趣:1. 与之前的稀疏格式不同,半结构化稀疏性是专门为在GPU上高效加速而设计的。

2020年NVIDIA通过其Ampere架构引入了硬件支持半结构化稀疏性,并通过CUTLASS/`cuSPARSELt <https://docs.nvidia.com/cuda/cusparselt/index.html>`_发布了快速的稀疏内核。

  1. 同时,与其他稀疏格式相比,半结构化稀疏性对模型准确性的影响往往较小,尤其是在考虑更高级裁剪/微调方法时。NVIDIA在其`白皮书 <https://arxiv.org/abs/2104.08378>`_中显示一次简单的范式:基于幅值裁剪将模型裁剪为2:4稀疏,然后重新训练模型,从而实现几乎相同的模型准确性。

半结构化稀疏性处于一个甜蜜点,提供了2倍(理论)加速,同时保持较低的稀疏水平(50%),而仍然足够细粒度以保持模型准确性。

网络

数据集

指标

稠密FP16

稀疏FP16

ResNet-50

ImageNet

Top-1

76.1

76.2

ResNeXt-101_32x8d

ImageNet

Top-1

79.3

79.3

Xception

ImageNet

Top-1

79.2

79.2

SSD-RN50

COCO2017

bbAP

24.8

24.8

MaskRCNN-RN50

COCO2017

bbAP

37.9

37.9

FairSeq Transformer

EN-DE WMT14

BLEU

28.2

28.5

BERT-Large

SQuAD v1.1

F1

91.9

91.9

从工作流程上看,半结构化稀疏性还具有额外的优势。由于稀疏水平固定为50%,将模型稀疏化问题分解为两个独立的子问题变得更加容易:

  • 准确性 - 我们如何找到一组2:4稀疏权重以将模型的准确性降级降至最小?

  • 性能 - 我们如何加速推理的2:4稀疏权重并减少内存开销?

\[\begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ \end{bmatrix} \]

这两个问题的自然交接点是零化的稠密张量。我们的推理解决方案旨在压缩和加速此格式中的张量。我们预计许多人会提出自定义屏蔽解决方案,因为这是一个活跃的研究领域。

现在我们已经更多地了解了半结构化稀疏性,让我们将其应用于一个以问答任务训练过的BERT模型,即SQuAD。

介绍和设置

让我们开始导入所有需要的包。

import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers

# force CUTLASS use if cuSPARSELt is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)

我们还需要定义一些特定于数据集/任务的辅助函数。这些函数是参考`这个 <https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt>`_ huggingface课程进行改编的。

def preprocess_validation_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def preprocess_train_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs["offset_mapping"]
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


def compute_metrics(start_logits, end_logits, features, examples):
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")

    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    # for example in tqdm(examples):
    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0
                    # or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

现在这些都已定义,我们只需要一个额外的辅助函数,它将帮助我们对我们的模型进行基准测试。

def measure_execution_time(model, batch_sizes, dataset):
    dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
    dataset_for_model.set_format("torch")
    model.cuda()
    batch_size_to_time_sec = {}
    for batch_size in batch_sizes:
        batch = {
            k: dataset_for_model[k][:batch_size].to(model.device)
            for k in dataset_for_model.column_names
        }

        with torch.inference_mode():
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
        batch_size_to_time_sec[batch_size] = p50
    return batch_size_to_time_sec

我们将从加载我们的模型和分词器开始,然后设置我们的数据集。

# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")

# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
    lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
    lambda x: preprocess_validation_function(x, tokenizer),
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)

接下来,我们将在SQuAD上快速训练一个模型基线。此任务要求我们的模型在给定上下文(Wikipedia文章)中识别出回答给定问题的文本段或片段。运行以下代码让我得到了86.9的F1分数。这与NVIDIA报告的分数相当接近,差异可能是由于使用了BERT-base与BERT-large或微调超参数。

training_args = transformers.TrainingArguments(
    "trainer",
    num_train_epochs=1,
    lr_scheduler_type="constant",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=512,
)

trainer = transformers.Trainer(
    model,
    training_args,
    train_dataset=tokenized_squad_dataset["train"],
    eval_dataset=tokenized_squad_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# 2:4 sparsity require fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
    with torch.inference_mode():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
    start_logits, end_logits = predictions.predictions
    fp16_baseline = compute_metrics(
        start_logits,
        end_logits,
        tokenized_squad_dataset["validation"],
        squad_dataset["validation"],
    )
    fp16_time = measure_execution_time(
        model,
        batch_sizes,
        tokenized_squad_dataset["validation"],
    )
print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)

# fp16 {'exact_match': 78.53358561967833, 'f1': 86.9280493093186}
# cuda_fp16 time {4: 10.927572380751371, 16: 19.607915310189128, 64: 73.18846387788653, 256: 286.91255673766136}

将BERT裁剪为2:4稀疏

现在我们有了基线,是时候裁剪BERT了。有许多不同的裁剪策略,但最常见的一种是**幅值裁剪**,其目标是移除L1范数最低的权重。NVIDIA在所有结果中都使用了幅值裁剪,这也是一种常见的基准。

为此,我们将使用``torch.ao.pruning``包,该包包含一个权重范数(幅值)稀疏器。这些稀疏器通过对模型中的权重张量应用掩码参数化来工作。这使得它们可以通过掩码掉被裁剪的权重来模拟稀疏性。

我们还必须决定将稀疏性应用于模型的哪些层,这种情况下是所有的`nn.Linear`层,除了任务特定的输出头部。这是因为半结构化稀疏性有`形状约束 <https://pytorch.org/docs/2.1/sparse.html#constructing-sparse-semi-structured-tensors>`_,而任务特定的nn.Linear层不满足这些约束。

sparsifier = WeightNormSparsifier(
    # apply sparsity to all blocks
    sparsity_level=1.0,
    # shape of 4 elemens is a block
    sparse_block_shape=(1, 4),
    # two zeros for every block of 4
    zeros_per_block=2
)

# add to config if nn.Linear and in the BERT model.
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, nn.Linear) and "layer" in fqn
]

裁剪模型的第一步是插入参数化以屏蔽模型的权重。这是通过准备步骤完成的。每当我们尝试访问``.weight``时,我们将得到``mask * weight``。

# Prepare the model, insert fake-sparsity parameterizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)

# BertOutput(
#   (dense): ParametrizedLinear(
#     in_features=3072, out_features=768, bias=True
#     (parametrizations): ModuleDict(
#       (weight): ParametrizationList(
#         (0-5): 6 x FakeSparsity()
#       )
#     )
#   )
#   (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#   (dropout): Dropout(p=0.1, inplace=False)
# )

然后,我们将进行一步裁剪。所有裁剪器都实现了一个``update_mask()``方法,该方法根据裁剪器实现的逻辑更新掩码。步骤方法为稀疏配置中指定的权重调用此``update_mask``函数。

我们还将评估模型以显示零次裁剪,即未经微调/重新训练的裁剪的准确性降级。

sparsifier.step()
with torch.autocast("cuda"):
    with torch.inference_mode():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
    pruned = compute_metrics(
        *predictions.predictions,
        tokenized_squad_dataset["validation"],
        squad_dataset["validation"],
    )
print("pruned eval metrics:", pruned)
# pruned eval metrics: {'exact_match': 40.59602649006622, 'f1': 56.51610004515979}

在这个阶段,我们可以开始微调模型,更新那些不会被剪枝的元素,以更好地弥补准确度损失。一旦达到满意的状态,我们可以调用``squash_mask``来将掩码和权重融合在一起。这将移除参数化,我们最终将得到一个填充零的2:4稠密模型。

trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight)

# Parameter containing:
# tensor([[ 0.0000, -0.0237,  0.0000,  0.0130,  ..., -0.0462, -0.0000, 0.0000, -0.0272],
#        [ 0.0436, -0.0000, -0.0000,  0.0492,  ..., -0.0000,  0.0844,  0.0340, -0.0000],
#        [-0.0302, -0.0350,  0.0000,  0.0000,  ...,  0.0303,  0.0175, -0.0000,  0.0000],
#        [ 0.0000, -0.0000, -0.0529,  0.0327,  ...,  0.0213,  0.0000, -0.0000,  0.0735],
#        ...,
#        [ 0.0000, -0.0000, -0.0258, -0.0239,  ..., -0.0000, -0.0000,  0.0380,  0.0562],
#        [-0.0432, -0.0000,  0.0000, -0.0598,  ...,  0.0000, -0.0000,  0.0262  -0.0227],
#        [ 0.0244,  0.0921, -0.0000, -0.0000,  ..., -0.0000, -0.0784,  0.0000,  0.0761],
#        [ 0.0000,  0.0225, -0.0395, -0.0000,  ..., -0.0000,  0.0684, -0.0344, -0.0000]], device='cuda:0', requires_grad=True)

加速2:4稀疏模型的推理 ——–i———————————— 现在我们已经拥有这种格式的模型,可以像快速入门指南中一样加速其推理。

model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
    if isinstance(module, nn.Linear) and "layer" in fqn:
        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))

with torch.inference_mode():
    predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
metrics_sparse = compute_metrics(
    start_logits,
    end_logits,
    tokenized_squad_dataset["validation"],
    squad_dataset["validation"],
)
print("sparse eval metrics: ", metrics_sparse)
sparse_perf = measure_execution_time(
    model,
    batch_sizes,
    tokenized_squad_dataset["validation"],
)
print("sparse perf metrics: ", sparse_perf)

# sparse eval metrics:  {'exact_match': 78.43897824030275, 'f1': 86.48718950090766}
# sparse perf metrics:  {4: 12.621004460379481, 16: 15.368514601141214, 64: 58.702805917710066, 256: 244.19364519417286}

重量剪枝后的模型重新训练几乎恢复了所有在剪枝时丢失的F1。同时我们在bs=16的情况下,达到了1.28倍的速度提升。请注意,并不是所有形状都适合性能提升。当批量大小较小且计算稀疏内核所用时间有限时,稀疏模型的性能可能比稠密模型更慢。

结果

指标

fp16

2:4稀疏

差异 / 提速

精确匹配率 (%)

78.53

78.44

-0.09

F1率 (%)

86.93

86.49

-0.44

时间 (bs=4)

10.93

12.62

0.87倍

时间 (bs=16)

19.61

15.37

1.28倍

时间 (bs=64)

73.19

58.70

1.25倍

时间 (bs=256)

286.91

244.19

1.18倍

总结

在本教程中,我们展示了如何对BERT进行2:4稀疏剪枝以及如何加速一个2:4稀疏模型的推理。通过利用我们的SparseSemiStructuredTensor子类,我们实现了比fp16基线快1.3倍的速度提升。我们还通过微调BERT恢复了因2:4稀疏性导致的任何F1损失(86.92稠密 vs 86.48稀疏)。

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源