以下代码的请求处理函数有问题,请重新设计。
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <pthread.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <unistd.h>

#include "common/debug.h"
#include "common/http_parser.h"
/* TCP port the server listens on. */
#define PORT 8003
/* Bound shared by the epoll event batch, the listen backlog and the client table. */
#define MAX_EVENTS 1024
/* Size in bytes of every request/response buffer. */
#define BUFFER_SIZE 8192

/* Per-connection bookkeeping; one slot per client descriptor. */
typedef struct ClientState {
    int    fd;                   /* descriptor this slot belongs to */
    char   buffer[BUFFER_SIZE];  /* request bytes received so far */
    size_t bytes_read;           /* number of valid bytes in buffer */
    int    state;                /* 0 = waiting for a request, 1 = processing */
} ClientState;
/*
 * Adds O_NONBLOCK to the file status flags of sockfd.
 *
 * Failures of either fcntl() call are logged and otherwise ignored; the
 * function returns nothing, so callers cannot tell whether the flag was set.
 */
void set_nonblocking(int sockfd) {
    const int current = fcntl(sockfd, F_GETFL, 0);

    if (current != -1) {
        /* Keep the existing flags and only add O_NONBLOCK. */
        if (fcntl(sockfd, F_SETFL, current | O_NONBLOCK) == -1) {
            LOG_ERROR("fcntl F_SETFL failed");
        }
        return;
    }

    LOG_ERROR("fcntl F_GETFL failed");
}
/* Public request handler; defined below, after the helpers it relies on. */
void handle_request(int client_fd);

/*
 * Returns the number of bytes up to and including the blank line
 * ("\r\n\r\n") that ends the HTTP header section, or 0 if that line has not
 * been received yet.
 */
static size_t header_section_length(const char *data, size_t length)
{
    for (size_t i = 0; i + 4 <= length; i++) {
        if (memcmp(data + i, "\r\n\r\n", 4) == 0) {
            return i + 4;
        }
    }
    return 0;
}

/* Returns 1 if the line starts with prefix (given in lower case), ignoring ASCII case. */
static int line_has_prefix(const char *line, size_t line_len, const char *prefix)
{
    const size_t prefix_len = strlen(prefix);

    if (line_len < prefix_len) {
        return 0;
    }
    for (size_t i = 0; i < prefix_len; i++) {
        char c = line[i];
        if (c >= 'A' && c <= 'Z') {
            c = (char)(c - 'A' + 'a');
        }
        if (c != prefix[i]) {
            return 0;
        }
    }
    return 1;
}

/*
 * Returns the value of the first Content-Length header found in the header
 * section, or 0 when there is none. Parsing stops as soon as the value
 * exceeds BUFFER_SIZE: such a request can never fit, and stopping early also
 * rules out arithmetic overflow. Chunked transfer encoding is not handled.
 */
static size_t declared_body_length(const char *headers, size_t header_len)
{
    static const char name[] = "content-length:";
    size_t pos = 0;

    while (pos < header_len) {
        const char *line = headers + pos;
        const char *eol = memchr(line, '\n', header_len - pos);
        const size_t line_len = eol ? (size_t)(eol - line) : header_len - pos;

        if (line_has_prefix(line, line_len, name)) {
            size_t i = sizeof(name) - 1;
            size_t value = 0;

            while (i < line_len && (line[i] == ' ' || line[i] == '\t')) {
                i++;
            }
            while (i < line_len && line[i] >= '0' && line[i] <= '9') {
                if (value > BUFFER_SIZE) {
                    break;
                }
                value = value * 10 + (size_t)(line[i] - '0');
                i++;
            }
            return value;
        }
        pos += line_len + 1;
    }
    return 0;
}

/*
 * Classifies the bytes buffered so far.
 *
 * capacity is the number of usable bytes in the buffer (one byte is kept in
 * reserve for a terminating NUL).
 *
 * Returns  1 when headers and the declared body are fully buffered,
 *          0 when more data is needed,
 *         -1 when the request cannot fit into capacity bytes.
 */
static int request_status(const char *data, size_t length, size_t capacity)
{
    const size_t header_len = header_section_length(data, length);

    if (header_len == 0) {
        return length >= capacity ? -1 : 0;
    }

    const size_t body_len = declared_body_length(data, header_len);
    if (body_len > capacity - header_len) {
        return -1;
    }
    return (length - header_len >= body_len) ? 1 : 0;
}

/*
 * Copies src into dst, replacing the HTML metacharacters & < > " ' with
 * entities. The output is always NUL-terminated and is truncated on an
 * entity boundary when dst_size is too small.
 */
static void html_escape(const char *src, char *dst, size_t dst_size)
{
    size_t used = 0;

    if (dst_size == 0) {
        return;
    }
    for (; *src != '\0'; src++) {
        const char *entity;

        switch (*src) {
        case '&':  entity = "&amp;";  break;
        case '<':  entity = "&lt;";   break;
        case '>':  entity = "&gt;";   break;
        case '"':  entity = "&quot;"; break;
        case '\'': entity = "&#39;";  break;
        default:   entity = NULL;     break;
        }

        const size_t piece_len = entity ? strlen(entity) : 1;
        if (used + piece_len >= dst_size) {
            break;
        }
        if (entity) {
            memcpy(dst + used, entity, piece_len);
        } else {
            dst[used] = *src;
        }
        used += piece_len;
    }
    dst[used] = '\0';
}

/*
 * Sends exactly length bytes on a socket, retrying short writes and EINTR.
 * MSG_NOSIGNAL keeps a peer that already disconnected from raising SIGPIPE.
 * EAGAIN is treated as failure: responses are at most BUFFER_SIZE bytes and
 * are sent on a connection with nothing else queued, so a full send buffer
 * is not expected.
 *
 * Returns 0 on success, -1 on failure.
 */
static int send_all(int fd, const char *data, size_t length)
{
    size_t sent = 0;

    while (sent < length) {
        const ssize_t n = send(fd, data + sent, length - sent, MSG_NOSIGNAL);
        if (n > 0) {
            sent += (size_t)n;
        } else if (n < 0 && errno == EINTR) {
            continue;
        } else {
            return -1;
        }
    }
    return 0;
}

/* Serializes response and sends it to client_fd. */
static void send_response(int client_fd, HttpResponse *response)
{
    char wire[BUFFER_SIZE];
    size_t wire_len = sizeof(wire);

    /* NOTE(review): assumes build_http_response stores the number of bytes
     * it produced in wire_len - confirm against common/http_parser.h. */
    build_http_response(response, wire, &wire_len);
    if (wire_len > sizeof(wire)) {
        wire_len = sizeof(wire);
    }
    if (send_all(client_fd, wire, wire_len) != 0) {
        LOG_WARN("Failed to send response on fd %d", client_fd);
    }
}

/*
 * Parses one complete, NUL-terminated request that is already in memory and
 * sends the matching response on client_fd.
 *
 * This function never reads from the socket and never closes it; the caller
 * owns the descriptor.
 */
static void process_request(int client_fd, char *data, size_t length)
{
    HttpResponse response = {0};
    HttpRequest request;
    /* Declared at function scope so it is still alive when the response is
     * serialized, in case build_200_response keeps the pointer. */
    char page[512];

    LOG_DEBUG("Received request:\n%.*s", (int)length, data);

    if (parse_http_request(data, length, &request) != 0) {
        /* NOTE(review): request is not released here, matching the previous
         * behaviour; confirm the parser cleans up after itself on failure. */
        LOG_WARN("Failed to parse HTTP request");
        build_400_response(&response);
        send_response(client_fd, &response);
        return;
    }

    if (strcmp(request.path, "/") == 0) {
        snprintf(page, sizeof(page),
                 "<html>"
                 "<head><title>Thread Server</title></head>"
                 "<body><h1>Welcome to Thread Server</h1>"
                 "<p>Hello from thread %ld</p>"
                 "<form method='POST' action='/submit'>"
                 "<input type='text' name='data'><input type='submit'></form>"
                 "</body></html>",
                 (long)pthread_self());
        build_200_response(&response, page, "text/html");
    } else if (strcmp(request.path, "/submit") == 0 && request.method == POST) {
        /* The body is untrusted input: escape it before embedding it in HTML.
         * 256 bytes of escaped text plus the template always fit in page, so
         * the closing tags are never cut off. */
        char echoed[256];

        html_escape(request.body ? request.body : "(empty)", echoed, sizeof(echoed));
        snprintf(page, sizeof(page),
                 "<html>"
                 "<head><title>Submission Received</title></head>"
                 "<body><h1>Thank You!</h1>"
                 "<p>Your submission: %s</p></body></html>",
                 echoed);
        build_200_response(&response, page, "text/html");
    } else {
        build_404_response(&response);
    }

    send_response(client_fd, &response);
    free_http_request(&request);
}

/*
 * Reads one HTTP request from client_fd, answers it and closes client_fd.
 *
 * This entry point reads from the descriptor itself, so it is meant for
 * callers that have not consumed any request bytes yet. The event loop in
 * main() buffers the request on its own and therefore hands the buffered
 * bytes to process_request() instead of calling this function.
 */
void handle_request(int client_fd)
{
    char buffer[BUFFER_SIZE];
    const size_t capacity = sizeof(buffer) - 1;
    size_t total = 0;
    int status = 0;

    while (status == 0) {
        const ssize_t count = read(client_fd, buffer + total, capacity - total);
        if (count < 0 && errno == EINTR) {
            continue;
        }
        if (count <= 0) {
            LOG_WARN("Failed to read request or connection closed");
            close(client_fd);
            return;
        }
        total += (size_t)count;
        status = request_status(buffer, total, capacity);
    }

    if (status < 0) {
        LOG_WARN("Request too large on fd %d", client_fd);
        close(client_fd);
        return;
    }

    buffer[total] = '\0';
    process_request(client_fd, buffer, total);
    close(client_fd);
}

/* Closes the connection and marks its slot as free. Closing the descriptor
 * also removes it from the epoll interest list. */
static void close_client(ClientState *client)
{
    close(client->fd);
    client->fd = -1;
    client->bytes_read = 0;
    client->state = 0;
}

/*
 * Accepts every pending connection on the non-blocking listening socket and
 * registers each one with epoll in edge-triggered mode.
 */
static void accept_clients(int epoll_fd, int server_fd, ClientState *clients)
{
    for (;;) {
        struct sockaddr_in peer;
        socklen_t peer_len = sizeof(peer);  /* must be reset before every accept() */
        const int client_fd = accept(server_fd, (struct sockaddr *)&peer, &peer_len);

        if (client_fd < 0) {
            if (errno == EINTR || errno == ECONNABORTED) {
                continue;
            }
            if (errno != EAGAIN && errno != EWOULDBLOCK) {
                LOG_WARN("accept failed");
            }
            return;  /* backlog drained, or a failure worth backing off from */
        }
        LOG_INFO("New connection from %s", inet_ntoa(peer.sin_addr));

        /* The client table is indexed by descriptor; anything beyond it
         * cannot be tracked and must not be registered. */
        if (client_fd >= MAX_EVENTS) {
            LOG_WARN("Too many connections, rejecting fd %d", client_fd);
            close(client_fd);
            continue;
        }

        set_nonblocking(client_fd);

        ClientState *client = &clients[client_fd];
        client->fd = client_fd;
        client->bytes_read = 0;
        client->state = 0;

        struct epoll_event event = {0};
        event.events = EPOLLIN | EPOLLET;  /* edge-triggered */
        event.data.fd = client_fd;
        if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &event) == -1) {
            LOG_ERROR("epoll_ctl: client_fd");
            close_client(client);
        }
    }
}

/*
 * Drains the readable data of one client (required in edge-triggered mode).
 * As soon as a complete request is buffered it is answered and the
 * connection is closed; one request is served per connection.
 */
static void service_client(ClientState *client)
{
    const int fd = client->fd;
    const size_t capacity = sizeof(client->buffer) - 1;

    for (;;) {
        const ssize_t count = read(fd, client->buffer + client->bytes_read,
                                   capacity - client->bytes_read);
        if (count < 0) {
            if (errno == EINTR) {
                continue;
            }
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                return;  /* nothing more for now; keep the partial request */
            }
            LOG_WARN("read error on fd %d", fd);
            close_client(client);
            return;
        }
        if (count == 0) {
            /* Peer closed the connection. */
            close_client(client);
            return;
        }
        client->bytes_read += (size_t)count;

        const int status = request_status(client->buffer, client->bytes_read, capacity);
        if (status < 0) {
            LOG_WARN("Request too large on fd %d", fd);
            close_client(client);
            return;
        }
        if (status > 0) {
            client->buffer[client->bytes_read] = '\0';
            LOG_DEBUG("Received complete request on fd %d", fd);
            client->state = 1;
            /* Hand over the bytes already read; the socket has been drained,
             * so reading it again would only yield EAGAIN. */
            process_request(fd, client->buffer, client->bytes_read);
            close_client(client);
            return;
        }
    }
}

/*
 * Entry point: sets up the listening socket and the epoll instance, then
 * runs the event loop forever. Setup failures terminate the process.
 */
int main(void)
{
    /* Create the listening socket; socket() reports failure with -1. */
    const int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (server_fd < 0) {
        LOG_ERROR("socket failed");
        exit(EXIT_FAILURE);
    }

    int opt = 1;
    if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
        LOG_WARN("setsockopt SO_REUSEADDR failed");
    }

    /* Bind to every local address on PORT. */
    struct sockaddr_in address;
    memset(&address, 0, sizeof(address));
    address.sin_family = AF_INET;
    address.sin_addr.s_addr = htonl(INADDR_ANY);
    address.sin_port = htons(PORT);
    if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) {
        LOG_ERROR("bind failed");
        exit(EXIT_FAILURE);
    }

    if (listen(server_fd, MAX_EVENTS) < 0) {
        LOG_ERROR("listen failed");
        exit(EXIT_FAILURE);
    }

    /* The accept loop runs until EAGAIN, which only happens on a
     * non-blocking listening socket. */
    set_nonblocking(server_fd);

    const int epoll_fd = epoll_create1(0);
    if (epoll_fd == -1) {
        LOG_ERROR("epoll_create1");
        exit(EXIT_FAILURE);
    }

    struct epoll_event listen_event = {0};
    listen_event.events = EPOLLIN;
    listen_event.data.fd = server_fd;
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, server_fd, &listen_event) == -1) {
        LOG_ERROR("epoll_ctl: server_fd");
        exit(EXIT_FAILURE);
    }
    LOG_INFO("Epoll-based server listening on port %d", PORT);

    /* Client table, indexed by descriptor. */
    ClientState *clients = calloc(MAX_EVENTS, sizeof(*clients));
    if (!clients) {
        LOG_ERROR("Memory allocation failed");
        exit(EXIT_FAILURE);
    }
    for (int i = 0; i < MAX_EVENTS; i++) {
        clients[i].fd = -1;  /* mark every slot as free */
    }

    struct epoll_event events[MAX_EVENTS];
    for (;;) {
        const int nfds = epoll_wait(epoll_fd, events, MAX_EVENTS, -1);
        if (nfds == -1) {
            if (errno == EINTR) {
                continue;
            }
            LOG_ERROR("epoll_wait");
            exit(EXIT_FAILURE);
        }

        for (int i = 0; i < nfds; i++) {
            const int fd = events[i].data.fd;

            if (fd == server_fd) {
                accept_clients(epoll_fd, server_fd, clients);
                continue;
            }
            if (fd < 0 || fd >= MAX_EVENTS) {
                continue;
            }

            ClientState *client = &clients[fd];
            if (client->fd != fd) {
                continue;  /* slot is free: event for a connection already closed */
            }
            service_client(client);
        }
    }

    /* Not reached: the loop above has no exit. */
    free(clients);
    close(epoll_fd);
    close(server_fd);
    return 0;
}