前言
看过《Windows网络与通信程序设计》的人都知道,里面有一段有关于IOCP的经典封装。大大方便了“伸手党”服务器端程序的开发(我也是其中之一)。但是应用到实际程序中你会发现经常出现一个莫名奇妙的问题:一旦客户端发送的字节数过多,服务器端接受其中几条后就“死掉”了,我也深受其害,于是乎
今天花了2小时时间通读了代码,把其中的bug找到(PS:不敢保证是否还是其他bug,暂时还未发现)
问题
引起上述现象的是由于作者的小疏忽导致的,我们知道采用完成端口时要开辟多个线程(一般为CPU核心数)来监听请求,如果客户端在短时间内发送了一大堆字节,
这些字节在客户端上肯定是分多次顺序进行发送。服务端接受的时候却是多个线程(单线程CPU例外)各自接受各自的原来的顺序就被打乱了,如果只是简单拼凑起来肯定会出现问题,如何解决这个问题呢,作者已经帮我想好了。虽然各个线程处理数据的顺序可能不一直,但是投递读请求的顺序肯定是与客户端发送的顺序是一直的,于是作者CIOCPBuffer中添加了一个顺序标识
nSequenceNumber
用来标识当前读取数据的发送顺序,并建立了一个队列将当前接受CIOCPBuffer联系起来。
// 这是per-I/O数据。它包含了在套节字上处理I/O操作的必要信息
struct CIOCPBuffer
{
WSAOVERLAPPED ol;
SOCKET sClient; // AcceptEx接收的客户方套节字
char *buff; // I/O操作使用的缓冲区
int nLen; // buff缓冲区(使用的)大小
ULONG nSequenceNumber; // 此I/O的序列号
int nOperation; // 操作类型
#define OP_ACCEPT 1
#define OP_WRITE 2
#define OP_READ 3
CIOCPBuffer *pNext;
};
读取数据的时候先判断读取数据CIOCPBuffer的nSequenceNumber与当前队列头的CIOCPBuffer是否一致,一致则循环将队列中的所有Buffer按顺序传递给用户,不一致就将它按顺序加入队列。
这种设计是不是很巧妙。
作者的代码:
CIOCPBuffer *CIOCPServer::GetNextReadBuffer(CIOCPContext *pContext, CIOCPBuffer *pBuffer)
{
if(pBuffer != NULL)
{
// 如果与要读的下一个序列号相等,则读这块缓冲区
if(pBuffer->nSequenceNumber == pContext->nCurrentReadSequence)
{
return pBuffer;
}
// 如果不相等,则说明没有按顺序接收数据,将这块缓冲区保存到连接的pOutOfOrderReads列表中
// 列表中的缓冲区是按照其序列号从小到大的顺序排列的
pBuffer->pNext = NULL;
CIOCPBuffer *ptr = pContext->pOutOfOrderReads;
CIOCPBuffer *pPre = NULL;
while(ptr != NULL)
{
if(pBuffer->nSequenceNumber < ptr->nSequenceNumber)
break;
pPre = ptr;
ptr = ptr->pNext;
}
if(pPre == NULL) // 应该插入到表头
{
pBuffer->pNext = pContext->pOutOfOrderReads;
pContext->pOutOfOrderReads = pBuffer;
}
else // 应该插入到表的中间
{
pBuffer->pNext = pPre->pNext;
pPre->pNext = pBuffer->pNext;
}
}
// 检查表头元素的序列号,如果与要读的序列号一致,就将它从表中移除,返回给用户
CIOCPBuffer *ptr = pContext->pOutOfOrderReads;
if(ptr != NULL && (ptr->nSequenceNumber == pContext->nCurrentReadSequence))
{
pContext->pOutOfOrderReads = ptr->pNext;
return ptr;
}
return NULL;
}
我处理后的代码
CIOCPBuffer *CIOCPServer::GetNextReadBuffer(CIOCPContext *pContext, CIOCPBuffer *pBuffer)
{
if(pBuffer != NULL)
{
// 如果与要读的下一个序列号相等,则读这块缓冲区
if(pBuffer->nSequenceNumber == pContext->nCurrentReadSequence)
{
return pBuffer;
}
// 如果不相等,则说明没有按顺序接收数据,将这块缓冲区保存到连接的pOutOfOrderReads列表中
// 列表中的缓冲区是按照其序列号从小到大的顺序排列的
pBuffer->pNext = NULL;
CIOCPBuffer *ptr = pContext->pOutOfOrderReads;
CIOCPBuffer *pPre = NULL;
while(ptr != NULL)
{
if(pBuffer->nSequenceNumber < ptr->nSequenceNumber)
break;
pPre = ptr;
ptr = ptr->pNext;
}
if(pPre == NULL) // 应该插入到表头
{
pBuffer->pNext = pContext->pOutOfOrderReads;
pContext->pOutOfOrderReads = pBuffer;
}
else // 应该插入到表的中间
{
pBuffer->pNext = pPre->pNext;
pPre->pNext = pBuffer;
}
}
// 检查表头元素的序列号,如果与要读的序列号一致,就将它从表中移除,返回给用户
CIOCPBuffer *ptr = pContext->pOutOfOrderReads;
if(ptr != NULL && (ptr->nSequenceNumber == pContext->nCurrentReadSequence))
{
pContext->pOutOfOrderReads = ptr->pNext;
return ptr;
}
return NULL;
}
对,你没有看错,仅仅是这么简单一句代码,就造成了整个程序的错误。
看不懂的同学自己去看看如何向链表中见插入一个元素吧。
废话不多说上一段完整程序:
////////////////////////////////////////
// IOCP.h文件
#ifndef __IOCP_H__
#define __IOCP_H__
#include <winsock2.h>
#include <windows.h>
#include <Mswsock.h>
#define BUFFER_SIZE 1024*2 // I/O请求的缓冲区大小
// 这是per-I/O数据。它包含了在套节字上处理I/O操作的必要信息
struct CIOCPBuffer
{
WSAOVERLAPPED ol;
SOCKET sClient; // AcceptEx接收的客户方套节字
char *buff; // I/O操作使用的缓冲区
int nLen; // buff缓冲区(使用的)大小
ULONG nSequenceNumber; // 此I/O的序列号
int nOperation; // 操作类型
#define OP_ACCEPT 1
#define OP_WRITE 2
#define OP_READ 3
CIOCPBuffer *pNext;
};
// 这是per-Handle数据。它包含了一个套节字的信息
struct CIOCPContext
{
SOCKET s; // 套节字句柄
SOCKADDR_IN addrLocal; // 连接的本地地址
SOCKADDR_IN addrRemote; // 连接的远程地址
BOOL bClosing; // 套节字是否关闭
int nOutstandingRecv; // 此套节字上抛出的重叠操作的数量
int nOutstandingSend;
ULONG nReadSequence; // 安排给接收的下一个序列号
ULONG nCurrentReadSequence; // 当前要读的序列号
CIOCPBuffer *pOutOfOrderReads; // 记录没有按顺序完成的读I/O
CRITICAL_SECTION Lock; // 保护这个结构
CIOCPContext *pNext;
};
class CIOCPServer // 处理线程
{
public:
CIOCPServer();
~CIOCPServer();
// 开始服务
BOOL Start(int nPort = 4567, int nMaxConnections = 2000,
int nMaxFreeBuffers = 200, int nMaxFreeContexts = 100, int nInitialReads = 4);
// 停止服务
void Shutdown();
// 关闭一个连接和关闭所有连接
void CloseAConnection(CIOCPContext *pContext);
void CloseAllConnections();
// 取得当前的连接数量
ULONG GetCurrentConnection() { return m_nCurrentConnection; }
// 向指定客户发送文本
BOOL SendText(CIOCPContext *pContext, char *pszText, int nLen);
// 获得本机处理器的数量
static int _GetNoOfProcessors();
protected:
// 申请和释放缓冲区对象
CIOCPBuffer *AllocateBuffer(int nLen);
void ReleaseBuffer(CIOCPBuffer *pBuffer);
// 申请和释放套节字上下文
CIOCPContext *AllocateContext(SOCKET s);
void ReleaseContext(CIOCPContext *pContext);
// 释放空闲缓冲区对象列表和空闲上下文对象列表
void FreeBuffers();
void FreeContexts();
// 向连接列表中添加一个连接
BOOL AddAConnection(CIOCPContext *pContext);
// 插入和移除未决的接受请求
BOOL InsertPendingAccept(CIOCPBuffer *pBuffer);
BOOL RemovePendingAccept(CIOCPBuffer *pBuffer);
// 取得下一个要读取的
CIOCPBuffer *GetNextReadBuffer(CIOCPContext *pContext, CIOCPBuffer *pBuffer);
// 投递接受I/O、发送I/O、接收I/O
BOOL PostAccept(CIOCPBuffer *pBuffer);
BOOL PostSend(CIOCPContext *pContext, CIOCPBuffer *pBuffer);
BOOL PostRecv(CIOCPContext *pContext, CIOCPBuffer *pBuffer);
void HandleIO(DWORD dwKey, CIOCPBuffer *pBuffer, DWORD dwTrans, int nError);
// 事件通知函数
// 建立了一个新的连接
virtual void OnConnectionEstablished(CIOCPContext *pContext, CIOCPBuffer *pBuffer);
// 一个连接关闭
virtual void OnConnectionClosing(CIOCPContext *pContext, CIOCPBuffer *pBuffer);
// 在一个连接上发生了错误
virtual void OnConnectionError(CIOCPContext *pContext, CIOCPBuffer *pBuffer, int nError);
// 一个连接上的读操作完成
virtual void OnReadCompleted(CIOCPContext *pContext, CIOCPBuffer *pBuffer);
// 一个连接上的写操作完成
virtual void OnWriteCompleted(CIOCPContext *pContext, CIOCPBuffer *pBuffer);
protected:
// 记录空闲结构信息
CIOCPBuffer *m_pFreeBufferList;
CIOCPContext *m_pFreeContextList;
int m_nFreeBufferCount;
int m_nFreeContextCount;
CRITICAL_SECTION m_FreeBufferListLock;
CRITICAL_SECTION m_FreeContextListLock;
// 记录抛出的Accept请求
CIOCPBuffer *m_pPendingAccepts; // 抛出请求列表。
long m_nPendingAcceptCount;
CRITICAL_SECTION m_PendingAcceptsLock;
// 记录连接列表
CIOCPContext *m_pConnectionList;
int m_nCurrentConnection;
CRITICAL_SECTION m_ConnectionListLock;
// 用于投递Accept请求
HANDLE m_hAcceptEvent;
HANDLE m_hRepostEvent;
LONG m_nRepostCount;
int m_nThread;
int m_nPort; // 服务器监听的端口
int m_nInitialAccepts;
int m_nInitialReads;
int m_nMaxAccepts;
int m_nMaxSends;
int m_nMaxFreeBuffers;
int m_nMaxFreeContexts;
int m_nMaxConnections;
HANDLE m_hListenThread