Inductor C++ 包装器教程¶
Created On: Oct 02, 2023 | Last Updated: Jan 16, 2024 | Last Verified: Nov 05, 2024
作者: Chunyuan Wu, Bin Bao, Jiong Gong
简介¶
Python作为PyTorch的主要接口,使用起来简单且便于开发和调试。Inductor的默认包装器生成Python代码以调用生成的内核和外部内核。然而,在需要高性能的部署中,Python作为一种解释语言,相较于编译语言运行较慢。
我们通过利用PyTorch的C++ API实现了一个Inductor C++包装器来生成结合生成和外部内核的纯C++代码。这允许在纯C++中执行每个捕获的Dynamo图,从而减少图中Python的开销。
启用接口¶
此功能仍处于原型阶段。要激活此功能,请在您的代码中添加以下内容:
import torch._inductor.config as config
config.cpp_wrapper = True
这将通过减少Inductor包装器的Python开销来加速您的模型。
示例代码¶
我们将使用以下前端代码作为示例:
import torch
def fn(x):
return torch.tensor(list(range(2, 40, 2)), device=x.device) + x
x = torch.randn(1)
opt_fn = torch.compile()(fn)
y = opt_fn(x)
对于CPU
默认Python包装器生成的Inductor代码主要部分将如下所示:
def call(args):
arg0_1, = args
args.clear()
assert_size_stride(arg0_1, (1, ), (1, ))
buf0 = empty_strided((19, ), (1, ), device='cpu', dtype=torch.float32)
cpp_fused_add_lift_fresh_0(c_void_p(constant0.data_ptr()), c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
del arg0_1
return (buf0, )
打开C++包装器后,call``函数的生成代码将变为C++扩展``module``的C++函数``inductor_entry_cpp
:
std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
at::Tensor arg0_1 = args[0];
at::Tensor constant0 = args[1];
auto buf0 = at::empty_strided({19L, }, {1L, }, at::device(at::kCPU).dtype(at::kFloat));
cpp_fused_add_lift_fresh_0((long*)(constant0.data_ptr()), (float*)(arg0_1.data_ptr()), (float*)(buf0.data_ptr()));
arg0_1.reset();
return {buf0};
}
module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'c2buojsvlqbywxe3itb43hldieh4jqulk72iswa2awalwev7hjn2', False)
def _wrap_func(f):
def g(args):
args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
constants_tensor = [constant0]
args_tensor.extend(constants_tensor)
return f(args_tensor)
return g
call = _wrap_func(module.inductor_entry_cpp)
对于GPU
基于相同示例代码,GPU生成的代码将如下所示:
def call(args):
arg0_1, = args
args.clear()
assert_size_stride(arg0_1, (1, ), (1, ))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0) # no-op to ensure context
buf0 = empty_strided((19, ), (1, ), device='cuda', dtype=torch.float32)
# Source Nodes: [add, tensor], Original ATen: [aten.add, aten.lift_fresh]
stream0 = get_cuda_stream(0)
triton_poi_fused_add_lift_fresh_0.run(constant0, arg0_1, buf0, 19, grid=grid(19), stream=stream0)
run_intermediate_hooks('add', buf0)
del arg0_1
return (buf0, )
打开C++包装器后,将生成以下等效的C++代码:
std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
at::Tensor arg0_1 = args[0];
at::Tensor constant0 = args[1];
at::cuda::CUDAGuard device_guard(0);
auto buf0 = at::empty_strided({19L, }, {1L, }, at::TensorOptions(c10::Device(at::kCUDA, 0)).dtype(at::kFloat));
// Source Nodes: [add, tensor], Original ATen: [aten.add, aten.lift_fresh]
if (triton_poi_fused_add_lift_fresh_0 == nullptr) {
triton_poi_fused_add_lift_fresh_0 = loadKernel("/tmp/torchinductor_user/mm/cmm6xjgijjffxjku4akv55eyzibirvw6bti6uqmfnruujm5cvvmw.cubin", "triton_poi_fused_add_lift_fresh_0_0d1d2d3");
}
CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(constant0.data_ptr());
CUdeviceptr var_1 = reinterpret_cast<CUdeviceptr>(arg0_1.data_ptr());
CUdeviceptr var_2 = reinterpret_cast<CUdeviceptr>(buf0.data_ptr());
auto var_3 = 19;
void* kernel_args_var_0[] = {&var_0, &var_1, &var_2, &var_3};
cudaStream_t stream0 = at::cuda::getCurrentCUDAStream(0);
launchKernel(triton_poi_fused_add_lift_fresh_0, 1, 1, 1, 1, 0, kernel_args_var_0, stream0);
arg0_1.reset();
return {buf0};
}
module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'czbpeilh4qqmbyejdgsbpdfuk2ss5jigl2qjb7xs4gearrjvuwem', True)
def _wrap_func(f):
def g(args):
args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
constants_tensor = [constant0]
args_tensor.extend(constants_tensor)
return f(args_tensor)
return g
call = _wrap_func(module.inductor_entry_cpp)
总结¶
在本教程中,我们介绍了TorchInductor中的一个新C++包装器,只需两行代码的更改即可加速您的模型。我们解释了此新功能的动机,并讲解了用于激活此实验功能的易于使用的API。此外,我们展示了在CPU和GPU上使用默认Python包装器和新C++包装器生成的Inductor代码,并以直观的方式展示了这两个包装器之间的区别。
此功能仍处于原型阶段。如果您有任何功能请求或遇到任何问题,请在`GitHub问题 <https://github.com/pytorch/pytorch/issues>`_中提交错误报告。