给别人写的
很奇怪的需求,随便改了一下;
必须使用 select 来接收数据,并用另一个线程发送;
为方便,直接写了个线程池;
另外,每个客户端连接时都会创建一个 ClientData 对象,每次直接 new ClientData 会产生大量内存碎片,因此可以考虑用对象池复用 ClientData(下面的代码尚未实现这一点)。
trans.h
#ifndef CMD_HEADER
#define CMD_HEADER
// Command codes carried in DataHeader::cmd.
// These numbers travel on the wire, so each one is spelled out explicitly;
// they are exactly the implicit 0..5 numbering of the enumerators in order.
enum CMD{
    CMD_LOGIN         = 0,  // client -> server: Login
    CMD_LOGOUT        = 1,  // client -> server: Logout
    CMD_LOGIN_RESULT  = 2,  // server -> client: LoginResult
    CMD_LOGOUT_RESULT = 3,  // server -> client: LogoutResult
    CMD_USER_JOIN     = 4,  // server -> clients: UserJoin broadcast
    CMD_ERROR         = 5
};
// Header at the start of every message on the wire.
typedef struct _DataHeader{
short dataLen; // number of bytes that follow this header (message size = dataLen + sizeof(DataHeader))
short cmd;     // one of the CMD values
} DataHeader, *LPDataHeader;
// Login request, cmd == CMD_LOGIN (sizeof == 1000 bytes).
// The char arrays are fixed-size and filled by the peer; nothing in this
// struct guarantees they are NUL-terminated.
typedef struct _Login {
DataHeader header;
char uname[32];
char passwd[32];
char data[932]; // NOTE(review): not read by the server code shown; presumably filler/test payload — confirm
} Login, *LPLogin;
// Reply to a Login, cmd == CMD_LOGIN_RESULT.
typedef struct _LoginResult{
DataHeader header;
short result;        // status code for the login
char testdata[992];  // NOTE(review): appears to be test filler only — confirm
}LoginResult, *LPLoginResult;
// Logout request, cmd == CMD_LOGOUT. uname is fixed-size and may not be NUL-terminated.
typedef struct _Logout{
DataHeader header;
char uname[32];
}Logout, *LPLogout;
// Reply to a Logout, cmd == CMD_LOGOUT_RESULT.
typedef struct _LogoutResult{
DataHeader header;
short result; // status code for the logout
}LogoutResult, *LPLogoutResult;
// "A user joined" notification, cmd == CMD_USER_JOIN.
typedef struct _UserJoin{
DataHeader header;
int sock; // NOTE(review): presumably identifies the joining client; an int cannot hold a 64-bit SOCKET — confirm
}UserJoin, *LPUserJoin;
#endif
help_func.hpp
#ifndef _GetTimeStamp
#define _GetTimeStamp
//#include <chrono>
#include <Windows.h>
class GetTimeStamp{
static LARGE_INTEGER StartingTime, EndingTime, ElapsedMicroseconds;
static LARGE_INTEGER Frequency;
//static std::chrono::time_point<std::chrono::high_resolution_clock> _starttime;
public:
static void start()
{
QueryPerformanceFrequency(&Frequency);
QueryPerformanceCounter(&StartingTime);
//_starttime = std::chrono::high_resolution_clock::now();
}
static long long elapsed()
{
QueryPerformanceCounter(&EndingTime);
ElapsedMicroseconds.QuadPart = EndingTime.QuadPart - StartingTime.QuadPart;
ElapsedMicroseconds.QuadPart *= 1000000;
return ElapsedMicroseconds.QuadPart / Frequency.QuadPart;
//return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - _starttime).count();
}
};
LARGE_INTEGER GetTimeStamp::EndingTime = { 0 };
LARGE_INTEGER GetTimeStamp::StartingTime = { 0 };
LARGE_INTEGER GetTimeStamp::ElapsedMicroseconds = { 0 };
LARGE_INTEGER GetTimeStamp::Frequency = { 0 };
#endif
task.cpp
#ifndef _TASK
#define _TASK
#include <Windows.h>
#include <process.h>
#include <list>
#define SPINCOUNT 4000
// Unit of work run by a TaskServer worker thread.
// Override doTask() in a subclass. The pool deletes tasks through a Task*
// after running them, which is why the destructor is virtual.
class Task
{
public:
    Task() { }

    virtual ~Task() { }

    // Default implementation: nothing to do.
    virtual void doTask() { }
};
class TaskServer
{
private:
std::list<Task*> taskList;
CRITICAL_SECTION mutex;
CONDITION_VARIABLE cond;
HANDLE * _threadHandlers;
BOOL bEnd;
DWORD threadnum;
public:
TaskServer(DWORD thread_num) :_threadHandlers(NULL), bEnd(FALSE), threadnum(thread_num)
{
InitializeCriticalSectionAndSpinCount(&mutex, SPINCOUNT);
InitializeConditionVariable(&cond);
}
~TaskServer()
{
WaitForMultipleObjects(threadnum, _threadHandlers, TRUE, INFINITE);
for (int i = 0; i < threadnum; ++i)
CloseHandle(_threadHandlers[i]);
delete[] _threadHandlers;
DeleteCriticalSection(&mutex);
}
void addTask(Task * pTask)
{
EnterCriticalSection(&mutex);
taskList.push_back(pTask);
LeaveCriticalSection(&mutex);
WakeConditionVariable(&cond);
}
void start()
{
if (threadnum < 1)
threadnum = 1;
_threadHandlers = new HANDLE[threadnum];
for (int i = 0; i < threadnum; ++i)
_threadHandlers[i] = (HANDLE)_beginthreadex(0, 0, TaskServer::thread_func, this, 0, 0);
}
void stop()
{
EnterCriticalSection(&mutex);
if (bEnd){
LeaveCriticalSection(&mutex);
return;
}
bEnd = TRUE;
LeaveCriticalSection(&mutex);
WakeAllConditionVariable(&cond);
}
static unsigned int __stdcall thread_func(void * param)
{
TaskServer * pServer = (TaskServer *)param;
pServer->onRun();
return 0;
}
virtual void onRun()
{
while (true)
{
EnterCriticalSection(&mutex);
while (!bEnd && taskList.empty())
SleepConditionVariableCS(&cond, &mutex, INFINITE);
if (bEnd){
printf(" %ld end thread!\n" , GetCurrentThreadId());
LeaveCriticalSection(&mutex);
break;
}
Task * pTask = taskList.front();
taskList.pop_front();
LeaveCriticalSection(&mutex);
pTask->doTask();
delete pTask;
}
}
};
#endif
core_serv.hpp
#ifndef CORE_SERV
#define CORE_SERV
#define WIN32_LEAN_AND_MEAN
#include <process.h>
#include <WinSock2.h>
#include <Windows.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include "../trans.h"
#include "../help_func.hpp"
#include "task.cpp"
#pragma comment(lib, "ws2_32.lib")
#define PORT 9988
#define BACKLOG 10
#define RECV_BUFFSIZE (1<<13)
#define MSG_BUFFSIZE (1<<16)
#define SEND_BUFFSIZE (1<<16)
//这个不做封装了,直接使用, 需要封装的自己去搞
struct ClientData
{
unsigned int lastPos;
unsigned int headLen;
SOCKET _sock;
char * msgbuff;
CRITICAL_SECTION send_mutex;
BOOL bEnd;
ClientData(SOCKET sock = INVALID_SOCKET) :
_sock(sock), lastPos(0), headLen(sizeof(DataHeader)), msgbuff(NULL),
bEnd(FALSE)
{
msgbuff = (char*)malloc(MSG_BUFFSIZE); //不做异常处理了
//sendbuff = (char*)malloc(SEND_BUFFSIZE);
InitializeCriticalSectionAndSpinCount(&send_mutex, SPINCOUNT);
}
~ClientData(){
_sock = INVALID_SOCKET;
if (msgbuff)
free(msgbuff);
DeleteCriticalSection(&send_mutex);
}
void Close()
{
puts("close begin");
EnterCriticalSection(&send_mutex);
if (bEnd){
LeaveCriticalSection(&send_mutex);
return;
}
bEnd = TRUE;
LeaveCriticalSection(&send_mutex);
closesocket(_sock);
_sock = INVALID_SOCKET;
puts("close end");
}
BOOL get_status()
{
BOOL end = FALSE;
EnterCriticalSection(&send_mutex);
end = bEnd;
LeaveCriticalSection(&send_mutex);
return end;
}
int write_msg(char * pSendData, int nBytes)
{
int nLeft = nBytes;
int len = 0;
char * buff = pSendData;
EnterCriticalSection(&send_mutex);
while (nLeft > 0 && !bEnd && _sock != INVALID_SOCKET && buff)
{
printf("thread :%ld send !\n", GetCurrentThreadId());
len = send(_sock, buff , nLeft, 0);
if (len == SOCKET_ERROR){
printf("send error : %ld\n", WSAGetLastError());
break;
}
nLeft -= len;
buff += len;
}
LeaveCriticalSection(&send_mutex);
return nBytes - nLeft;
}
int SendData(void *pSendData)
{
int ret = SOCKET_ERROR;
if (_sock == INVALID_SOCKET || !pSendData){
return ret;
}
DataHeader * pHeader = (DataHeader *)pSendData;
switch (pHeader->cmd)
{
case CMD_LOGIN_RESULT:
{
int totallen = sizeof(LoginResult);
LoginResult * lr = (LoginResult *)pSendData;
BOOL isEnd = get_status();
if (!isEnd)
ret = write_msg((char*)lr, sizeof(LoginResult));
delete lr;
return ret;
}
case CMD_LOGOUT_RESULT:
{
int totallen = sizeof(LogoutResult);
LogoutResult * lr = (LogoutResult *)pSendData;
BOOL isEnd = get_status();
if (!isEnd)
ret = write_msg((char*)lr, sizeof(LogoutResult));
delete lr;
return ret;
}
default:{
puts("got error msg");
break;
}
}
}
};
class OneTask : public Task
{
ClientData * _pClientData;
void *pSendData;
public:
OneTask(ClientData *pClientData, void * sendData) : _pClientData(pClientData), pSendData(sendData)
{}
virtual ~OneTask()
{}
virtual void doTask()
{
_pClientData->SendData(pSendData);
}
};
class TcpServ
{
SOCKET hListenSocket;
//2个数组可直接使用一个struct , 用vector 或 dequeue 来替换
//我这里为了方便就直接使用数组了
SOCKET clients[FD_SETSIZE]; //一个socket对应一个clientData,其索引 一 一 对应
ClientData * clients_data[FD_SETSIZE];
SOCKADDR_IN serv_addr;
TaskServer * _task_server;
public:
TcpServ() :hListenSocket(INVALID_SOCKET)
{
_task_server = new TaskServer(4);
}
virtual ~TcpServ()
{
_task_server->stop();
delete _task_server;
Close();
}
int initSocket()
{
if (hListenSocket == INVALID_SOCKET){
hListenSocket = socket(AF_INET, SOCK_STREAM, 0);
if (hListenSocket == INVALID_SOCKET){
printf("socket error : %ld\n" , WSAGetLastError());
return -1;
}
}
return 0;
}
bool isSocketValid() const{
return hListenSocket != INVALID_SOCKET;
}
int Bind(const char *ip = NULL, unsigned short port = PORT)
{
if (!isSocketValid())
{
puts("invalid socket");
return -1;
}
memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_addr.s_addr = ip ? inet_addr(ip) : INADDR_ANY ;
serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(port);
int serv_sock_len = sizeof(serv_addr);
return bind(hListenSocket, (SOCKADDR*)&serv_addr, serv_sock_len);
}
int Listen(int backlog = BACKLOG)
{
if (!isSocketValid())
{
puts("invalid socket");
return -1;
}
return listen(hListenSocket, backlog);
}
int setNonBlockMode(SOCKET sock, u_long bEnable)
{
return ioctlsocket(sock, FIONBIO, &bEnable);
}
void Close()
{
for (int i = 0; i < FD_SETSIZE; ++i){
if (clients[i] != INVALID_SOCKET){
closesocket(clients[i]);
delete clients_data[i];
}
}
if (hListenSocket != INVALID_SOCKET){
closesocket(hListenSocket);
hListenSocket = INVALID_SOCKET;
}
}
void start()
{
if (!isSocketValid()){
puts("invalid socket");
return;
}
_task_server->start();
FD_SET readset, allreadset;
FD_ZERO(&allreadset);
FD_SET(hListenSocket, &allreadset);
int nready = 0, maxfd = (int)hListenSocket, maxi = 0;
for (int i = 0; i < FD_SETSIZE; ++i)
clients[i] = -1;
struct timeval tval = { 3, 0 };
SOCKADDR_IN client_addr = {};
int client_sock_len = 0;
while (1)
{
client_sock_len = sizeof(client_addr);
readset = allreadset;
//puts("before select");
nready = select(maxfd + 1, &readset, NULL, NULL, &tval);
if (nready == SOCKET_ERROR){
printf("select error :%ld\n", WSAGetLastError());
continue;
}
else if (nready == 0){
//puts("Timeout , u can do something else!!!");
continue;
}
if (FD_ISSET(hListenSocket, &readset))
{
SOCKET hClientSock = accept(hListenSocket, (SOCKADDR*)&client_addr, &client_sock_len);
int i = 0;
for (i = 0; i < FD_SETSIZE && -1 != clients[i]; ++i);
if (FD_SETSIZE == i)
{
puts("full clients");
closesocket(hClientSock);
if (--nready == 0)
continue;
}
else
{
int clientSockNo = (int)hClientSock;//这步在win32下没用,win32中无视maxfd
if (clientSockNo > maxfd)
maxfd = clientSockNo;
if (i > maxi)
maxi = i;
FD_SET(hClientSock, &allreadset);
clients[i] = hClientSock;
ClientData * pClientData = new ClientData(hClientSock);
clients_data[i] = pClientData;
--nready;
int sock_arr[] = { i };
boardcast_msg(CMD_USER_JOIN, sock_arr, maxi);
if (nready == 0)
continue;
}
}
recvMsg(maxi, &readset, &allreadset, &nready);
}
}
void peerClose(int i, FD_SET * allreadset)
{
int sockno = (int)clients[i];
clients_data[i]->Close();
delete clients_data[i];
clients_data[i] = NULL;
FD_CLR(clients[i], allreadset);
clients[i] = -1;
printf("\tsocket:%d peer close !\n", sockno);
}
void recvMsg(int maxi, FD_SET *readset, FD_SET *allreadset, int *nready)
{
if (!isSocketValid()){
puts("invalid socket");
return;
}
//GetTimeStamp::start();
int len = 0, goon = 1;
DWORD totallen = 0;
char * recvbuff = NULL;
for (int i = 0; i <= maxi; ++i)
{
if (clients[i] < 0)
continue;
if (FD_ISSET(clients[i], readset))
{
puts("\tprepare recv from client!");
ClientData* pData = clients_data[i];
len = recv(clients[i], pData->msgbuff + pData->lastPos, MSG_BUFFSIZE - pData->lastPos, 0);
if (len <= 0)
{
peerClose(i,allreadset);
goto RecvMsgDone;
}
pData->lastPos += len;
while (pData->lastPos >= pData->headLen)
{
DataHeader * pHeader = (DataHeader *)pData->msgbuff;
totallen = pHeader->dataLen + pData->headLen;
if (pData->lastPos >= totallen){
onMsgRecved(pData);
memcpy(pData->msgbuff, pData->msgbuff + totallen, pData->lastPos - totallen);
pData->lastPos -= totallen;
}
else{
break;
}
}
RecvMsgDone:
if (--(*nready) == 0)
break;
}
}
//printf("- >>>>> time:%lld elapsed\n", GetTimeStamp::elapsed());
}
virtual void onMsgRecved(ClientData * pData)
{
DataHeader * header =(DataHeader *) pData->msgbuff;
switch (header->cmd)
{
case CMD_LOGIN:
{
Login * pLogin = (Login*)pData->msgbuff;
printf("\tLogin : %s, %s , datalen:%d, cmd:%d\n",
pLogin->uname, pLogin->passwd,
pLogin->header.dataLen, pLogin->header.cmd);
LoginResult * res = new LoginResult();
res->result = 1;
res->header.cmd = CMD_LOGIN_RESULT;
res->header.dataLen = sizeof(LoginResult) - sizeof(DataHeader);
OneTask * oneTask = new OneTask(pData, res);
_task_server->addTask(oneTask);
break;
}
case CMD_LOGOUT:
{
Logout *pLogout = (Logout*)pData->msgbuff;
printf("\tLogout : %s , cmd:%d, datalen:%d\n", pLogout->uname,
pLogout->header.cmd, pLogout->header.dataLen);
LogoutResult * res = new LogoutResult();
res->result = 1;
res->header.cmd = CMD_LOGOUT_RESULT;
res->header.dataLen = sizeof(LogoutResult)-sizeof(DataHeader);
OneTask * oneTask = new OneTask(pData, res);
_task_server->addTask(oneTask);
break;
}
}
}
void boardcast_msg(int cmd, int * sock_arr_index, int maxi)
{
puts(" **** boardcast msg **** ");
unsigned int num = (maxi == -1) ? FD_SETSIZE : maxi;
int arrlen = sizeof(sock_arr_index) / sizeof(int*);
SOCKET tmp_sock[FD_SETSIZE] = {};
switch (cmd){
case CMD_USER_JOIN:
{
UserJoin userjoin;
userjoin.header.cmd = CMD_USER_JOIN;
userjoin.header.dataLen = sizeof(UserJoin)-sizeof(DataHeader);
int sockindex = 0;
if (sock_arr_index) {
for (int i = 0; i < arrlen; ++i)
{
sockindex = sock_arr_index[i];
tmp_sock[i] = clients[sockindex];
clients[sockindex] = -1;
}
for (int i = 0; i <= num; ++i)
{
if (clients[i] != -1){
printf("send:%d\n", clients[i]);
send(clients[i], (const char*)&userjoin, sizeof(userjoin), 0);
}
}
for (int i = 0; i < arrlen; ++i)
{
clients[sock_arr_index[i]] = tmp_sock[i];
}
}
else
{
for (int i = 0; i < num; ++i)
if (clients[i] != -1)
send(clients[i], (const char*)&userjoin, sizeof(userjoin), 0);
}
}
}
}
};
#endif