C++如何自己实现一个shared_ptr

1. shared_ptr介绍

C++中的shared_ptr智能指针是行为类似于指针的类对象,封装了原始指针并提供了自动内存管理的功能(不用手动delete),从而实现了RAII的思想。

shared_ptr 内部是利用引用计数来实现内存的自动管理,每当复制一个 shared_ptr,引用计数会 + 1。当一个 shared_ptr 离开作用域时,引用计数会 - 1。当引用计数为 0 的时候,则 delete 内存。

2. 实现的功能

  1. 构造函数
  2. 析构函数
  3. 拷贝构造函数
  4. 拷贝赋值运算符
  5. 移动构造函数
  6. 移动赋值运算符
  7. 解引用、箭头运算符
  8. 引用计数、原始指针、重置指针

3. 具体实现

shared_ptr.h 如下:

#pragma once

#include <atomic>  // 引入原子操作

template <typename T>
class shared_ptr {
private:
    T* ptr;                                 // 指向管理的对象
    std::atomic<std::size_t>* ref_count;    // 原子引用计数

    // 释放资源
    void release() {
        // P.S. 这里使用 std::memory_order_acq_rel 内存序,保证释放资源的同步
        if (ref_count && ref_count->fetch_sub(1, std::memory_order_acq_rel) == 1) {
            delete ptr;
            delete ref_count;
        }
    }

public:
    // 默认构造函数
    shared_ptr() : ptr(nullptr), ref_count(nullptr) {}

    // 构造函数
    // P.S. 这里使用 explicit 关键字,防止隐式类型转换
    // shared_ptr<int> ptr1 = new int(10);  不允许出现
    explicit shared_ptr(T* p) : ptr(p), ref_count(p ? new std::atomic<std::size_t>(1) : nullptr) {

    }

    // 析构函数
    ~shared_ptr() {
        release();
    }

    // 拷贝构造函数
    shared_ptr(const shared_ptr<T>& other) : ptr(other.ptr), ref_count(other.ref_count) {
        if (ref_count) {
            ref_count->fetch_add(1, std::memory_order_relaxed);  // 引用计数增加,不需要强内存序
        }
    }

    // 拷贝赋值运算符
    shared_ptr<T>& operator=(const shared_ptr<T>& other) {
        if (this != &other) {
            release();  // 释放当前资源
            ptr = other.ptr;
            ref_count = other.ref_count;
            if (ref_count) {
                ref_count->fetch_add(1, std::memory_order_relaxed);  // 引用计数增加
            }
        }
        return *this;
    }

    // 移动构造函数
	// P.S. noexcept 关键字表示该函数不会抛出异常。
    // 标准库中的某些操作(如 std::swap)要求移动操作是 noexcept 的,以确保异常安全。
	// noexcept 可以帮助编译器生成更高效的代码,因为它不需要为异常处理生成额外的代码。
    shared_ptr(shared_ptr<T>&& other) noexcept : ptr(other.ptr), ref_count(other.ref_count) {
        other.ptr = nullptr;
        other.ref_count = nullptr;
    }

    // 移动赋值运算符
    shared_ptr<T>& operator=(shared_ptr<T>&& other) noexcept {
        if (this != &other) {
            release();  // 释放当前资源
            ptr = other.ptr;
            ref_count = other.ref_count;
            other.ptr = nullptr;
            other.ref_count = nullptr;
        }
        return *this;
    }

    // 解引用运算符
    // P.S. const 关键字表示该函数不会修改对象的状态。
    T& operator*() const {
        return *ptr;
    }

    // 箭头运算符
    T* operator->() const {
        return ptr;
    }

    // 获取引用计数
    std::size_t ./use_count() const {
        return ref_count ? ref_count->load(std::memory_order_acquire) : 0;
    }

    // 获取原始指针
    T* get() const {
        return ptr;
    }

    // 重置指针
    void reset(T* p = nullptr) {
        release();
        ptr = p;
        ref_count = p ? new std::atomic<std::size_t>(1) : nullptr;
    }
};

测试代码testExample.cc如下

#include <iostream>
#include "shared_ptr.h"
#include <thread>
#include <vector>
#include <chrono>

void test_shared_ptr_thread_safety() {
    shared_ptr<int> ptr(new int(42));

    // 创建多个线程,每个线程都增加和减少引用计数
    const int num_threads = 10;
    std::vector<std::thread> threads;
    for (int i = 0; i < num_threads; ++i) {
        threads.emplace_back([&ptr]() {
            for (int j = 0; j < 100; ++j) {
                shared_ptr<int> local_ptr(ptr);
                // 短暂暂停,增加线程切换的可能性
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
        });
    }

    // 等待所有线程完成
    for (auto& thread : threads) {
        thread.join();
    }

    // 检查引用计数是否正确
    std::cout << "use_count: " << ptr.use_count() << std::endl;
    if (ptr.use_count() == 1) {
        std::cout << "Test passed: shared_ptr is thread-safe!" << std::endl;
    } else {
        std::cout << "Test failed: shared_ptr is not thread-safe!" << std::endl;
    }
}

// 测试代码
int main() {
    shared_ptr<int> ptr1(new int(10));
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 1

    {
        shared_ptr<int> ptr2 = ptr1;
        std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 2
        std::cout << "ptr2 use_count: " << ptr2.use_count() << std::endl;  // 2
    }

    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 1

    shared_ptr<int> ptr3(new int(20));
    ptr1 = ptr3;
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 2
    std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl;  // 2

    ptr1.reset();
    std::cout << "ptr1 use_count: " << ptr1.use_count() << std::endl;  // 0
    std::cout << "ptr3 use_count: " << ptr3.use_count() << std::endl;  // 1

    test_shared_ptr_thread_safety();
    return 0;
}

输出结果:

ptr1 use_count: 1
ptr1 use_count: 2
ptr2 use_count: 2
ptr1 use_count: 1
ptr1 use_count: 2
ptr3 use_count: 2
ptr1 use_count: 0
ptr3 use_count: 1
use_count: 1
Test passed: shared_ptr is thread-safe!

补充知识

(1)shared_ptr中的引用计数,可以使用std::atomic来管理

std::atomic<std::size_t>* ref_count;    // 原子引用计数

进而达到以下目的:

  1. 原子操作:对原子变量的操作是不可分割的,意味着在多线程中不会被打断。
  2. 原子变量:一个变量可以被声明为原子类型(如 std::atomic<int>),它保证在多线程环境下对该变量的操作是安全的

更具体的,我们会用到atomic的以下方法:

fetch_add() 和 fetch_sub()

  • fetch_add():执行原子加法操作,并返回旧值。
  • fetch_sub():执行原子减法操作,并返回旧值。

memory_order(内存序)
std::atomic 的操作可以指定不同的内存顺序(memory ordering),控制不同线程之间的操作顺序。这对于高效并发编程非常重要。常见的内存顺序有:

  • memory_order_relaxed:不保证其他线程与该线程的操作顺序。
  • memory_order_consume:保证后续操作依赖于当前操作。
  • memory_order_acquire:保证所有的读取操作不会在当前操作之前执行。
  • memory_order_release:保证所有的写操作不会在当前操作之后执行。
  • memory_order_acq_rel:同时保证 acquire 和 release。
  • memory_order_seq_cst:最强的内存顺序,保证所有操作的顺序一致。

注:在shared_ptr的实现代码中,如果对内存序这个概念不熟,所有出现它的地方都可以不填,默认使用memory_order_seq_cst

(2)线程安全

如果不用std::atomic来管理引用计数,那么可以用mutex(互斥锁),所有对ref_count的操作都要加上mutex。

(3)构造函数与析构函数

  • 1. 默认构造函数 (`Default Constructor`): 用于创建类的对象。如果没有定义任何构造函数,编译器会生成一个默认构造函数。
  • 2. 拷贝构造函数 (`Copy Constructor`): 接收同类的另一个对象的引用,用于通过已存在的对象来初始化新对象的成员。
  • 3. 拷贝赋值操作符 (`Copy Assignment Operator`): 用于将一个对象的内容复制到另一个已经存在的对象中。
  • 4. 移动构造函数 (`Move Constructor`): C++11 引入。如果可能,用于将一个对象的资源“移动”到新创建的对象中,而非复制
  • 5. 移动赋值操作符 (`Move Assignment Operator`): C++11 引入。用于将一个对象的资源转移给另一个已经存在的对象。
  • 6. 析构函数 (`Destructor`): 当对象的生命周期结束时被调用,用于执行清理工作,如释放资源。
     
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值