声明:本文参考自《C++并发编程实战(第二版)》
本文尝试根据书中的代码,在力所能及的范围内实现一个线程安全的 stack 类。底层直接复用了标准库现成的 stack,写法比较简陋(请忽略)。
自认为这个类的功能还不够全面,等之后写 TinySTL 项目时再进一步泛化(模板化)并追加功能。
5-17 更新:追加 swap 方法的类内(成员)实现与友元实现,改写 push 方法,追加 try_pop 和 wait_and_pop 方法。
// A thread-safe stack implementation follows.
// Exception thrown by threadsafe_stack when an operation needs an element
// (top/pop) but the stack is empty.
struct empty_stack : std::exception
{
    // Kept for source compatibility only; what() no longer points at it,
    // because a single char is not a null-terminated C string.
    const char ch = 'n';

    // Returns a null-terminated message with static storage duration.
    const char* what() const noexcept override { return "empty stack"; }
};
// Forward declaration of the class template defined below.
// Its free swap() and operator== are defined as inline friends inside the
// class and are found through argument-dependent lookup, so no separate
// namespace-scope template declarations (which were never defined) are needed.
template<class T>
class threadsafe_stack;
// A stack that can be shared between threads. Every public member function
// that touches the underlying std::stack first locks the single internal
// mutex `m`. Consumers can block in wait_and_pop() until a producer pushes.
template<class T>
class threadsafe_stack
{
private:
    std::stack<T> data;
    mutable std::mutex m;        // guards `data`; mutable so const members can lock it
    std::condition_variable dt;  // notified whenever `data` may have become non-empty
public:
    threadsafe_stack() = default;

    // Copy-constructs from `other`; other's mutex is held while its contents are copied.
    threadsafe_stack(const threadsafe_stack& other)
    {
        std::lock_guard<std::mutex> lck(other.m);
        data = other.data;
    }

    // Low-level accessors kept for backward compatibility. They bypass the
    // locking discipline: callers must hold get_mutex() for as long as they use
    // the reference returned by get_data(). Whether to lock before reading
    // depends on the caller's logic; if a consistent snapshot matters, lock first.
    std::mutex& get_mutex()
    {
        return m;
    }
    std::stack<T>& get_data()
    {
        return data;
    }

    // Copy assignment is disabled; use the copy constructor or swap() instead.
    threadsafe_stack& operator=(const threadsafe_stack&) = delete;

    // Returns true if the stack held no elements at the moment of the call.
    bool empty() const
    {
        std::lock_guard<std::mutex> lck(m);
        return data.empty();
    }

    // Exchanges contents with `other`, locking both mutexes without deadlock.
    // Self-swap is a no-op: locking the same std::mutex twice is undefined behaviour.
    void swap(threadsafe_stack<T>& other) noexcept
    {
        if (this == &other)
            return;
        std::scoped_lock guard(m, other.m);
        data.swap(other.data);
        // Either stack may have just become non-empty: wake any blocked
        // wait_and_pop() callers so they re-check their predicate.
        dt.notify_all();
        other.dt.notify_all();
    }

    // ADL-visible free swap; delegates to the member so both share one locking scheme.
    friend void swap(threadsafe_stack<T>& l, threadsafe_stack<T>& r)
    {
        l.swap(r);
    }

    // NOTE: size() is intentionally not provided. Like top(), its result can be
    // outdated as soon as the lock is released, so callers must not base later
    // decisions on it without their own external locking.

    // Returns a copy of the top element; throws empty_stack if the stack is empty.
    // The value may be stale once returned; to read and remove atomically use pop().
    T top() const
    {
        std::lock_guard<std::mutex> lck(m);
        if (data.empty())
            throw empty_stack();
        return data.top();
    }

    // Same as top(), but returns the copy through a shared_ptr. The unnamed
    // parameter only selects this overload; its value is ignored.
    std::shared_ptr<T> top(T) const
    {
        std::lock_guard<std::mutex> lck(m);
        if (data.empty())
            throw empty_stack();
        return std::make_shared<T>(data.top());
    }

    // Pushes a copy of `val` and wakes one thread blocked in wait_and_pop().
    void push(const T& val)
    {
        std::lock_guard<std::mutex> lck(m);
        data.push(val);
        dt.notify_one();
    }

    // Pushes `val` by move and wakes one thread blocked in wait_and_pop().
    void push(T&& val)
    {
        std::lock_guard<std::mutex> lck(m);
        data.push(std::move(val));
        dt.notify_one();
    }

    // Removes the top element and returns it in a shared_ptr; throws empty_stack
    // if empty. The element is copied out before pop(), so a throwing copy
    // leaves the stack unchanged.
    std::shared_ptr<T> pop()
    {
        std::lock_guard<std::mutex> lck(m);
        if (data.empty())
            throw empty_stack();
        std::shared_ptr<T> res(std::make_shared<T>(data.top()));
        data.pop();
        return res;
    }

    // Removes the top element into `value`; throws empty_stack if empty.
    void pop(T& value)
    {
        std::lock_guard<std::mutex> lck(m);
        if (data.empty())
            throw empty_stack();
        value = data.top();
        data.pop();
    }

    // Non-throwing variant: returns false (leaving `value` untouched) if empty.
    bool try_pop(T& value)
    {
        std::lock_guard<std::mutex> lck(m);
        if (data.empty())
            return false;
        value = data.top();
        data.pop();
        return true;
    }

    // Non-throwing variant: returns an empty shared_ptr if the stack is empty.
    std::shared_ptr<T> try_pop()
    {
        std::lock_guard<std::mutex> lck(m);
        if (data.empty())
            return std::shared_ptr<T>();
        std::shared_ptr<T> res(std::make_shared<T>(data.top()));
        data.pop();
        return res;
    }

    // Blocks until the stack is non-empty, then removes the top element into `value`.
    void wait_and_pop(T& value)
    {
        std::unique_lock<std::mutex> lck(m);
        dt.wait(lck, [this] { return !data.empty(); });
        value = data.top();
        data.pop();
    }

    // Blocks until the stack is non-empty, then removes and returns the top element.
    std::shared_ptr<T> wait_and_pop()
    {
        std::unique_lock<std::mutex> lck(m);
        dt.wait(lck, [this] { return !data.empty(); });
        std::shared_ptr<T> res(std::make_shared<T>(data.top()));
        data.pop();
        return res;
    }

    // Element-wise comparison taken under both locks, so the result reflects one
    // consistent moment. Self-comparison locks only once (double-locking is UB).
    bool operator==(const threadsafe_stack<T>& oth) const
    {
        if (this == &oth)
        {
            std::lock_guard<std::mutex> lck(m);
            return data == oth.data;
        }
        std::scoped_lock guard(m, oth.m);
        return data == oth.data;
    }

    // Free operator== for non-const operands; delegates to the member.
    friend bool operator==(threadsafe_stack<T>& l, threadsafe_stack<T>& r)
    {
        return l.operator==(r);
    }
};
测试代码如下:
// Test driver run on its own thread: pushes 0..99 onto `s`, removes 50 elements
// with pop() and the remaining 50 with wait_and_pop(), then prints "empty" if
// nothing is left.
void oper(threadsafe_stack<int>& s)
{
    // main() runs several oper() calls concurrently against the shared std::cout;
    // serialize whole lines so output from different threads does not interleave.
    static std::mutex out_mutex;
    const auto log_line = [](const char* label, int value) {
        std::lock_guard<std::mutex> lck(out_mutex);
        std::cout << label << value << '\n';
    };

    for (int i = 0; i < 100; i++)
    {
        s.push(i);
        log_line("push : ", i);
    }
    for (int i = 0; i < 50; i++)
    {
        const std::shared_ptr<int> ptr = s.pop();  // pop() throws rather than returning null
        log_line("pop : ", *ptr);
    }
    for (int i = 0; i < 50; i++)
    {
        int j = 0;
        s.wait_and_pop(j);
        log_line("wait_and_pop : ", j);
    }
    if (s.empty())
    {
        std::lock_guard<std::mutex> lck(out_mutex);
        std::cout << "empty" << '\n';
    }
}
int main()
{
threadsafe_stack<int>s1;
threadsafe_stack<int>s2;
thread obj1(oper, ref(s1));
thread obj2(oper, ref(s2));
obj1.join();
obj2.join();
threadsafe_stack<int> s3(s2);
thread obj3(oper, ref(s3));
obj3.join();
threadsafe_stack<int>s4;
for (int i = 0; i < 5; i++)
s4.push(i);
threadsafe_stack<int>s5;
for (int i = 5; i < 10; i++)
s5.push(i);
swap(s4, s5);
while (!s4.empty())
{
int i = 0;
int tmp = s4.top();
cout << tmp << " ";
s4.pop(i);
cout << i << " " << endl;
}
s4.swap(s5);
while (!s4.empty())
{
int i = 0;
shared_ptr<int>tmp = s4.top(i);
cout << *tmp << " ";
s4.pop(i);
cout << i << " " << endl;
}
threadsafe_stack<int>s6;
for (int i = 0; i < 5; i++)
s6.push(i);
threadsafe_stack<int>s7;
for (int i = 0; i < 5; i++)
s7.push(i);
bool flg1 = (s6 == s7);
bool flg2 = s6.operator==(s7);
cout << flg1 << " " << flg2 << endl;
threadsafe_stack<int>s8;
for (int i = 0; i < 5; i++)
s8.push(i);
for (int i = 0; i < 10; i++)
{
int t = 0;
bool res = s8.try_pop(t);
if (res)
cout << t << endl;
else
cout << "empty" << endl;
}
return 0;
}