实现引用计数线程安全的shared_ptr

c++11引入了三个智能指针,用来自动管理内存,使用智能指针可以有效地减少内存泄漏。

其中,shared_ptr是共享智能指针,可以被多次拷贝,拷贝时其内部的引用计数+1,被销毁时引用计数-1,如果引用计数为0,那么释放其所管理的资源

线程安全上,shared_ptr具有如下特点:

  • shared_ptr的引用计数是线程安全的
  • 修改shared_ptr不是线程安全的
  • 读写shared_ptr管理的数据不是线程安全的

具体可以参考:https://zhuanlan.zhihu.com/p/664993437

在网上找到的shared_ptr的手动实现都是线程不安全的,那么如何实现一个引用计数线程安全的shared_ptr呢?

参考:从零简单实现一个线程安全的C++共享指针(shared_ptr)-优快云博客,本文在这篇博客的基础上增加了验证代码,并指出原有实现一个潜在的bug

#include <iostream>
#include <atomic>
#include <mutex>
#include <thread>
#include <vector>

using namespace std;

#define N 10000

class Counter
{
public:
    Counter() { count = 1; }

    void add() {
        lock_guard<std::mutex> lk(mutex_);
        count++; 
    }

    void sub() {
        lock_guard<std::mutex> lk(mutex_);
        count--;
    }

    int get() {
        lock_guard<std::mutex> lk(mutex_);
        return count; 
    }

private:
    int count;
    std::mutex mutex_;
};

template <typename T>
class Sp
{
public:
	Sp();                           //默认构造函数
	~Sp();                          //析构函数
	Sp(T *ptr);                     //参数构造函数
	Sp(const Sp &obj);              //复制构造函数
	Sp &operator=(const Sp &obj);   //重载=
	T *get();                       //得到共享指针指向的类
	   
	int getcount();                 //得到引用计数器
private:
	
	T *my_ptr;                      //共享指针所指向的对象
	Counter* counter;                   //引用计数器
	void clear();                   //清理函数
};
 
//默认构造函数,参数为空,构造一个引用计数器
template<typename T>
Sp<T>::Sp()
{
	my_ptr = nullptr;
	counter = new Counter();
}
 
//复制构造函数,新的共享指针指向旧的共享指针所指对象
template<typename T>
Sp<T>::Sp(const Sp &obj)
{
    //将所指对象也变为目标所指的对象
	my_ptr = obj.my_ptr;
    //获取引用计数器,使得两个共享指针用一个引用计数器
	counter = obj.counter;
    //使这个对象的引用计数器+1
	counter->add();	
};
 
//重载=
template<typename T>
Sp<T> &Sp<T>::operator=(const Sp&obj)
{
    //清理当前所引用对象和引用计数器
	clear();

    //指向新的对象,并获取目标对象的引用计数器
	my_ptr = obj.my_ptr;
	counter = obj.counter;
    //引用计数器+1
	counter->add();
    //返回自己
	return *this;	
}
 
//创建一个共享指针指向目标类,构造一个新的引用计数器
template<typename T>
Sp<T>::Sp(T *ptr)
{
	my_ptr = ptr;
	counter = new Counter();
}
 
//析构函数,出作用域的时候,调用清理函数
template<typename T>
Sp<T>:: ~Sp()
{
	clear();
}
 
//清理函数,调用时将引用计数器的值减1,若减为0,清理指向的对象内存区域
template<typename T>
void Sp<T>::clear()
{
    //引用计数器-1
	counter->sub();
    //如果引用计数器变为0,清理对象
	if(0 == counter->get())
	{
        // 这里有个bug,如果在此间隙处,有另外一个地方执行了share ptr的copy操作,则会crash
		if(my_ptr)
		{
			delete my_ptr;
		}
		delete counter;
	}
}
 
//当前共享指针指向的对象,被几个共享指针所引用
template<typename T>
int Sp<T>::getcount()
{
	return counter->get();	
};
 
class A{
public:
	A(){ cout<<"A construct!"<<endl; };
	~A() { cout<<"A destruct!"<<endl; };
};

Sp<A> sp(new A);
std::vector<Sp<A>> vec1(N);
std::vector<Sp<A>> vec2(N);

Sp<A> sp1(new A);
Sp<A> sp2(new A);
Sp<A> sp3(new A);


void thread_func1() {
    for(int i = 0; i < N; i++) {
        vec1[i] = sp;
    }
}

void thread_func2() {
    for(int i = 0; i < N; i++) {
        vec2[i] = sp;
    }
}

void test_crash_func1() {
    sp1 = sp2;
}

void test_crash_func2() {
    sp3 = sp1;
}

void test_crash() {
    for(int i = 0; i < 10 * N; i++) {
        std::thread t1(test_crash_func1);
        std::thread t2(test_crash_func2);
        t1.join();
        t2.join();
    }
}



int main()
{
    std::thread t1(thread_func1);
    std::thread t2(thread_func2);
    t1.join();
    t2.join();
    std::cout<<"the count is:"<<sp.getcount()<<std::endl;

    test_crash();
}

按理说调用test_crash应该会导致crash才对,但是不知道为什么没有crash

TODO:使用原子操作实现,对比性能

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值