定制PyTorch后端通信(backend)实战

定制PyTorch后端通信(backend)实战

缘起

相关研究中,需要替换PyTorch的通信后端,定制分布式训练中例如all_reduce的算法,包括通信协议的定制。

PyTorch默认的通信后端为‘gloo’和‘nccl’,其他支持的还有mpi(需要基于源码编译),实际上PyTorch 2.7的源代码里还实现了一个‘ucc’的实验性的通信后端,后续可以参考UCC后端的代码进行相关定制。

参见:Distributed communication package - torch.distributed — PyTorch 2.7 documentation

本地环境

$ lsb_release -a   # Ubuntu 23.04

$ python3 --version  # Python 3.11.4

$ python3 -c "import torch; print(torch.__version__)"  # 2.7.0+cu126

实战

对于一个Python和PyTorch的新手,摸索了不知道多久,终于发现解决问题的入口:

1、使用 C++ 扩展定制进程组后端 — PyTorch 教程 2.7.0+cu126 文档 - PyTorch 深度学习库 此链接为中文,标题‘使用c++扩展定制(PyTorch)进程组后端’,正是我要找的东西。

2、https://github.com/H-Huang/torch_collective_extension 一个扩展torch集合通信的例子,该项目给出了两种扩展方式,一种基于custom_backend(推荐的方式),一种基于custom_process_group(旧方式,网上大部分找到的内容都是这种方式),其中基于custom_backend模式的代码与(1)相同,但更全面,默认给出了有关PyTorch所有集合通信的接口示例。

下面主要基于(1)的代码进行验证

代码

在用户目录下创建相关的目录及代码,本例中目录名为‘custom_bankend’,包含四个文件:

  1. dummy.hpp  --  c++头文件
  2. dummy.cpp  --  c++源文件
  3. setup.py  --  Python项目编译和配置脚本
  4. example.py  --  验证测试代码

以下假设用户目录为‘/home/~usrname’

1. dummy.hpp

// file name: dummy.hpp

// 代码源自PyTorch官网实例:https://pytorch.ac.cn/tutorials/intermediate/process_group_cpp_extension_tutorial.html

#pragma once

// 根据实际情况调整路径,参见 setup.py

// #include <torch/python.h>

#include <torch/csrc/api/include/torch/python.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>

#include <torch/csrc/distributed/c10d/Work.hpp>

#include <torch/csrc/distributed/c10d/Store.hpp>

#include <torch/csrc/distributed/c10d/Types.hpp>

#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <pybind11/chrono.h>

namespace c10d {

class BackendDummy : public Backend {

  public:

    BackendDummy(int rank, int size);

    c10::intrusive_ptr<Work> allgather(

        std::vector<std::vector<at::Tensor>>& outputTensors,

        std::vector<at::Tensor>& inputTensors,

        const AllgatherOptions& opts = AllgatherOptions()) override;

    c10::intrusive_ptr<Work> allreduce(

        std::vector<at::Tensor>& tensors,

        const AllreduceOptions& opts = AllreduceOptions()) override;

    // The collective communication APIs without a custom implementation

    // will error out if invoked by application code.

    static c10::intrusive_ptr<Backend> createBackendDummy(

        const c10::intrusive_ptr<::c10d::Store>& store,

        int rank,

        int size,

        const std::chrono::duration<float>& timeout);

    static void BackendDummyConstructor() __attribute__((constructor)) {

        py::object module = py::module::import("torch.distributed");

        py::object register_backend =

            module.attr("Backend").attr("register_backend");

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值