Model Freezing in TorchScript
Created On: Jul 28, 2020 | Last Updated: Dec 02, 2024 | Last Verified: Nov 05, 2024
Warning
TorchScript is no longer in active development.
In this tutorial, we introduce the syntax for *model freezing* in TorchScript. Freezing is the process of inlining PyTorch module parameters and attribute values into the TorchScript internal representation. Parameter and attribute values are treated as final values, and they cannot be modified in the resulting frozen module.
Basic Syntax
Model freezing can be invoked using the API below:
torch.jit.freeze(mod : ScriptModule, names : str[]) -> ScriptModule
Note that the input module can be the result of either scripting or tracing. See https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
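For instance, a minimal sketch of freezing a traced module might look like the following. The tiny torch.nn.Linear module here is a stand-in for illustration only; it is not part of the tutorial's example.

# Sketch: freeze() accepts a traced ScriptModule just as it accepts a scripted one.
import torch

tiny = torch.nn.Linear(4, 2).eval()               # freezing expects an eval-mode module
traced = torch.jit.trace(tiny, torch.rand(1, 4))  # build a ScriptModule via tracing
frozen = torch.jit.freeze(traced)                 # weight and bias become inlined constants
print(frozen.graph)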
Next, we demonstrate how freezing works using an example:
import torch, time

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout2d(0.25)
        self.dropout2 = torch.nn.Dropout2d(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

    @torch.jit.export
    def version(self):
        return 1.0
net = torch.jit.script(Net().eval())  # freezing requires the module to be in eval mode
fnet = torch.jit.freeze(net)

print(net.conv1.weight.size())
print(net.conv1.bias)

try:
    print(fnet.conv1.bias)
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'conv1'
except RuntimeError:
    print("field 'conv1' is inlined. It does not exist in 'fnet'")

try:
    fnet.version()
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'version'
except RuntimeError:
    print("method 'version' is deleted in fnet. Only 'forward' is preserved")

# Preserve 'version' explicitly so it survives freezing
fnet2 = torch.jit.freeze(net, ["version"])
print(fnet2.version())
# Warm-up and inference timing for the scripted and frozen models
B = 1
warmup = 1
iter = 1000
input = torch.rand(B, 1, 28, 28)

start = time.time()
for i in range(warmup):
    net(input)
end = time.time()
print("Scripted - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(warmup):
    fnet(input)
end = time.time()
print("Frozen - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(iter):
    input = torch.rand(B, 1, 28, 28)
    net(input)
end = time.time()
print("Scripted - Inference: {0:5.2f}".format(end - start), flush=True)

start = time.time()
for i in range(iter):
    input = torch.rand(B, 1, 28, 28)
    fnet2(input)  # fnet2 is frozen with 'version' preserved; its 'forward' matches fnet's
end = time.time()
print("Frozen - Inference time: {0:5.2f}".format(end - start), flush=True)
On my machine, I measured the following times:
Scripted - Warm up time:  0.0107
Frozen - Warm up time:  0.0048
Scripted - Inference:  1.35
Frozen - Inference time:  1.17
In our example, the warm-up time measures the first (warm-up) runs. The frozen model warms up in roughly half the time of the scripted model, and on more complex models we have observed even larger warm-up speedups. Freezing achieves this speedup because it does ahead of time some of the work that TorchScript would otherwise have to do during the first runs.
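One way to see what that ahead-of-time work produced is to compare the two graphs. This is a small sketch that reuses net and fnet from the example above; it is not part of the original tutorial.

# Sketch: compare the scripted and frozen graphs (reuses net/fnet from above).
# In the frozen graph the conv/linear weights appear as inlined constants
# instead of attribute lookups on self.
print(net.graph)   # parameters are fetched from self at run time
print(fnet.graph)  # parameters appear as constants in the graph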
The inference time measures the model's execution time after it has warmed up. Although we observed significant run-to-run variance, the frozen model is typically about 15% faster than the scripted model. With larger inputs we observe a smaller speedup, because execution becomes dominated by tensor operations.
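As a rough way to observe the effect of larger inputs, a sketch along these lines can be used. It reuses net, fnet2, and time from the example above; the batch size of 64 and the iteration count of 100 are arbitrary choices, not values from the tutorial.

# Sketch: with a larger batch the tensor math dominates, so the relative gap
# between the scripted and frozen models typically narrows.
big_input = torch.rand(64, 1, 28, 28)   # arbitrary larger batch
for model, name in ((net, "Scripted"), (fnet2, "Frozen")):
    start = time.time()
    for _ in range(100):                # arbitrary iteration count
        model(big_input)
    print("{0} - batch 64 inference: {1:5.2f}".format(name, time.time() - start), flush=True)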