
(prototype) FX Graph Mode Post Training Static Quantization

Created On: Feb 08, 2021 | Last Updated: Jan 24, 2025 | Last Verified: Nov 05, 2024

Author: Jerry Zhang | Edited by: Charles Hernandez

This tutorial introduces the steps to do post training static quantization in graph mode based on `torch.fx <https://github.com/pytorch/pytorch/blob/master/torch/fx/__init__.py>`_. The advantage of FX graph mode quantization is that quantization can be performed fully automatically on the model. Although some effort may be required to make the model compatible with FX graph mode quantization (symbolically traceable with ``torch.fx``), we will have a separate tutorial showing how to make the parts of the model we want to quantize compatible with it. We also have a tutorial on `FX Graph Mode Post Training Dynamic Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_dynamic.html>`_. In short, the FX graph mode API looks like the following:

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QConfigMapping
float_model.eval()
# The old 'fbgemm' is still available but 'x86' is the recommended default.
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
example_inputs = (next(iter(data_loader))[0]) # get an example input
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)  # fuse modules and insert observers
calibrate(prepared_model, data_loader_test)  # run calibration on sample data
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

1. Motivation of FX Graph Mode Quantization

Currently, PyTorch only has one alternative: eager mode static quantization, described in `Static Quantization with Eager Mode in PyTorch <https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_.

As we can see, the eager mode quantization process involves multiple manual steps, including:

  • Explicitly quantizing and dequantizing activations. This is time consuming when a model mixes floating point and quantized operations.

  • Explicitly fusing modules. This requires manually identifying sequences of convolution, batch norm, ReLU, and other fusion patterns.

  • Special handling for PyTorch tensor operations (such as add, concat, etc.).

  • Functionals did not have first class support (``functional.conv2d`` and ``functional.linear`` would not be quantized).

Most of these required modifications come from the underlying limitations of eager mode quantization. Eager mode works at the module level: since it cannot inspect the code that is actually run (the forward function), quantization is achieved by module swapping, and we do not know how the modules are used in the forward function. Users must therefore manually insert ``QuantStub`` and ``DeQuantStub`` to mark the points where they want to quantize or dequantize. In graph mode, we can inspect the actual code executed in the forward function (for example, aten function calls), and quantization is achieved by module and graph manipulation. Since graph mode has full visibility of the code that is run, the tool can automatically figure out which modules to fuse, where to insert observer calls, quantize/dequantize functions, and so on, allowing the whole quantization process to be automated.
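
To make this concrete, here is a minimal sketch (not part of this tutorial's flow; the module and layer sizes are hypothetical) of the manual markers that eager mode quantization requires and that FX graph mode inserts automatically:

import torch
from torch.ao.quantization import QuantStub, DeQuantStub

class EagerQuantModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # manually mark where quantization begins
        self.conv = torch.nn.Conv2d(3, 16, 3)
        self.relu = torch.nn.ReLU()
        self.dequant = DeQuantStub()  # manually mark where dequantization happens

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)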

The advantages of FX graph mode quantization are:

  • A simple quantization flow with minimal manual steps

  • The possibility of higher level optimizations, such as automatic precision selection

2. Define Helper Functions and Prepare the Dataset

We will start with the necessary imports, define some helper functions, and prepare the data. These steps are identical to those in `Static Quantization with Eager Mode in PyTorch <https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html>`_.

To run the code in this tutorial on the entire ImageNet dataset, first download ImageNet by following the instructions at `ImageNet Data <http://www.image-net.org/download>`_. Unzip the downloaded file into the 'data_path' folder.

Download the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_ and rename it to ``data/resnet18_pretrained_float.pth``.

import os
import sys
import time
import numpy as np

import torch
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision.models.resnet import resnet18
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
_ = torch.manual_seed(191009)


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
    print('')

    return top1, top5

def load_model(model_file):
    model = resnet18(pretrained=False)
    state_dict = torch.load(model_file, weights_only=True)
    model.load_state_dict(state_dict)
    model.to("cpu")
    return model

def print_size_of_model(model):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print("Size (MB):", os.path.getsize("temp.p")/1e6)
    os.remove("temp.p")

def prepare_data_loaders(data_path):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = torchvision.datasets.ImageNet(
        data_path, split="train", transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    dataset_test = torchvision.datasets.ImageNet(
        data_path, split="val", transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

data_path = '~/.data/imagenet'
saved_model_dir = 'data/'
float_model_file = 'resnet18_pretrained_float.pth'

train_batch_size = 30
eval_batch_size = 50

data_loader, data_loader_test = prepare_data_loaders(data_path)
example_inputs = (next(iter(data_loader))[0])
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
float_model.eval()

# create another instance of the model since
# we need to keep the original model around
model_to_quantize = load_model(saved_model_dir + float_model_file).to("cpu")

3. Set the Model to Eval Mode

For post training quantization, we need to set the model to eval mode.

model_to_quantize.eval()

4. Specify How to Quantize the Model with ``QConfigMapping``

qconfig_mapping = QConfigMapping().set_global(default_qconfig)

We use the same qconfig used in eager mode quantization; ``qconfig`` is just a named tuple of the observers for activations and weights. ``QConfigMapping`` contains mapping information from ops to qconfigs:

qconfig_mapping = (QConfigMapping()
    .set_global(qconfig_opt)  # qconfig_opt is an optional qconfig, either a valid qconfig or None
    .set_object_type(torch.nn.Conv2d, qconfig_opt)  # can be a callable...
    .set_object_type("reshape", qconfig_opt)  # ...or a string of the method
    .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig_opt) # matched in order, first match takes precedence
    .set_module_name("foo.bar", qconfig_opt)
    .set_module_name_object_type_order()
)
# priority (in increasing order): global, object_type, module_name_regex, module_name
# qconfig == None means fusion and quantization should be skipped for anything
# matching the rule (unless a higher priority match is found)

Utility functions related to ``qconfig`` can be found in the `qconfig <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/qconfig.py>`_ file, and those related to ``QConfigMapping`` in `qconfig_mapping <https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/qconfig_mapping_utils.py>`_.

# The old 'fbgemm' is still available but 'x86' is the recommended default.
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
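
For example, to keep the x86 global default but skip quantization for a particular type of module, you can map that type to ``None``. This is a hypothetical configuration, not part of this tutorial's flow:

# Hypothetical: use the x86 default globally, but skip all Linear modules.
# A qconfig of None means fusion and quantization are skipped for matches.
qconfig_mapping_skip_linear = (QConfigMapping()
    .set_global(get_default_qconfig("x86"))
    .set_object_type(torch.nn.Linear, None)
)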

5. Prepare the Model for Post Training Static Quantization

prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)

``prepare_fx`` folds BatchNorm modules into the preceding Conv2d modules and inserts observers at the appropriate places in the model.

prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
print(prepared_model.graph)
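
If you prefer a tabular view over the printed graph, ``torch.fx`` graphs also provide ``print_tabular()``. This is optional and requires the ``tabulate`` package to be installed:

# Optional: print the prepared graph as a table of nodes (needs `tabulate`).
prepared_model.graph.print_tabular()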

6. Calibration

The calibration function runs after observers have been inserted into the model. The purpose of calibration is to run through some sample inputs that are representative of the workload (for example, a sample of the training dataset) so that the observers can record the statistics of the tensors; this information is later used to compute the quantization parameters.

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
calibrate(prepared_model, data_loader_test)  # run calibration on sample data
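
Calibrating on the full test set can be slow. A common shortcut is to calibrate on a fixed number of batches; here is a minimal sketch, where the batch count ``n`` is an arbitrary choice rather than a value from this tutorial:

import itertools

def calibrate_n_batches(model, data_loader, n=32):
    # Run only the first n batches through the observed model.
    model.eval()
    with torch.no_grad():
        for image, _ in itertools.islice(data_loader, n):
            model(image)

calibrate_n_batches(prepared_model, data_loader_test, n=32)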

7. Convert the Model to a Quantized Model

``convert_fx`` takes the calibrated model and produces a quantized model.

quantized_model = convert_fx(prepared_model)
print(quantized_model)

8. Evaluation

We can now print the size and accuracy of the quantized model.

print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("[before serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))

fx_graph_mode_model_file_path = saved_model_dir + "resnet18_fx_graph_mode_quantized.pth"

# this does not run due to some errors loading convrelu module:
# ModuleAttributeError: 'ConvReLU2d' object has no attribute '_modules'
# save the whole model directly
# torch.save(quantized_model, fx_graph_mode_model_file_path)
# loaded_quantized_model = torch.load(fx_graph_mode_model_file_path, weights_only=False)

# save with state_dict
# torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path)
# import copy
# model_to_quantize = copy.deepcopy(float_model)
# prepared_model = prepare_fx(model_to_quantize, {"": qconfig})
# loaded_quantized_model = convert_fx(prepared_model)
# loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path, weights_only=True))

# save with script
torch.jit.save(torch.jit.script(quantized_model), fx_graph_mode_model_file_path)
loaded_quantized_model = torch.jit.load(fx_graph_mode_model_file_path)

top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test)
print("[after serialization/deserialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))

If you want better accuracy or performance, try changing the ``qconfig_mapping``. We plan to add support for graph mode in the Numeric Suite so that you can easily determine the sensitivity of different modules in a model to quantization. For more information, see the `PyTorch Numeric Suite Tutorial <https://pytorch.org/tutorials/prototype/numeric_suite_tutorial.html>`_.

9. Debugging the Quantized Model

We can also print the weights of the quantized and non-quantized convolution ops to see the difference. We first call fuse explicitly to fuse the convolution and batch norm in the float model. Note that ``fuse_fx`` only works in eval mode.

fused = fuse_fx(float_model)

conv1_weight_after_fuse = fused.conv1[0].weight[0]
conv1_weight_after_quant = quantized_model.conv1.weight().dequantize()[0]

print(torch.max(abs(conv1_weight_after_fuse - conv1_weight_after_quant)))
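
Beyond comparing weights, it can help to compare the float and quantized model outputs on a batch of data. This rough sanity check only reports the error magnitude; how much error is acceptable depends on the workload:

# Compare float and quantized outputs on one example batch.
with torch.no_grad():
    float_out = float_model(example_inputs)
    quant_out = quantized_model(example_inputs)
print("max abs output difference:", (float_out - quant_out).abs().max())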

10. Comparison with the Baseline Float Model and Eager Mode Quantization

scripted_float_model_file = "resnet18_scripted.pth"

print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = evaluate(float_model, criterion, data_loader_test)
print("Baseline Float Model Evaluation accuracy: %2.2f, %2.2f"%(top1.avg, top5.avg))
torch.jit.save(torch.jit.script(float_model), saved_model_dir + scripted_float_model_file)

In this section, we compare the model quantized with FX graph mode quantization with the model quantized in eager mode. FX graph mode and eager mode produce very similar quantized models, so we expect the accuracy and speedup to be similar as well.

print("Size of Fx graph mode quantized model")
print_size_of_model(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("FX graph mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))

from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
print("Size of eager mode quantized model")
eager_quantized_model = torch.jit.script(eager_quantized_model)
print_size_of_model(eager_quantized_model)
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
print("eager mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))
eager_mode_model_file = "resnet18_eager_mode_quantized.pth"
torch.jit.save(eager_quantized_model, saved_model_dir + eager_mode_model_file)

We can see that the model size and accuracy of the FX graph mode and eager mode quantized models are very similar.

Running the model in AIBench (with single threading) gives the following results:

Scripted Float Model:
Self CPU time total: 192.48ms

Scripted Eager Mode Quantized Model:
Self CPU time total: 50.76ms

Scripted FX Graph Mode Quantized Model:
Self CPU time total: 50.63ms

As we can see, for resnet18 both the FX graph mode and eager mode quantized models achieve a similar speedup over the float model, roughly 2-4x faster. The actual speedup over the float model, however, may vary with the model, device, build, input batch size, threading, and so on.
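
If you want a rough latency check on your own machine, a simple single-threaded timing loop like the sketch below can be used; the numbers it produces will vary by machine and are not comparable to the AIBench results above:

import time

torch.set_num_threads(1)  # single-threaded, to mirror the benchmark setting
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    for _ in range(10):   # warm-up iterations
        loaded_quantized_model(x)
    start = time.time()
    for _ in range(50):
        loaded_quantized_model(x)
print("avg latency (ms): %.2f" % ((time.time() - start) / 50 * 1000))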
