Picking up where the previous two posts left off (生产者-消费者模型_cxb1998的博客-CSDN博客, on the producer-consumer model, and RAII + 接口模式_cxb1998的博客-CSDN博客, on RAII + the interface pattern), this post walks through how to wrap the producer-consumer model using RAII and an interface class.
infer.hpp
#ifndef INFER_HPP
#define INFER_HPP
#include <memory>
#include <string>
#include <future>

class InferInterface{
public:
    // virtual destructor so a derived Infer is destroyed correctly through the interface
    virtual ~InferInterface() = default;
    virtual std::shared_future<std::string> forward(std::string pic) = 0;
};
std::shared_ptr<InferInterface> create_infer(const std::string& file);
#endif // INFER_HPP
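A quick aside on the return type: forward() hands back a std::shared_future rather than a std::future, so the handle can be copied around and get() can be called more than once. A minimal standalone sketch of that difference (not part of the project):

#include <future>
#include <string>
#include <cstdio>

int main(){
    std::promise<std::string> pro;
    // a future<string> converts implicitly into a shared_future<string>
    std::shared_future<std::string> fut = pro.get_future();
    auto fut2 = fut;   // copyable; a plain std::future is move-only
    pro.set_value("result");
    // get() is valid on every copy; std::future allows only a single get()
    printf("%s %s\n", fut.get().c_str(), fut2.get().c_str());
    return 0;
}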
infer.cpp
#include"infer.hpp"
#include<thread>
#include<queue>
#include<mutex>
#include<future>
#include<condition_variable>
using namespace std;
struct Job{
shared_ptr<promise<string>> pro;
string input;
};
// Consumer: runs the worker thread that pulls jobs from the queue
class Infer : public InferInterface{
public:
    virtual ~Infer(){
        // flip the flag under the lock so the worker cannot check the predicate,
        // miss this notification, and then sleep forever (lost-wakeup race)
        {
            lock_guard<mutex> l(job_lock_);
            worker_running_ = false;
        }
        cv_.notify_one();
        if(worker_thread_.joinable()){
            worker_thread_.join();
        }
    }
    bool load_model(const string& file){
        // load the model inside the worker thread, so the thread that uses the
        // model also owns it; the promise blocks us until the worker reports back
        promise<bool> pro;
        worker_running_ = true;
        worker_thread_ = thread(&Infer::worker, this, file, std::ref(pro));
        // context_ = file;   // old approach: load in the caller's thread
        return pro.get_future().get();   // wait for the worker's success/failure verdict
    }
    virtual shared_future<string> forward(string pic) override{
        // producer side: wrap the input in a Job and queue it
        // printf("running inference with %s\n", context.c_str());   // old synchronous version
        Job job;
        job.pro.reset(new promise<string>());
        job.input = pic;
        {
            lock_guard<mutex> l(job_lock_);
            qjobs.push(job);
        }
        // notify-driven inference: as soon as new data arrives, wake the consumer;
        // notifying after the lock is released avoids waking it straight into a blocked mutex
        cv_.notify_one();
        return job.pro->get_future();   // future<string> converts to shared_future<string>
    }
    // destroy() from the earlier post is gone: with RAII the model's lifetime is
    // tied to the object, so release happens when the worker exits (see worker()).
    // void destroy(){
    //     context_.clear();
    // }
    void worker(string file, promise<bool>& pro){
        // the worker owns the whole model lifecycle: load, use, release
        string context = file;   // stands in for a real model handle
        if(context.empty()){
            pro.set_value(false);   // loading failed; load_model() returns false
            return;
        }else{
            pro.set_value(true);    // loading succeeded
        }
        int max_batch_size = 5;
        vector<Job> jobs;
        int batch_id = 0;
        // consumer loop: pull jobs from the queue
        while(worker_running_){
            unique_lock<mutex> l(job_lock_);
            // sleep until there is work or shutdown is requested;
            // the wait ends when the predicate returns true
            cv_.wait(l, [&](){
                return !qjobs.empty() || !worker_running_;
            });
            if(!worker_running_){
                break;
            }
            // batching: grab up to max_batch_size jobs in one go
            while((int)jobs.size() < max_batch_size && !qjobs.empty()){
                jobs.emplace_back(qjobs.front());
                qjobs.pop();
            }
            l.unlock();   // don't hold the lock during inference; producers can keep queueing
            // batch inference
            for(size_t i = 0; i < jobs.size(); i++){
                auto& job = jobs[i];
                char name[100];
                snprintf(name, sizeof(name), "%s : batch->%d[%d]",
                         job.input.c_str(), batch_id, (int)jobs.size());
                job.pro->set_value(name);
            }
            batch_id++;   // one id per batch, not per job
            jobs.clear();
        }
        context.clear();   // release the model in the same thread that loaded it
        printf("model released\n");
        printf("worker done\n");
    }
private:
    atomic<bool> worker_running_{false};   // shutdown flag for the worker loop
    thread worker_thread_;
    queue<Job> qjobs;                      // shared job queue (producer -> consumer)
    mutex job_lock_;                       // protects qjobs
    condition_variable cv_;                // wakes the worker when jobs arrive or on shutdown
};
// RAII
// Acquiring an Infer instance is the same thing as loading the model:
// if loading fails, resource acquisition fails and the caller gets nothing;
// if loading succeeds, the caller holds a ready-to-use resource.
shared_ptr<InferInterface> create_infer(const string& file){
    shared_ptr<Infer> instance(new Infer());
    // if load_model fails, reset so the caller receives a null pointer
    if(!instance->load_model(file)){
        instance.reset();
    }
    return instance;
}
// Interface pattern
// 1. Forbids external calls to load_model: loading is fused into creation.
// 2. Hides member variables from the caller.
//    For example, if a member were a special type such as cudaStream_t, every user
//    of the class would have to include cuda_runtime.h, polluting their headers.
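To make point 2 concrete, here is a hypothetical sketch (not part of this project) of what infer.hpp could look like without the interface pattern, assuming the implementation held a cudaStream_t; cuda_runtime.h would then leak into every file that includes the header:

// hypothetical header without the interface pattern
#include <cuda_runtime.h>   // forced on every user of this header
#include <string>
#include <future>
class Infer{
public:
    bool load_model(const std::string& file);   // callers could skip it or call it twice
    std::shared_future<std::string> forward(std::string pic);
private:
    cudaStream_t stream_ = nullptr;   // implementation detail exposed to the world
};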
main.cpp
#include"infer.hpp"
int main(){
    auto infer = create_infer("A");
    if(infer == nullptr){
        printf("load model failed\n");
        return -1;
    }
    auto fa = infer->forward("A");
    auto fb = infer->forward("B");
    auto fc = infer->forward("C");
    printf("%s\n%s\n%s\n", fa.get().c_str(), fb.get().c_str(), fc.get().c_str());
    return 0;
}
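The exact batching depends on timing. If the worker only wakes after all three forward() calls have queued their jobs, they land in one batch of size 3 and the output looks roughly like this (with smaller batches the ids and sizes differ):

A : batch->0[3]
B : batch->0[3]
C : batch->0[3]
model released
worker done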