• Tutorials >
  • (原型)PyTorch 2导出量化感知训练(QAT)
Shortcuts

(原型)PyTorch 2导出量化感知训练(QAT)

Created On: Oct 02, 2023 | Last Updated: Oct 23, 2024 | Last Verified: Nov 05, 2024

作者: Andrew Or

本教程展示了如何基于`torch.export.export <https://pytorch.org/docs/main/export.html>`_在图模式下进行量化感知训练(QAT)。有关PyTorch 2 Export量化的一般详细信息,请参考`后训练量化教程 <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_。

PyTorch 2 Export QAT流如下——它在大多数方面与后训练量化(PTQ)流相似:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
  prepare_qat_pt2e,
  convert_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
  XNNPACKQuantizer,
  get_symmetric_quantization_config,
)

class M(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(5, 10)

   def forward(self, x):
      return self.linear(x)


example_inputs = (torch.randn(1, 5),)
m = M()

# Step 1. program capture
# This is available for pytorch 2.5+, for more details on lower pytorch versions
# please check `Export the model with torch.export` section
m = torch.export.export_for_training(m, example_inputs).module()
# we get a model with aten ops

# Step 2. quantization-aware training
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_qat_pt2e(m, quantizer)

# train omitted

m = convert_pt2e(m)
# we have a model with aten ops doing integer computations when possible

# move the quantized model to eval mode, equivalent to `m.eval()`
torch.ao.quantization.move_exported_model_to_eval(m)

请注意,在程序捕获后调用``model.eval()``或``model.train()``是不允许的,因为这些方法不再正确改变某些操作(如dropout和批量归一化)的行为。相反,请分别使用``torch.ao.quantization.move_exported_model_to_eval()``和``torch.ao.quantization.move_exported_model_to_train()``(即将推出)。

定义辅助函数并准备数据集

要使用整个ImageNet数据集运行本教程中的代码,请首先按照此处的说明下载ImageNet`ImageNet Data <http://www.image-net.org/download>`_。将下载的文件解压到``data_path``文件夹中。

接下来,下载`torchvision resnet18模型 <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_并将其重命名为``data/resnet18_pretrained_float.pth``。

我们将从必要的导入、定义一些辅助函数以及准备数据开始。这些步骤与`静态急切模式后训练量化教程 <https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_定义的步骤非常相似:

import os
import sys
import time
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision.models.resnet import resnet18
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
_ = torch.manual_seed(191009)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified
    values of k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def evaluate(model, criterion, data_loader, device):
    torch.ao.quantization.move_exported_model_to_eval(model)
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            image = image.to(device)
            target = target.to(device)
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
    print('')

    return top1, top5

def load_model(model_file):
    model = resnet18(pretrained=False)
    state_dict = torch.load(model_file, weights_only=True)
    model.load_state_dict(state_dict)
    return model

def print_size_of_model(model):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print("Size (MB):", os.path.getsize("temp.p")/1e6)
    os.remove("temp.p")

def prepare_data_loaders(data_path):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = torchvision.datasets.ImageNet(
        data_path, split="train", transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    dataset_test = torchvision.datasets.ImageNet(
        data_path, split="val", transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
    # Note: do not call model.train() here, since this doesn't work on an exported model.
    # Instead, call `torch.ao.quantization.move_exported_model_to_train(model)`, which will
    # be added in the near future
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    avgloss = AverageMeter('Loss', '1.5f')

    cnt = 0
    for image, target in data_loader:
        start_time = time.time()
        print('.', end = '')
        cnt += 1
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0], image.size(0))
        top5.update(acc5[0], image.size(0))
        avgloss.update(loss, image.size(0))
        if cnt >= ntrain_batches:
            print('Loss', avgloss.avg)

            print('Training: * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                  .format(top1=top1, top5=top5))
            return

    print('Full imagenet train set:  * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
          .format(top1=top1, top5=top5))
    return

data_path = '~/.data/imagenet'
saved_model_dir = 'data/'
float_model_file = 'resnet18_pretrained_float.pth'

train_batch_size = 32
eval_batch_size = 32

data_loader, data_loader_test = prepare_data_loaders(data_path)
example_inputs = (next(iter(data_loader))[0])
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cuda")

使用torch.export导出模型

以下是如何使用``torch.export``导出模型:

from torch._export import capture_pre_autograd_graph

example_inputs = (torch.rand(2, 3, 224, 224),)
# for pytorch 2.5+
exported_model = torch.export.export_for_training(float_model, example_inputs).module()
# for pytorch 2.4 and before
# from torch._export import capture_pre_autograd_graph
# exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)
# or, to capture with dynamic dimensions:

# for pytorch 2.5+
dynamic_shapes = tuple(
  {0: torch.export.Dim("dim")} if i == 0 else None
  for i in range(len(example_inputs))
)
exported_model = torch.export.export_for_training(float_model, example_inputs, dynamic_shapes=dynamic_shapes).module()

# for pytorch 2.4 and before
# dynamic_shape API may vary as well
# from torch._export import dynamic_dim

# example_inputs = (torch.rand(2, 3, 224, 224),)
# exported_model = capture_pre_autograd_graph(
#     float_model,
#     example_inputs,
#     constraints=[dynamic_dim(example_inputs[0], 0)],
# )

导入特定后端的量化器并配置如何对模型进行量化

以下代码片段描述了如何对模型进行量化:

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))

``Quantizer``是后端特定的,每个``Quantizer``将提供自己的方式来允许用户配置其模型。

备注

查看我们的`教程 <https://pytorch.org/tutorials/prototype/pt2e_quantizer.html>`_,描述如何编写新的``Quantizer``。

为量化感知训练准备模型

prepare_qat_pt2e``在模型中的适当位置插入假量化操作并执行适当的QAT“融合”,例如``Conv2d + BatchNorm2d,以获得更好的训练准确性。融合操作在准备好的图中表示为ATen操作子图。

prepared_model = prepare_qat_pt2e(exported_model, quantizer)
print(prepared_model)

备注

如果您的模型包含批规范化,在导出模型时,您在图中获得的实际 ATen 操作取决于模型所在的设备。如果模型在 CPU 上,那么您将获得 torch.ops.aten._native_batch_norm_legit。如果模型在 CUDA 上,那么您将获得 torch.ops.aten.cudnn_batch_norm。然而,这并不是固定的,将来可能会有所变化。

在这两个操作之间,已经证明 torch.ops.aten.cudnn_batch_norm 在像 MobileNetV2 这样的模型上提供了更好的数值性能。要获取此操作,可以在导出之前调用 model.cuda(),或者在准备之后运行以下代码以手动交换操作:

for n in prepared_model.graph.nodes:
    if n.target == torch.ops.aten._native_batch_norm_legit.default:
        n.target = torch.ops.aten.cudnn_batch_norm.default
prepared_model.recompile()

未来,我们计划整合批规范化操作,以使上述步骤不再必要。

训练循环

训练循环与以前版本的 QAT 中的训练循环类似。为了获得更好的准确性,您可以选择在一定数量的训练周期后禁用观察者和更新批规范化统计数据,或者每隔 N 个周期评估迄今为止训练的 QAT 或量化模型。

num_epochs = 10
num_train_batches = 20
num_eval_batches = 20
num_observer_update_epochs = 4
num_batch_norm_update_epochs = 3
num_epochs_between_evals = 2

# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch
for nepoch in range(num_epochs):
    train_one_epoch(prepared_model, criterion, optimizer, data_loader, "cuda", num_train_batches)

    # Optionally disable observer/batchnorm stats after certain number of epochs
    if epoch >= num_observer_update_epochs:
        print("Disabling observer for subseq epochs, epoch = ", epoch)
        prepared_model.apply(torch.ao.quantization.disable_observer)
    if epoch >= num_batch_norm_update_epochs:
        print("Freezing BN for subseq epochs, epoch = ", epoch)
        for n in prepared_model.graph.nodes:
            # Args: input, weight, bias, running_mean, running_var, training, momentum, eps
            # We set the `training` flag to False here to freeze BN stats
            if n.target in [
                torch.ops.aten._native_batch_norm_legit.default,
                torch.ops.aten.cudnn_batch_norm.default,
            ]:
                new_args = list(n.args)
                new_args[5] = False
                n.args = new_args
        prepared_model.recompile()

    # Check the quantized accuracy every N epochs
    # Note: If you wish to just evaluate the QAT model (not the quantized model),
    # then you can just call `torch.ao.quantization.move_exported_model_to_eval/train`.
    # However, the latter API is not ready yet and will be available in the near future.
    if (nepoch + 1) % num_epochs_between_evals == 0:
        prepared_model_copy = copy.deepcopy(prepared_model)
        quantized_model = convert_pt2e(prepared_model_copy)
        top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
        print('Epoch %d: Evaluation accuracy on %d images, %2.2f' % (nepoch, num_eval_batches * eval_batch_size, top1.avg))

保存和加载模型检查点

PyTorch 2 导出 QAT 流的模型检查点与其他训练流相同。它们对于暂停训练并稍后恢复训练、从失败的训练运行中恢复以及稍后在不同机器上进行推理很有用。您可以在训练期间或训练后保存模型检查点,如下所示:

checkpoint_path = "/path/to/my/checkpoint_%s.pth" % nepoch
torch.save(prepared_model.state_dict(), "checkpoint_path")

要加载检查点,必须以与最初导出和准备时完全相同的方式导出和准备模型。例如:

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torchvision.models.resnet import resnet18

example_inputs = (torch.rand(2, 3, 224, 224),)
float_model = resnet18(pretrained=False)
exported_model = capture_pre_autograd_graph(float_model, example_inputs)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)
prepared_model.load_state_dict(torch.load(checkpoint_path))

# resume training or perform inference

将训练的模型转换为量化模型

convert_pt2e 接受一个经过校准的模型并生成一个量化模型。请注意,在推理之前,您必须首先调用 torch.ao.quantization.move_exported_model_to_eval() 以确保像 dropout 这样的操作在评估图中表现正确。否则,例如在推理期间,我们会在前向传播中继续错误地应用 dropout。

quantized_model = convert_pt2e(prepared_model)

# move certain ops like dropout to eval mode, equivalent to `m.eval()`
torch.ao.quantization.move_exported_model_to_eval(m)

print(quantized_model)

top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Final evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))

总结

在本教程中,我们演示了如何运行 PyTorch 2 导出量化感知训练 (QAT) 流程。在转换之后,其余流程与后训练量化 (PTQ) 相同;用户可以序列化/反序列化模型,并进一步将其降低到支持基于 XNNPACK 后端推理的后端。有关更多详细信息,请参阅 PTQ 教程

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源