用互斥量和条件变量实现生产者和消费者模型
- 应用在在线学习SGD中:普通的SGD是单线程计算,可以用生产者和消费者模型加速计算——用多个线程读取文件并把数据放入一个缓冲区,再用一个线程从缓冲区读取数据进行SGD优化。
C++代码如下
#include "reader.h"

#include <pthread.h>

#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <list>
#include <string>
#include <vector>

#include <boost/algorithm/string.hpp>
#include <boost/lexical_cast.hpp>

using namespace std;
using namespace boost;
// Shared state of the reader. The three lists below together form the sample
// buffer: the functions in this file push and pop them in lock-step, so the
// entries at the same position belong to the same sample.
list<vector<double> > _x_list; // buffer of sample feature values
list<vector<int> > _id_list; // buffer of sample feature ids
list<double> _y_list; // buffer of sample y values
int size; // meant to record the number of entries in the buffer; not referenced in the visible code
pthread_mutex_t _f_lock; // file lock: held by a producer while it reads a line from _fp
pthread_mutex_t _m_lock; // buffer lock: held while the lists and _n are modified
pthread_cond_t _r_ready; // initialised by Reader and destroyed by Close, but never waited on or signalled in the visible code
pthread_cond_t _w_ready; // producers wait on this while the buffer is full; ReadLine signals it after removing a sample
int EXIT_NUM = 0; // incremented by each producer thread when it stops reading
int THREAD_NUM = 0; // number of producer threads, set by Reader
ifstream _fp; // input file stream shared by the producer threads
int _n = 0; // number of samples currently stored in the buffer
// Capacity limit of the buffer: producers block while _n >= _N.
int _N = 4; // maximum number of samples the buffer may hold
int _f_mark = 1; // "file finished" marker; decremented by each producer when it stops, not read in the visible code
int K = 1; // running counter printed (and incremented) by ReadFromMemory
int K1 = 1; // not referenced in the visible code
// Thread entry point of a producer: reads lines from _fp into the buffer.
void* ReadFile2Memory(void*);
// Thread entry point of the consumer: repeatedly takes samples out of the buffer.
void* ReadFromMemory(void*);
// Takes one sample out of the buffer; returns 1 on success, 2 if the buffer is
// empty but producers are still running, 0 if the buffer is empty and all
// producers have stopped.
int ReadLine(vector<double>& xi, vector<int>& id, double& yi);
/**
 * Reads the sample file at `path` with a pool of producer threads and one
 * consumer thread, and blocks until the whole file has been consumed.
 *
 * Producers (ReadFile2Memory) parse lines into the shared buffer; the consumer
 * (ReadFromMemory) drains it. All threads started here are joined before the
 * function returns, so the caller may safely call Close() afterwards.
 *
 * @param path        path of the input file.
 * @param thread_num  number of producer threads to start; must be >= 1.
 *
 * On any set-up failure a message is written to the error stream and the
 * function returns without processing the file.
 */
void Reader(string path, int thread_num = 2)
{
    // A non-positive count would leave the consumer waiting for producers that
    // never exist (EXIT_NUM could never reach a negative THREAD_NUM).
    if(thread_num < 1)
    {
        cerr << "thread_num must be at least 1, got " << thread_num << "." << endl;
        return;
    }
    THREAD_NUM = thread_num;

    // Open the input file.
    _fp.open(path.c_str(), ios::in);
    if(!_fp)
    {
        cerr << "open file " << path << " failed." << endl;
        return;
    }

    // Initialise the mutexes.
    if(pthread_mutex_init(&_f_lock, NULL) != 0)
    {
        cerr << "init file thread lock failed." << endl;
        return;
    }
    if(pthread_mutex_init(&_m_lock, NULL) != 0)
    {
        cerr << "init memory thread lock failed." << endl;
        return;
    }

    // Initialise the condition variables.
    if(pthread_cond_init(&_r_ready, NULL) != 0)
    {
        cerr << "init read condition variable failed." << endl;
        return;
    }
    if(pthread_cond_init(&_w_ready, NULL) != 0)
    {
        cerr << "init write condition variable failed." << endl;
        return;
    }

    // Start the producer threads. The handles live in a vector so they are
    // released automatically on every return path.
    vector<pthread_t> tid(thread_num);
    int started = 0;
    for(; started < thread_num; ++started)
    {
        if(pthread_create(&tid[started], NULL, ReadFile2Memory, NULL) != 0)
        {
            cerr << "create read file thread " << started << " failed." << endl;
            break;
        }
    }

    // The consumer stops once EXIT_NUM reaches THREAD_NUM, so THREAD_NUM must
    // count only the producers that really exist. Producers never read
    // THREAD_NUM and the consumer has not been started yet, so this is safe.
    THREAD_NUM = started;
    if(started == 0)
    {
        return;
    }

    // Start the consumer thread and wait for it to finish.
    pthread_t m_tid;
    if(pthread_create(&m_tid, NULL, ReadFromMemory, NULL) != 0)
    {
        cout << "create read memory thread faild." << endl;
        // Without a consumer the producers would block forever on a full
        // buffer. Drain (and discard) the samples here so they can finish.
        vector<double> xi;
        vector<int> id;
        double yi = 0.0;
        while(ReadLine(xi, id, yi) != 0)
        {
        }
    }
    else
    {
        pthread_join(m_tid, NULL);
    }

    // Reap every producer so no thread outlives this call.
    for(int i = 0; i < started; ++i)
    {
        pthread_join(tid[i], NULL);
    }
}
void* ReadFile2Memory(void*)
{
string line;
while(1)
{
pthread_mutex_lock(&_f_lock); //锁文件
istream& t = std::getline(_fp, line);
if(t == NULL || (line.compare("") == 0))
{
pthread_mutex_unlock(&_f_lock); //释放文件锁
_f_mark--;
EXIT_NUM++;
pthread_exit((void*)2);;
}
pthread_mutex_unlock(&_f_lock); //释放文件锁
vector<double> xi;
vector<int> id;
double yi;
Parser(line, xi, id, yi);
pthread_mutex_lock(&_m_lock); //锁住内存
while(_n >= _N)
{
pthread_cond_wait(&_w_ready, &_m_lock);
} //如果缓存中数据大于给定阈值,等待
_x_list.push_back(xi);
_id_list.push_back(id);
_y_list.push_back(yi);
_n++; //将当前的缓冲区数据加一
pthread_mutex_unlock(&_m_lock);
}
}
/**
 * Takes the oldest sample out of the shared buffer without blocking.
 *
 * @param xi  receives the feature values of the sample.
 * @param id  receives the feature ids of the sample.
 * @param yi  receives the y value of the sample.
 * @return 1 if a sample was stored in the output parameters;
 *         2 if the buffer is empty but producers are still running
 *           (the caller is expected to try again);
 *         0 if the buffer is empty and every producer has stopped.
 *
 * The output parameters are left untouched unless 1 is returned.
 */
int ReadLine(vector<double>& xi, vector<int>& id, double& yi)
{
    // _n and EXIT_NUM are written by the producer threads, so they are only
    // inspected while the buffer lock is held.
    pthread_mutex_lock(&_m_lock);
    if(_n <= 0)
    {
        const int status = (EXIT_NUM == THREAD_NUM) ? 0 : 2;
        pthread_mutex_unlock(&_m_lock);
        return status;
    }

    // Swap instead of copy: the previous contents of xi/id end up in the list
    // node and are destroyed by pop_front().
    xi.swap(_x_list.front());
    _x_list.pop_front();
    id.swap(_id_list.front());
    _id_list.pop_front();
    yi = _y_list.front();
    _y_list.pop_front();
    _n--;
    pthread_mutex_unlock(&_m_lock);

    // A slot is free again: wake one producer waiting on a full buffer.
    pthread_cond_signal(&_w_ready);
    return 1;
}
void* ReadFromMemory(void*)
{
vector<double> xi;
vector<int> id;
double yi;
int n;
while(1)
{
n = ReadLine(xi, id, yi);
if(n == 0)
{
cout << "read thread " << pthread_self() << " exit." << endl;
pthread_exit((void*)2);;
}
else if(n == 2)
{
//cout << "文件没读完,缓冲区为空" << endl;
continue;
}
cout << n << "," << K++ << " = " << yi << endl;
}
pthread_exit((void*)2);
}
/**
 * Releases the resources set up by Reader: both condition variables, both
 * mutexes and the input file stream.
 */
void Close()
{
    // Condition variables first ...
    pthread_cond_destroy(&_w_ready);
    pthread_cond_destroy(&_r_ready);
    // ... then the buffer lock and the file lock ...
    pthread_mutex_destroy(&_m_lock);
    pthread_mutex_destroy(&_f_lock);
    // ... and finally the file itself.
    _fp.close();
}
/**
 * Parses one sample line of the form "<y> <id>:<value> <id>:<value> ...".
 *
 * @param line  the text to parse; it is not modified.
 * @param xi    cleared, then filled with the feature values in line order.
 * @param id    cleared, then filled with the feature ids in line order.
 * @param yi    receives the leading y value.
 * @throws boost::bad_lexical_cast if the y value, an id or a value is not a
 *         valid number, or if a feature token has no ':' separator.
 *
 * Leading/trailing whitespace (including a '\r' left over from Windows line
 * endings) and repeated spaces between tokens are tolerated. If an exception
 * is thrown, xi and id contain the pairs parsed so far and always have equal
 * length.
 */
void Parser(string& line, vector<double>& xi, vector<int>& id, double& yi)
{
    xi.clear();
    id.clear();

    // Trim first and compress adjacent spaces so no empty tokens are produced
    // (an all-blank line still yields one empty token, which fails the cast
    // below).
    const string trimmed = trim_copy(line);
    vector<string> ele;
    split(ele, trimmed, is_any_of(" "), token_compress_on);

    // split() always yields at least one token, so ele[0] is valid.
    yi = lexical_cast<double>(ele[0]);

    xi.reserve(ele.size() - 1);
    id.reserve(ele.size() - 1);
    for(vector<string>::const_iterator it = ele.begin() + 1; it != ele.end(); ++it)
    {
        vector<string> tmp;
        split(tmp, *it, is_any_of(":"));
        if(tmp.size() < 2)
        {
            // No ':' in the token: there is no value part to read.
            throw bad_lexical_cast();
        }
        // Convert both parts before storing either, so id and xi stay paired
        // even if the second conversion throws.
        const int feature_id = lexical_cast<int>(tmp[0]);
        const double feature_value = lexical_cast<double>(tmp[1]);
        id.push_back(feature_id);
        xi.push_back(feature_value);
    }
}
测试代码如下
#include <string>
#include <iostream>
#include <vector>
#include <boost/lexical_cast.hpp>
#include "reader.h"
using namespace std;
using namespace boost;
// Declaration of the reader entry point used by main below: reads the file at
// `path` using `thread_num` file-reading threads (2 when omitted).
// NOTE(review): if reader.h already declares Reader with this default
// argument, repeating the default here is an error - confirm against the header.
void Reader(string path, int thread_num = 2);
int main(int argc, char** argv)
{
string path("mnist_train01.txt");
int n = lexical_cast<int>(argv[1]);
Reader(path, n);
Close();
return 0;
}