epoll+协程

本文介绍了一个自定义的日志类Fiber,用于实现协程,支持在单线程中执行异步任务。Fiber类包含了状态管理、栈分配等功能,允许在执行过程中通过Yeild2Hold和Yeild2Ready进行上下文切换。通过Fiber的使用,可以在函数调用之间自由切换,实现类似函数指针的功能但具备更灵活的控制流程。示例代码展示了如何在epoll事件驱动中结合Fiber实现客户端连接和读取的异步处理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

自己看代码

忽略LOG代码,自己实现的日志类,可以替换为printf
忽略命令空间,可以把其去掉
// fiber.h
#ifndef __FIBER_H__
#define __FIBER_H__

#include <stdio.h>
#include <ucontext.h>
#include <functional>
#include <memory>

namespace Jarvis {
/**
 * @brief 协程类
 */ 
class Fiber : public std::enable_shared_from_this<Fiber>
{
public:
    typedef std::shared_ptr<Fiber> sp;
    enum FiberState {
        HOLD,       // 暂停状态
        EXEC,       // 执行状态
        TERM,       // 结束状态
        READY,      // 可执行态
        EXCEPT      // 异常状态
    };
    Fiber(std::function<void()> cb, uint64_t stackSize = 0);
    ~Fiber();

           void         Reset(std::function<void()> cb);
    static void         SetThis(Fiber *f);  // 设置当前正在执行的协程
    static Fiber::sp    GetThis();          // 获取当前正在执行的协程
           void         Resume();           // 唤醒协程
    static void         Yeild2Hold();       // 将当前正在执行的协程让出执行权给主协程,并设置状态为HOLD
    
private:
    static void         Yeild2Ready();      // 将当前正在执行的协程让出执行权给主协程,并设置状态为READY
    Fiber();                    // 线程的第一个协程调用
    static void FiberEntry();   // 协程入口函数
    void SwapIn();              // 切换到前台, 获取执行权限
    void SwapOut();             // 切换到后台, 让出执行权限

private:
    ucontext_t      mCtx;
    FiberState      mState;
    uint64_t        mFiberId;
    uint64_t        mStackSize;
    void *          mStack;
    std::function<void()> mCb;
};
} // namespace Jarvis

#endif // __FIBER_H__
// fiber.cpp
#include "fiber.h"
#include <log/log.h>
#include <atomic>
#include <exception>

#ifdef LOG_TAG
#undef LOG_TAG
#define LOG_TAG "fiber"
#else
#define LOG_TAG "fiber"
#endif

namespace Jarvis{

static std::atomic<uint64_t> gFiberId(0);       // 协程ID
static std::atomic<uint64_t> gFiberCount(0);    // 当前协程总数

static thread_local Fiber *gCurrFiber = nullptr;            // 当前正在执行的协程
static thread_local Fiber::sp gThreadMainFiber = nullptr;   // 一个线程的主协程

// TODO: 配置yaml文件,获取栈大小
uint64_t getStackSize()
{
    static uint64_t size = 1024 * 1024;
    if (size == 0) {
        // size = Config::Lookup<uint64_t>("fiber.stack_size", 1024 * 1024);
    }
    return size;
}

class MallocAllocator
{
public:
    static void *alloc(uint64_t size)
    {
        return malloc(size);
    }
    static void dealloc(void *ptr, uint64_t size)
    {
        LOG_ASSERT(ptr, "dealloc a null pointer");
        free(ptr);
    }
};

using Allocator = MallocAllocator;

Fiber::Fiber() :
    mFiberId(++gFiberId)
{
    ++gFiberCount;
    mState = EXEC;
    if (getcontext(&mCtx)) {
        LOG_ASSERT(false, "getcontext error, %d %s", errno, strerror(errno));
    }
    SetThis(this);
    LOGD("Fiber::Fiber() start id = %d, total = %d", mFiberId, gFiberCount.load());
}

Fiber::Fiber(std::function<void()> cb, uint64_t stackSize) :
    mFiberId(++gFiberId),
    mCb(cb),
    mState(READY)
{
    ++gFiberCount;

    mStackSize = stackSize ? stackSize : getStackSize();
    mStack = Allocator::alloc(mStackSize);
    if (!mStack) {
        LOGD("Fiber id = %lu, stack pointer is null", mFiberId);
    }
    if (getcontext(&mCtx)) {
        LOG_ASSERT(false, "Fiber::Fiber(std::function<void()>, uint64_t) getcontext error, %d %s",
            errno, strerror(errno));
    }
    mCtx.uc_stack.ss_sp = mStack;
    mCtx.uc_stack.ss_size = mStackSize;
    mCtx.uc_link = nullptr;
    makecontext(&mCtx, &FiberEntry, 0);

    LOGD("Fiber::Fiber(std::function<void()>, uint64_t) id = %lu, total = %d start",
        mFiberId, gFiberCount.load());
}

Fiber::~Fiber()
{
    --gFiberCount;
    LOGD("Fiber::~Fiber() id = %lu, total = %d", mFiberId, gFiberCount.load());
    if (mStack) {
        LOG_ASSERT(mState == TERM || mState == EXCEPT,
            "file %s, line %d", __FILE__, __LINE__);
        Allocator::dealloc(mStack, mStackSize);
    } else {    // main fiber
        LOG_ASSERT(!mCb, "");
        if (gCurrFiber == this) {
            SetThis(nullptr);
        }
    }
}

// 调用位置在主协程中。
void Fiber::Reset(std::function<void()> cb)
{
    LOG_ASSERT(mStack, "main fiber can't reset"); // 排除main fiber
    // 暂停态,执行态,ready态无法reset
    LOG_ASSERT(mState == TERM || mState == EXCEPT, "Reset unauthorized operation");
    mCb = cb;
    if (getcontext(&mCtx)) {
        LOG_ASSERT(false, "File %s, Line %s. getcontex error.", __FILE__, __LINE__);
    }
    mCtx.uc_stack.ss_sp = mStack;
    mCtx.uc_stack.ss_size = mStackSize;
    mCtx.uc_link = nullptr;
    makecontext(&mCtx, &FiberEntry, 0);
    mState = READY;
}

/**
 * 同一个线程应使SwapIn的调用次数比Yeild调用次数多一,如果多二则会在FiberEntry结尾处退出线程
 * 如果不多一则会使只能指针的引用计数大于1,导致释放不干净,原因是FiberEntry的回调未执行完毕,
 * 即最后一次Yeild操作保存了堆栈,使得在栈上的只能无法释放
 */
void Fiber::SwapIn()
{
    SetThis(this);
    LOG_ASSERT(mState != EXEC, "");
    mState = EXEC;
    if (swapcontext(&gThreadMainFiber->mCtx, &mCtx)) {
        LOG_ASSERT(false, "SwapIn() id = %d, errno = %d, %s", mFiberId, errno, strerror(errno));
    }
}

void Fiber::SwapOut()
{
    SetThis(gThreadMainFiber.get());
    if (swapcontext(&mCtx, &gThreadMainFiber->mCtx)) {
        LOG_ASSERT(false, "SwapIn() id = %d, errno = %d, %s", mFiberId, errno, strerror(errno));
    }
}

void Fiber::SetThis(Fiber *f)
{
    gCurrFiber = f;
}

Fiber::sp Fiber::GetThis()
{
    if (gCurrFiber) {
        return gCurrFiber->shared_from_this();
    }
    Fiber::sp fiber(new Fiber());
    LOG_ASSERT(fiber.get() == gCurrFiber, "");
    gThreadMainFiber = fiber;
    LOG_ASSERT(gThreadMainFiber, "");
    return gCurrFiber->shared_from_this();
}

void Fiber::Resume()
{
    SwapIn();
}

void Fiber::Yeild2Hold()
{
    Fiber::sp ptr = GetThis();
    LOG_ASSERT(ptr != nullptr, "");
    LOG_ASSERT(ptr->mState == EXEC, "");
    ptr->mState = HOLD;
    ptr->SwapOut();
}

void Fiber::Yeild2Ready()
{
    Fiber::sp ptr = GetThis();
    LOG_ASSERT(ptr != nullptr, "");
    LOG_ASSERT(ptr->mState == EXEC, "");
    ptr->mState = READY;
    ptr->SwapOut();
}

void Fiber::FiberEntry()
{
    Fiber::sp curr = GetThis();
    LOG_ASSERT(curr != nullptr, "");
    try {
        curr->mCb();
        curr->mCb = nullptr;
        curr->mState = TERM;
    } catch (const std::exception& e) {
        curr->mState = EXCEPT;
        LOGE("Fiber except: %s; id = %lu", e.what(), curr->mFiberId);
    } catch (...) {
        curr->mState = EXCEPT;
        LOGE("Fiber except. id = %lu", curr->mFiberId);
    }

    Fiber *ptr = curr.get();
    curr.reset();
    ptr->SwapOut();

    LOG_ASSERT(false, "never reach here");
}

}
// epoll.cpp
#include "fiber.h"
#include <log/log.h>
#include <utils/Errors.h>
#include <utils/thread.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <arpa/inet.h>
#include <sys/time.h>
#include <sys/epoll.h>

#define LOG_TAG "epoll"
#define EPOLL_EVENT_SIZE 512
#define RD_BUF_SIZE 512
#define WR_BUF_SIZE 512



int InitSocket(uint16_t port = 8080)
{
    int sock = ::socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0) {
        LOGE("socket error. error code = %d, error message: %s", errno, strerror(errno));
        return Jarvis::UNKNOWN_ERROR;
    }
    sockaddr_in server_addr;
    bzero(&server_addr, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(port);
    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");;

    int opt = 1;
    setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

    if (::bind(sock, (sockaddr *)&server_addr, sizeof(server_addr)) < 0) {
        LOGE("bind error. error code = %d, error message: %s", errno, strerror(errno));
        goto error_return;
    }
    if (::listen(sock, 128) < 0) {
        LOGE("listen error. error code = %d, error message: %s", errno, strerror(errno));
        goto error_return;
    }
    return sock;

error_return:
    ::close(sock);
    return Jarvis::UNKNOWN_ERROR;
}

static int gServerSocket = 0;
static int gEpollFd = 0;

void AcceptClient()
{
    LOGI("%s() begin", __func__);
    if (!gServerSocket || !gEpollFd) {
        return;
    }
    static epoll_event event;
    sockaddr_in client_addr;
    bzero(&client_addr, sizeof(client_addr));
    socklen_t addrLen = 0;
    int clientSock = 0;
    clientSock = ::accept(gServerSocket, (sockaddr *)&client_addr, &addrLen);
    if (clientSock < 0) {
        LOGE("accept error. error code = %d, error message: %s", errno, strerror(errno));
        return;
    }
    LOGI("[IP: %s, port: %u] connected.", inet_ntoa(client_addr.sin_addr),
        ntohs(client_addr.sin_port));
    // 设置非阻塞
    int flags = fcntl(clientSock, F_GETFL, 0);
    fcntl(clientSock, F_SETFL, flags | O_NONBLOCK);
    event.data.fd = clientSock;
    event.events = EPOLLIN | EPOLLET;
    if (epoll_ctl(gEpollFd, EPOLL_CTL_ADD, clientSock, &event) != 0) {
        LOGE("epoll_ctl error. error code = %d, error message: %s", errno, strerror(errno));
    }
    LOGI("%s() end", __func__);
}

static epoll_event gEvent;
static int gClientFd = 0;
void ReadFormClient()
{
    LOGI("%s() begin", __func__);
    if (gClientFd <= 0) {
        return;
    }
    char readBuf[RD_BUF_SIZE] = {0};
    bzero(readBuf, RD_BUF_SIZE);
    int readSize = ::read(gClientFd, readBuf, RD_BUF_SIZE);
    if (readSize < 0) {
        LOGE("read error. error code = %d, error message: %s", errno, strerror(errno));
        if (errno == ECONNRESET) {
            epoll_ctl(gEpollFd, EPOLL_CTL_DEL, gClientFd, &gEvent);
            close(gClientFd);
        }
    } else if (readSize == 0) {
        // client quit
        LOGI("client quit fd = %d", gClientFd);
        epoll_ctl(gEpollFd, EPOLL_CTL_DEL, gClientFd, &gEvent);
        close(gClientFd);
    } else {
        LOGI("recv buf: %s\n", readBuf);
    }
    LOGI("%s() end", __func__);
    gClientFd = 0;
}

int main(int argc, char *argv[])
{
    int serverSock = InitSocket();
    if (serverSock < 0) {
        return 0;
    }

    char readBuf[RD_BUF_SIZE] = {0};
    int allSock[EPOLL_EVENT_SIZE] = {0};
    int clientSock = -1;
    sockaddr_in client_addr;
    bzero(&client_addr, sizeof(client_addr));
    socklen_t addrLen = 0;

    struct epoll_event event;
    struct epoll_event allEvent[EPOLL_EVENT_SIZE];

    int epollfd = epoll_create(EPOLL_EVENT_SIZE);
    event.events = EPOLLIN | EPOLLET;   // 边沿触发,数据到达才会触发
    event.data.fd = serverSock;
    epoll_ctl(epollfd, EPOLL_CTL_ADD, serverSock, &event);
    gServerSocket = serverSock;
    gEpollFd = epollfd;
    Jarvis::Fiber::GetThis();
    Jarvis::Fiber::sp acceptFiebr(new Jarvis::Fiber(AcceptClient));
    Jarvis::Fiber::sp readFiber(new Jarvis::Fiber(ReadFormClient));

    LOGI("server fd[%d] waiting for client...", serverSock);
    while (1) {
        int ret = epoll_wait(epollfd, allEvent, EPOLL_EVENT_SIZE, -1);
        if (ret < 0) {
            LOGE("epoll_wait error. error code = %d, error message: %s", errno, strerror(errno));
            if (errno == EINTR) {
                continue;
            }
            break;
        } else if (ret > 0) {
            LOGI("event size = %d", ret);
            for (int i = 0; i < ret; ++i) {
                int fdtmp = allEvent[i].data.fd;
                gClientFd = fdtmp;
                LOGI("fd = %d\n", fdtmp);
                if (fdtmp == serverSock) {    // accept event
                    acceptFiebr->Resume();
                    acceptFiebr->Reset(AcceptClient);
                } else if (allEvent[i].events | EPOLLIN){    // read event
                    readFiber->Resume();
                    readFiber->Reset(ReadFormClient);
                }
            }
        }
    }
    exit(0);	// 不想释放内存
}
// 编译
g++ epoll.cpp fiber.cpp -o epoll -std=c++11 -g

从执行效果来看,如果不调用Yeild2,则和函数指针没有什么区别
但为什么还要选择协程呢?是因为协程可以在单线程中执行异步,从这个函数(1)调用到另一个函数(2),在函数(2)还未结束之前调用Yeild可以返回上一个函数(1),函数(1)在调用resume可以继续按刚才的进度执行函数(2)
这是函数指针所不能实现的

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值