定制PyTorch后端通信(backend)实战
缘起
相关研究中,需要替换PyTorch的通信后端,定制分布式训练中例如all_reduce的算法,包括通信协议的定制。
PyTorch默认的通信后端为‘gloo’和‘nccl’,其他支持的还有mpi(需要基于源码编译),实际上PyTorch 2.7的源代码里还实现了一个‘ucc’的实验性的通信后端,后续可以参考UCC后端的代码进行相关定制。
参见:Distributed communication package - torch.distributed — PyTorch 2.7 documentation
本地环境
$ lsb_release -a # Ubuntu 23.04
$ python3 --version # Python 3.11.4
$ python3 -c "import torch; print(torch.__version__)" # 2.7.0+cu126
实战
对于一个Python和PyTorch的新手,摸索了不知道多久,终于发现解决问题的入口:
1、使用 C++ 扩展定制进程组后端 — PyTorch 教程 2.7.0+cu126 文档 - PyTorch 深度学习库 此链接为中文,标题‘使用c++扩展定制(PyTorch)进程组后端’,正是我要找的东西。
2、https://github.com/H-Huang/torch_collective_extension 一个扩展torch集合通信的例子,该项目给出了两种扩展方式,一种基于custom_backend(推荐的方式),一种基于custom_process_group(旧方式,网上大部分找到的内容都是这种方式),其中基于custom_backend模式的代码与(1)相同,但更全面,默认给出了有关PyTorch所有集合通信的接口示例。
下面主要基于(1)的代码进行验证
代码
在用户目录下创建相关的目录及代码,本例中目录名为‘custom_bankend’,包含四个文件:
- dummy.hpp -- c++头文件
- dummy.cpp -- c++源文件
- setup.py -- Python项目编译和配置脚本
- example.py -- 验证测试代码
以下假设用户目录为‘/home/~usrname’
1. dummy.hpp
| // file name: dummy.hpp // 代码源自PyTorch官网实例:https://pytorch.ac.cn/tutorials/intermediate/process_group_cpp_extension_tutorial.html #pragma once // 根据实际情况调整路径,参见 setup.py // #include <torch/python.h> #include <torch/csrc/api/include/torch/python.h> #include <torch/csrc/distributed/c10d/Backend.hpp> #include <torch/csrc/distributed/c10d/Work.hpp> #include <torch/csrc/distributed/c10d/Store.hpp> #include <torch/csrc/distributed/c10d/Types.hpp> #include <torch/csrc/distributed/c10d/Utils.hpp> #include <pybind11/chrono.h> namespace c10d { class BackendDummy : public Backend { public: BackendDummy(int rank, int size); c10::intrusive_ptr<Work> allgather( std::vector<std::vector<at::Tensor>>& outputTensors, std::vector<at::Tensor>& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr<Work> allreduce( std::vector<at::Tensor>& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; // The collective communication APIs without a custom implementation // will error out if invoked by application code. static c10::intrusive_ptr<Backend> createBackendDummy( const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration<float>& timeout); static void BackendDummyConstructor() __attribute__((constructor)) { py::object module = py::module::import("torch.distributed"); py::object register_backend = module.attr("Backend").attr("register_backend"); |

最低0.47元/天 解锁文章
159

被折叠的 条评论
为什么被折叠?



