(原型)FX图模式量化用户指南¶
Created On: Aug 20, 2021 | Last Updated: Dec 12, 2023 | Last Verified: Nov 05, 2024
作者:Jerry Zhang
FX图模式量化需要一个可符号跟踪的模型。我们使用FX框架将一个可符号跟踪的nn.Module实例转换为IR,并操作IR以执行量化过程。如有关于符号跟踪模型的问题,请在`PyTorch讨论论坛 <https://discuss.pytorch.org/c/quantization/17>`_中发布您的问题。
量化只适用于模型中可以符号化跟踪的部分。基于数据的控制流(if语句/for循环等使用符号化跟踪值)是常见的一种模式,但不被支持。如果您的模型无法端到端符号化跟踪,您有以下几个选项可以仅对模型的一部分启用FX图模式量化。您可以使用以下选项的任意组合:
- 不可跟踪的代码不需要量化
只符号化跟踪需要量化的代码
跳过符号化跟踪不可跟踪的代码
- 不可跟踪的代码需要量化
重构代码使其可符号化跟踪
编写自己的已观察和量化的子模块
如果不可符号化跟踪的代码不需要量化,我们有以下两个选项来运行FX图模式量化:
只符号化跟踪需要量化的代码¶
当整个模型不可符号化跟踪,但我们想要量化的子模块可符号化跟踪时,我们可以仅对该子模块运行量化。
之前:
class M(nn.Module):
def forward(self, x):
x = non_traceable_code_1(x)
x = traceable_code(x)
x = non_traceable_code_2(x)
return x
之后:
class FP32Traceable(nn.Module):
def forward(self, x):
x = traceable_code(x)
return x
class M(nn.Module):
def __init__(self):
self.traceable_submodule = FP32Traceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# We'll only symbolic trace/quantize this submodule
x = self.traceable_submodule(x)
x = self.traceable_code_2(x)
return x
量化代码:
qconfig_mapping = QConfigMapping().set_global(qconfig)
model_fp32.traceable_submodule = \
prepare_fx(model_fp32.traceable_submodule, qconfig_mapping, example_inputs)
注意如果需要保留原始模型,您需要在调用量化API之前自行复制。
跳过符号化跟踪不可跟踪的代码¶
当我们在模块中有一些不可跟踪的代码,并且这部分代码不需要量化时,我们可以将这部分代码拆分为一个子模块并跳过符号化跟踪该子模块。
之前
class M(nn.Module):
def forward(self, x):
x = self.traceable_code_1(x)
x = non_traceable_code(x)
x = self.traceable_code_2(x)
return x
之后,将不可跟踪的部分移到一个模块并标记为叶子
class FP32NonTraceable(nn.Module):
def forward(self, x):
x = non_traceable_code(x)
return x
class M(nn.Module):
def __init__(self):
...
self.non_traceable_submodule = FP32NonTraceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# we will configure the quantization call to not trace through
# this submodule
x = self.non_traceable_submodule(x)
x = self.traceable_code_2(x)
return x
量化代码:
qconfig_mapping = QConfigMapping.set_global(qconfig)
prepare_custom_config_dict = {
# option 1
"non_traceable_module_name": "non_traceable_submodule",
# option 2
"non_traceable_module_class": [MNonTraceable],
}
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict,
)
如果不可符号化跟踪的代码需要量化,我们有以下两个选项:
重构代码使其可符号化跟踪¶
如果重构代码并使其可符号化跟踪较容易,那么我们可以重构代码并去除在Python中使用不可跟踪的结构。
关于符号化跟踪支持的更多信息可以查看 这里。
之前:
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
此处不可符号化跟踪,因为x.view(*new_x_shape)的解包操作不被支持,然而,我们可以轻松移除解包操作,因为x.view也支持列表输入。
之后:
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
这可以与其他方法结合使用,而量化代码依赖于模型。
编写自己的已观察和量化的子模块¶
如果不可跟踪的代码不能重构为可符号化跟踪,例如其中有些循环无法消除(例如 nn.LSTM),我们需要将不可跟踪的代码拆分成一个子模块(在FX图模式量化中我们称其为 CustomModule),并定义该子模块的已观察和量化版本(用于静态后训练量化或静态量化感知训练),或者定义量化版本(用于动态后训练量化和仅权重量化)。
之前:
class M(nn.Module):
def forward(self, x):
x = traceable_code_1(x)
x = non_traceable_code(x)
x = traceable_code_1(x)
return x
之后:
1. Factor out non_traceable_code to FP32NonTraceable non-traceable logic, wrapped in a module
class FP32NonTraceable:
...
2. Define observed version of FP32NonTraceable
class ObservedNonTraceable:
@classmethod
def from_float(cls, ...):
...
3. Define statically quantized version of FP32NonTraceable and a class method “from_observed” to convert from ObservedNonTraceable to StaticQuantNonTraceable
class StaticQuantNonTraceable:
@classmethod
def from_observed(cls, ...):
...
# refactor parent class to call FP32NonTraceable
class M(nn.Module):
def __init__(self):
...
self.non_traceable_submodule = FP32NonTraceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# this part will be quantized manually
x = self.non_traceable_submodule(x)
x = self.traceable_code_1(x)
return x
量化代码:
# post training static quantization or
# quantization aware training (that produces a statically quantized module)v
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
"static": {
FP32NonTraceable: ObservedNonTraceable,
}
},
}
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict)
校准/训练(未显示)
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"static": {
ObservedNonTraceable: StaticQuantNonTraceable,
}
},
}
model_quantized = convert_fx(
model_prepared,
convert_custom_config_dict)
动态后训练量化/仅权重量化在这两种模式中,我们不需要观察原始模型,因此我们只需定义量化模型
class DynamicQuantNonTraceable: # or WeightOnlyQuantMNonTraceable
...
@classmethod
def from_observed(cls, ...):
...
prepare_custom_config_dict = {
"non_traceable_module_class": [
FP32NonTraceable
]
}
# The example is for post training quantization
model_fp32.eval()
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict)
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"dynamic": {
FP32NonTraceable: DynamicQuantNonTraceable,
}
},
}
model_quantized = convert_fx(
model_prepared,
convert_custom_config_dict)
您还可以在“torch/test/quantization/test_quantize_fx.py”的``test_custom_module_class``测试中找到自定义模块的示例。