pybind11中虚函数与继承结合

pybind11虚函数与继承结合

1. 概述

当在pybind11中处理具有继承关系的C++类层次结构时,重写虚函数需要特别注意。特别是当一个类继承自另一个已经有trampoline类的类时,需要正确处理所有继承的虚函数。

2. 多层继承的挑战

2.1 扩展的类层次结构

考虑以下扩展的C++类层次结构:

class Animal {
public:
    virtual ~Animal() {}
    virtual std::string name() = 0;
    virtual std::string go(int n_times) = 0;
};

class Dog : public Animal {
public:
    std::string name() override { return "Dog"; }
    std::string go(int n_times) override {
        std::string result;
        for (int i = 0; i < n_times; ++i)
            result += bark() + " ";
        return result;
    }
    virtual std::string bark() { return "woof!"; }
};

class Husky : public Dog {};  // 不添加新的虚函数

2.2 简单trampoline的局限性

如果只为Animal创建trampoline类,Python代码无法正确继承Dog类。因为:

  • Dog添加了新的虚函数bark()
  • Dog继承并重写了Animal的虚函数
  • Python需要能够重写所有这些虚函数

3. 正确的trampoline实现

3.1 Animal的trampoline类

class PyAnimal : public Animal, public py::trampoline_self_life_support {
public:
    using Animal::Animal;  // 继承构造函数
    
    std::string name() override {
        PYBIND11_OVERRIDE_PURE(
            std::string,  // 返回类型
            Animal,       // 父类
            name          // 函数名,注意没有括号和逗号
        );
    }
    
    std::string go(int n_times) override {
        PYBIND11_OVERRIDE_PURE(
            std::string,  // 返回类型
            Animal,       // 父类
            go,           // 函数名
            n_times       // 参数
        );
    }
};

3.2 Dog的trampoline类

重要:Dog的trampoline类必须重写所有虚函数,包括:

  • 从Animal继承的虚函数(即使Dog已经提供了实现)
  • Dog自己添加的新虚函数
class PyDog : public Dog, public py::trampoline_self_life_support {
public:
    using Dog::Dog;  // 继承构造函数
    
    // 重写从Animal继承的虚函数
    std::string name() override {
        PYBIND11_OVERRIDE(
            std::string,  // 返回类型
            Dog,          // 父类(现在是Dog,不是Animal)
            name          // 函数名,无参数时的写法
        );
    }
    
    std::string go(int n_times) override {
        PYBIND11_OVERRIDE(
            std::string,  // 返回类型
            Dog,          // 父类
            go,           // 函数名
            n_times       // 参数
        );
    }
    
    // 重写Dog添加的新虚函数
    std::string bark() override {
        PYBIND11_OVERRIDE(
            std::string,  // 返回类型
            Dog,          // 父类
            bark          // 函数名,无参数
        );
    }
};

3.3 无新虚函数的派生类

即使Husky没有添加新的虚函数,也需要trampoline类:

class PyHusky : public Husky, public py::trampoline_self_life_support {
public:
    using Husky::Husky;  // 继承构造函数
    
    // 必须重写所有继承的虚函数
    std::string name() override {
        PYBIND11_OVERRIDE(
            std::string,
            Husky,
            name
        );
    }
    
    std::string go(int n_times) override {
        PYBIND11_OVERRIDE(
            std::string,
            Husky,
            go,
            n_times
        );
    }
    
    std::string bark() override {
        PYBIND11_OVERRIDE(
            std::string,
            Husky,
            bark
        );
    }
};

4. 关键注意事项

4.1 无参数函数的特殊处理

注意无参数函数在宏中的写法:

  • 正确PYBIND11_OVERRIDE(std::string, Dog, name)
  • 错误PYBIND11_OVERRIDE(std::string, Dog, name,)(有逗号)

pybind11需要这种特殊语法来正确处理无参数函数。

4.2 父类指定

在每个trampoline类中,第二个参数应该是直接父类,而不是最顶层的基类:

  • 在PyDog中使用Dog而不是Animal
  • 在PyHusky中使用Husky而不是DogAnimal

4.3 纯虚函数vs普通虚函数

  • 使用PYBIND11_OVERRIDE_PURE处理纯虚函数
  • 使用PYBIND11_OVERRIDE处理有默认实现的虚函数

5. 避免代码重复的模板技术

5.1 模板trampoline基类

为了避免在每个trampoline类中重复实现相同的虚函数,可以使用模板:

template <class AnimalBase = Animal>
class PyAnimalT : public AnimalBase, public py::trampoline_self_life_support {
public:
    using AnimalBase::AnimalBase;  // 继承构造函数
    
    std::string name() override {
        PYBIND11_OVERRIDE_PURE(
            std::string,
            AnimalBase,
            name
        );
    }
    
    std::string go(int n_times) override {
        PYBIND11_OVERRIDE_PURE(
            std::string,
            AnimalBase,
            go,
            n_times
        );
    }
};

template <class DogBase = Dog>
class PyDogT : public PyAnimalT<DogBase> {
public:
    using PyAnimalT<DogBase>::PyAnimalT;  // 继承构造函数
    
    // 重写PyAnimalT的纯虚函数为普通虚函数
    std::string name() override {
        PYBIND11_OVERRIDE(
            std::string,
            DogBase,
            name
        );
    }
    
    std::string go(int n_times) override {
        PYBIND11_OVERRIDE(
            std::string,
            DogBase,
            go,
            n_times
        );
    }
    
    // 添加Dog特有的虚函数
    std::string bark() override {
        PYBIND11_OVERRIDE(
            std::string,
            DogBase,
            bark
        );
    }
};

5.2 使用模板trampoline

PYBIND11_MODULE(example, m) {
    // Animal使用基础模板
    py::class_<Animal, PyAnimalT<>,py::smart_holder> animal(m, "Animal");
    animal.def(py::init<>())
          .def("name", &Animal::name)
          .def("go", &Animal::go);
    
    // Dog使用Dog模板
    py::class_<Dog, PyDogT<>,py::smart_holder> dog(m, "Dog");
    dog.def(py::init<>())
        .def("bark", &Dog::bark);
    
    // Husky不需要专门的模板,使用Dog模板即可
    py::class_<Husky, PyDogT<Husky>,py::smart_holder> husky(m, "Husky");
    husky.def(py::init<>());
}

5.3 模板技术的优势

  1. 减少代码重复:只需实现一次虚函数重写
  2. 灵活性:可以轻松适应类层次结构的变化
  3. 可维护性:修改虚函数时只需在一个地方更改

6. Python中的使用示例

6.1 继承和重写

from example import *

# 测试基础功能
dog = Dog()
print(dog.name())    # 输出: Dog
print(dog.bark())    # 输出: woof!
print(dog.go(2))     # 输出: woof! woof!

# 在Python中继承Animal
class Cat(Animal):
    def name(self):
        return "Cat"
    
    def go(self, n_times):
        return "meow! " * n_times

cat = Cat()
print(cat.name())    # 输出: Cat
print(cat.go(3))     # 输出: meow! meow! meow!

# 在Python中继承Dog
class Puppy(Dog):
    def name(self):
        return "Puppy"
    
    def bark(self):
        return "yap!"

puppy = Puppy()
print(puppy.name())  # 输出: Puppy
print(puppy.bark())  # 输出: yap!
print(puppy.go(2))   # 输出: yap! yap!

# 在Python中继承Husky
class HuskyPuppy(Husky):
    def bark(self):
        return "aroooo!"

husky_pup = HuskyPuppy()
print(husky_pup.name())  # 输出: Dog (继承自Dog)
print(husky_pup.bark())  # 输出: aroooo!
print(husky_pup.go(2))   # 输出: aroooo! aroooo!

6.2 多层继承的优势

# 更复杂的继承
class SuperDog(Dog):
    def name(self):
        return "Super " + super().name()
    
    def bark(self):
        return super().bark().upper()
    
    def fly(self):
        return "I'm flying!"

super_dog = SuperDog()
print(super_dog.name())  # 输出: Super Dog
print(super_dog.bark())  # 输出: WOOF!
print(super_dog.go(2))   # 输出: WOOF! WOOF!
print(super_dog.fly())   # 输出: I'm flying!

7. 绑定代码的最佳实践

7.1 完整的绑定代码

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// C++类定义
class Animal {
public:
    virtual ~Animal() {}
    virtual std::string name() = 0;
    virtual std::string go(int n_times) = 0;
};

class Dog : public Animal {
public:
    std::string name() override { return "Dog"; }
    std::string go(int n_times) override {
        std::string result;
        for (int i = 0; i < n_times; ++i)
            result += bark() + " ";
        return result;
    }
    virtual std::string bark() { return "woof!"; }
};

class Husky : public Dog {};

// 模板trampoline类
template <class AnimalBase = Animal>
class PyAnimalT : public AnimalBase, public py::trampoline_self_life_support {
public:
    using AnimalBase::AnimalBase;
    
    std::string name() override {
        PYBIND11_OVERRIDE_PURE(
            std::string,
            AnimalBase,
            name
        );
    }
    
    std::string go(int n_times) override {
        PYBIND11_OVERRIDE_PURE(
            std::string,
            AnimalBase,
            go,
            n_times
        );
    }
};

template <class DogBase = Dog>
class PyDogT : public PyAnimalT<DogBase> {
public:
    using PyAnimalT<DogBase>::PyAnimalT;
    
    std::string name() override {
        PYBIND11_OVERRIDE(
            std::string,
            DogBase,
            name
        );
    }
    
    std::string go(int n_times) override {
        PYBIND11_OVERRIDE(
            std::string,
            DogBase,
            go,
            n_times
        );
    }
    
    std::string bark() override {
        PYBIND11_OVERRIDE(
            std::string,
            DogBase,
            bark
        );
    }
};

// 绑定代码
PYBIND11_MODULE(example, m) {
    py::class_<Animal, PyAnimalT<>,py::smart_holder> animal(m, "Animal");
    animal.def(py::init<>())
          .def("name", &Animal::name)
          .def("go", &Animal::go);
    
    py::class_<Dog, PyDogT<>,py::smart_holder> dog(m, "Dog");
    dog.def(py::init<>())
        .def("bark", &Dog::bark);
    
    py::class_<Husky, PyDogT<Husky>,py::smart_holder> husky(m, "Husky");
    husky.def(py::init<>());
}

7.2 绑定要点总结

  1. 模板参数顺序:trampoline类作为第二个模板参数
  2. 绑定目标:始终绑定到原始C++类,不是trampoline类
  3. 继承关系:正确指定C++类之间的继承关系
  4. 函数签名:确保Python和C++的函数签名一致

8. 常见问题和解决方案

8.1 函数名不匹配

问题:Python中重写的函数不被调用
原因:C++和Python函数名不匹配
解决方案:使用PYBIND11_OVERRIDE_NAME

// 当C++函数名与Python函数名不同时
std::string operator()() override {
    PYBIND11_OVERRIDE_NAME(
        std::string,
        MyClass,
        "__call__",  // Python中的函数名
        operator()   // C++中的函数名
    );
}

8.2 继承切片问题

问题:通过基类指针访问时丢失派生类信息
解决方案:使用py::smart_holder代替std::shared_ptr

8.3 内存泄漏

问题:使用自定义trampoline导致内存泄漏
解决方案:确保正确继承py::trampoline_self_life_support

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值