On WSARecv and WSASend: closing the socket after SOCKET_ERROR, and error 997

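This article explains the three possible outcomes of WSASend and WSARecv in the IOCP programming model, how each one affects memory, and gives some optimization advice.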

Original article:

http://www.xixis.net/doserver.net//read.php/2063.htm

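I was recently writing an IOCP server. When WSARecv returned SOCKET_ERROR, I closed the associated socket and hit error 997. (997 is ERROR_IO_PENDING, the same value as WSA_IO_PENDING, which means the request was queued successfully rather than rejected.) I found a good article online that explains this.

The material appears to come from the second edition of Network Programming for Microsoft Windows. Here is my summary of it.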
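A minimal C++ sketch of the check that the three cases below boil down to. SOCKET_ERROR and WSA_IO_PENDING are redefined locally (as -1 and 997, their values in the Winsock headers) so the snippet compiles without <winsock2.h>; the names kSocketError, kWsaIoPending, PostResult and ClassifyPost are illustrative, not part of any API.

```cpp
#include <cassert>

// Local copies of the Winsock values: SOCKET_ERROR is -1, and
// WSA_IO_PENDING equals ERROR_IO_PENDING, i.e. 997.
constexpr int kSocketError = -1;
constexpr int kWsaIoPending = 997;

enum class PostResult {
    CompletedNow,  // case 1: the call returned 0
    Pending,       // case 2: SOCKET_ERROR + WSA_IO_PENDING, request queued
    Failed         // case 3: any other error, release the socket's resources
};

// ret:       return value of WSASend/WSARecv
// lastError: WSAGetLastError(), read right after the call
PostResult ClassifyPost(int ret, int lastError) {
    if (ret != kSocketError) return PostResult::CompletedNow;
    if (lastError == kWsaIoPending) return PostResult::Pending;  // not an error
    return PostResult::Failed;
}
```

On a socket associated with a completion port, both CompletedNow and Pending are normally followed by a completion packet (unless FILE_SKIP_COMPLETION_PORT_ON_SUCCESS has been set with SetFileCompletionNotificationModes), so the OVERLAPPED and the context that embeds it must stay alive until GetQueuedCompletionStatus returns it. Only Failed should lead to closing the socket.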

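1: When WSASend posted on an IOCP socket returns WSA_IO_PENDING, the asynchronous request has been accepted, but the send completes later. Three buffers are involved:
the network adapter buffer, the TCP/IP stack buffer, and the application buffer.
Case 1: WSASend succeeds immediately (returns 0 with no error). TCP/IP copies the data from the application buffer into the stack buffer and does not lock the application buffer, which is left to the program. When the network allows, the stack copies its data to the adapter buffer for the actual transmission.
Case 2: WSASend returns SOCKET_ERROR and the error code is WSA_IO_PENDING. The TCP/IP buffer is full and has no room to copy the application data yet, so the system locks the application buffer; according to the book, the buffer passed to WSASend is locked into non-paged memory. Once the stack buffer has room, the data is copied out and a completion notification is posted to the IOCP.
Case 3: WSASend returns SOCKET_ERROR and the error code is not WSA_IO_PENDING. The send has genuinely failed, and all resources associated with that socket should be released.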

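2: Posting WSARecv on an IOCP socket works similarly.
Case 1: WSARecv succeeds immediately. TCP/IP copies data from the stack buffer into our buffer for the program to process, and removes that data from the stack buffer.
Case 2: WSARecv returns SOCKET_ERROR and the error code is WSA_IO_PENDING. The stack buffer has no data to take yet, so the system locks the buffer we posted with WSARecv until new data arrives in the stack buffer.
Case 3: WSARecv returns SOCKET_ERROR and the error code is not WSA_IO_PENDING. The receive has failed, and all resources associated with that socket should be released.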
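Case 3 only covers a request that fails at the moment it is posted. A request that returned WSA_IO_PENDING can still fail later, and that shows up when its completion is dequeued. The sketch below classifies what GetQueuedCompletionStatus can return, following the documented meaning of its BOOL result and lpOverlapped output; the enum, the function name and the requestedBytes parameter are illustrative assumptions. The requestedBytes check keeps the zero-length WSARecv technique described later from being mistaken for a disconnect.

```cpp
#include <cassert>
#include <cstdint>

enum class CompletionAction {
    NothingDequeued,  // FALSE + NULL overlapped: the wait failed or timed out
    PostedSignal,     // TRUE + NULL overlapped: a PostQueuedCompletionStatus packet
    IoFailed,         // FALSE + non-NULL overlapped: the I/O itself failed
    PeerClosed,       // a non-empty WSARecv completed with 0 bytes
    Process           // normal completion
};

// ok:             BOOL returned by GetQueuedCompletionStatus
// hasOverlapped:  whether *lpOverlapped came back non-NULL
// isRecv:         whether the dequeued context was a WSARecv
// requestedBytes: buffer length that was posted (0 for a zero-length recv)
// transferred:    *lpNumberOfBytesTransferred
CompletionAction ClassifyCompletion(bool ok, bool hasOverlapped, bool isRecv,
                                    std::uint32_t requestedBytes,
                                    std::uint32_t transferred) {
    if (!hasOverlapped)
        return ok ? CompletionAction::PostedSignal : CompletionAction::NothingDequeued;
    if (!ok) return CompletionAction::IoFailed;  // close socket, free context
    if (isRecv && requestedBytes > 0 && transferred == 0)
        return CompletionAction::PeerClosed;     // graceful shutdown by the peer
    return CompletionAction::Process;
}
```

Only IoFailed and PeerClosed call for closing the socket, and in both cases the per-I/O context must be freed too, because no further completion will arrive for that OVERLAPPED.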

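Several points in the cases above need particular attention:
When the system locks memory, the smallest unit it locks is one 4 KB page (this depends on your system's settings and can apparently be lowered in the registry, though I suspect Microsoft knows better than we do which values are appropriate). So when we post many WSARecv or WSASend operations, every operation that ends up pending (IO_PENDING) locks at least 4 KB of memory, however small its buffer is (zero-length buffers excepted). This is why developers often run into WSAENOBUFS.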
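A back-of-the-envelope sketch of the model described above, assuming 4 KB pages and that each pending non-empty buffer is locked in whole pages (a buffer that straddles a page boundary would actually pin one page more). The function names and the 64 MB budget used in the test are illustrative assumptions, not Windows limits.

```cpp
#include <cassert>
#include <cstdint>

// Bytes locked by one pending operation under the text's model:
// the buffer is rounded up to whole pages; a zero-length buffer locks nothing.
std::uint64_t LockedBytesPerOp(std::uint64_t bufferLen, std::uint64_t pageSize = 4096) {
    if (bufferLen == 0) return 0;
    std::uint64_t pages = (bufferLen + pageSize - 1) / pageSize;  // ceiling division
    return pages * pageSize;
}

// Total locked by `pendingOps` operations that all use the same buffer size.
std::uint64_t LockedBytesTotal(std::uint64_t pendingOps, std::uint64_t bufferLen,
                               std::uint64_t pageSize = 4096) {
    return pendingOps * LockedBytesPerOp(bufferLen, pageSize);
}

// How many such operations fit in a locked-memory budget, i.e. a possible
// "predefined limit" for throttling posts. Zero-length ops are unlimited here.
std::uint64_t MaxPendingOps(std::uint64_t budgetBytes, std::uint64_t bufferLen,
                            std::uint64_t pageSize = 4096) {
    std::uint64_t perOp = LockedBytesPerOp(bufferLen, pageSize);
    return perOp == 0 ? UINT64_MAX : budgetBytes / perOp;
}
```

For example, 10,000 connections that each keep one 100-byte WSARecv pending lock 10,000 × 4,096 = 40,960,000 bytes (about 39 MB) under this model, while the same connections using zero-length receives lock none.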

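To solve this problem, WSASend and WSARecv each need their own handling:
1: For WSARecv, a clever design is to first post a WSARecv with a zero-length buffer. When it completes, data is available, and we then call a non-blocking recv to pull the data out of the TCP/IP buffer until it fails with WSAEWOULDBLOCK. Because a zero-length buffer locks no memory, waiting connections cost no locked pages.
2: Count the outstanding WSARecv and WSASend operations, and stop posting new ones once the count exceeds a predefined limit.
3: This model also explains why WSASend can report fewer bytes sent than the buffer we posted: when the free space in the TCP/IP buffer is smaller than our buffer, TCP/IP copies only as much as fits and sets WSASend's bytes-sent count to that amount. If the stack has still not sent that data by the next post, the next WSASend returns WSA_IO_PENDING.
4: It is often suggested that turning off the TCP/IP stack buffer (for example by setting SO_SNDBUF to 0) improves efficiency and performance. The analysis above suggests this is possible, but it has to be measured under real network conditions.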
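Item 1 can be sketched as follows, assuming the socket has been put in non-blocking mode (for example with ioctlsocket and FIONBIO) so that recv fails with WSAEWOULDBLOCK (10035) once the stack buffer is empty. To keep the snippet runnable without Winsock, the non-blocking recv is passed in as a callable; tryRecv, RecvAttempt and DrainAfterZeroByteRecv are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>

constexpr int kWsaEWouldBlock = 10035;  // value of WSAEWOULDBLOCK

// Outcome of one non-blocking recv: bytes > 0 (data), bytes == 0 (peer
// closed), or bytes < 0 with the WSAGetLastError() value in `error`.
struct RecvAttempt {
    int bytes;
    int error;
};

enum class DrainResult { Drained, PeerClosed, Failed };

// Called when the zero-length WSARecv completes: pull everything the stack
// has buffered, then report why the loop stopped. `tryRecv` stands in for a
// non-blocking recv() on the socket; `sink` consumes each chunk.
DrainResult DrainAfterZeroByteRecv(
        const std::function<RecvAttempt(char*, int)>& tryRecv,
        const std::function<void(const char*, int)>& sink) {
    char buf[4096];  // used only during each synchronous call, never left pending
    for (;;) {
        RecvAttempt r = tryRecv(buf, static_cast<int>(sizeof(buf)));
        if (r.bytes > 0) {
            sink(buf, r.bytes);
            continue;
        }
        if (r.bytes == 0) return DrainResult::PeerClosed;
        return r.error == kWsaEWouldBlock ? DrainResult::Drained : DrainResult::Failed;
    }
}
```

On Drained the server re-posts the zero-length WSARecv; on PeerClosed or Failed it closes the socket and frees the per-I/O context.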

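Item 2 amounts to a counter that is checked before every post. A minimal thread-safe sketch, assuming the application picks the limit itself (OutstandingIoLimiter is a hypothetical name): take a slot before calling WSASend/WSARecv, and give it back when the completion is dequeued, or immediately if the call fails with an error other than WSA_IO_PENDING.

```cpp
#include <atomic>
#include <cassert>

// Caps how many WSASend/WSARecv operations may be outstanding at once.
class OutstandingIoLimiter {
public:
    explicit OutstandingIoLimiter(int limit) : limit_(limit) { assert(limit > 0); }

    // Returns true if a new operation may be posted; false means "over the
    // limit, queue the work instead of posting it".
    bool TryAcquire() {
        int cur = count_.load(std::memory_order_relaxed);
        while (cur < limit_) {
            // On failure, compare_exchange_weak reloads `cur`, so the loop
            // re-checks the limit against the latest value.
            if (count_.compare_exchange_weak(cur, cur + 1, std::memory_order_acq_rel))
                return true;
        }
        return false;
    }

    // Call once per successful TryAcquire(), when the I/O is finished.
    void Release() { count_.fetch_sub(1, std::memory_order_acq_rel); }

    int Outstanding() const { return count_.load(std::memory_order_relaxed); }

private:
    const int limit_;
    std::atomic<int> count_{0};
};
```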
So how do WSASend/WSARecv tie in with GetQueuedCompletionStatus?

int MyServer::threadIocp()
{
    DWORD transferred = 0;
    ULONG_PTR CompletionKey = 0;
    OVERLAPPED* lpOverlapped = NULL;
    if (GetQueuedCompletionStatus(m_hIOCP, &transferred, &CompletionKey, &lpOverlapped, INFINITE)) {
        if (CompletionKey != 0) {
            MyOverlapped* pOverlapped = CONTAINING_RECORD(lpOverlapped, MyOverlapped, m_overlapped);
            TRACE("pOverlapped->m_operator %d \r\n", pOverlapped->m_operator);
            pOverlapped->m_server = this;
            switch (pOverlapped->m_operator) {
            case EAccept: {
                ACCEPTOVERLAPPED* pOver = (ACCEPTOVERLAPPED*)pOverlapped;
                m_pool.DispatchWorker(pOver->m_worker);
            } break;
            case ERecv: {
                RECVOVERLAPPED* pOver = (RECVOVERLAPPED*)pOverlapped;
                m_pool.DispatchWorker(pOver->m_worker);
            } break;
            case ESend: {
                SENDOVERLAPPED* pOver = (SENDOVERLAPPED*)pOverlapped;
                m_pool.DispatchWorker(pOver->m_worker);
            } break;
            case EError: {
                ERROROVERLAPPED* pOver = (ERROROVERLAPPED*)pOverlapped;
                m_pool.DispatchWorker(pOver->m_worker);
            } break;
            }
        } else {
            return -1;
        }
    }
    return 0;
}

template<MyOperator op>
int AcceptOverlapped<op>::AcceptWorker()
{
    INT lLength = 0, rLength = 0;
    if (m_client->GetBufferSize() > 0) {
        sockaddr* plocal = NULL, * remote = NULL;
        GetAcceptExSockaddrs(*m_client, 0,
            sizeof(sockaddr_in) + 16, sizeof(sockaddr_in) + 16,
            (sockaddr**)&plocal, &lLength,
            (sockaddr**)&remote, &rLength);
        memcpy(m_client->GetLocalAddr(), plocal, sizeof(sockaddr_in));
        memcpy(m_client->GetLocalAddr(), remote, sizeof(sockaddr_in));
        m_server->BindNewSocket(*m_client);
        int ret = WSARecv((SOCKET)*m_client, m_client->RecvWSABuffer(), 1, *m_client,
                          &m_client->flags(), m_client->RecvOverlapped(), NULL);
        if (ret == SOCKET_ERROR && (WSAGetLastError() != WSA_IO_PENDING)) {
            TRACE("ret = %d error = %d\r\n", ret, WSAGetLastError());
        }
        if (!m_server->NewAccept()) {
            return -2;
        }
    }
    return -1;
}
12-08
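WSASend/WSARecv and GetQueuedCompletionStatus are linked only through the OVERLAPPED pointer passed when the request is posted: GetQueuedCompletionStatus hands that same pointer back, and a CONTAINING_RECORD call like the one above turns it into the enclosing per-I/O structure. The sketch below performs the same offset arithmetic in portable C++; FakeOverlapped is a stand-in for the real OVERLAPPED, and IoContext and ContextFromOverlapped are illustrative names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Stand-in for the Windows OVERLAPPED struct so the sketch compiles anywhere.
struct FakeOverlapped {
    std::uintptr_t internal[4];
    void* hEvent;
};

enum class Op { Accept, Recv, Send };

// Per-I/O context with the OVERLAPPED embedded, so the pointer returned by
// GetQueuedCompletionStatus points *inside* this object.
struct IoContext {
    int id;
    Op op;
    FakeOverlapped overlapped;
};

// Same arithmetic as CONTAINING_RECORD(ov, IoContext, overlapped):
// step back from the member's address by the member's offset.
IoContext* ContextFromOverlapped(FakeOverlapped* ov) {
    return reinterpret_cast<IoContext*>(
        reinterpret_cast<char*>(ov) - offsetof(IoContext, overlapped));
}
```

Because the context is found from the pointer alone, it must not be freed or reused while its request is still pending.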
// test/IOCP2.hpp
#pragma once
#include <windows.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <mswsock.h>
#include <vector>
#include <memory>
#include <functional>
#include <string>
#include <iostream>
#include <mutex>
#include <unordered_map>
#include <atomic>
#include <thread>
#include <queue>
#pragma comment(lib, "ws2_32.lib")
#pragma comment(lib, "mswsock.lib")

namespace IOCP2 {
// Constants
constexpr int BUFFER_SIZE = 4096;
// I/O operation types
enum IO_OPERATION { ACCEPT, READ, WRITE };
// Node types
enum NODE_TYPE { NODE, SERVER, CLIENT };

/**
 * Per-I/O context data
 */
class IOContext {
public:
    // Unique identifier of this I/O operation: the same operation always has the same address,
    // so the LPOVERLAPPED returned by GetQueuedCompletionStatus maps back to the IOContext.
    OVERLAPPED m_overlapped;
    // Describes the data buffer; used by WSARecv/WSASend.
    WSABUF m_dataBuf;
    // I/O operation type
    IO_OPERATION m_operation;
    // Buffer data
    char m_buffer[BUFFER_SIZE];
    SOCKET m_socket;
public:
    IOContext() {
        ZeroMemory(&this->m_overlapped, sizeof(OVERLAPPED));
        this->m_dataBuf.len = BUFFER_SIZE;
        this->m_dataBuf.buf = this->m_buffer;
        this->m_operation = IO_OPERATION::READ;
        this->m_socket = INVALID_SOCKET;
    }
    ~IOContext() {}
    bool operator == (const IOContext& other) { return &this->m_overlapped == &other.m_overlapped; }
    bool operator != (const IOContext& other) { return &this->m_overlapped != &other.m_overlapped; }
};

/**
 * Socket control helper for calling Windows extension functions
 */
class WSAIoctlTool {
private:
    // GUID of the server's accept function (AcceptEx)
    static GUID guidAcceptEx;
    // GUID of the function that retrieves the addresses of an accepted connection
    static GUID guidGetAcceptExSockaddrs;
private:
    template<typename FUNC_POINTERE>
    static bool WSAIoctlFun(SOCKET& targetSocket, GUID& funcGuID, FUNC_POINTERE& funcPointer) {
        DWORD bytesReceived;
        return WSAIoctl(targetSocket, SIO_GET_EXTENSION_FUNCTION_POINTER,
                        &funcGuID, sizeof(funcGuID),
                        &funcPointer, sizeof(funcPointer),
                        &bytesReceived, nullptr, nullptr) != SOCKET_ERROR;
    }
public:
    /**
     * Accepts a connection asynchronously
     * @param sListenSocket     listening socket (after bind + listen)
     * @param sAcceptSocket     pre-created socket that will accept the new connection
     * @param lpOutputBuffer    buffer that receives data plus the local and remote addresses
     * @param lpdwBytesReceived actual number of bytes received (valid only for synchronous completion)
     * @param lpOverlapped      OVERLAPPED structure for the asynchronous operation
     */
    static bool acceptExFun(SOCKET sListenSocket, SOCKET sAcceptSocket, PVOID lpOutputBuffer,
                            LPDWORD lpdwBytesReceived, LPOVERLAPPED lpOverlapped) {
        LPFN_ACCEPTEX lpAcceptEx = nullptr;
        if (!WSAIoctlTool::WSAIoctlFun(sListenSocket, WSAIoctlTool::guidAcceptEx, lpAcceptEx))
            return false;
        return lpAcceptEx(sListenSocket, sAcceptSocket, lpOutputBuffer, 0,
                          sizeof(sockaddr_in) + 16, sizeof(sockaddr_in) + 16,
                          lpdwBytesReceived, lpOverlapped) != FALSE;
    }
    /**
     * Gets the address of the accepted connection
     * @param sAcceptSocket  the AcceptEx socket
     * @param lpOutputBuffer the buffer returned by AcceptEx
     * @return ip:port
     */
    static std::string GetAcceptExSockaddrsFun(SOCKET sAcceptSocket, PVOID lpOutputBuffer) {
        LPFN_GETACCEPTEXSOCKADDRS lpGetAcceptExSockaddrs = nullptr;
        // Client address information
        sockaddr_in* localAddr = nullptr;
        sockaddr_in* remoteAddr = nullptr;
        int localAddrLen = 0;
        int remoteAddrLen = 0;
        constexpr const int addressBufferSize = sizeof(sockaddr_in) + 16;
        if (!WSAIoctlTool::WSAIoctlFun(sAcceptSocket, WSAIoctlTool::guidGetAcceptExSockaddrs, lpGetAcceptExSockaddrs))
            return NULL;
        // Use GetAcceptExSockaddrs to get the address information
        lpGetAcceptExSockaddrs(lpOutputBuffer, 0, addressBufferSize, addressBufferSize,
                               (sockaddr**)&localAddr, &localAddrLen,
                               (sockaddr**)&remoteAddr, &remoteAddrLen);
        char clientIP[INET_ADDRSTRLEN];
        inet_ntop(AF_INET, &(remoteAddr->sin_addr), clientIP, INET_ADDRSTRLEN);
        unsigned short clientPort = ntohs(remoteAddr->sin_port);
        return std::string(clientIP, INET_ADDRSTRLEN) + ":" + std::to_string(clientPort);
    }
};

class RemoteNodesData;

/**
 * Per-node socket context data
 */
class NodeSocketContextData {
private:
    SOCKET m_socket;
    std::string m_IPPortStr;
    std::shared_ptr<IOContext> m_readContext;
    // The void* key is the OVERLAPPED address; the same address means the same I/O operation
    std::unordered_map<void*, std::shared_ptr<IOContext>> m_writeContexts;
private:
    void deleteReadContext(std::shared_ptr<IOContext>& context) {
        // Not the same context
        if (&this->m_readContext.get()->m_overlapped != &context->m_overlapped) return;
        this->m_readContext.reset();
    }
    void setWriteContext(std::shared_ptr<IOContext>& context) {
        this->m_writeContexts[&context->m_overlapped] = context;
    }
    void setReadContext(std::shared_ptr<IOContext>& context) {
        this->m_readContext = context;
    }
    void deleteWriteContext(std::shared_ptr<IOContext>& context) {
        this->m_writeContexts.erase(&context->m_overlapped);
    }
public:
    NodeSocketContextData() {
        this->m_socket = INVALID_SOCKET;
        this->m_IPPortStr = "";
        this->m_readContext = nullptr;
    }
    NodeSocketContextData(const std::string& ipPortStr, SOCKET& socket) {
        this->m_socket = socket;
        this->m_IPPortStr = ipPortStr;
    }
    ~NodeSocketContextData() { this->m_writeContexts.clear(); }
    std::string& getIPPortStr() { return this->m_IPPortStr; }
    SOCKET& getSocket() { return this->m_socket; }
    void setSocket(SOCKET& socket) { this->m_socket = socket; }
public:
    friend class RemoteNodesData;
};

/**
 * Remote node data
 */
class RemoteNodesData {
private:
    std::unordered_map<std::string, std::shared_ptr<NodeSocketContextData>> m_remoteNodes;
    std::unordered_map<SOCKET, std::string> m_remoteNodesSocketIPPortStrMap;
    std::mutex m_remoteNodesMx;
public:
    RemoteNodesData() {}
    ~RemoteNodesData() {}
    // Whether a node exists for this socket
    bool hashNode(SOCKET& socket) {
        auto it = this->m_remoteNodesSocketIPPortStrMap.find(socket);
        return it != this->m_remoteNodesSocketIPPortStrMap.end();
    }
    // Whether a node exists for this ip:port
    bool hashNode(const std::string& IPPortStr) {
        auto it = this->m_remoteNodes.find(IPPortStr);
        return it != this->m_remoteNodes.end();
    }
    SOCKET getNodeSocketByIPPortStr(std::string& IPPortStr) {
        auto it = this->m_remoteNodes.find(IPPortStr);
        if (it == this->m_remoteNodes.end()) return INVALID_SOCKET;
        return it->second->getSocket();
    }
    void addRemoteNode(std::shared_ptr<NodeSocketContextData>& remoteNode) {
        std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
        if (remoteNode.get() == nullptr || remoteNode.get()->getIPPortStr() == "" ||
            remoteNode.get()->getSocket() == INVALID_SOCKET) return;
        this->m_remoteNodesSocketIPPortStrMap[remoteNode.get()->getSocket()] = remoteNode.get()->getIPPortStr();
        this->m_remoteNodes[remoteNode.get()->getIPPortStr()] = remoteNode;
    }
    // Calling this function also closes the socket
    void deleteRemeoteNodesByIPPortStr(const std::string& ipPortStr) {
        std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
        if (ipPortStr == "" || ipPortStr.empty()) return;
        auto it = this->m_remoteNodes.find(ipPortStr);
        // No data for this node
        if (it == this->m_remoteNodes.end()) return;
        SOCKET socket = it->second.get()->getSocket();
        this->m_remoteNodesSocketIPPortStrMap.erase(socket);
        this->m_remoteNodes.erase(it);
        // Close the socket
        closesocket(socket);
    }
    // Calling this function also closes the socket
    void deleteRemeoteNodesBySocket(SOCKET& socket) {
        std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
        if (socket == INVALID_SOCKET) return;
        auto it = this->m_remoteNodesSocketIPPortStrMap.find(socket);
        // No data for this node
        if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
        this->m_remoteNodes.erase(it->second);
        this->m_remoteNodesSocketIPPortStrMap.erase(it);
        // Close the socket
        closesocket(socket);
    }
    void addRemeoteNodeReadContext(std::shared_ptr<IOContext>& context) {
        std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
        if (context.get()->m_socket == INVALID_SOCKET) return;
        auto it = this->m_remoteNodesSocketIPPortStrMap.find(context->m_socket);
        // No node data for this socket
        if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
        this->m_remoteNodes[it->second]->setReadContext(context);
    }
    void deleteRemeoteNodeReadContext(std::shared_ptr<IOContext>& context) {
        std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
        if (context.get()->m_socket == INVALID_SOCKET) return;
        auto it = this->m_remoteNodesSocketIPPortStrMap.find(context->m_socket);
        // No node data for this socket
        if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
        this->m_remoteNodes[it->second]->deleteReadContext(context);
    }
    void addRemeoteNodeWriteContext(std::shared_ptr<IOContext>& context) {
        std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
        if (context.get()->m_socket == INVALID_SOCKET) return;
        auto it = this->m_remoteNodesSocketIPPortStrMap.find(context->m_socket);
        // No node data for this socket
        if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
        this->m_remoteNodes[it->second]->setWriteContext(context);
    }
    void deleteRemeoteNodeWriteContext(std::shared_ptr<IOContext>& context) {
        std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
        if (context.get()->m_socket == INVALID_SOCKET) return;
        auto it = this->m_remoteNodesSocketIPPortStrMap.find(context->m_socket);
        // No node data for this socket
        if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
        this->m_remoteNodes[it->second]->deleteWriteContext(context);
    }
    /**
     * Iterates over the nodes
     * @param handleFun handler; returning false stops the loop, returning true continues it
     */
    void traverseHandleNodes(std::function<bool(std::shared_ptr<NodeSocketContextData>&)> handleFun) {
        std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
        for (auto& pair : this->m_remoteNodes)
            if (!handleFun(pair.second)) return;
    }
    // Clears all data
    void clear() {
        this->m_remoteNodesSocketIPPortStrMap.clear();
        this->m_remoteNodes.clear();
    }
};

/**
 * Pending server information
 */
class PrePareServerInformation {
public:
    int m_port;
    const char* m_IP;
public:
    PrePareServerInformation(const char* IP, int port) {
        this->m_port = port;
        this->m_IP = IP;
    }
    ~PrePareServerInformation() {}
};

/**
 * Node interface
 */
class NodeInterface {
protected:
    HANDLE m_iocpHandle;
    std::vector<std::thread> m_workerThreads;
    size_t m_maxConcurrentThreads = 1; // number of threads
    std::atomic<bool> m_running;
    static std::unordered_map<NODE_TYPE, const std::string> ipPortStrMap;
    // Callbacks
    std::function<void(SOCKET)> m_onConnected = NULL;
    std::function<void(SOCKET)> m_onDisconnected = NULL;
    std::function<void(SOCKET, const char*, int)> m_onDataReceived = NULL;
    // Remote node data
    RemoteNodesData m_remoteNodesData;
    static std::mutex handleErrorMx;
    // IOContext registry: every newly created IOContext is recorded here
    std::unordered_map<void*, std::shared_ptr<IOContext>> m_registerIOContexts;
    std::mutex m_registerIOContextsMx;
protected:
    virtual NODE_TYPE getNodeType() {}
    // Dispatches a completion to the matching handler
    virtual void jumpHandleFun(std::shared_ptr<IOContext>& context, DWORD bytesTransferred) {}
    // Start callback
    virtual void startCallBackFun() {}
    // Shutdown callback
    virtual void shutDownCallBackFun() {}
    const std::string& getNodeTypeStr() { return NodeInterface::ipPortStrMap[this->getNodeType()]; }
    // Registers a context
    void registerIOContext(SOCKET socket, IO_OPERATION IOOperation) {
        if (socket == INVALID_SOCKET) return;
        IOContext* context = new IOContext();
        context->m_operation = IOOperation;
        context->m_socket = socket;
        this->m_registerIOContexts[&context->m_overlapped] = std::shared_ptr<IOContext>(context);
    }
    // Registers a context
    std::shared_ptr<IOContext> registerIOContext() {
        std::shared_ptr<IOContext> context(new IOContext());
        std::cout << "Register IOContext = " << &context << std::endl;
        this->m_registerIOContexts[&context.get()->m_overlapped] = context;
        return context;
    }
    void registerIOContext(std::shared_ptr<IOContext>& context) {
        this->m_registerIOContexts[&context.get()->m_overlapped] = context;
    }
    void deleteRegisterIOContext(void* overlappedAddress) {
        auto it = this->m_registerIOContexts.find(overlappedAddress);
        if (it == this->m_registerIOContexts.end()) return;
        std::lock_guard<std::mutex> lock(this->m_registerIOContextsMx);
        std::cout << "Delete RegisterIOContext" << std::endl;
        it = this->m_registerIOContexts.find(overlappedAddress);
        if (it == this->m_registerIOContexts.end()) return;
        this->m_registerIOContexts.erase(it);
    }
    std::shared_ptr<IOContext> getRegisterIOContext(void* overlappedAddress) {
        auto it = this->m_registerIOContexts.find(overlappedAddress);
        if (it == this->m_registerIOContexts.end()) return nullptr;
        return it->second;
    }
    void handleError(const char* message, bool fatal = false) {
        std::lock_guard<std::mutex> lock(NodeInterface::handleErrorMx);
        DWORD error = WSAGetLastError();
        std::string nodeTypeStr = this->getNodeTypeStr() + ":";
        std::string str = error == 0 ? "Unknown" : strerror(error);
        std::cout << nodeTypeStr << message << " failed with error" << error << ":" << str << std::endl;
        if (!fatal) return;
        WSACleanup();
        ExitProcess(1);
    }
    void closeSocket(SOCKET& socket) {
        if (this->m_onDisconnected != NULL) this->m_onDisconnected(socket);
        this->m_remoteNodesData.deleteRemeoteNodesBySocket(socket);
    }
    void postRead(std::shared_ptr<IOContext>& context) {
        DWORD flags = 0;
        DWORD bytesRead;
        context.get()->m_operation = IO_OPERATION::READ;
        if (WSARecv(context.get()->m_socket, &context.get()->m_dataBuf, 1, &bytesRead, &flags,
                    &context.get()->m_overlapped, NULL) == SOCKET_ERROR) {
            DWORD error = WSAGetLastError();
            if (error != ERROR_IO_PENDING) {
                this->handleError("WSARecv");
                this->closeSocket(context.get()->m_socket);
            }
        }
    }
    void postWrite(const char* IP, int port, const char* data, int length) {
        std::string IPPortStr = std::string(IP) + ":" + std::to_string(port);
        // No such node
        if (!this->m_remoteNodesData.hashNode(IPPortStr)) return;
        // Register the context
        std::shared_ptr<IOContext> writeContext = this->registerIOContext();
        writeContext.get()->m_socket = this->m_remoteNodesData.getNodeSocketByIPPortStr(IPPortStr);
        writeContext.get()->m_operation = IO_OPERATION::WRITE;
        memcpy(writeContext.get()->m_buffer, data, length);
        writeContext.get()->m_dataBuf.len = length;
        DWORD bytesSent;
        if (WSASend(writeContext.get()->m_socket, &writeContext.get()->m_dataBuf, 1, &bytesSent, 0,
                    &writeContext.get()->m_overlapped, NULL) == SOCKET_ERROR) {
            DWORD error = WSAGetLastError();
            if (error != WSA_IO_PENDING) {
                handleError("WSASend");
                this->closeSocket(writeContext.get()->m_socket);
                return;
            }
        }
        // Add to the node's read/write contexts
        this->m_remoteNodesData.addRemeoteNodeWriteContext(writeContext);
    }
    void postWrite(SOCKET socket, const char* data, int length) {
        // Invalid socket
        if (socket == INVALID_SOCKET) return;
        // Register the context
        std::shared_ptr<IOContext> writeContext = this->registerIOContext();
        writeContext.get()->m_socket = socket;
        writeContext.get()->m_operation = IO_OPERATION::WRITE;
        memcpy(writeContext.get()->m_buffer, data, length);
        writeContext.get()->m_dataBuf.len = length;
        DWORD bytesSent;
        if (WSASend(writeContext.get()->m_socket, &writeContext.get()->m_dataBuf, 1, &bytesSent, 0,
                    &writeContext.get()->m_overlapped, NULL) == SOCKET_ERROR) {
            DWORD error = WSAGetLastError();
            if (error != WSA_IO_PENDING) {
                handleError("WSASend");
                this->closeSocket(writeContext.get()->m_socket);
                this->deleteRegisterIOContext(&writeContext.get()->m_overlapped);
                return;
            }
        }
        // Add to the node's read/write contexts
        this->m_remoteNodesData.addRemeoteNodeWriteContext(writeContext);
    }
    // Handles a read completion
    void handleRead(std::shared_ptr<IOContext>& context, DWORD bytesTransferred) {
        // Client disconnected
        if (bytesTransferred == 0) {
            this->closeSocket(context.get()->m_socket);
            return;
        }
        if (this->m_onDataReceived != NULL) {
            this->m_onDataReceived(context.get()->m_socket, context.get()->m_buffer, bytesTransferred);
        }
        // Post the next read
        this->postRead(context);
    }
    void handleWrite(std::shared_ptr<IOContext>& context, DWORD bytesTransferred) {
        if (!context) return;
        this->deleteRegisterIOContext(&context.get()->m_overlapped);
        // Write finished, release the context
        this->m_remoteNodesData.deleteRemeoteNodeWriteContext(context);
        //context.reset();
    }
    // Worker thread
    void workerThread() {
        DWORD bytesTransferred;
        ULONG_PTR completionKey;
        LPOVERLAPPED overlapped;
        while (this->m_running) {
            std::cout << this->getNodeTypeStr() << " waiting for the I/O port =========" << std::endl;
            BOOL result = GetQueuedCompletionStatus(this->m_iocpHandle, &bytesTransferred, &completionKey, &overlapped, INFINITE);
            std::cout << this->getNodeTypeStr() << " I/O port returned data =========" << std::endl;
            // Handle errors or timeouts
            if (!result) {
                DWORD error = GetLastError();
                // Timeout
                if (overlapped == NULL && error == WAIT_TIMEOUT) continue;
                if (error != WAIT_TIMEOUT && error != ERROR_NETNAME_DELETED)
                    this->handleError("GetQueuedCompletionStatus");
                // Client disconnected
                if (overlapped != NULL) {
                    // Get the registered context
                    std::shared_ptr<IOContext> context = this->getRegisterIOContext(overlapped);
                    if (context == nullptr) continue;
                    if (context.get()->m_operation == IO_OPERATION::READ || context.get()->m_operation == IO_OPERATION::WRITE)
                        this->closeSocket(context.get()->m_socket);
                }
                continue;
            }
            // Exit signal received
            if (bytesTransferred == 0 && completionKey == NULL && overlapped == NULL) break;
            // Get the registered context
            std::shared_ptr<IOContext> context = this->getRegisterIOContext(overlapped);
            if (context == nullptr) continue;
            this->jumpHandleFun(context, bytesTransferred);
        }
    }
public:
    void setOnConnected(std::function<void(SOCKET)> callback) { this->m_onConnected = callback; }
    void setOnDisconnected(std::function<void(SOCKET)> callback) { this->m_onDisconnected = callback; }
    void setOnDataReceived(std::function<void(SOCKET, const char*, int)> callback) { this->m_onDataReceived = callback; }
    // Starts the node
    void start() {
        // Already running
        if (this->m_running) return;
        this->startCallBackFun();
        // Failed to start
        if (!this->m_running) return;
        std::cout << this->getNodeTypeStr() << ":startCallBackFun is runned!" << std::endl;
        // Join the threads
        for (std::thread& worker : this->m_workerThreads)
            if (worker.joinable()) worker.join();
    }
    // Shuts the node down
    void shutDown() {
        if (!this->m_running) return;
        this->m_running = false;
        // Close the sockets of all remote nodes
        auto closeSocketFun = [](std::shared_ptr<NodeSocketContextData>& remoteNode) -> bool {
            if (remoteNode->getSocket() == INVALID_SOCKET) return true;
            // Close the socket
            closesocket(remoteNode.get()->getSocket());
            return true;
        };
        // Pass the handler to the remote-node traversal
        this->m_remoteNodesData.traverseHandleNodes(closeSocketFun);
        // Clear the remote node data
        this->m_remoteNodesData.clear();
        // Close the IOCP handle
        if (this->m_iocpHandle != NULL) {
            CloseHandle(this->m_iocpHandle);
            this->m_iocpHandle = NULL;
        }
        // Call the node's shutdown callback
        this->shutDownCallBackFun();
        // Wake all worker threads by sending the exit signal
        for (size_t i = 0; i < this->m_workerThreads.size(); i++)
            PostQueuedCompletionStatus(this->m_iocpHandle, 0, (ULONG_PTR)NULL, NULL);
        // Wait for all worker threads
        for (std::thread& thred : this->m_workerThreads)
            if (thred.joinable()) thred.join();
        this->m_workerThreads.clear();
        WSACleanup();
    }
    void send(const char* IP, int port, const char* message) {
        if (!this->m_running) return;
        this->postWrite(IP, port, message, strlen(message));
    }
    void send(SOCKET socket, const char* message) {
        if (!this->m_running) return;
        this->postWrite(socket, message, strlen(message));
    }
    NodeInterface() {}
    virtual ~NodeInterface() {}
};

/**
 * Server
 */
class Server : public NodeInterface {
private:
    int m_port;
    SOCKET m_listenSocket;
    // static constexpr members need no out-of-class definition
    static constexpr NODE_TYPE nodeType = IOCP2::NODE_TYPE::SERVER;
private:
    // Posts a new accept operation
    void postAccept() {
        SOCKET clientSocket = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED);
        if (clientSocket == INVALID_SOCKET) {
            this->handleError("WSASocket for accept");
            return;
        }
        // Register the context so it is not destroyed early
        std::shared_ptr<IOContext> context = this->registerIOContext();
        context.get()->m_operation = IO_OPERATION::ACCEPT;
        context.get()->m_socket = clientSocket;
        DWORD bytesReceived;
        // Accept the connection
        if (!WSAIoctlTool::acceptExFun(this->m_listenSocket, clientSocket, context.get()->m_buffer,
                                       &bytesReceived, &context.get()->m_overlapped)) {
            DWORD error = WSAGetLastError();
            if (error == ERROR_IO_PENDING) return;
            this->handleError("AcceptEx");
            // Remove the registered context
            this->deleteRegisterIOContext(&context.get()->m_overlapped);
            closesocket(clientSocket);
        }
    }
    // Handles an accepted connection
    void handleAccept(std::shared_ptr<IOContext>& context) {
        // Invoke the client-connected callback
        if (this->m_onConnected != NULL) this->m_onConnected(context.get()->m_socket);
        // Associate the new socket with the IOCP
        if (CreateIoCompletionPort((HANDLE)context.get()->m_socket, this->m_iocpHandle,
                                   (ULONG_PTR)context.get()->m_socket, 0) == NULL) {
            this->handleError("CreateIoCompletionPort for client socket");
            closesocket(context.get()->m_socket);
            return;
        }
        // Get the client information
        std::string IPPortStr = WSAIoctlTool::GetAcceptExSockaddrsFun(context.get()->m_socket, context.get()->m_buffer);
        // No address retrieved
        if (IPPortStr.empty() || IPPortStr == "") {
            this->handleError("GetAcceptExSockaddrs");
            return;
        }
        // Add to the remote node data
        std::shared_ptr<NodeSocketContextData> nodeContextData(new NodeSocketContextData(IPPortStr, context.get()->m_socket));
        this->m_remoteNodesData.addRemoteNode(nodeContextData);
        /*// Register the context
        this->registerIOContext(context);*/
        // Post a read for the new connection
        this->postRead(context);
        // Add the read context to node management to prevent early release
        //this->m_remoteNodesData.addRemeoteNodeReadContext(context);
        // Post the next accept
        this->postAccept();
    }
    void jumpHandleFun(std::shared_ptr<IOContext>& context, DWORD bytesTransferred) override {
        switch (context->m_operation) {
        case IO_OPERATION::ACCEPT: this->handleAccept(context); break;
        case IO_OPERATION::READ: this->handleRead(context, bytesTransferred); break;
        case IO_OPERATION::WRITE: this->handleWrite(context, bytesTransferred); break;
        }
    }
    // Start callback
    void startCallBackFun() override {
        // Initialize Winsock
        WSADATA WSAData;
        if (WSAStartup(MAKEWORD(2, 2), &WSAData) != 0) {
            this->handleError("WSAStartup", true);
            return;
        }
        // Create the listening socket
        this->m_listenSocket = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED);
        if (this->m_listenSocket == INVALID_SOCKET) {
            this->handleError("WSASocket", true);
            return;
        }
        // Bind the address
        sockaddr_in serverAddr;
        serverAddr.sin_family = AF_INET;
        serverAddr.sin_addr.s_addr = htonl(INADDR_ANY);
        serverAddr.sin_port = htons(this->m_port);
        if (bind(this->m_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) {
            this->handleError("bind", true);
            closesocket(this->m_listenSocket);
            return;
        }
        // Create the IOCP
        this->m_iocpHandle = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
        if (this->m_iocpHandle == NULL) {
            handleError("CreateIoCompletionPort", true);
            closesocket(this->m_listenSocket);
            return;
        }
        // Associate the listening socket with the IOCP
        if (CreateIoCompletionPort((HANDLE)this->m_listenSocket, this->m_iocpHandle,
                                   (ULONG_PTR)this->m_listenSocket, 0) == NULL) {
            this->handleError("CreateIoCompletionPort for listen socket", true);
            closesocket(this->m_listenSocket);
            CloseHandle(this->m_iocpHandle);
            return;
        }
        // Start listening
        if (listen(this->m_listenSocket, SOMAXCONN) == SOCKET_ERROR) {
            this->handleError("listen", true);
            closesocket(this->m_listenSocket);
            CloseHandle(this->m_iocpHandle);
            return;
        }
        // Create the worker threads
        this->m_running = true;
        for (size_t i = 0; i < this->m_maxConcurrentThreads; i++)
            this->m_workerThreads.emplace_back(&IOCP2::Server::workerThread, this);
        // Post the initial accept
        this->postAccept();
    }
    // Shutdown callback
    void shutDownCallBackFun() override {
        if (this->m_listenSocket == INVALID_SOCKET) return;
        closesocket(this->m_listenSocket);
        this->m_listenSocket = INVALID_SOCKET;
    }
public:
    NODE_TYPE getNodeType() override { return Server::nodeType; }
    Server(int port) {
        this->m_port = port;
        this->m_running = false;
    }
    ~Server() override { this->shutDown(); }
};

/**
 * Client
 */
class Client : public NodeInterface {
private:
    static constexpr NODE_TYPE nodeType = NODE_TYPE::CLIENT;
    // Pending server information
    std::queue<PrePareServerInformation> m_prePareServers;
private:
    void jumpHandleFun(std::shared_ptr<IOContext>& context, DWORD bytesTransferred) override {
        switch (context.get()->m_operation) {
        case IO_OPERATION::READ: this->handleRead(context, bytesTransferred); break;
        case IO_OPERATION::WRITE: this->handleWrite(context, bytesTransferred); break;
        default: break;
        }
    }
    void connectServer(PrePareServerInformation& prePareServer) {
        // Create the socket
        SOCKET serverSocket = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED);
        if (serverSocket == INVALID_SOCKET) {
            this->handleError("WSASocket", true);
            return;
        }
        // Resolve the server address
        sockaddr_in serverAddr;
        serverAddr.sin_family = AF_INET;
        serverAddr.sin_port = htons(prePareServer.m_port);
        inet_pton(AF_INET, prePareServer.m_IP, &serverAddr.sin_addr);
        // Connect to the server
        if (connect(serverSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) {
            DWORD errror = WSAGetLastError();
            this->handleError("connect");
            closesocket(serverSocket);
            return;
        }
        // Associate the socket with the IOCP
        if (CreateIoCompletionPort((HANDLE)serverSocket, this->m_iocpHandle, (ULONG_PTR)serverSocket, 0) == NULL) {
            this->handleError("CreateIoCompletionPort");
            closesocket(serverSocket);
            CloseHandle(this->m_iocpHandle);
            return;
        }
        this->m_running = true;
        // Add the remote node
        std::string IPPortStr = std::string(prePareServer.m_IP) + ":" + std::to_string(prePareServer.m_port);
        std::shared_ptr<NodeSocketContextData> nodeContextData(new NodeSocketContextData(IPPortStr, serverSocket));
        this->m_remoteNodesData.addRemoteNode(nodeContextData);
        // Register the context
        std::shared_ptr<IOContext> context = this->registerIOContext();
        context.get()->m_socket = serverSocket;
        // Post a read
        this->postRead(context);
        // Invoke the connected callback
        if (this->m_onConnected != NULL) this->m_onConnected(serverSocket);
    }
    // Start callback
    void startCallBackFun() override {
        // Initialize Winsock
        WSAData wsaData;
        if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
            this->handleError("WSAStartup", true);
            return;
        }
        // Create the IOCP
        this->m_iocpHandle = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
        if (this->m_iocpHandle == NULL) {
            this->handleError("CreateIoCompletionPort", true);
            return;
        }
        // Connect to the servers
        while (!this->m_prePareServers.empty()) {
            this->connectServer(this->m_prePareServers.front());
            std::cout << "client:connectServer......." << std::endl;
            this->m_prePareServers.pop();
        }
        // Create worker threads only if at least one connection succeeded
        if (this->m_running)
            for (size_t i = 0; i < this->m_maxConcurrentThreads; i++)
                this->m_workerThreads.emplace_back(&IOCP2::Client::workerThread, this);
    }
    // Shutdown callback
    void shutDownCallBackFun() override {}
public:
    void addServer(const char* IP, int port) {
        if (this->m_running) return;
        this->m_prePareServers.push(std::move(PrePareServerInformation(IP, port)));
    }
    NODE_TYPE getNodeType() override { return Client::nodeType; }
    Client() { this->m_running = false; }
    ~Client() override { this->shutDown(); }
};
} // namespace IOCP2

GUID IOCP2::WSAIoctlTool::guidAcceptEx = WSAID_ACCEPTEX;
GUID IOCP2::WSAIoctlTool::guidGetAcceptExSockaddrs = WSAID_GETACCEPTEXSOCKADDRS;
std::unordered_map<IOCP2::NODE_TYPE, const std::string> IOCP2::NodeInterface::ipPortStrMap = {
    {IOCP2::NODE_TYPE::SERVER, "server"},
    {IOCP2::NODE_TYPE::CLIENT, "client"}
};
std::mutex IOCP2::NodeInterface::handleErrorMx;

// Test program
#include <iostream>
#include <string>
#include <thread>
#include "test/IOCP2.hpp"
#pragma comment(lib, "ws2_32.lib")
#pragma comment(lib, "mswsock.lib")

int main() {
    IOCP2::Server server(1001);
    IOCP2::Client client;
    client.addServer("127.0.0.1", 1001);
    std::thread serverThread([&]() {
        server.setOnDataReceived([&](SOCKET socket, const char* buffer, int len) {
            std::cout << "server received message: " << buffer << std::endl;
            const char* message = "hellow client! server received message";
            client.send(socket, message);
        });
        server.start();
    });
    std::this_thread::sleep_for(std::chrono::seconds(3));
    std::thread clientThread([&]() {
        client.setOnConnected([&](SOCKET socket) {
            std::cout << "client: connect socket = " << socket << std::endl;
        });
        client.setOnDataReceived([&](SOCKET socket, const char* buffer, int len) {
            std::cout << "client received message: " << buffer << std::endl;
            client.send(socket, "hellow client! server received message");
        });
        client.start();
    });
    std::this_thread::sleep_for(std::chrono::seconds(3));
    std::thread clientSendThread([&]() {
        const char* buffer = "Hello World!";
        client.send("127.0.0.1", 1001, buffer);
    });
    serverThread.join();
    clientThread.join();
    clientSendThread.join();
    return 0;
}

Running this leaks memory; my preliminary diagnosis is that the problem is in handleWrite.
10-17
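One pattern that avoids this kind of leak is to give every per-I/O context exactly one release on every exit path. In the code above, handleWrite does erase its registry entry, but the zero-byte branch of handleRead, the error branch of postRead, the failure branch in workerThread, and the error branch of postWrite(IP, port, ...) close the socket without calling deleteRegisterIOContext, so those contexts appear to stay in m_registerIOContexts; registerIOContext and getRegisterIOContext also touch that map without taking m_registerIOContextsMx. The sketch below is a hypothetical, platform-neutral registry (Ctx, ContextRegistry and OnReadCompleted are illustrative names) showing the release on the read path.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

// Hypothetical per-I/O context; a real one would embed the OVERLAPPED.
struct Ctx {
    int socket;
    bool isRead;
};

// Owns each context from the moment it is posted until its completion is handled.
class ContextRegistry {
public:
    Ctx* Register(int socket, bool isRead) {
        std::unique_ptr<Ctx> ctx(new Ctx{socket, isRead});
        Ctx* raw = ctx.get();
        std::lock_guard<std::mutex> lock(mx_);
        live_[raw] = std::move(ctx);
        return raw;
    }
    // Must run exactly once per context: after a completion that is not
    // re-posted, after a failed completion, after a 0-byte read, or when the
    // post itself fails.
    void Release(Ctx* ctx) {
        std::lock_guard<std::mutex> lock(mx_);
        live_.erase(ctx);
    }
    std::size_t Live() {
        std::lock_guard<std::mutex> lock(mx_);
        return live_.size();
    }
private:
    std::mutex mx_;
    std::unordered_map<Ctx*, std::unique_ptr<Ctx>> live_;
};

// Read completion: returns true if the caller should re-post WSARecv with the
// same context, false if the socket should be closed (context already freed).
bool OnReadCompleted(ContextRegistry& reg, Ctx* ctx, bool ok, unsigned bytes) {
    if (!ok || bytes == 0) {
        reg.Release(ctx);  // the step missing from the close paths listed above
        return false;
    }
    return true;  // context stays registered for the next WSARecv
}
```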