#pragma once
// winsock2.h must be included before windows.h: unless WIN32_LEAN_AND_MEAN is defined,
// windows.h pulls in the legacy winsock.h, which clashes with winsock2.h.
#include <winsock2.h>
#include <ws2tcpip.h>
#include <mswsock.h>
#include <windows.h>
#include <atomic>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#pragma comment(lib, "ws2_32.lib")
#pragma comment(lib, "mswsock.lib")
namespace IOCP2 {
// Capacity, in bytes, of the data buffer carried by every I/O context.
constexpr int BUFFER_SIZE = 4 * 1024;
// Kind of overlapped operation an I/O context is currently posted for.
enum IO_OPERATION {
    ACCEPT = 0, // AcceptEx on the listening socket
    READ = 1,   // WSARecv
    WRITE = 2   // WSASend
};
// Role of a node; selects the label printed in log messages.
enum NODE_TYPE {
    NODE = 0,   // generic / base node
    SERVER = 1,
    CLIENT = 2
};
/**
 * Per-operation I/O context ================================================================
 *
 * One instance accompanies each overlapped operation (AcceptEx, WSARecv, WSASend). The
 * address of m_overlapped is unique per context, so the LPOVERLAPPED returned by
 * GetQueuedCompletionStatus identifies which context (and therefore which operation)
 * completed.
 */
class IOContext {
public:
    // Per-operation OVERLAPPED; its address is what the completion port hands back.
    OVERLAPPED m_overlapped;
    // Buffer descriptor passed to WSARecv / WSASend; points into m_buffer.
    WSABUF m_dataBuf;
    // Kind of operation currently posted with this context.
    IO_OPERATION m_operation;
    // Payload storage (AcceptEx also writes the local/remote address block here).
    char m_buffer[BUFFER_SIZE];
    // Socket the operation is posted on.
    SOCKET m_socket;
public:
    IOContext() {
        ZeroMemory(&this->m_overlapped, sizeof(OVERLAPPED));
        this->m_dataBuf.len = BUFFER_SIZE;
        this->m_dataBuf.buf = this->m_buffer;
        this->m_operation = IO_OPERATION::READ;
        this->m_socket = INVALID_SOCKET;
    }
    ~IOContext() {}
    // m_dataBuf.buf points into this object's own m_buffer, so a member-wise copy would leave
    // the copy's WSABUF aimed at the source's buffer; an OVERLAPPED in use by the kernel must
    // not be duplicated either. Contexts are therefore non-copyable (they are always held
    // through std::shared_ptr).
    IOContext(const IOContext&) = delete;
    IOContext& operator=(const IOContext&) = delete;
    /**
     * Identity comparison: two contexts are equal only if they are the same object.
     * @param other context to compare with
     * @return true when both refer to the same OVERLAPPED
     */
    bool operator==(const IOContext& other) const {
        return &this->m_overlapped == &other.m_overlapped;
    }
    /**
     * Negation of operator==.
     * @param other context to compare with
     * @return true when the contexts are different objects
     */
    bool operator!=(const IOContext& other) const {
        return !(*this == other);
    }
};
/**
* 套接字控制工具,用于调用Windows扩展功能=======================================================
*/
class WSAIoctlTool {
private:
//服务器接收链接函数地址
static GUID guidAcceptEx;
//获取接受到的链接地址的函数的地址
static GUID guidGetAcceptExSockaddrs;
private:
/**
*
* @tparam FUNC_POINTERE
* @param targetSocket
* @param funcGuID
* @param funcPointer
* @return
*/
template<typename FUNC_POINTERE>
static bool WSAIoctlFun(SOCKET& targetSocket,GUID& funcGuID, FUNC_POINTERE& funcPointer) {
DWORD bytesReceived;
return WSAIoctl(
targetSocket,
SIO_GET_EXTENSION_FUNCTION_POINTER,
&funcGuID,
sizeof(funcGuID),
&funcPointer,
sizeof(funcPointer),
&bytesReceived,
nullptr,
nullptr) != SOCKET_ERROR;
}
public:
/**
*异步接收套接字连接
* @param sListenSocket 监听 socket(bind+ listen后的 socket)
* @param sAcceptSocket 预先创建的 socket,用于接受新连接
* @param lpOutputBuffer 接收数据的缓冲区,同时包含本地和远程地址信息
* @param lpdwBytesReceived 返回实际接收的字节数(仅在同步操作时有效)
* @param lpOverlapped 用于异步操作的 OVERLAPPED结构
* @return
*/
static bool acceptExFun(SOCKET sListenSocket,SOCKET sAcceptSocket,PVOID lpOutputBuffer, LPDWORD lpdwBytesReceived,LPOVERLAPPED lpOverlapped) {
LPFN_ACCEPTEX lpAcceptEx = nullptr;
if (!WSAIoctlTool::WSAIoctlFun(sListenSocket,WSAIoctlTool::guidAcceptEx,lpAcceptEx)) return false;
return lpAcceptEx(sListenSocket,sAcceptSocket,lpOutputBuffer,0,sizeof(sockaddr_in)+16,sizeof(sockaddr_in)+16,lpdwBytesReceived,lpOverlapped) != FALSE;
}
/**
*获取连接的地址
* @param sAcceptSocket //AcceptEx的套接字
* @param lpOutputBuffer //AcceptEx 返回的缓冲区
* @return ip:port
*/
static std::string GetAcceptExSockaddrsFun(SOCKET sAcceptSocket, PVOID lpOutputBuffer) {
LPFN_GETACCEPTEXSOCKADDRS lpGetAcceptExSockaddrs = nullptr;
//获取客户端地址信息
sockaddr_in* localAddr = nullptr;
sockaddr_in* remoteAddr = nullptr;
int localAddrLen = 0;
int remoteAddrLen = 0;
constexpr const int addressBufferSize = sizeof(sockaddr_in) + 16;
if (!WSAIoctlTool::WSAIoctlFun(sAcceptSocket,WSAIoctlTool::guidGetAcceptExSockaddrs,lpGetAcceptExSockaddrs)) return NULL;
//使用使用GetAcceptExSockaddrs获取地址信息
lpGetAcceptExSockaddrs(lpOutputBuffer,0,addressBufferSize,
addressBufferSize,
(sockaddr**)&localAddr,
&localAddrLen,
(sockaddr**)&remoteAddr,
&remoteAddrLen);
char clientIP[INET_ADDRSTRLEN];
inet_ntop(AF_INET,&(remoteAddr->sin_addr),clientIP,INET_ADDRSTRLEN);
unsigned short clientPort = ntohs(remoteAddr->sin_port);
return std::string(clientIP,INET_ADDRSTRLEN)+":"+std::to_string(clientPort);
}
};
class RemoteNodesData;
/**
* 节点Socket上下文数据=======================================================================
*/
class NodeSocketContextData {
private:
SOCKET m_socket;
std::string m_IPPortStr;
std::shared_ptr<IOContext> m_readContext;
//void指针存储overlapped地址,overlapped地址一样说明是同一个I/O操作
std::unordered_map<void*,std::shared_ptr<IOContext>> m_writeContexts;
private:
/**
*
* @param context
*/
void deleteReadContext(std::shared_ptr<IOContext>& context) {
//不是同一个cosntext
if (&this->m_readContext.get()->m_overlapped != &context->m_overlapped) return;
this->m_readContext.reset();
}
/**
*
* @param context
*/
void setWriteContext(std::shared_ptr<IOContext>& context) {
this->m_writeContexts[&context->m_overlapped] = context;
}
/**
*
* @param context
*/
void setReadContext(std::shared_ptr<IOContext>& context) {
this->m_readContext = context;
}
/**
*
* @param context
*/
void deleteWriteContext(std::shared_ptr<IOContext>& context) {
this->m_writeContexts.erase(&context->m_overlapped);
}
public:
NodeSocketContextData() {
this->m_socket = INVALID_SOCKET;
this->m_IPPortStr = "";
this->m_readContext = nullptr;
}
/**
*
* @param ipPortStr
* @param socket
*/
NodeSocketContextData(const std::string& ipPortStr,SOCKET& socket) {
this->m_socket = socket;
this->m_IPPortStr = ipPortStr;
}
~NodeSocketContextData() {
this->m_writeContexts.clear();
}
/**
*
* @return
*/
std::string& getIPPortStr() {
return this->m_IPPortStr;
}
/**
*
* @return
*/
SOCKET& getSocket() {
return this->m_socket;
}
/**
*
* @param socket
*/
void setSocket(SOCKET& socket) {
this->m_socket = socket;
}
public:
friend class RemoteNodesData;
};
/**
* 远程节点数据===============================================================================
*/
class RemoteNodesData {
private:
std::unordered_map<std::string, std::shared_ptr<NodeSocketContextData>> m_remoteNodes;
std::unordered_map<SOCKET,std::string> m_remoteNodesSocketIPPortStrMap;
std::mutex m_remoteNodesMx;
public:
RemoteNodesData() {}
~RemoteNodesData() {}
/**
* 是否存在节点
* @param socket
* @return
*/
bool hashNode(SOCKET& socket) {
auto it = this->m_remoteNodesSocketIPPortStrMap.find(socket);
return it != this->m_remoteNodesSocketIPPortStrMap.end();
}
/**
* 是否存在该节点
* @param IPPortStr
* @return
*/
bool hashNode(const std::string& IPPortStr) {
auto it = this->m_remoteNodes.find(IPPortStr);
return it != this->m_remoteNodes.end();
}
/**
*
* @param IPPortStr
* @return
*/
SOCKET getNodeSocketByIPPortStr(std::string& IPPortStr) {
auto it = this->m_remoteNodes.find(IPPortStr);
if (it == this->m_remoteNodes.end()) return INVALID_SOCKET;
return it->second->getSocket();
}
/**
*
* @param remoteNode
*/
void addRemoteNode(std::shared_ptr<NodeSocketContextData>& remoteNode) {
std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
if (remoteNode.get() == nullptr
|| remoteNode.get()->getIPPortStr() == ""
|| remoteNode.get()->getSocket() == INVALID_SOCKET) return;
this->m_remoteNodesSocketIPPortStrMap[remoteNode.get()->getSocket()] = remoteNode.get()->getIPPortStr();
this->m_remoteNodes[remoteNode.get()->getIPPortStr()] = remoteNode;
}
/**
* 调用此函数会清除Socket
* @param ipPortStr
*/
void deleteRemeoteNodesByIPPortStr(const std::string& ipPortStr) {
std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
if (ipPortStr == "" || ipPortStr.empty()) return;
auto it = this->m_remoteNodes.find(ipPortStr);
//不存在该节点数据
if (it == this->m_remoteNodes.end()) return;
SOCKET socket = it->second.get()->getSocket();
this->m_remoteNodesSocketIPPortStrMap.erase(socket);
this->m_remoteNodes.erase(it);
//清除Socket
closesocket(socket);
}
/**
* 调用此函数会清除Socket
* @param socket
*/
void deleteRemeoteNodesBySocket(SOCKET& socket) {
std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
if (socket == INVALID_SOCKET) return;
auto it = this->m_remoteNodesSocketIPPortStrMap.find(socket);
//不存在该节点数据
if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
this->m_remoteNodes.erase(it->second);
this->m_remoteNodesSocketIPPortStrMap.erase(it);
//清除Socket
closesocket(socket);
}
/**
*
* @param context
*/
void addRemeoteNodeReadContext(std::shared_ptr<IOContext> & context) {
std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
if (context.get()->m_socket == INVALID_SOCKET) return;
auto it = this->m_remoteNodesSocketIPPortStrMap.find(context->m_socket);
//该socket对应的节点数据不存在
if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
this->m_remoteNodes[it->second]->setReadContext(context);
}
/**
*
* @param context
*/
void deleteRemeoteNodeReadContext(std::shared_ptr<IOContext> & context) {
std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
if (context.get()->m_socket == INVALID_SOCKET) return;
auto it = this->m_remoteNodesSocketIPPortStrMap.find(context->m_socket);
//该socket对应的节点数据不存在
if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
this->m_remoteNodes[it->second]->deleteReadContext(context);
}
/**
*
* @param context
*/
void addRemeoteNodeWriteContext(std::shared_ptr<IOContext> & context) {
std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
if (context.get()->m_socket == INVALID_SOCKET) return;
auto it = this->m_remoteNodesSocketIPPortStrMap.find(context->m_socket);
//该socket对应的节点数据不存在
if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
this->m_remoteNodes[it->second]->setWriteContext(context);
}
/**
*
* @param context
*/
void deleteRemeoteNodeWriteContext(std::shared_ptr<IOContext> & context) {
std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
if (context.get()->m_socket == INVALID_SOCKET) return;
auto it = this->m_remoteNodesSocketIPPortStrMap.find(context->m_socket);
//该socket对应的节点数据不存在
if (it == this->m_remoteNodesSocketIPPortStrMap.end()) return;
this->m_remoteNodes[it->second]->deleteWriteContext(context);
}
/**
* 遍历处理节点
* @param handleFun 传入的处理函数 返回false退出循环 返回true退出循环
*/
void traverseHandleNodes(std::function<bool(std::shared_ptr<NodeSocketContextData>&)> handleFun) {
std::lock_guard<std::mutex> lock(this->m_remoteNodesMx);
for (auto& pair : this->m_remoteNodes)
if (!handleFun(pair.second)) return;
}
/**
* 清除数据
*/
void clear() {
this->m_remoteNodesSocketIPPortStrMap.clear();
this->m_remoteNodes.clear();
}
};
/**
 * Pending server address =================================================================
 *
 * Address of a server that a Client connects to when it is started.
 * NOTE(review): m_IP only stores the caller's pointer (no copy); the string it points to must
 * stay alive until the connection attempt has been made.
 */
class PrePareServerInformation {
public:
    int m_port;
    const char* m_IP;
public:
    /**
     * @param IP dotted IPv4 address of the server
     * @param port TCP port of the server
     */
    PrePareServerInformation(const char* IP, int port) : m_port(port), m_IP(IP) {}
    ~PrePareServerInformation() = default;
};
/**
* 节点接口================================================================================
*/
class NodeInterface {
protected:
HANDLE m_iocpHandle;
std::vector<std::thread> m_workerThreads;
size_t m_maxConcurrentThreads = 1;//线程数
std::atomic<bool> m_running;
static std::unordered_map<NODE_TYPE,const std::string> ipPortStrMap;
//回调函数
std::function<void(SOCKET)> m_onConnected = NULL;
std::function<void(SOCKET)> m_onDisconnected = NULL;
std::function<void(SOCKET,const char*,int)> m_onDataReceived = NULL;
//远程节点数据
RemoteNodesData m_remoteNodesData;
static std::mutex handleErrorMx;
//注册IOContext上下文Map,需要new新IOcontext时的登记列表
std::unordered_map<void*,std::shared_ptr<IOContext>> m_registerIOContexts;
std::mutex m_registerIOContextsMx;
protected:
/**
*
* @return
*/
virtual NODE_TYPE getNodeType() {}
/**
* 跳转处理函数
* @param context
* @param bytesTransferred
*/
virtual void jumpHandleFun(std::shared_ptr<IOContext>& context ,DWORD bytesTransferred ){}
/**
* 启动回调函数
*/
virtual void startCallBackFun() {}
/**
* 关闭回调函数
*/
virtual void shutDownCallBackFun() {}
/**
*
* @return
*/
const std::string& getNodeTypeStr() {
return NodeInterface::ipPortStrMap[this->getNodeType()];
}
/**
* 注册上下文
* @param socket
* @param IOOperation
*/
void registerIOContext(SOCKET socket, IO_OPERATION IOOperation) {
if (socket == INVALID_SOCKET) return;
IOContext* context = new IOContext();
context->m_operation = IOOperation;
context->m_socket = socket;
this->m_registerIOContexts[&context->m_overlapped] = std::shared_ptr<IOContext>(context);
}
/**
* 注册上下文
* @return
*/
std::shared_ptr<IOContext> registerIOContext() {
std::shared_ptr<IOContext> context(new IOContext());
std::cout << "Register IOContext = " << &context<< std::endl;
this->m_registerIOContexts[&context.get()->m_overlapped] = context;
return context;
}
/**
*
* @param context
*/
void registerIOContext(std::shared_ptr<IOContext>& context) {
this->m_registerIOContexts[&context.get()->m_overlapped] = context;
}
/**
*
* @param overlappedAddress
*/
void deleteRegisterIOContext(void* overlappedAddress) {
auto it = this->m_registerIOContexts.find(overlappedAddress);
if (it == this->m_registerIOContexts.end()) return;
std::lock_guard<std::mutex> lock(this->m_registerIOContextsMx);
std::cout << "Delete RegisterIOContext" << std::endl;
it = this->m_registerIOContexts.find(overlappedAddress);
if (it == this->m_registerIOContexts.end()) return;
this->m_registerIOContexts.erase(it);
}
/**
*
* @param overlappedAddress
* @return
*/
std::shared_ptr<IOContext> getRegisterIOContext(void* overlappedAddress) {
auto it = this->m_registerIOContexts.find(overlappedAddress);
if (it == this->m_registerIOContexts.end()) return nullptr;
return it->second;
}
/**
*
* @param message
* @param fatal
*/
void handleError(const char* message, bool fatal = false) {
std::lock_guard<std::mutex> lock(NodeInterface::handleErrorMx);
DWORD error = WSAGetLastError();
std::string nodeTypeStr = this->getNodeTypeStr()+":";
std::string str= error == 0 ? "Unknown" : strerror( error );
std::cout <<nodeTypeStr<< message << " failed with error" << error
<< ":" << str << std::endl;
if (!fatal) return;
WSACleanup();
ExitProcess(1);
}
/**
*
* @param socket
*/
void closeSocket(SOCKET& socket) {
if (this->m_onDisconnected != NULL) this->m_onDisconnected(socket);
this->m_remoteNodesData.deleteRemeoteNodesBySocket(socket);
}
/**
*
* @param context
*/
void postRead(std::shared_ptr<IOContext>& context) {
DWORD flags = 0;
DWORD bytesRead;
context.get()->m_operation = IO_OPERATION::READ;
if (WSARecv(
context.get()->m_socket,
&context.get()->m_dataBuf,
1,
&bytesRead,
&flags,
&context.get()->m_overlapped,
NULL) == SOCKET_ERROR) {
DWORD error = WSAGetLastError();
if (error != ERROR_IO_PENDING) {
this->handleError("WSARecv");
this->closeSocket(context.get()->m_socket);
}
}
}
/**
*
* @param IP
* @param port
* @param data
* @param length
*/
void postWrite(const char* IP, int port, const char* data, int length) {
std::string IPPortStr = std::string(IP) + ":" + std::to_string(port);
//不存在该节点
if (!this->m_remoteNodesData.hashNode(IPPortStr)) return;
//注册上下文
std::shared_ptr<IOContext> writeContext = this->registerIOContext();
writeContext.get()->m_socket = this->m_remoteNodesData.getNodeSocketByIPPortStr(IPPortStr);
writeContext.get()->m_operation = IO_OPERATION::WRITE;
memcpy(writeContext.get()->m_buffer,data,length);
writeContext.get()->m_dataBuf.len = length;
DWORD bytesSent;
if (WSASend(
writeContext.get()->m_socket,
&writeContext.get()->m_dataBuf,
1,
&bytesSent,
0,
&writeContext.get()->m_overlapped,
NULL) == SOCKET_ERROR) {
DWORD error = WSAGetLastError();
if (error != WSA_IO_PENDING) {
handleError("WSASend");
this->closeSocket(writeContext.get()->m_socket);
return;
}
}
//加入节点读写上下文
this->m_remoteNodesData.addRemeoteNodeWriteContext(writeContext);
}
/**
*
* @param socket
* @param data
* @param length
*/
void postWrite(SOCKET socket, const char* data, int length) {
//Socket为空
if (socket == INVALID_SOCKET) return;
//注册上下文
std::shared_ptr<IOContext> writeContext = this->registerIOContext();
writeContext.get()->m_socket = socket;
writeContext.get()->m_operation = IO_OPERATION::WRITE;
memcpy(writeContext.get()->m_buffer,data,length);
writeContext.get()->m_dataBuf.len = length;
DWORD bytesSent;
if (WSASend(
writeContext.get()->m_socket,
&writeContext.get()->m_dataBuf,
1,
&bytesSent,
0,
&writeContext.get()->m_overlapped,
NULL) == SOCKET_ERROR) {
DWORD error = WSAGetLastError();
if (error != WSA_IO_PENDING) {
handleError("WSASend");
this->closeSocket(writeContext.get()->m_socket);
this->deleteRegisterIOContext(&writeContext.get()->m_overlapped);
return;
}
}
//加入节点读写上下文
this->m_remoteNodesData.addRemeoteNodeWriteContext(writeContext);
}
/**
* 处理读操作
* @param context
* @param bytesTransfered
*/
void handleRead(std::shared_ptr<IOContext>& context, DWORD bytesTransferred) {
//客户端断开连接
if (bytesTransferred == 0) {
this->closeSocket(context.get()->m_socket);
return;
}
if (this->m_onDataReceived != NULL) {
this->m_onDataReceived(context.get()->m_socket,context.get()->m_buffer,bytesTransferred);
}
//继续投递读操作
this->postRead(context);
}
/**
*
* @param context
* @param bytesTransferred
*/
void handleWrite(std::shared_ptr<IOContext>& context, DWORD bytesTransferred) {
if (!context) return;
this->deleteRegisterIOContext(&context.get()->m_overlapped);
//写入完成,释放上下文
this->m_remoteNodesData.deleteRemeoteNodeWriteContext(context);
//context.reset();
}
/**
* 工作线程
*/
void workerThread() {
DWORD bytesTransferred;
ULONG_PTR completionKey;
LPOVERLAPPED overlapped;
while (this->m_running) {
std::cout << this->getNodeTypeStr() << "等待I/O端口返回=========" << std::endl;
BOOL result = GetQueuedCompletionStatus(this->m_iocpHandle,&bytesTransferred,&completionKey,&overlapped,INFINITE);
std::cout << this->getNodeTypeStr() << "I/O端口返回数据=========" << std::endl;
//处理错误或者超时
if (!result) {
DWORD error = GetLastError();
//超时
if (overlapped == NULL && error == WAIT_TIMEOUT) continue;
if (error != WAIT_TIMEOUT && error != ERROR_NETNAME_DELETED) this->handleError("GetQueuedCompletionStatus");
//客户端断开连接
if (overlapped != NULL) {
//获取注册的上下文
std::shared_ptr<IOContext> context = this->getRegisterIOContext(overlapped);
if (context == nullptr) continue;
if (context.get()->m_operation == IO_OPERATION::READ || context.get()->m_operation == IO_OPERATION::WRITE)
this->closeSocket(context.get()->m_socket);
}
continue;
}
//收到退出信号
if (bytesTransferred == 0 && completionKey == NULL && overlapped == NULL) break;
//获取注册的上下文
std::shared_ptr<IOContext> context = this->getRegisterIOContext(overlapped);
if (context == nullptr) continue;
this->jumpHandleFun(context,bytesTransferred);
}
}
public:
void setOnConnected(std::function<void(SOCKET)> callback) {this->m_onConnected = callback;}
void setOnDisconnected(std::function<void(SOCKET)> callback) {this->m_onDisconnected = callback;}
void setOnDataReceived(std::function<void(SOCKET,const char* ,int)> callback){this->m_onDataReceived =callback;}
/**
* 启动节点
*/
void start() {
//已经启动
if (this->m_running) return;
this->startCallBackFun();
//启动失败
if (!this->m_running) return;
std::cout << this->getNodeTypeStr() << ":startCallBackFun is runned!" << std::endl;
//join线程
for (std::thread& worker : this->m_workerThreads)
if (worker.joinable())worker.join();
}
/**
* 关闭节点
*/
void shutDown() {
if (!this->m_running) return;
this->m_running = false;
//关闭所有远程节点的Socket
auto closeSocketFun = [](std::shared_ptr<NodeSocketContextData>& remoteNode)->bool {
if(remoteNode->getSocket() == INVALID_SOCKET) return true;
//关闭socket
closesocket(remoteNode.get()->getSocket());
return true;
}; //封装遍历处理函数
//传入遍历处理远程节点数据函数
this->m_remoteNodesData.traverseHandleNodes(closeSocketFun);
//清除远程节点数据
this->m_remoteNodesData.clear();
//关闭IOCP句柄
if (this->m_iocpHandle != NULL) {
CloseHandle(this->m_iocpHandle);
this->m_iocpHandle = NULL;
}
//调用节点关闭回调函数
this->shutDownCallBackFun();
//唤醒所有工作线程, 发送退出信号
for (size_t i = 0; i < this->m_workerThreads.size();i++)
PostQueuedCompletionStatus(this->m_iocpHandle,0,(ULONG_PTR)NULL,NULL);
//等待所有工作线程
for (std::thread& thred:this->m_workerThreads)
if (thred.joinable()) thred.join();
this->m_workerThreads.clear();
WSACleanup();
}
/**
*
* @param IP
* @param port
* @param message
*/
void send(const char* IP, int port, const char* message) {
if (!this->m_running) return;
this->postWrite(IP,port,message,strlen(message));
}
/**
*
* @param socket
* @param message
*/
void send(SOCKET socket, const char* message) {
if (!this->m_running) return;
this->postWrite(socket,message,strlen(message));
}
NodeInterface() {}
virtual ~NodeInterface() {}
};
/**
* 服务器==================================================================================
*/
class Server: public NodeInterface {
private:
int m_port;
SOCKET m_listenSocket;
//static constexpr不需要类外定义
static constexpr NODE_TYPE nodeType = IOCP2::NODE_TYPE::SERVER;
private:
/**
* 投递新的接收连接操作
*/
void postAccept() {
SOCKET clientSocket = WSASocket(AF_INET,SOCK_STREAM,IPPROTO_TCP,NULL,0,WSA_FLAG_OVERLAPPED);
if (clientSocket == INVALID_SOCKET) {
this->handleError("WSASocket for accept");
return;
}
//注册上下文,防止上下文被析构
std::shared_ptr<IOContext> context = this->registerIOContext();
context.get()->m_operation = IO_OPERATION::ACCEPT;
context.get()->m_socket = clientSocket;
DWORD bytesReceived;
//接收连接
if (!WSAIoctlTool::acceptExFun(
this->m_listenSocket,
clientSocket,
context.get()->m_buffer,
&bytesReceived,
&context.get()->m_overlapped)) {
DWORD error = WSAGetLastError();
if (error == ERROR_IO_PENDING) return;
this->handleError("AcceptEx");
//删除注册上下文
this->deleteRegisterIOContext(&context.get()->m_overlapped);
closesocket(clientSocket);
}
}
/**
* 处理接收连接
* @param context
*/
void handleAccept(std::shared_ptr<IOContext>& context) {
//调用连接到客户端回调函数
if (this->m_onConnected != NULL) this->m_onConnected(context.get()->m_socket);
//将新socket关联到IOCP
if (CreateIoCompletionPort(
(HANDLE)context.get()->m_socket,
this->m_iocpHandle,
(ULONG_PTR)context.get()->m_socket,
0) == NULL) {
this->handleError("CreateIoCompletionPort for client socket");
closesocket(context.get()->m_socket);
return;
}
//获取客户端信息
std::string IPPortStr = WSAIoctlTool::GetAcceptExSockaddrsFun(context.get()->m_socket,context.get()->m_buffer);
//获取到的数据为空
if (IPPortStr.empty() || IPPortStr == "") {
this->handleError("GetAcceptExSockaddrs");
return;
}
//加入远程节点数据
std::shared_ptr<NodeSocketContextData> nodeContextData(new NodeSocketContextData(IPPortStr,context.get()->m_socket));
this->m_remoteNodesData.addRemoteNode(nodeContextData);
/*//注册上下文
this->registerIOContext(context);*/
//为新连接投递读操作
this->postRead(context);
//将读操作的上下文加入节点管理,防止提前释放
//this->m_remoteNodesData.addRemeoteNodeReadContext(context);
//继续投递接收连接操作
this->postAccept();
}
/**
*
* @param context
* @param bytesTransferred
*/
void jumpHandleFun(std::shared_ptr<IOContext> &context, DWORD bytesTransferred) override {
switch (context->m_operation) {
case IO_OPERATION::ACCEPT:
this->handleAccept(context);
break;
case IO_OPERATION::READ:
this->handleRead(context, bytesTransferred);
break;
case IO_OPERATION::WRITE:
this->handleWrite(context, bytesTransferred);
break;
}
}
/**
* 启动回调函数
*/
void startCallBackFun() override {
//初始化WinSocket
WSADATA WSAData;
if (WSAStartup(MAKEWORD(2,2), &WSAData) != 0) {
this->handleError("WSAStartup",true);
return;
}
//创建监听socket
this->m_listenSocket = WSASocket(AF_INET,SOCK_STREAM,IPPROTO_TCP,NULL,0,WSA_FLAG_OVERLAPPED);
if (this->m_listenSocket == INVALID_SOCKET) {
this->handleError("WSASocket",true);
return;
}
//绑定地址
sockaddr_in serverAddr;
serverAddr.sin_family = AF_INET;
serverAddr.sin_addr.s_addr = htonl(INADDR_ANY);
serverAddr.sin_port = htons(this->m_port);
if (bind(this->m_listenSocket,(sockaddr*)&serverAddr,sizeof(serverAddr)) == SOCKET_ERROR) {
this->handleError("bind",true);
closesocket(this->m_listenSocket);
return;
}
//创建IOCP
this->m_iocpHandle =CreateIoCompletionPort(INVALID_HANDLE_VALUE,NULL,0,0);
if (this->m_iocpHandle == NULL) {
handleError("CreateIoCompletionPort",true);
closesocket(this->m_listenSocket);
return;
}
//将监听socket关联到IOCP
if (CreateIoCompletionPort(
(HANDLE)this->m_listenSocket,
this->m_iocpHandle,
(ULONG_PTR)this->m_listenSocket,
0) == NULL) {
this->handleError("CreateIoCompletionPort for listen socket",true);
closesocket(this->m_listenSocket);
CloseHandle(this->m_iocpHandle);
return;
}
//开始监听
if (listen(this->m_listenSocket,SOMAXCONN) == SOCKET_ERROR) {
this->handleError("listen",true);
closesocket(this->m_listenSocket);
CloseHandle(this->m_iocpHandle);
return;
}
//创建工作线程
this->m_running = true;
for (size_t i = 0; i < this->m_maxConcurrentThreads; i++)
this->m_workerThreads.emplace_back(&IOCP2::Server::workerThread,this);
//投递初始化接收连接操作
this->postAccept();
}
/**
* 关闭回调函数
*/
void shutDownCallBackFun() override {
if (this->m_listenSocket == INVALID_SOCKET) return;
closesocket(this->m_listenSocket);
this->m_listenSocket = INVALID_SOCKET;
}
public:
/**
*
* @return
*/
NODE_TYPE getNodeType() override {
return Server::nodeType;
}
/**
*
* @param port
*/
Server(int port) {
this->m_port = port;
this->m_running = false;
}
~Server() override {
this->shutDown();
}
};
/**
* 客户端==================================================================================
*/
class Client: public NodeInterface {
private:
static constexpr NODE_TYPE nodeType = NODE_TYPE::CLIENT;
//预备服务器信息
std::queue<PrePareServerInformation> m_prePareServers;
private:
/**
*
* @param context
* @param bytesTransferred
*/
void jumpHandleFun(std::shared_ptr<IOContext> &context, DWORD bytesTransferred) override {
switch (context.get()->m_operation) {
case IO_OPERATION::READ:
this->handleRead(context,bytesTransferred);
break;
case IO_OPERATION::WRITE:
this->handleWrite(context,bytesTransferred);
break;
default:
break;
}
}
/**
*
* @param prePareServer
*/
void connectServer(PrePareServerInformation& prePareServer) {
//创建Socket
SOCKET serverSocket = WSASocket(AF_INET,SOCK_STREAM,IPPROTO_TCP,NULL,0,WSA_FLAG_OVERLAPPED);
if (serverSocket == INVALID_SOCKET) {
this->handleError("WSASocket",true);
return;
}
//解析服务器地址
sockaddr_in serverAddr;
serverAddr.sin_family = AF_INET;
serverAddr.sin_port = htons(prePareServer.m_port);
inet_pton(AF_INET,prePareServer.m_IP,&serverAddr.sin_addr);
//连接服务器
if (connect(serverSocket,(sockaddr*)&serverAddr,sizeof(serverAddr)) == SOCKET_ERROR) {
DWORD errror = WSAGetLastError();
this->handleError("connect");
closesocket(serverSocket);
return;
}
//将socket关联到IOCP
if (CreateIoCompletionPort(
(HANDLE)serverSocket,
this->m_iocpHandle,
(ULONG_PTR)serverSocket,
0) == NULL) {
this->handleError("CreateIoCompletionPort");
closesocket(serverSocket);
CloseHandle(this->m_iocpHandle);
return;
}
this->m_running = true;
//加入远程节点
std::string IPPortStr = std::string(prePareServer.m_IP)+":"+std::to_string(prePareServer.m_port);
std::shared_ptr<NodeSocketContextData> nodeContextData(new NodeSocketContextData(IPPortStr,serverSocket));
this->m_remoteNodesData.addRemoteNode(nodeContextData);
//注册上下文
std::shared_ptr<IOContext> context = this->registerIOContext();
context.get()->m_socket = serverSocket;
//投递读操作
this->postRead(context);
//调用连接回调函数
if (this->m_onConnected != NULL) this->m_onConnected(serverSocket);
}
/**
* 启动回到调函数
*/
void startCallBackFun() override {
//初始化WinSocket
WSAData wsaData;
if (WSAStartup(MAKEWORD(2,2), &wsaData) != 0) {
this->handleError("WSAStartup",true);
return;
}
//创建IOCP
this->m_iocpHandle = CreateIoCompletionPort(INVALID_HANDLE_VALUE,NULL,0,0);
if (this->m_iocpHandle == NULL) {
this->handleError("CreateIoCompletionPort",true);
return;
}
//连接服务器
while (!this->m_prePareServers.empty()) {
this->connectServer(this->m_prePareServers.front());
std::cout << "client:connectServer......."<< std::endl;
this->m_prePareServers.pop();
}
// 仅当有有效连接时才创建工作线程
if (this->m_running)
for (size_t i = 0; i < this->m_maxConcurrentThreads ; i++)
this->m_workerThreads.emplace_back(&IOCP2::Client::workerThread,this);
}
/**
* 关闭回调函数
*/
void shutDownCallBackFun() override {
}
public:
/**
*
* @param IP
* @param port
*/
void addServer(const char* IP, int port) {
if (this->m_running) return;
this->m_prePareServers.push(std::move(PrePareServerInformation(IP,port)));
}
/**
*
* @return
*/
NODE_TYPE getNodeType() override {
return Client::nodeType;
}
Client() {
this->m_running = false;
}
~Client() override {
this->shutDown();
}
};
};
// Out-of-class definitions of the static data members. They are `inline` (C++17) because this
// is a header: without it, every translation unit including it defines the symbols again and
// linking fails with duplicate definitions.
inline GUID IOCP2::WSAIoctlTool::guidAcceptEx = WSAID_ACCEPTEX;
inline GUID IOCP2::WSAIoctlTool::guidGetAcceptExSockaddrs = WSAID_GETACCEPTEXSOCKADDRS;
inline std::unordered_map<IOCP2::NODE_TYPE, const std::string> IOCP2::NodeInterface::ipPortStrMap = {
    {IOCP2::NODE_TYPE::NODE, "node"},
    {IOCP2::NODE_TYPE::SERVER, "server"},
    {IOCP2::NODE_TYPE::CLIENT, "client"}
};
inline std::mutex IOCP2::NodeInterface::handleErrorMx;
// ===== main.cpp (separate file that includes "test/IOCP2.hpp") =====
#include <iostream>
#include <chrono>
#include <iostream>
#include <string>
#include <thread>
#include "test/IOCP2.hpp"
#pragma comment(lib, "ws2_32.lib")
#pragma comment(lib, "mswsock.lib")
/**
 * Demo: a Server on port 1001 and a Client connected to it. Each side answers every message it
 * receives, so they keep exchanging messages.
 * NOTE(review): start() blocks until the workers exit and nothing here calls shutDown(), so the
 * demo runs until the process is killed.
 */
int main() {
    IOCP2::Server server(1001);
    IOCP2::Client client;
    client.addServer("127.0.0.1", 1001);
    std::thread serverThread([&]() {
        server.setOnDataReceived([&](SOCKET socket, const char* buffer, int len) {
            // Received bytes are not NUL-terminated: print exactly `len` of them.
            std::cout << "server received message: " << std::string(buffer, len) << std::endl;
            const char* message = "hellow client! server received message";
            // `socket` is a server-side connection, associated with the SERVER's completion
            // port, so the reply must go through the server. Sending it through the client made
            // the client register a write context whose completion went to the server's port.
            // The client never saw the completion, so it never released the context: one
            // IOContext leaked per message (the leak that looked like a handleWrite problem).
            server.send(socket, message);
        });
        server.start();
    });
    std::this_thread::sleep_for(std::chrono::seconds(3));
    std::thread clientThread([&]() {
        client.setOnConnected([&](SOCKET socket) {
            std::cout << "client: connect socket = " << socket << std::endl;
        });
        client.setOnDataReceived([&](SOCKET socket, const char* buffer, int len) {
            std::cout << "client received message: " << std::string(buffer, len) << std::endl;
            client.send(socket, "hellow client! server received message");
        });
        client.start();
    });
    std::this_thread::sleep_for(std::chrono::seconds(3));
    std::thread clientSendThread([&]() {
        const char* buffer = "Hello World!";
        client.send("127.0.0.1", 1001, buffer);
    });
    serverThread.join();
    clientThread.join();
    clientSendThread.join();
    return 0;
}
// Reported issue: memory leaks after running; the problem was initially traced to handleWrite.
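// Stray text "最新发布" ("latest release") left over from the page this code was copied from; not part of the program.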