【网络-编程】使用 Ring Buffer 缓存通信数据

本文介绍了网络通信中网络缓冲区的作用,包括发送和接收缓冲区对数据的缓存。还阐述了环形缓冲区的概念和特点,如数据读取后无需移动、容量固定等。同时提供了环形缓冲区的参考实现代码,以及Linux系统中服务端和客户端TCP通信测试代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

声明:仅为个人学习总结,还请批判性查看,如有不同观点,欢迎交流。

摘要

简单介绍 Ring Buffer(环形缓冲区)的概念,并提供一份参考实现代码,以及 Linux 系统中服务端和客户端 TCP 通信测试代码。


1 网络缓冲区

在进行网络通信时,应用层和网络层的数据处理速度可能会不匹配:

  • 应用层产生数据的速度,可能大于网络层发送数据的速度,表现为调用 send/write 等函数时,实际发送长度小于数据长度,此时需要通过“发送缓冲区”进行数据缓存(缓存剩余的数据,和应用层新产生的数据);
  • 同样,网络层通过 recv/read 等函数接收到的数据,应用层可能不能及时处理,此时需要通过“接收缓冲区”进行数据缓存;

另外,一次通信接收到的数据,可能只是完整消息的一部分,或者包含多条消息,此时也需要通过“接收缓冲区”来收集完整消息,或对消息进行拆分,再逐一提供给应用层处理。

因为通信数据具有先后次序,所以“缓冲区”可以采用队列结构,按照 FIFO(First In First Out)方式进行读写。

2 环形缓冲区

环形队列(circular queue)/ 环形缓冲区(ring buffer),表示为一个固定尺寸、头尾相连的数据结构,是一种常见的缓冲区实现方式。

环形缓冲区示意

环形缓冲区的一些特点:

  • 当一个数据元素被读取出后,其余数据元素不需要移动其存储位置;
  • 缓冲区的容量(大小)一般固定,适合事先明确缓冲区最大容量的情形;
    • 存储空间为一次性分配,如果扩大容量,通常需要重新分配空间并移动数据;
    • 如果缓冲区的大小需要经常调整,更加适合使用链表结构。

3 示例代码

3.0 ring buffer 数据结构

代码中的数据结构定义如下:

typedef struct ringbuffer_t
{
    unsigned char *buf; // 缓存指针
    unsigned int size;  // 缓存大小,值为 2 的整数次幂,便于“求余运算”,以及“溢出后对齐”
    unsigned int wpos;  // 累加后的写位置,数据实际在 (wpos % size) 索引位置写入
    unsigned int rpos;  // 累加后的读位置,数据实际在 (rpos % size) 索引位置读出
} ringbuffer_t;

成员说明如下:

  • 对于 size 成员:
    • 值为 2 的整数次幂,便于“求余运算”,即:position % size = position & (size - 1)
    • 例如,假设 position = 25; size = 8; 那么 11001 & 00111 = 1,与 25 % 8 结果一致
  • 对于 wposrpos 成员:
    1. 在写操作时,正向增加 wpos;在读操作时,正向增加 rpos
      利用无符号整型的性质,在达到最大值时,产生溢出,“累加位置”从 0 开始循环;
      由于 size 为 2 的幂,与无符号整型存在整数倍关系,所以“索引位置”也同步从 0 开始循环。
    2. rpos + 数据长度 = wpos,恒定成立。例如,对于溢出处理:
      假设 rposwposunsigned char,取值范围为 0~255,
      假设 rpos = 253; 数据长度 = 4; 那么 wpos = (253 + 4) % 256 = 1,位于正确位置
    3. rpos == wpos 为真时,缓冲区为空;
    4. wpos - rpos == size 为真时,缓冲区为满。

3.1 ring buffer 接口定义

/**
 * @file buffer.h
 * @brief 定义缓冲区通用接口
 */
typedef struct ringbuffer_t buffer_t;

buffer_t *buffer_create(unsigned int size); // 创建缓冲区,实际大小可能大于 size
void buffer_destroy(buffer_t **buf);        // 销毁缓冲区

unsigned int buffer_len(buffer_t *buf);    // 获取缓冲区存储数据长度
unsigned int buffer_remain(buffer_t *buf); // 获取缓冲区剩余空间大小(数据为空时,返回缓冲区大小)

/**
 * @brief 向 buf 缓冲区中写入数据
 * @details 如果缓冲区剩余空间不足,直接返回失败,不会写入任何数据。
 *
 * @param[in] buf  缓冲区指针
 * @param[in] data 待写入数据的指针
 * @param[in] len  待写入数据的长度
 *
 * @return 0 为成功;
 *         -1 为缓冲区剩余空间不足(len > 剩余空间)
 */
int buffer_write(buffer_t *buf, unsigned char *data, unsigned int len);

/**
 * @brief 从 buf 缓冲区中读取数据
 * @details 连续调用 buffer_read 会读取同一地址数据;
 *          如果读取后面数据,需要通过 buffer_drain() 向前移动读指针。
 *
 * @param[in] buf   缓冲区指针
 * @param[out] data 待存储数据的空间指针
 * @param[in] size  data 指向空间的大小
 *
 * @return 实际读取数据的长度
 */
unsigned int buffer_read(buffer_t *buf, unsigned char *data, unsigned int size);

/**
 * @brief 向前移动 buf 缓冲区中读指针的位置
 * @details 在调用 buffer_read() 后,读指针仍然在原来的位置不变,
 *          可以根据数据的实际使用长度,通过 buffer_drain() 向前移动读指针,
 *          以便下一次调用 buffer_read() 时,从新的地址读取数据。
 *
 * @param[in] buf 缓冲区指针
 * @param[in] len 读指针向前移动长度,即需要舍弃的数据长度
 *
 * @return 读指针实际向前移动的长度(不会超过缓冲区中数据的长度)
 */
unsigned int buffer_drain(buffer_t *buf, unsigned int len);

/**
 * @brief 在 buf 缓冲区的数据中,检索指定分隔符。
 * @details 如果数据中包含多个分隔符,通过多次调用 buffer_search() 和 buffer_read(),
 *          直到 buffer_search() 返回 0,可以取出所有完整的数据包。
 *
 * @param[in] buf    缓冲区指针
 * @param[in] sep    分隔符指针
 * @param[in] seplen 分隔符长度
 *
 * @return 第 1 次出现分隔符的数据长度(包含分隔符);
 *         如果没有找到分隔符,返回 0
 */
unsigned int buffer_search(buffer_t *buf, unsigned char *sep, unsigned int seplen);

3.2 ring buffer 代码实现

/**
 * @file ringbuffer.c
 * @brief Ring Buffer 实现
 */
#include "buffer.h"
#include <stdlib.h>
#include <string.h>
#include <assert.h>

#define min(a, b) ((a) < (b) ? (a) : (b))

typedef struct ringbuffer_t
{
    unsigned char *buf; // 缓存指针
    unsigned int size;  // 缓存大小,值为 2 的整数次幂,便于“求余运算”,以及“溢出后对齐”
    unsigned int wpos;  // 累加后的写位置,数据实际在 (wpos % size) 索引位置写入
    unsigned int rpos;  // 累加后的读位置,数据实际在 (rpos % size) 索引位置读出
} ringbuffer_t;

// 判断 val 值是否为 2 的幂
static inline _Bool is_power_of_two(unsigned int val)
{
    if (val < 2)
        return 0;

    // 假设 val 值为 8,那么 1000 & 0111 == 0 为真
    return (val & (val - 1)) == 0;
}

// 向上获取距离 val 值最近的 2 的幂(val 本身不能是 2 的幂)
static inline unsigned int roundup_power_of_two(unsigned int val)
{
    if (val == 0)
        return 2;

    int bits = 0; // val 的二进制位数
    for (; val != 0; bits++)
        val >>= 1;

    return 1U << bits; // 返回比 val 位数多一位的 2 的幂
}

// 创建 ring buffer,实际大小为 2 的幂,可能大于 size
ringbuffer_t *buffer_create(unsigned int size)
{
    if (!is_power_of_two(size)) // 缓冲区大小需要是 2 的幂
        size = roundup_power_of_two(size);

    ringbuffer_t *ringbuf = (ringbuffer_t *)malloc(sizeof(*ringbuf));
    if (ringbuf != NULL)
    {
        ringbuf->buf = (unsigned char *)malloc(size);
        if (ringbuf->buf != NULL)
        {
            ringbuf->size = size;
            ringbuf->rpos = ringbuf->wpos = 0;
            return ringbuf;
        }
        free(ringbuf);
    }
    return NULL;
}

void buffer_destroy(ringbuffer_t **ringbuf)
{
    if (ringbuf == NULL || *ringbuf == NULL)
        return;

    if ((*ringbuf)->buf)
        free((*ringbuf)->buf);

    free(*ringbuf);
    *ringbuf = NULL;
}

unsigned int buffer_len(ringbuffer_t *ringbuf)
{
    assert(ringbuf);
    return ringbuf->wpos - ringbuf->rpos;
}

unsigned int buffer_remain(ringbuffer_t *ringbuf)
{
    assert(ringbuf);
    return ringbuf->size - (ringbuf->wpos - ringbuf->rpos);
}

static inline _Bool buffer_is_empty(ringbuffer_t *ringbuf)
{
    return ringbuf->rpos == ringbuf->wpos;
}

static inline _Bool buffer_is_full(ringbuffer_t *ringbuf)
{
    return ringbuf->wpos - ringbuf->rpos == ringbuf->size;
}

int buffer_write(ringbuffer_t *ringbuf, unsigned char *data, unsigned int len)
{
    assert(ringbuf && data);

    if (len > ringbuf->size - (ringbuf->wpos - ringbuf->rpos)) // 剩余空间不足
        return -1;

    unsigned int widx = ringbuf->wpos & (ringbuf->size - 1); // 写索引

    unsigned int wlen = min(len, ringbuf->size - widx); // 第 1 次写入长度
    memcpy(ringbuf->buf + widx, data, wlen);            // 写入写索引后面空间
    if (wlen < len)                                     // 写入缓冲区起始空间
        memcpy(ringbuf->buf, data + wlen, len - wlen);

    ringbuf->wpos += len;
    return 0;
}

unsigned int buffer_read(ringbuffer_t *ringbuf, unsigned char *data, unsigned int size)
{
    assert(ringbuf && data);

    unsigned int len = min(size, ringbuf->wpos - ringbuf->rpos); // 实际读取数据长度
    unsigned int ridx = ringbuf->rpos & (ringbuf->size - 1);     // 读索引

    unsigned int rlen = min(len, ringbuf->size - ridx); // 第 1 次读取长度
    memcpy(data, ringbuf->buf + ridx, rlen);            // 从读索引后面空间读取
    if (rlen < len)                                     // 从缓冲区起始空间读取
        memcpy(data + rlen, ringbuf->buf, len - rlen);

    // ringbuf->rpos += len; // 允许多次读
    return len;
}

unsigned int buffer_drain(ringbuffer_t *ringbuf, unsigned int len)
{
    assert(ringbuf);

    if (len > ringbuf->wpos - ringbuf->rpos)
        len = ringbuf->wpos - ringbuf->rpos; // 最大为缓冲区数据长度

    ringbuf->rpos += len;
    return len;
}

unsigned int buffer_search(ringbuffer_t *ringbuf, unsigned char *sep, unsigned int seplen)
{
    assert(ringbuf && sep);

    unsigned int len = ringbuf->wpos - ringbuf->rpos; // 缓冲区数据长度
    if (len < seplen)
        return 0;

    for (unsigned int i = 0; i <= len - seplen; i++)
    {
        unsigned int idx = (ringbuf->rpos + i) & (ringbuf->size - 1); // 当前比较索引

        if (idx + seplen <= ringbuf->size) // “比较区间”连续,只需要一次比较
        {
            if (memcmp(ringbuf->buf + idx, sep, seplen) == 0)
                return i + seplen;
            continue;
        }

        unsigned int len1 = ringbuf->size - idx; // “比较区间”不连续,需要两次比较
        if (memcmp(ringbuf->buf + idx, sep, len1) == 0 &&
            memcmp(ringbuf->buf, sep + len1, seplen - len1) == 0)
            return i + seplen;
    }

    return 0;
}

3.3 服务端测试代码

/**
 * @file server.c
 * @brief 缓冲区测试代码(服务端)
 * @details 为每个客户端连接创建一个缓冲区,首先将接收到的客户端数据写入缓冲区,
 *          然后通过指定“分隔符”对数据进行分割,将每个独立的数据包发送给客户端。
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>

#include <unistd.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/epoll.h>

#include "buffer.h"

#define SOCK_IP INADDR_ANY
#define SOCK_PORT 2048
#define LISTEN_BACKLOG 10
#define BUF_SIZE 1024
#define MAX_FD_SIZE 1024
#define COMMU_SEPARATOR "#"

#define handle_error(msg) \
    do { perror(msg); exit(EXIT_FAILURE); } while (0)

#define set_error_en(en, msg) \
    do { errno = en; perror(msg); } while (0)

typedef enum
{
    kRetFail = -1,
    kRetSuccess = 0,
    kRetBlock,
    kRetClose,
    kRetQuit
} ret_e;

int set_nonblock(int fd);
int event_ctl(int epfd, int op, int fd, uint32_t events);
int accept_proc(int sockfd, int epfd, buffer_t **buf);
ret_e send_proc(int clientfd, buffer_t *buf, _Bool wait_call, int epfd);
ret_e recv_proc(int clientfd, buffer_t *buf);
void close_fd(int clientfd, int *buf_fd, buffer_t **buf, int epfd);

int main()
{
    int sockfd;
    socklen_t addrlen;
    struct sockaddr_in addr, server_addr;

    // 创建一个套接字,用于监听客户端的连接
    sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd == -1)
        handle_error("socket");

    // 允许地址的立即重用
    int reuse = 1;
    if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) == -1)
        handle_error("setsockopt");

    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(SOCK_IP);
    addr.sin_port = htons(SOCK_PORT);

    if (bind(sockfd, (struct sockaddr *)&addr, sizeof(addr)) == -1)
        handle_error("bind");

    // 获取本地套接字地址
    addrlen = sizeof(server_addr);
    if (getsockname(sockfd, (struct sockaddr *)&server_addr, &addrlen) == -1)
        handle_error("getsockname");
    printf("Server Addr [%s:%d]\n", inet_ntoa(server_addr.sin_addr), ntohs(server_addr.sin_port));

    if (listen(sockfd, LISTEN_BACKLOG) == -1)
        handle_error("listen");

    // 设置为非阻塞套接字,在调用 accept 时可以快速返回
    if (set_nonblock(sockfd) == -1)
        handle_error("set_nonblock");

    struct // 简化实现,将 fd(如 3)保存在索引为 fd(如 fd_buf[3])的位置
    {
        int fd;
        buffer_t *buf;
    } fd_buf[MAX_FD_SIZE] = {0}; // 读写缓存数组

    struct epoll_event events[MAX_FD_SIZE] = {0};

    int epfd = epoll_create(1);
    if (epfd == -1)
        handle_error("epoll_create");

    // 设置 “监听套接字” 监测 “读” 事件
    if (event_ctl(epfd, EPOLL_CTL_ADD, sockfd, EPOLLIN) == -1)
        handle_error("EPOLL_CTL_ADD");

    while (1)
    {
        int nready = epoll_wait(epfd, events, MAX_FD_SIZE, -1);
        if (nready == -1)
            handle_error("epoll_wait");

        for (int i = 0; i < nready; i++)
        {
            if (events[i].data.fd == sockfd) // “监听套接字” 可读
            {
                buffer_t *buf;
                int clientfd = accept_proc(sockfd, epfd, &buf);
                if (clientfd > 0)
                {
                    fd_buf[clientfd].fd = clientfd;
                    fd_buf[clientfd].buf = buf;

                    if (send_proc(clientfd, fd_buf[clientfd].buf, 0, epfd) == kRetFail)
                        close_fd(clientfd, &fd_buf[clientfd].fd, &fd_buf[clientfd].buf, epfd);
                }
                continue;
            }

            // 处理每一个客户端连接套接字事件
            int clientfd = events[i].data.fd;
            if (events[i].events & EPOLLIN) // “客户端连接套接字” 可读
            {
                ret_e ret = recv_proc(clientfd, fd_buf[clientfd].buf);
                if (ret == kRetClose || ret == kRetFail) // 断开连接,或网络出错
                {
                    close_fd(clientfd, &fd_buf[clientfd].fd, &fd_buf[clientfd].buf, epfd);
                    continue;
                }

                if (ret == kRetQuit) // 退出服务
                {
                    for (int cfd = epfd + 1; cfd < MAX_FD_SIZE; cfd++) // 关闭所有客户端 FD
                    {
                        if (fd_buf[cfd].fd > 0) // 查找未关闭的客户端连接
                            close_fd(cfd, &fd_buf[cfd].fd, &fd_buf[cfd].buf, epfd);
                    }
                    if (close(epfd) == -1)
                        perror("close epfd");
                    if (close(sockfd) == -1)
                        perror("close sockfd");
                    return 0;
                }

                if (send_proc(clientfd, fd_buf[clientfd].buf, 0, epfd) == kRetFail)
                    close_fd(clientfd, &fd_buf[clientfd].fd, &fd_buf[clientfd].buf, epfd);
            }

            if (events[i].events & EPOLLOUT) // “客户端连接套接字” 可写
            {
                ret_e ret = send_proc(clientfd, fd_buf[clientfd].buf, 1, epfd);
                if (ret == kRetSuccess) // 成功发送,取消“可写”监测
                {
                    if (event_ctl(epfd, EPOLL_CTL_MOD, clientfd, EPOLLIN) == -1)
                    {
                        perror("EPOLL_CTL_MOD");
                        close_fd(clientfd, &fd_buf[clientfd].fd, &fd_buf[clientfd].buf, epfd);
                    }
                }
                else if (ret == kRetFail)
                    close_fd(clientfd, &fd_buf[clientfd].fd, &fd_buf[clientfd].buf, epfd);
            }
        }
    }

    return 0;
}

int set_nonblock(int fd)
{
    int flag = fcntl(fd, F_GETFL, 0);
    if (flag < 0)
        return flag;

    return fcntl(fd, F_SETFL, flag | O_NONBLOCK);
}

int event_ctl(int epfd, int op, int fd, uint32_t events)
{
    struct epoll_event ev;
    ev.data.fd = fd;
    ev.events = events;
    return epoll_ctl(epfd, op, fd, &ev);
}

int accept_proc(int sockfd, int epfd, buffer_t **buf)
{
    struct sockaddr_in client_addr;
    socklen_t addrlen = sizeof(client_addr);

    // 从连接请求队列中取出排在最前面的客户端请求
    int clientfd = accept(sockfd, (struct sockaddr *)&client_addr, &addrlen);
    if (clientfd >= 0)
    {
        // 设置为非阻塞套接字,在调用 recv/send 时可以快速返回
        if (set_nonblock(clientfd) == -1)
        {
            perror("set_nonblock");
            close(clientfd);
            return -1;
        }

        if (event_ctl(epfd, EPOLL_CTL_ADD, clientfd, EPOLLIN) >= 0)
        {
            *buf = buffer_create(BUF_SIZE);
            if (*buf != NULL)
            {
                printf("accept fd(%d) [%s:%d]\n", clientfd, inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));

                char msg[256];
                sprintf(msg, "Welcome [%s:%d] to the server!\n"
                             "Send \"bye\" to end the communication,\n"
                             "Send \"q\" to quit the service,\n"
                             "Send anything else to continue the communication.\n"
                             "Each message needs to end with the separator: %s",
                        inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port), COMMU_SEPARATOR);

                if (buffer_write(*buf, msg, strlen(msg)) == 0) // 向缓冲区写入数据
                    return clientfd;

                set_error_en(ENOMEM, "buffer_write");
                buffer_destroy(buf);
            }
            else
                set_error_en(ENOMEM, "buffer_create");
            event_ctl(epfd, EPOLL_CTL_DEL, clientfd, 0);
        }
        else
            perror("event_ctl");
        close(clientfd);
    }
    else
        perror("accept");

    return -1;
}

void close_fd(int clientfd, int *buf_fd, buffer_t **buf, int epfd)
{
    printf("close fd(%d)\n", clientfd);

    *buf_fd = -1;
    buffer_destroy(buf);

    if (event_ctl(epfd, EPOLL_CTL_DEL, clientfd, 0) == -1)
        perror("EPOLL_CTL_DEL");

    if (close(clientfd) == -1)
        perror("close clientfd");
}

ret_e send_proc(int clientfd, buffer_t *buf, _Bool wait_call, int epfd)
{
    ret_e ret = kRetSuccess;
    char msg[BUF_SIZE];

    while (1) // 多次调用 buffer_search()
    {
        unsigned int msglen = buffer_search(buf, COMMU_SEPARATOR, strlen(COMMU_SEPARATOR));
        if (msglen == 0) // 没有完整数据包
            break;

        msglen = msglen < BUF_SIZE ? msglen : BUF_SIZE;
        unsigned int rlen = buffer_read(buf, msg, msglen);
        if (rlen != msglen)
        {
            sprintf(msg, "buffer_read(%u/%u)", rlen, msglen);
            set_error_en(EINVAL, msg);
            return kRetFail;
        }

        while (1) // 在 EINTR 情况下多次调用 send()
        {
            // int wlen = write(clientfd, msg, msglen);
            int wlen = send(clientfd, msg, msglen, 0);
            if (wlen == -1)
            {
                if (errno == EINTR)
                    continue;
                if (errno == EWOULDBLOCK || errno == EAGAIN)
                {
                    ret = kRetBlock;
                    break;
                }

                perror("send");
                return kRetFail;
            }

            if (wlen < msglen)
                ret = kRetBlock;

            buffer_drain(buf, wlen); // 移动读指针

            msg[wlen] = '\0';
            printf("fd(%d) send: %s(%d)\n", clientfd, msg, wlen);

            break;
        }

        if (ret == kRetBlock && !wait_call) // 没有完全发送,并且不是通过 epoll_wait 调用
        {
            if (event_ctl(epfd, EPOLL_CTL_MOD, clientfd, EPOLLIN | EPOLLOUT) == -1)
            {
                perror("EPOLL_CTL_MOD");
                return kRetFail;
            }
            break;
        }
    }

    return ret;
}

ret_e recv_proc(int clientfd, buffer_t *buf)
{
    char msg[BUF_SIZE];

    while (1) // 多次调用 recv()
    {
        // int rlen = read(clientfd, msg, BUF_SIZE);
        int rlen = recv(clientfd, msg, BUF_SIZE, 0);
        if (rlen == -1)
        {
            if (errno == EINTR)
                continue;
            if (errno == EWOULDBLOCK || errno == EAGAIN)
                return kRetBlock;

            perror("recv");
            return kRetFail;
        }

        msg[rlen] = '\0';
        printf("fd(%d) recv: %s(%d)\n", clientfd, msg, rlen);

        // 通信控制
        if (rlen == 0 || strcmp(msg, "bye") == 0)
            return kRetClose;
        else if (strcmp(msg, "q") == 0)
            return kRetQuit;

        if (buffer_write(buf, msg, rlen) < 0)
        {
            set_error_en(ENOMEM, "buffer_write");
            return kRetFail;
        }
    }

    return kRetSuccess;
}

3.4 客户端测试代码

/**
 * @file client.c
 * @brief 缓冲区测试代码(客户端)
 * @details 通过 stdin 读取数据,写入缓冲区;再从缓冲区读取数据,发送给服务端;
 *          通过 stdout 显示从服务端接收到的数据。
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>

#include <unistd.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/epoll.h>

#include "buffer.h"

#define SERVER_IP "192.168.126.128"
#define SERVER_PORT 2048
#define BUF_SIZE 1024

#define handle_error(msg) \
    do { perror(msg); exit(EXIT_FAILURE); } while (0)

#define handle_error_en(en, msg) \
    do { errno = en; perror(msg); exit(EXIT_FAILURE); } while (0)

typedef enum
{
    kRetFail = -1,
    kRetSuccess = 0,
    kRetBlock,
    kRetClose,
    kRetQuit
} ret_e;

int set_nonblock(int fd);
int event_ctl(int epfd, int op, int fd, uint32_t events);
ret_e recv_proc(int clientfd);
ret_e send_proc(int clientfd, buffer_t *buf, _Bool wait_call, int epfd);
ret_e input_send(int sockfd, buffer_t *buf, int epfd);
void exit_proc(int clientfd, buffer_t **buf, int epfd, int status);

int main()
{
    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd == -1)
        handle_error("socket");

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = inet_addr(SERVER_IP);
    server_addr.sin_port = htons(SERVER_PORT);

    if (connect(sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1)
        handle_error("connect");

    if (set_nonblock(sockfd) == -1)
        handle_error("set_nonblock");

    buffer_t *buf = buffer_create(BUF_SIZE);
    if (buf == NULL)
        handle_error_en(ENOMEM, "buffer_create");

    int epfd = epoll_create(1);
    if (epfd == -1)
        handle_error("epoll_create");

    if (event_ctl(epfd, EPOLL_CTL_ADD, sockfd, EPOLLIN | EPOLLET) == -1)
        handle_error("EPOLL_CTL_ADD");

    if (event_ctl(epfd, EPOLL_CTL_ADD, 0, EPOLLIN | EPOLLET) == -1)
        handle_error("EPOLL_CTL_ADD");

    struct epoll_event events[2]; // 监听 网络套接字 和 标准输入

    while (1) // 在循环中与服务器端通信
    {
        int nready = epoll_wait(epfd, events, 2, -1);
        if (nready == -1)
            handle_error("epoll_wait");

        for (int i = 0; i < nready; i++)
        {
            if (events[i].events & EPOLLERR || events[i].events & EPOLLHUP)
                exit_proc(sockfd, &buf, epfd, EXIT_FAILURE);

            if (events[i].events & EPOLLIN)
            {
                if (events[i].data.fd == sockfd)
                {
                    // 从服务端接收数据,并显示
                    ret_e ret = recv_proc(sockfd);
                    if (ret == kRetClose || ret == kRetFail)
                        exit_proc(sockfd, &buf, epfd, ret == kRetClose ? EXIT_SUCCESS : EXIT_FAILURE);
                }

                if (events[i].data.fd == 0)
                {
                    // 从 stdin 读取一行数据,并发送
                    ret_e ret = input_send(sockfd, buf, epfd);
                    if (ret == kRetFail)
                        exit_proc(sockfd, &buf, epfd, EXIT_FAILURE);
                }
            }

            if (events[i].events & EPOLLOUT && events[i].data.fd == sockfd)
            {
                ret_e ret = send_proc(sockfd, buf, 1, epfd);
                if (ret == kRetSuccess) // 成功发送,取消“可写”监测
                {
                    if (event_ctl(epfd, EPOLL_CTL_MOD, sockfd, EPOLLIN | EPOLLET) == -1)
                    {
                        perror("EPOLL_CTL_MOD");
                        exit_proc(sockfd, &buf, epfd, EXIT_FAILURE);
                    }
                }
                else if (ret == kRetFail)
                    exit_proc(sockfd, &buf, epfd, EXIT_FAILURE);
            }
        }
    }

    exit_proc(sockfd, &buf, epfd, EXIT_SUCCESS);

    return 0;
}

void exit_proc(int clientfd, buffer_t **buf, int epfd, int status)
{
    buffer_destroy(buf);

    if (event_ctl(epfd, EPOLL_CTL_DEL, clientfd, 0) == -1)
        perror("EPOLL_CTL_DEL");

    if (event_ctl(epfd, EPOLL_CTL_DEL, 0, 0) == -1)
        perror("EPOLL_CTL_DEL");

    if (close(epfd) == -1)
        perror("close epfd");

    if (close(clientfd) == -1)
        perror("close clientfd");

    exit(status);
}

int set_nonblock(int fd)
{
    int flag = fcntl(fd, F_GETFL, 0);
    if (flag < 0)
        return flag;

    return fcntl(fd, F_SETFL, flag | O_NONBLOCK);
}

int event_ctl(int epfd, int op, int fd, uint32_t events)
{
    struct epoll_event ev;
    ev.data.fd = fd;
    ev.events = events;
    return epoll_ctl(epfd, op, fd, &ev);
}

ret_e recv_proc(int clientfd)
{
    char msg[BUF_SIZE];

    while (1) // 多次调用 recv()
    {
        int rlen = recv(clientfd, msg, BUF_SIZE, 0);
        if (rlen == -1)
        {
            if (errno == EINTR)
                continue;
            if (errno == EWOULDBLOCK || errno == EAGAIN)
                return kRetBlock;
            perror("recv");
            return kRetFail;
        }

        msg[rlen] = '\0';
        printf("recv: %s(%d)\n", msg, rlen);

        if (rlen == 0)
            return kRetClose;
    }

    return kRetSuccess;
}

ret_e send_proc(int clientfd, buffer_t *buf, _Bool wait_call, int epfd)
{
    ret_e ret = kRetSuccess;
    char msg[BUF_SIZE];

    while (1) // 多次调用 buffer_read()
    {
        unsigned int rlen = buffer_read(buf, msg, BUF_SIZE);
        if (rlen == 0) // 没有数据
            break;

        while (1) // 在 EINTR 情况下多次调用 send()
        {
            int wlen = send(clientfd, msg, rlen, 0);
            if (wlen == -1)
            {
                if (errno == EINTR)
                    continue;
                if (errno == EWOULDBLOCK || errno == EAGAIN)
                {
                    ret = kRetBlock;
                    break;
                }

                perror("send");
                return kRetFail;
            }

            if (wlen < rlen)
                ret = kRetBlock;

            buffer_drain(buf, wlen); // 移动读指针

            msg[wlen] = '\0';
            printf("send: %s(%d)\n", msg, wlen);

            break;
        }

        if (ret == kRetBlock && !wait_call) // 没有完全发送,并且不是通过 epoll_wait 调用
        {
            if (event_ctl(epfd, EPOLL_CTL_MOD, clientfd, EPOLLIN | EPOLLOUT | EPOLLET) == -1)
            {
                perror("EPOLL_CTL_MOD");
                return kRetFail;
            }
            break;
        }
    }

    return ret;
}

ret_e input_send(int sockfd, buffer_t *buf, int epfd)
{
    ret_e ret = kRetSuccess;
    char msg[BUF_SIZE];

    fgets(msg, sizeof(msg), stdin);
    int datalen = strlen(msg);
    if (datalen != 1)
        msg[--datalen] = '\0'; // 如果非空,去掉换行符

    if (buffer_write(buf, msg, strlen(msg)) < 0)
    {
        errno = ENOMEM;
        perror("buffer_write");
        return kRetFail;
    }

    return send_proc(sockfd, buf, 0, epfd);
}

参考

  1. 知乎文章:ring buffer,一篇文章讲透它?

宁静以致远,感谢 Mark 老师。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值