
文章目录
第22章 Web服务器项目
在本章中,我们将深入探讨如何使用Rust构建一个完整的Web服务器。从单线程基础版本开始,逐步扩展到多线程架构,最终实现一个支持优雅停机的生产级Web服务器。这个项目将综合运用我们之前学到的所有权系统、并发编程、错误处理等概念。
22.1 构建单线程web服务器
让我们从构建一个基础的单线程Web服务器开始,理解HTTP协议的基本原理和网络编程的基础知识。
HTTP协议基础
在开始编码之前,我们需要理解HTTP协议的基本结构。HTTP请求和响应都有特定的格式:起始行和每个头部行都以 CRLF(\r\n)结尾,头部之后必须跟一个空行,然后才是可选的消息体(Body)。
HTTP请求格式:
METHOD PATH HTTP/VERSION
Header: value
Header: value
(空行)
Body (可选)
HTTP响应格式:
HTTP/VERSION STATUS_CODE STATUS_TEXT
Header: value
Header: value
(空行)
Body (可选)
基础TCP服务器
首先,我们创建一个能够监听TCP连接并响应基本HTTP请求的服务器。
use std::net::{TcpListener, TcpStream};
use std::io::{Read, Write};
use std::fs;
// HTTP response status codes
/// HTTP status codes this server can produce.
///
/// The explicit discriminants equal the numeric status code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HttpStatus {
    Ok = 200,
    NotFound = 404,
    InternalServerError = 500,
}

impl HttpStatus {
    /// Returns the status code and reason phrase as written on the status line,
    /// e.g. `"404 Not Found"`.
    fn as_str(&self) -> &'static str {
        match self {
            HttpStatus::Ok => "200 OK",
            HttpStatus::NotFound => "404 Not Found",
            HttpStatus::InternalServerError => "500 Internal Server Error",
        }
    }
}

// HTTP response builder
/// Builder for an HTTP/1.1 response: a status, header lines and an optional body.
struct HttpResponse {
    status: HttpStatus,
    /// Header lines already formatted as `"Name: value"` (without the trailing CRLF),
    /// emitted in insertion order.
    headers: Vec<String>,
    body: Option<Vec<u8>>,
}

impl HttpResponse {
    /// Creates a response with the given status, no headers and no body.
    fn new(status: HttpStatus) -> Self {
        Self {
            status,
            headers: Vec::new(),
            body: None,
        }
    }

    /// Appends the header line `name: value`.
    fn with_header(mut self, name: &str, value: &str) -> Self {
        self.headers.push(format!("{}: {}", name, value));
        self
    }

    /// Sets the body and a `Content-Length` header matching its size.
    ///
    /// Any `Content-Length` header set earlier is removed first, so the response
    /// never carries two conflicting lengths.
    fn with_body(mut self, body: Vec<u8>) -> Self {
        self.headers.retain(|header| !Self::is_content_length(header));
        self.headers.push(format!("Content-Length: {}", body.len()));
        self.body = Some(body);
        self
    }

    /// Sets an HTML body with a UTF-8 `Content-Type`.
    fn with_html_body(self, html: &str) -> Self {
        self.with_header("Content-Type", "text/html; charset=utf-8")
            .with_body(html.as_bytes().to_vec())
    }

    /// Sets a plain-text body with a UTF-8 `Content-Type`.
    fn with_text_body(self, text: &str) -> Self {
        self.with_header("Content-Type", "text/plain; charset=utf-8")
            .with_body(text.as_bytes().to_vec())
    }

    /// Serializes the response: status line, one CRLF-terminated line per header,
    /// the empty line that ends the header section, then the body (if any).
    fn to_bytes(&self) -> Vec<u8> {
        let body: &[u8] = self.body.as_deref().unwrap_or(&[]);
        let mut head = format!("HTTP/1.1 {}\r\n", self.status.as_str());
        // Terminating every header individually keeps the framing correct even when
        // there are no headers at all (joining would leave an extra blank line).
        for header in &self.headers {
            head.push_str(header);
            head.push_str("\r\n");
        }
        head.push_str("\r\n");
        let mut response = Vec::with_capacity(head.len() + body.len());
        response.extend_from_slice(head.as_bytes());
        response.extend_from_slice(body);
        response
    }

    /// Returns true when `header` is a `Content-Length` line (name compared case-insensitively).
    fn is_content_length(header: &str) -> bool {
        header
            .split_once(':')
            .map_or(false, |(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
    }
}
// Request handler
/// Single-threaded request handler: reads one request from a connection and serves
/// static files from the current working directory.
struct RequestHandler;

impl RequestHandler {
    /// Reads one request (at most 1024 bytes) from `stream`, logs it, writes the
    /// response and flushes. A connection that sends nothing (0 bytes read) is ignored.
    ///
    /// # Errors
    /// Propagates I/O errors from reading, writing or flushing `stream`.
    fn handle_request(&self, stream: &mut TcpStream) -> std::io::Result<()> {
        let mut buffer = [0; 1024];
        let bytes_read = stream.read(&mut buffer)?;
        if bytes_read == 0 {
            return Ok(());
        }
        let request = String::from_utf8_lossy(&buffer[..bytes_read]);
        println!("收到请求:\n{}", request);
        let response = self.process_request(&request);
        stream.write_all(&response.to_bytes())?;
        stream.flush()?;
        Ok(())
    }

    /// Validates the request line and serves the requested file for `GET`.
    ///
    /// Empty or malformed requests and non-GET methods are answered with
    /// `500 Internal Server Error` and a plain-text explanation.
    fn process_request(&self, request: &str) -> HttpResponse {
        let request_line = match request.lines().next() {
            Some(line) => line,
            None => {
                return HttpResponse::new(HttpStatus::InternalServerError)
                    .with_text_body("无效请求");
            }
        };
        let parts: Vec<&str> = request_line.split_whitespace().collect();
        if parts.len() < 2 {
            return HttpResponse::new(HttpStatus::InternalServerError)
                .with_text_body("无效请求行");
        }
        let (method, path) = (parts[0], parts[1]);
        // Only GET is supported.
        if method != "GET" {
            return HttpResponse::new(HttpStatus::InternalServerError)
                .with_text_body("只支持GET方法");
        }
        self.serve_file(path)
    }

    /// Serves the file named by the request target `path`, relative to the working
    /// directory (`/` maps to `index.html`).
    ///
    /// Missing files and targets rejected by `resolve_path` get the 404 page.
    fn serve_file(&self, path: &str) -> HttpResponse {
        let file = Self::resolve_path(path)
            .and_then(|actual_path| fs::read(actual_path).ok().map(|bytes| (actual_path, bytes)));
        match file {
            Some((actual_path, content)) => {
                // Choose the Content-Type from the file extension.
                let content_type = self.get_content_type(actual_path);
                HttpResponse::new(HttpStatus::Ok)
                    .with_header("Content-Type", content_type)
                    .with_body(content)
            }
            None => {
                // Unknown or rejected path: return the 404 page.
                let not_found_html = r#"
<!DOCTYPE html>
<html>
<head>
<title>404 Not Found</title>
<style>
body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
h1 { color: #d32f2f; }
</style>
</head>
<body>
<h1>404 Not Found</h1>
<p>请求的资源不存在。</p>
<a href="/">返回首页</a>
</body>
</html>
"#;
                HttpResponse::new(HttpStatus::NotFound).with_html_body(not_found_html)
            }
        }
    }

    /// Maps a request target to a relative file path under the working directory.
    ///
    /// Any query string or fragment is ignored and `/` maps to `index.html`.
    /// Returns `None` for targets that could reach outside the working directory:
    /// `.`/`..` segments, empty segments (which also rejects absolute paths such as
    /// `//etc/passwd`), backslashes and `:` (Windows drive syntax).
    fn resolve_path(path: &str) -> Option<&str> {
        let path = path.split(|c: char| c == '?' || c == '#').next().unwrap_or(path);
        // `strip_prefix` never slices inside a multi-byte character, unlike `&path[1..]`.
        let relative = path.strip_prefix('/').unwrap_or(path);
        if relative.is_empty() {
            return Some("index.html");
        }
        let is_safe = !relative.contains('\\')
            && !relative.contains(':')
            && relative
                .split('/')
                .all(|segment| !segment.is_empty() && segment != "." && segment != "..");
        if is_safe {
            Some(relative)
        } else {
            None
        }
    }

    /// Guesses the Content-Type from the file extension; unknown extensions are
    /// served as `application/octet-stream`.
    fn get_content_type(&self, filename: &str) -> &'static str {
        if filename.ends_with(".html") {
            "text/html; charset=utf-8"
        } else if filename.ends_with(".css") {
            "text/css"
        } else if filename.ends_with(".js") {
            "application/javascript"
        } else if filename.ends_with(".png") {
            "image/png"
        } else if filename.ends_with(".jpg") || filename.ends_with(".jpeg") {
            "image/jpeg"
        } else {
            "application/octet-stream"
        }
    }
}
// Single-threaded web server
/// Web server that accepts connections one at a time and answers each with
/// `RequestHandler` before accepting the next.
struct SingleThreadedWebServer {
    /// Socket address to bind, e.g. `"127.0.0.1:8080"`.
    address: String,
    handler: RequestHandler,
}

impl SingleThreadedWebServer {
    /// Creates a server for `address`; nothing is bound until `run` is called.
    fn new(address: &str) -> Self {
        let handler = RequestHandler;
        Self {
            handler,
            address: String::from(address),
        }
    }

    /// Binds the listener, writes the default `index.html`, then serves connections
    /// sequentially in an endless loop. Per-connection errors are logged and do not
    /// stop the loop.
    ///
    /// # Errors
    /// Fails if the address cannot be bound or `index.html` cannot be written.
    fn run(&self) -> std::io::Result<()> {
        let listener = TcpListener::bind(&self.address)?;
        println!("服务器运行在 http://{}", self.address);
        // Write the default page that `/` serves.
        self.create_default_files()?;
        for incoming in listener.incoming() {
            let mut connection = match incoming {
                Ok(connection) => connection,
                Err(err) => {
                    eprintln!("连接失败: {}", err);
                    continue;
                }
            };
            if let Err(err) = self.handler.handle_request(&mut connection) {
                eprintln!("处理请求时出错: {}", err);
            }
        }
        Ok(())
    }

    /// Writes the welcome page to `index.html` in the working directory, replacing
    /// any existing file.
    fn create_default_files(&self) -> std::io::Result<()> {
        const INDEX_HTML: &str = r#"
<!DOCTYPE html>
<html>
<head>
<title>Rust Web Server</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.container {
background: rgba(255, 255, 255, 0.1);
padding: 30px;
border-radius: 10px;
backdrop-filter: blur(10px);
}
h1 {
text-align: center;
margin-bottom: 30px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.feature {
background: rgba(255, 255, 255, 0.2);
padding: 15px;
margin: 10px 0;
border-radius: 5px;
}
</style>
</head>
<body>
<div class="container">
<h1>🚀 Rust Web Server</h1>
<p>欢迎使用基于Rust构建的单线程Web服务器!</p>
<div class="feature">
<h3>✨ 特性</h3>
<ul>
<li>单线程架构</li>
<li>静态文件服务</li>
<li>自动Content-Type检测</li>
<li>自定义错误页面</li>
</ul>
</div>
<div class="feature">
<h3>📁 示例文件</h3>
<ul>
<li><a href="/" style="color: #ffd700;">首页</a></li>
<li><a href="/nonexistent.html" style="color: #ffd700;">测试404页面</a></li>
</ul>
</div>
</div>
</body>
</html>
"#;
        fs::write("index.html", INDEX_HTML)?;
        println!("已创建默认的 index.html 文件");
        Ok(())
    }
}
/// Entry point for the single-threaded variant: serves on 127.0.0.1:8080.
fn main() -> std::io::Result<()> {
    println!("启动单线程Web服务器...");
    SingleThreadedWebServer::new("127.0.0.1:8080").run()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder emits the status line, the given header and the body.
    #[test]
    fn test_http_response_builder() {
        let response = HttpResponse::new(HttpStatus::Ok)
            .with_header("Content-Type", "text/html")
            .with_html_body("<h1>Hello</h1>")
            .to_bytes();
        let response_str = String::from_utf8_lossy(&response);
        assert!(response_str.contains("200 OK"));
        assert!(response_str.contains("Content-Type: text/html"));
        assert!(response_str.contains("<h1>Hello</h1>"));
    }

    /// Status codes chosen by `process_request`.
    ///
    /// NOTE: the `/` case reads `index.html` from the current working directory, so it
    /// only passes when that file exists (e.g. after the server has been run once).
    #[test]
    fn test_request_handler() {
        let handler = RequestHandler;
        // Root path request.
        let get_request = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n";
        let response = handler.process_request(get_request);
        assert!(matches!(response.status, HttpStatus::Ok));
        // Path that does not exist.
        let not_found_request = "GET /nonexistent.html HTTP/1.1\r\nHost: localhost\r\n\r\n";
        let response = handler.process_request(not_found_request);
        assert!(matches!(response.status, HttpStatus::NotFound));
        // Malformed request: "INVALID" is not GET.
        let invalid_request = "INVALID REQUEST";
        let response = handler.process_request(invalid_request);
        assert!(matches!(response.status, HttpStatus::InternalServerError));
    }
}
增强的HTTP解析器
现在让我们实现一个更完整的HTTP请求解析器,能够处理各种HTTP方法和头部信息。
use std::collections::HashMap;
// HTTP method enum
/// HTTP request methods recognised by the enhanced parser.
#[derive(Debug, PartialEq)]
enum HttpMethod {
    Get,
    Post,
    Put,
    Delete,
    Head,
    Options,
    Patch,
}

impl From<&str> for HttpMethod {
    /// Converts a method token, ignoring case.
    ///
    /// Unrecognised tokens fall back to `Get` (the tests expect this fallback).
    fn from(s: &str) -> Self {
        match s.to_uppercase().as_str() {
            "GET" => HttpMethod::Get,
            "POST" => HttpMethod::Post,
            "PUT" => HttpMethod::Put,
            "DELETE" => HttpMethod::Delete,
            "HEAD" => HttpMethod::Head,
            "OPTIONS" => HttpMethod::Options,
            "PATCH" => HttpMethod::Patch,
            _ => HttpMethod::Get, // default to GET
        }
    }
}

// HTTP request
/// A parsed HTTP request head. The body is not parsed yet and is always `None`.
#[derive(Debug)]
struct HttpRequest {
    method: HttpMethod,
    path: String,
    version: String,
    /// Header fields keyed by their lowercase name, so lookups are case-insensitive.
    headers: HashMap<String, String>,
    body: Option<Vec<u8>>,
}

impl HttpRequest {
    /// Parses the request line and header fields of `request`.
    ///
    /// Lines may end in `\n` or `\r\n`; header parsing stops at the first empty line.
    /// Header lines without a `:` are skipped, and a repeated header keeps its last value.
    ///
    /// # Errors
    /// `"空请求"` if `request` has no lines; `"无效的请求行"` if the request line does
    /// not consist of exactly three whitespace-separated parts.
    fn parse(request: &str) -> Result<Self, &'static str> {
        let mut lines = request.lines();
        // Request line: METHOD PATH VERSION
        let request_line = lines.next().ok_or("空请求")?;
        let request_parts: Vec<&str> = request_line.split_whitespace().collect();
        if request_parts.len() != 3 {
            return Err("无效的请求行");
        }
        let method = HttpMethod::from(request_parts[0]);
        let path = request_parts[1].to_string();
        let version = request_parts[2].to_string();

        // Header fields, up to the blank line that ends the head.
        let mut headers = HashMap::new();
        for line in lines {
            if line.is_empty() {
                break;
            }
            if let Some((name, value)) = line.split_once(':') {
                // Field names are case-insensitive; normalise them so `get_header`
                // finds them however the client spelled them.
                headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
            }
        }

        // NOTE: a complete implementation would also read the request body here.
        Ok(HttpRequest {
            method,
            path,
            version,
            headers,
            body: None,
        })
    }

    /// Looks up a header value by name, ignoring ASCII case.
    fn get_header(&self, name: &str) -> Option<&String> {
        self.headers.get(&name.to_ascii_lowercase())
    }
}
// Enhanced request handler
/// Request handler built on `HttpRequest`: dispatches on the HTTP method and serves
/// static files from the current working directory.
struct EnhancedRequestHandler;

impl EnhancedRequestHandler {
    /// Reads one request (up to 4096 bytes), parses it and writes the response.
    /// Unparseable requests get `500` with the parse error in the body.
    /// A connection that sends nothing (0 bytes read) is ignored.
    ///
    /// # Errors
    /// Propagates I/O errors from reading, writing or flushing `stream`.
    fn handle_enhanced_request(&self, stream: &mut TcpStream) -> std::io::Result<()> {
        let mut buffer = [0; 4096]; // larger buffer than the basic handler
        let bytes_read = stream.read(&mut buffer)?;
        if bytes_read == 0 {
            return Ok(());
        }
        let request_str = String::from_utf8_lossy(&buffer[..bytes_read]);
        let response = match HttpRequest::parse(&request_str) {
            Ok(request) => {
                println!("解析的请求: {:?}", request);
                self.process_enhanced_request(request)
            }
            Err(e) => {
                eprintln!("解析HTTP请求失败: {}", e);
                HttpResponse::new(HttpStatus::InternalServerError)
                    .with_text_body(&format!("解析请求失败: {}", e))
            }
        };
        stream.write_all(&response.to_bytes())?;
        stream.flush()?;
        Ok(())
    }

    /// Dispatches on the request method. PUT, DELETE and PATCH are not supported
    /// and get `500`.
    fn process_enhanced_request(&self, request: HttpRequest) -> HttpResponse {
        match request.method {
            HttpMethod::Get => self.handle_get_request(&request),
            HttpMethod::Post => self.handle_post_request(&request),
            HttpMethod::Head => self.handle_head_request(&request),
            HttpMethod::Options => self.handle_options_request(),
            _ => HttpResponse::new(HttpStatus::InternalServerError)
                .with_text_body("不支持的HTTP方法"),
        }
    }

    /// GET: serve the requested file.
    fn handle_get_request(&self, request: &HttpRequest) -> HttpResponse {
        self.serve_file(&request.path)
    }

    /// POST: acknowledges the request; the body is not processed.
    fn handle_post_request(&self, _request: &HttpRequest) -> HttpResponse {
        HttpResponse::new(HttpStatus::Ok)
            .with_text_body("POST请求已接收")
    }

    /// HEAD: same status and header fields as GET, but without a body.
    ///
    /// `Content-Length` is kept on purpose. It announces the size a GET would
    /// return, as RFC 9110 §9.3.2 recommends. HEAD responses never carry a body,
    /// so the header does not affect message framing.
    fn handle_head_request(&self, request: &HttpRequest) -> HttpResponse {
        let mut response = self.serve_file(&request.path);
        response.body = None;
        response
    }

    /// OPTIONS: advertise the supported methods.
    fn handle_options_request(&self) -> HttpResponse {
        HttpResponse::new(HttpStatus::Ok)
            .with_header("Allow", "GET, POST, HEAD, OPTIONS")
            .with_header("Content-Length", "0")
    }

    /// Serves the file named by the request target `path`, relative to the working
    /// directory (`/` maps to `index.html`). Missing files and targets rejected by
    /// `resolve_path` get a 404 page.
    fn serve_file(&self, path: &str) -> HttpResponse {
        let file = Self::resolve_path(path)
            .and_then(|actual_path| fs::read(actual_path).ok().map(|bytes| (actual_path, bytes)));
        match file {
            Some((actual_path, content)) => {
                let content_type = self.get_content_type(actual_path);
                HttpResponse::new(HttpStatus::Ok)
                    .with_header("Content-Type", content_type)
                    .with_body(content)
            }
            None => {
                let not_found_html = r#"
<!DOCTYPE html>
<html>
<head><title>404 Not Found</title></head>
<body>
<h1>404 Not Found</h1>
<p>请求的资源不存在。</p>
</body>
</html>
"#;
                HttpResponse::new(HttpStatus::NotFound).with_html_body(not_found_html)
            }
        }
    }

    /// Maps a request target to a relative file path under the working directory.
    ///
    /// Any query string or fragment is ignored and `/` maps to `index.html`.
    /// Returns `None` for targets that could reach outside the working directory:
    /// `.`/`..` segments, empty segments (which also rejects absolute paths such as
    /// `//etc/passwd`), backslashes and `:` (Windows drive syntax).
    fn resolve_path(path: &str) -> Option<&str> {
        let path = path.split(|c: char| c == '?' || c == '#').next().unwrap_or(path);
        // `strip_prefix` never slices inside a multi-byte character, unlike `&path[1..]`.
        let relative = path.strip_prefix('/').unwrap_or(path);
        if relative.is_empty() {
            return Some("index.html");
        }
        let is_safe = !relative.contains('\\')
            && !relative.contains(':')
            && relative
                .split('/')
                .all(|segment| !segment.is_empty() && segment != "." && segment != "..");
        if is_safe {
            Some(relative)
        } else {
            None
        }
    }

    /// Guesses the Content-Type from the file extension; unknown extensions are
    /// served as `application/octet-stream`.
    fn get_content_type(&self, filename: &str) -> &'static str {
        if filename.ends_with(".html") {
            "text/html; charset=utf-8"
        } else if filename.ends_with(".css") {
            "text/css"
        } else if filename.ends_with(".js") {
            "application/javascript"
        } else if filename.ends_with(".png") {
            "image/png"
        } else if filename.ends_with(".jpg") || filename.ends_with(".jpeg") {
            "image/jpeg"
        } else {
            "application/octet-stream"
        }
    }
}
// HTTP parsing tests
#[cfg(test)]
mod http_tests {
    use super::*;

    /// The request line and headers of a simple GET are parsed; headers are looked up by name.
    #[test]
    fn test_http_request_parsing() {
        let raw = "GET /index.html HTTP/1.1\r\nHost: localhost\r\nUser-Agent: test\r\n\r\n";
        let parsed = HttpRequest::parse(raw).unwrap();
        assert_eq!(parsed.method, HttpMethod::Get);
        assert_eq!(parsed.path, "/index.html");
        assert_eq!(parsed.version, "HTTP/1.1");
        let expected_host = "localhost".to_string();
        let expected_agent = "test".to_string();
        assert_eq!(parsed.get_header("Host"), Some(&expected_host));
        assert_eq!(parsed.get_header("User-Agent"), Some(&expected_agent));
    }

    /// Known method names map to their variants; unknown names fall back to `Get`.
    #[test]
    fn test_http_method_conversion() {
        let cases = [
            ("GET", HttpMethod::Get),
            ("POST", HttpMethod::Post),
            ("PUT", HttpMethod::Put),
            ("UNKNOWN", HttpMethod::Get), // fallback
        ];
        for (token, expected) in cases.iter() {
            assert_eq!(&HttpMethod::from(*token), expected);
        }
    }
}
22.2 将服务器变为多线程
单线程服务器一次只能处理一个连接:只要某个请求处理得慢(例如读取大文件,或客户端迟迟不发送数据),后续所有请求都必须排队等待。现在我们将服务器改造为多线程架构,使用固定大小的线程池来并发处理连接,同时避免为每个连接都创建新线程所带来的开销和资源耗尽风险。
线程池实现
首先,我们实现一个高效的线程池来管理工作者线程。
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
// Job type alias
/// A unit of work executed by a pool worker.
type Job = Box<dyn FnOnce() + Send + 'static>;

// Thread pool state
/// State guarded by the pool's mutex.
struct ThreadPoolState {
    workers: Vec<Worker>,
    /// Sending half of the job queue; set to `None` when shutdown begins.
    sender: Option<mpsc::Sender<Job>>,
}

// Thread pool
/// Fixed-size thread pool. Jobs are sent over an mpsc channel whose receiver is
/// shared (behind a mutex) by all workers.
pub struct ThreadPool {
    state: Arc<Mutex<ThreadPoolState>>,
}

impl ThreadPool {
    /// Creates a pool with `size` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `size` is 0.
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);
        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }
        ThreadPool {
            state: Arc::new(Mutex::new(ThreadPoolState {
                workers,
                sender: Some(sender),
            })),
        }
    }

    /// Queues `f` to run on one of the worker threads.
    ///
    /// A panic inside `f` is caught by the worker, which keeps serving later jobs.
    ///
    /// # Panics
    ///
    /// Panics if the job cannot be queued because every worker thread has exited.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        let state = self.state.lock().unwrap();
        if let Some(ref sender) = state.sender {
            sender
                .send(job)
                .expect("thread pool has no running workers to accept the job");
        }
    }

    /// Number of workers in the pool.
    pub fn size(&self) -> usize {
        let state = self.state.lock().unwrap();
        state.workers.len()
    }

    /// Number of workers whose thread has not yet been joined by shutdown.
    ///
    /// This is not the number of workers currently running a job.
    pub fn active_threads(&self) -> usize {
        let state = self.state.lock().unwrap();
        state.workers.iter().filter(|w| w.is_active()).count()
    }
}

impl Drop for ThreadPool {
    /// Closes the job channel, then waits for every worker to finish its current job and exit.
    fn drop(&mut self) {
        println!("正在关闭线程池...");
        // The state stays consistent even if the lock was poisoned, and drop must not panic.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Dropping the sender makes `recv` fail in the workers, ending their loops.
        state.sender = None;
        for worker in &mut state.workers {
            println!("关闭工作者 {}", worker.id);
            if let Some(thread) = worker.thread.take() {
                if thread.join().is_err() {
                    eprintln!("工作者 {} 异常退出", worker.id);
                }
            }
        }
        println!("线程池已关闭");
    }
}

// Worker thread
/// One pool thread that repeatedly takes jobs from the shared receiver.
struct Worker {
    id: usize,
    /// `None` once the thread has been joined during shutdown.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns the worker thread. It runs jobs until the channel's sender is dropped.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || {
            loop {
                // The lock is held only while waiting for the next job; it is released
                // at the end of this block, before the job runs.
                let job = {
                    let receiver_guard = receiver.lock().unwrap();
                    match receiver_guard.recv() {
                        Ok(job) => job,
                        // The sender was dropped: the pool is shutting down.
                        Err(_) => break,
                    }
                };
                println!("工作者 {} 开始执行任务", id);
                // Contain job panics so one failing job cannot permanently take this
                // worker out of the pool.
                match std::panic::catch_unwind(std::panic::AssertUnwindSafe(job)) {
                    Ok(()) => println!("工作者 {} 完成任务", id),
                    Err(_) => eprintln!("工作者 {} 执行的任务发生 panic", id),
                }
            }
            println!("工作者 {} 退出", id);
        });
        Worker {
            id,
            thread: Some(thread),
        }
    }

    /// True until the worker's thread has been joined by the pool's `Drop`.
    fn is_active(&self) -> bool {
        self.thread.is_some()
    }
}
// Connection handler
/// Hands each accepted connection to a thread pool, where it is served by
/// `EnhancedRequestHandler`.
struct ConnectionHandler {
    thread_pool: ThreadPool,
    // Not read: jobs build their own `EnhancedRequestHandler` (a unit struct)
    // because the closure sent to the pool must be `'static`.
    request_handler: EnhancedRequestHandler,
}

impl ConnectionHandler {
    /// Creates a handler backed by a pool of `pool_size` workers.
    ///
    /// # Panics
    /// Panics if `pool_size` is 0, as `ThreadPool::new` asserts a non-zero size.
    fn new(pool_size: usize) -> Self {
        Self {
            thread_pool: ThreadPool::new(pool_size),
            request_handler: EnhancedRequestHandler,
        }
    }

    /// Queues `stream` to be served on the pool. Errors are logged, not returned.
    fn handle_connection(&self, mut stream: TcpStream) {
        self.thread_pool.execute(move || {
            // The job owns the stream, so it can be borrowed mutably directly. The
            // fallible `try_clone().unwrap()` (which could panic the worker) is not needed.
            if let Err(e) = EnhancedRequestHandler.handle_enhanced_request(&mut stream) {
                eprintln!("处理连接时出错: {}", e);
            }
        });
    }
}
// Multi-threaded web server
/// Web server that accepts connections on the calling thread and serves each one on
/// a thread pool through `ConnectionHandler`.
struct MultiThreadedWebServer {
    /// Socket address to bind, e.g. `"127.0.0.1:8080"`.
    address: String,
    connection_handler: ConnectionHandler,
}

impl MultiThreadedWebServer {
    /// Creates a server for `address` with a pool of `pool_size` workers.
    fn new(address: &str, pool_size: usize) -> Self {
        Self {
            address: address.to_string(),
            connection_handler: ConnectionHandler::new(pool_size),
        }
    }

    /// Binds the listener, writes the default `index.html`, then dispatches every
    /// incoming connection to the pool in an endless loop.
    ///
    /// # Errors
    /// Fails if the address cannot be bound or `index.html` cannot be written.
    fn run(&self) -> std::io::Result<()> {
        let listener = TcpListener::bind(&self.address)?;
        println!("多线程服务器运行在 http://{}", self.address);
        println!("线程池大小: {}", self.connection_handler.thread_pool.size());
        // Write the default page.
        self.create_default_files()?;
        for stream in listener.incoming() {
            match stream {
                Ok(stream) => {
                    // `peer_addr` returns an io::Result and can fail; that must not
                    // bring down the accept loop, so log instead of unwrapping.
                    match stream.peer_addr() {
                        Ok(addr) => println!("接受新连接 from: {}", addr),
                        Err(e) => eprintln!("无法获取对端地址: {}", e),
                    }
                    self.connection_handler.handle_connection(stream);
                    // Report the pool state.
                    let pool = &self.connection_handler.thread_pool;
                    println!("活跃线程: {}/{}", pool.active_threads(), pool.size());
                }
                Err(e) => {
                    eprintln!("连接失败: {}", e);
                }
            }
        }
        Ok(())
    }

    /// Writes the multi-threaded welcome page to `index.html`, replacing any existing file.
    fn create_default_files(&self) -> std::io::Result<()> {
        let index_html = r#"
<!DOCTYPE html>
<html>
<head>
<title>Rust Multi-threaded Web Server</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.container {
background: rgba(255, 255, 255, 0.1);
padding: 30px;
border-radius: 10px;
backdrop-filter: blur(10px);
}
h1 {
text-align: center;
margin-bottom: 30px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.feature {
background: rgba(255, 255, 255, 0.2);
padding: 15px;
margin: 10px 0;
border-radius: 5px;
}
.stats {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 10px;
margin-top: 20px;
}
.stat-item {
background: rgba(255, 255, 255, 0.15);
padding: 10px;
border-radius: 5px;
text-align: center;
}
</style>
</head>
<body>
<div class="container">
<h1>🚀 Rust 多线程 Web 服务器</h1>
<p>欢迎使用基于Rust构建的多线程Web服务器!</p>
<div class="feature">
<h3>✨ 特性</h3>
<ul>
<li>多线程架构</li>
<li>线程池管理</li>
<li>静态文件服务</li>
<li>支持多种HTTP方法</li>
<li>自动Content-Type检测</li>
</ul>
</div>
<div class="stats">
<div class="stat-item">
<strong>线程池大小</strong>
<div>4 个工作者线程</div>
</div>
<div class="stat-item">
<strong>并发处理</strong>
<div>支持多个并发连接</div>
</div>
</div>
<div class="feature">
<h3>🔧 测试端点</h3>
<ul>
<li><a href="/" style="color: #ffd700;">首页 (GET)</a></li>
<li><a href="/nonexistent.html" style="color: #ffd700;">测试404页面</a></li>
<li><button onclick="testHead()" style="color: #333;">测试HEAD请求</button></li>
<li><button onclick="testOptions()" style="color: #333;">测试OPTIONS请求</button></li>
</ul>
</div>
</div>
<script>
async function testHead() {
const response = await fetch('/', { method: 'HEAD' });
alert(`HEAD请求状态: ${response.status}`);
}
async function testOptions() {
const response = await fetch('/', { method: 'OPTIONS' });
alert(`OPTIONS请求状态: ${response.status}\n允许的方法: ${response.headers.get('Allow')}`);
}
</script>
</body>
</html>
"#;
        fs::write("index.html", index_html)?;
        println!("已创建多线程版本的 index.html 文件");
        Ok(())
    }
}
// Thread pool tests
#[cfg(test)]
mod thread_pool_tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert_eq!(pool.size(), 4);
    }

    #[test]
    #[should_panic]
    fn test_zero_size_pool() {
        let _pool = ThreadPool::new(0);
    }

    /// A submitted job runs exactly once.
    #[test]
    fn test_thread_pool_execution() {
        let pool = ThreadPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        let (done_tx, done_rx) = std::sync::mpsc::channel();
        pool.execute(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            done_tx.send(()).unwrap();
        });
        // Wait for the job to signal completion rather than sleeping a fixed time,
        // which fails intermittently on slow or loaded machines.
        done_rx
            .recv_timeout(Duration::from_secs(5))
            .expect("job did not run within 5 seconds");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
/// Entry point for the multi-threaded variant: serves on 127.0.0.1:8080 with 4 workers.
fn main() -> std::io::Result<()> {
    println!("启动多线程Web服务器...");
    let (address, workers) = ("127.0.0.1:8080", 4);
    MultiThreadedWebServer::new(address, workers).run()
}
性能优化:工作窃取线程池
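为了进一步提高性能,我们可以实现一个更高级的工作窃取(work-stealing)线程池:每个工作者拥有自己的本地任务队列,空闲时先从全局队列批量获取任务,再从其他工作者的队列中"窃取"任务,从而减少所有线程争用同一把锁的情况。该实现依赖 crossbeam crate(需要在 Cargo.toml 中添加 crossbeam 依赖)。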
use crossbeam::deque::{Injector, Steal, Worker};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
// 工作窃取线程池
struct WorkStealingThreadPool {
workers: Vec<Arc<Worker<std::sync::Mutex<Worker<Job>>>>>,
injector: Arc<Injector<Job>>,
size: usize,
}
impl WorkStealingThreadPool {
fn new(size: usize) -> Self {
let injector = Arc::new(Injector::new());
let mut workers = Vec::with_capacity(size);
for _ in 0..size {
let worker = Arc::new(std::sync::Mutex::new(Worker::new_fifo()));
workers.push(worker);
}
// 启动工作者线程
for (id, worker) in workers.iter().enumerate() {
let injector = Arc::clone(&injector);
let workers = workers.clone();
let worker_clone = Arc::clone(worker);
thread::spawn(move || {
Self::worker_loop(id, worker_clone, injector, workers);
});
}
Self {
workers,
injector,
size,
}
}
fn worker_loop(
id: usize,
worker: Arc<std::sync::Mutex<Worker<Job>>>,
injector: Arc<Injector<Job>>,
workers: Vec<Arc<std::sync::Mutex<Worker<Job>>>>,
) {
loop {
// 首先从自己的队列获取任务
if let Some(job) = worker.lock().unwrap().pop() {
job();
continue;
}
// 然后从全局注入器获取任务
if let Steal::Success(job) = injector.steal() {
job();
continue;
}
// 最后尝试从其他工作者窃取任务
for other_worker in &workers {
if Arc::ptr_eq(other_worker, &worker) {
continue;
}
if let Steal::Success(job) = other_worker.lock().unwrap().steal() {
job();
break;
}
}
// 没有任务,短暂休眠
thread::sleep(Duration::from_micros(100));
}
}
fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(f);
self.injector.push(job);
}
}
impl Drop for WorkStealingThreadPool {
fn drop(&mut self) {
// 在实际实现中,我们需要更复杂的关闭逻辑
println!("关闭工作窃取线程池");
}
}
22.3 优雅停机与清理
生产级服务器需要能够优雅地关闭:收到关闭信号后停止接受新的连接,并等待所有正在处理的请求完成后再退出。下面的代码使用 ctrlc crate 捕获 Ctrl+C 信号(需要在 Cargo.toml 中添加 ctrlc 依赖)。
信号处理和优雅关闭
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
// Server state management
/// Shared server lifecycle state: a shutdown flag plus a count of connections
/// currently being handled.
struct ServerState {
    shutdown: AtomicBool,
    active_connections: std::sync::Mutex<usize>,
}

impl ServerState {
    /// Creates a running (not shutting down) state with zero active connections.
    fn new() -> Self {
        let shutdown = AtomicBool::new(false);
        let active_connections = std::sync::Mutex::new(0);
        ServerState {
            shutdown,
            active_connections,
        }
    }

    /// True once `initiate_shutdown` has been called.
    fn should_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }

    /// Sets the shutdown flag and logs it.
    fn initiate_shutdown(&self) {
        self.shutdown.store(true, Ordering::SeqCst);
        println!("初始化服务器关闭...");
    }

    /// Registers a new connection. Returns `false` (and counts nothing) when the
    /// server is shutting down.
    fn increment_connections(&self) -> bool {
        let accepting = !self.should_shutdown();
        if accepting {
            *self.active_connections.lock().unwrap() += 1;
        }
        accepting
    }

    /// Unregisters a finished connection.
    fn decrement_connections(&self) {
        *self.active_connections.lock().unwrap() -= 1;
    }

    /// Current number of registered connections.
    fn get_active_connections(&self) -> usize {
        let guard = self.active_connections.lock().unwrap();
        *guard
    }

    /// Polls every 100 ms until no connections remain. Returns `true` when that
    /// happens, or `false` once more than `timeout` has elapsed.
    fn wait_for_zero_connections(&self, timeout: Duration) -> bool {
        let started = std::time::Instant::now();
        loop {
            if self.get_active_connections() == 0 {
                return true;
            }
            if started.elapsed() > timeout {
                return false;
            }
            thread::sleep(Duration::from_millis(100));
        }
    }
}
// 优雅的Web服务器
struct GracefulWebServer {
address: String,
thread_pool: ThreadPool,
state: Arc<ServerState>,
request_handler: EnhancedRequestHandler,
}
impl GracefulWebServer {
fn new(address: &str, pool_size: usize) -> Self {
Self {
address: address.to_string(),
thread_pool: ThreadPool::new(pool_size),
state: Arc::new(ServerState::new()),
request_handler: EnhancedRequestHandler,
}
}
fn run(&self) -> std::io::Result<()> {
// 设置信号处理
self.setup_signal_handlers();
let listener = TcpListener::bind(&self.address)?;
listener.set_nonblocking(true)?; // 设置为非阻塞模式,以便检查关闭信号
println!("优雅服务器运行在 http://{}", self.address);
println!("线程池大小: {}", self.thread_pool.size());
println!("使用 Ctrl+C 或发送 SIGTERM 信号来优雅关闭服务器");
self.create_default_files()?;
// 主服务器循环
while !self.state.should_shutdown() {
match listener.accept() {
Ok((stream, addr)) => {
if !self.state.increment_connections() {
// 服务器正在关闭,拒绝新连接
println!("拒绝新连接 from {} (服务器正在关闭)", addr);
let _ = stream.shutdown(std::net::Shutdown::Both);
continue;
}
println!("接受新连接 from: {}", addr);
let state = Arc::clone(&self.state);
let handler = EnhancedRequestHandler;
self.thread_pool.execute(move || {
defer! {
state.decrement_connections();
println!("连接 from {} 处理完成", addr);
}
if let Err(e) = handler.handle_enhanced_request(&mut stream.try_clone().unwrap()) {
eprintln!("处理连接时出错: {}", e);
}
});
// 打印服务器状态
self.print_status();
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// 没有新连接,短暂休眠后继续
thread::sleep(Duration::from_millis(100));
continue;
}
Err(e) => {
eprintln!("接受连接时出错: {}", e);
}
}
}
// 开始关闭过程
self.initiate_shutdown()
}
fn setup_signal_handlers(&self) {
let state = Arc::clone(&self.state);
// 处理 Ctrl+C
ctrlc::set_handler(move || {
println!("\n收到中断信号,开始优雅关闭...");
state.initiate_shutdown();
}).expect("设置信号处理器失败");
}
fn initiate_shutdown(&self) -> std::io::Result<()> {
println!("\n开始优雅关闭过程...");
println!("等待活动连接完成...");
// 等待所有活动连接完成(最多30秒)
if self.state.wait_for_zero_connections(Duration::from_secs(30)) {
println!("所有活动连接已完成");
} else {
println!("警告: 在超时前并非所有连接都完成");
}
println!("服务器关闭完成");
Ok(())
}
fn print_status(&self) {
let active_connections = self.state.get_active_connections();
let active_threads = self.thread_pool.active_threads();
println!("状态: 连接={}, 线程={}/{}",
active_connections,
active_threads,
self.thread_pool.size());
}
fn create_default_files(&self) -> std::io::Result<()> {
let index_html = r#"
<!DOCTYPE html>
<html>
<head>
<title>Rust Graceful Web Server</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.container {
background: rgba(255, 255, 255, 0.1);
padding: 30px;
border-radius: 10px;
backdrop-filter: blur(10px);
}
h1 {
text-align: center;
margin-bottom: 30px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.status {
background: rgba(255, 255, 255, 0.2);
padding: 15px;
margin: 10px 0;
border-radius: 5px;
text-align: center;
}
.shutdown {
background: rgba(220, 53, 69, 0.3);
padding: 15px;
margin: 10px 0;
border-radius: 5px;
text-align: center;
}
</style>
</head>
<body>
<div class="container">
<h1>🛡️ Rust 优雅 Web 服务器</h1>
<p>支持优雅关闭的生产级Web服务器</p>
<div class="status">
<h3>📊 服务器状态</h3>
<p>服务器正在运行,可以安全地使用 Ctrl+C 关闭</p>
</div>
<div class="shutdown">
<h3>🔒 优雅关闭</h3>
<p>当收到关闭信号时,服务器会:</p>
<ul>
<li>停止接受新连接</li>
<li>等待现有连接完成</li>
<li>清理资源后退出</li>
</ul>
</div>
<div style="text-align: center; margin-top: 20px;">
<button onclick="simulateLongRequest()" style="padding: 10px 20px; font-size: 16px;">
模拟长时间请求
</button>
</div>
</div>
<script>
async function simulateLongRequest() {
const startTime = Date.now();
const response = await fetch('/simulate-long-task');
const text = await response.text();
const duration = (Date.now() - startTime) / 1000;
alert(`请求完成!耗时: ${duration}秒\n响应: ${text}`);
}
</script>
</body>
</html>
"#;
fs::write("index.html", index_html)?;
println!("已创建优雅关闭版本的 index.html 文件");
Ok(())
}
}
// Macro: run code when the enclosing scope ends
/// Runs the given statements when the enclosing scope ends, including during unwinding.
///
/// Several `defer!`s in one scope run in reverse order of declaration. Like every
/// `macro_rules!` macro, it must be defined textually before its first use.
macro_rules! defer {
    ($($body:tt)*) => {
        let _guard = {
            struct Guard<F: FnOnce()>(Option<F>);
            impl<F: FnOnce()> Drop for Guard<F> {
                fn drop(&mut self) {
                    if let Some(f) = self.0.take() {
                        f()
                    }
                }
            }
            // The braces make every statement part of the closure body. Without
            // them, `|| a; b;` would end the closure after the first expression.
            Guard(Some(|| { $($body)* }))
        };
    };
}
// Request handler with a simulated long-running task
/// Request handler for the graceful-shutdown servers. It serves static files from the
/// working directory and exposes `/simulate-long-task`, which takes 5 seconds to answer.
struct GracefulRequestHandler;

impl GracefulRequestHandler {
    /// Reads one request (up to 4096 bytes), parses it and writes the response.
    /// Unparseable requests get `500` with the parse error in the body.
    /// A connection that sends nothing (0 bytes read) is ignored.
    ///
    /// # Errors
    /// Propagates I/O errors from reading, writing or flushing `stream`.
    fn handle_graceful_request(&self, stream: &mut TcpStream) -> std::io::Result<()> {
        let mut buffer = [0; 4096];
        let bytes_read = stream.read(&mut buffer)?;
        if bytes_read == 0 {
            return Ok(());
        }
        let request_str = String::from_utf8_lossy(&buffer[..bytes_read]);
        let response = match HttpRequest::parse(&request_str) {
            Ok(request) => self.process_graceful_request(request),
            Err(e) => HttpResponse::new(HttpStatus::InternalServerError)
                .with_text_body(&format!("解析请求失败: {}", e)),
        };
        stream.write_all(&response.to_bytes())?;
        stream.flush()?;
        Ok(())
    }

    /// Routes `/simulate-long-task` to the slow endpoint and everything else to file serving.
    fn process_graceful_request(&self, request: HttpRequest) -> HttpResponse {
        match request.path.as_str() {
            "/simulate-long-task" => self.simulate_long_task(),
            _ => self.serve_file(&request.path),
        }
    }

    /// Blocks the calling worker for 5 seconds, then answers with a text message.
    fn simulate_long_task(&self) -> HttpResponse {
        println!("开始模拟长时间任务...");
        thread::sleep(Duration::from_secs(5));
        println!("长时间任务完成");
        HttpResponse::new(HttpStatus::Ok)
            .with_text_body("长时间任务已完成!")
    }

    /// Serves the file named by the request target `path`, relative to the working
    /// directory (`/` maps to `index.html`). Missing files and targets rejected by
    /// `resolve_path` get a 404 page.
    fn serve_file(&self, path: &str) -> HttpResponse {
        let file = Self::resolve_path(path)
            .and_then(|actual_path| fs::read(actual_path).ok().map(|bytes| (actual_path, bytes)));
        match file {
            Some((actual_path, content)) => {
                let content_type = self.get_content_type(actual_path);
                HttpResponse::new(HttpStatus::Ok)
                    .with_header("Content-Type", content_type)
                    .with_body(content)
            }
            None => {
                let not_found_html = r#"
<!DOCTYPE html>
<html>
<head><title>404 Not Found</title></head>
<body>
<h1>404 Not Found</h1>
<p>请求的资源不存在。</p>
</body>
</html>
"#;
                HttpResponse::new(HttpStatus::NotFound).with_html_body(not_found_html)
            }
        }
    }

    /// Maps a request target to a relative file path under the working directory.
    ///
    /// Any query string or fragment is ignored and `/` maps to `index.html`.
    /// Returns `None` for targets that could reach outside the working directory:
    /// `.`/`..` segments, empty segments (which also rejects absolute paths such as
    /// `//etc/passwd`), backslashes and `:` (Windows drive syntax).
    fn resolve_path(path: &str) -> Option<&str> {
        let path = path.split(|c: char| c == '?' || c == '#').next().unwrap_or(path);
        // `strip_prefix` never slices inside a multi-byte character, unlike `&path[1..]`.
        let relative = path.strip_prefix('/').unwrap_or(path);
        if relative.is_empty() {
            return Some("index.html");
        }
        let is_safe = !relative.contains('\\')
            && !relative.contains(':')
            && relative
                .split('/')
                .all(|segment| !segment.is_empty() && segment != "." && segment != "..");
        if is_safe {
            Some(relative)
        } else {
            None
        }
    }

    /// Guesses the Content-Type from the file extension; unknown extensions are
    /// served as `application/octet-stream`.
    fn get_content_type(&self, filename: &str) -> &'static str {
        if filename.ends_with(".html") {
            "text/html; charset=utf-8"
        } else if filename.ends_with(".css") {
            "text/css"
        } else if filename.ends_with(".js") {
            "application/javascript"
        } else if filename.ends_with(".png") {
            "image/png"
        } else if filename.ends_with(".jpg") || filename.ends_with(".jpeg") {
            "image/jpeg"
        } else {
            "application/octet-stream"
        }
    }
}
/// Entry point for the graceful-shutdown variant: serves on 127.0.0.1:8080 with 4 workers.
fn main() -> std::io::Result<()> {
    println!("启动优雅Web服务器...");
    const POOL_SIZE: usize = 4;
    GracefulWebServer::new("127.0.0.1:8080", POOL_SIZE).run()
}
配置管理和监控
生产级服务器还需要配置管理和运行监控。下面的配置代码使用 serde 和 toml crate 读写 config.toml(需要在 Cargo.toml 中添加这两个依赖,并为 serde 启用 derive 特性)。
use serde::{Deserialize, Serialize};
use std::time::Duration;
// Server configuration
/// Server settings, persisted as TOML in `config.toml`.
#[derive(Debug, Serialize, Deserialize)]
struct ServerConfig {
    address: String,
    thread_pool_size: usize,
    max_connections: usize,
    shutdown_timeout_secs: u64,
    static_files_dir: String,
    log_level: String,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            address: "127.0.0.1:8080".to_string(),
            thread_pool_size: 4,
            max_connections: 100,
            shutdown_timeout_secs: 30,
            static_files_dir: ".".to_string(),
            log_level: "info".to_string(),
        }
    }
}

impl ServerConfig {
    /// Loads `config.toml` from the working directory. If the file does not exist,
    /// the defaults are written to it and returned.
    ///
    /// # Errors
    /// Returns read errors other than "not found", an `InvalidData` error when the
    /// file is not valid TOML for this struct, and errors from writing the default file.
    fn load() -> std::io::Result<Self> {
        match fs::read_to_string("config.toml") {
            Ok(content) => toml::from_str(&content)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
            // Only a missing file means "first run". Other read errors (permissions,
            // invalid UTF-8, ...) must not cause an existing config to be overwritten
            // with the defaults.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                let config = ServerConfig::default();
                config.save()?;
                Ok(config)
            }
            Err(e) => Err(e),
        }
    }

    /// Writes this configuration to `config.toml` as pretty-printed TOML.
    ///
    /// # Errors
    /// Serialization failures are reported as `InvalidData`; write errors are returned as-is.
    fn save(&self) -> std::io::Result<()> {
        let content = toml::to_string_pretty(self)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        fs::write("config.toml", content)
    }
}
// Server monitoring
/// Lock-free request and connection counters plus the server start time.
struct ServerMonitor {
    start_time: std::time::Instant,
    total_requests: std::sync::atomic::AtomicUsize,
    active_connections: std::sync::atomic::AtomicUsize,
}

impl ServerMonitor {
    /// Starts the uptime clock with all counters at zero.
    fn new() -> Self {
        Self {
            start_time: std::time::Instant::now(),
            total_requests: std::sync::atomic::AtomicUsize::new(0),
            active_connections: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Counts one handled request.
    fn record_request(&self) {
        self.total_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Counts one newly opened connection.
    fn increment_connections(&self) {
        self.active_connections.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Counts one closed connection.
    fn decrement_connections(&self) {
        self.active_connections.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Takes a snapshot of the counters and derives the average request rate.
    fn get_stats(&self) -> ServerStats {
        let uptime = self.start_time.elapsed();
        let total_requests = self.total_requests.load(std::sync::atomic::Ordering::Relaxed);
        let active_connections = self.active_connections.load(std::sync::atomic::Ordering::Relaxed);
        // Use fractional seconds. Whole seconds report 0 during the first second and
        // skew the rate afterwards (e.g. 1.9 s would be counted as 1 s).
        let elapsed_secs = uptime.as_secs_f64();
        let requests_per_second = if elapsed_secs > 0.0 {
            total_requests as f64 / elapsed_secs
        } else {
            0.0
        };
        ServerStats {
            uptime,
            total_requests,
            active_connections,
            requests_per_second,
        }
    }
}

/// Point-in-time snapshot produced by `ServerMonitor::get_stats`.
#[derive(Debug)]
struct ServerStats {
    uptime: Duration,
    total_requests: usize,
    active_connections: usize,
    /// Average requests per second since start-up.
    requests_per_second: f64,
}

impl std::fmt::Display for ServerStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "运行时间: {:?}, 总请求: {}, 活动连接: {}, 请求/秒: {:.2}",
            self.uptime, self.total_requests, self.active_connections, self.requests_per_second
        )
    }
}
22.4 性能优化与扩展
连接池和Keep-Alive
use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
// Connection pool
/// Bounded pool of outgoing TCP connections shared between threads.
struct ConnectionPool {
    connections: Mutex<VecDeque<TcpStream>>,
    /// Signalled whenever connections are added to the pool.
    condition: Condvar,
    /// Maximum number of idle connections kept by `return_connection`.
    max_size: usize,
}

impl ConnectionPool {
    /// Creates an empty pool that keeps at most `max_size` returned connections.
    fn new(max_size: usize) -> Self {
        Self {
            connections: Mutex::new(VecDeque::new()),
            condition: Condvar::new(),
            max_size,
        }
    }

    /// Takes the oldest idle connection, blocking until one is available.
    ///
    /// It always returns `Some`: the call waits rather than returning `None`.
    fn get_connection(&self) -> Option<TcpStream> {
        let mut connections = self.connections.lock().unwrap();
        while connections.is_empty() {
            connections = self.condition.wait(connections).unwrap();
        }
        connections.pop_front()
    }

    /// Gives a connection back to the pool. When the pool already holds
    /// `max_size` connections, the stream is dropped (closed) instead.
    fn return_connection(&self, stream: TcpStream) {
        let mut connections = self.connections.lock().unwrap();
        if connections.len() < self.max_size {
            connections.push_back(stream);
            self.condition.notify_one();
        }
        // Pool is full: `stream` is dropped here.
    }

    /// Opens `count` connections to `address` and adds them to the pool.
    ///
    /// # Errors
    /// Returns the first connect error. Connections opened before the failure are
    /// still added to the pool.
    fn create_connections(&self, address: &str, count: usize) -> std::io::Result<()> {
        // Connect without holding the lock, so `get_connection`/`return_connection`
        // are not blocked behind network I/O.
        let mut opened = Vec::with_capacity(count);
        let mut result = Ok(());
        for _ in 0..count {
            match TcpStream::connect(address) {
                Ok(stream) => opened.push(stream),
                Err(e) => {
                    result = Err(e);
                    break;
                }
            }
        }
        if !opened.is_empty() {
            self.connections.lock().unwrap().extend(opened);
            // Wake waiters even if a later connect failed: the streams opened so far
            // are usable.
            self.condition.notify_all();
        }
        result
    }
}
// HTTP handler with keep-alive support
/// Serves several requests over a single TCP connection (HTTP persistent connections).
struct KeepAliveHandler {
    /// How long to wait for the next request before closing an idle connection.
    timeout: Duration,
}

impl KeepAliveHandler {
    fn new(timeout: Duration) -> Self {
        Self { timeout }
    }

    /// Answers requests on `stream` until one of these happens: the peer closes the
    /// connection, a request does not ask for a persistent connection, or no data
    /// arrives within `timeout`.
    ///
    /// # Errors
    /// I/O errors other than the idle read timeout are returned.
    fn handle_keep_alive(&self, stream: &mut TcpStream) -> std::io::Result<()> {
        stream.set_read_timeout(Some(self.timeout))?;
        let mut buffer = [0; 4096];
        loop {
            let bytes_read = match stream.read(&mut buffer) {
                Ok(0) => break, // peer closed the connection
                Ok(n) => n,
                // The read timeout expired and the client has been idle for `timeout`.
                // Close the connection instead of waiting forever. Depending on the
                // platform this surfaces as `WouldBlock` or `TimedOut`.
                Err(ref e)
                    if matches!(
                        e.kind(),
                        std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                    ) =>
                {
                    break
                }
                Err(e) => return Err(e),
            };
            let request_str = String::from_utf8_lossy(&buffer[..bytes_read]);
            let keep_alive = Self::wants_keep_alive(&request_str);
            let mut response = self.process_request(&request_str);
            if !keep_alive {
                // `process_request` always advertises keep-alive; tell the client
                // that this connection is about to be closed.
                response
                    .headers
                    .retain(|header| !header.to_ascii_lowercase().starts_with("connection:"));
                response.headers.push("Connection: close".to_string());
            }
            stream.write_all(&response.to_bytes())?;
            stream.flush()?;
            if !keep_alive {
                break;
            }
        }
        Ok(())
    }

    /// Builds the (fixed) response for a request. It always carries `Connection: keep-alive`.
    fn process_request(&self, _request: &str) -> HttpResponse {
        HttpResponse::new(HttpStatus::Ok)
            .with_header("Connection", "keep-alive")
            .with_text_body("Hello with Keep-Alive!")
    }

    /// Decides whether the connection should stay open after answering `request`.
    ///
    /// HTTP/1.1 connections are persistent unless the client sends `Connection: close`.
    /// For any other version (e.g. HTTP/1.0) an explicit `Connection: keep-alive` is
    /// required. Only the header section is inspected, never the body.
    fn wants_keep_alive(request: &str) -> bool {
        let mut lines = request.lines();
        let is_http_1_1 = lines
            .next()
            .and_then(|request_line| request_line.split_whitespace().nth(2))
            .map_or(false, |version| version.eq_ignore_ascii_case("HTTP/1.1"));
        let (mut close, mut keep_alive) = (false, false);
        for line in lines.take_while(|line| !line.is_empty()) {
            let (name, value) = match line.split_once(':') {
                Some(pair) => pair,
                None => continue,
            };
            if !name.trim().eq_ignore_ascii_case("connection") {
                continue;
            }
            for token in value.split(',').map(str::trim) {
                close |= token.eq_ignore_ascii_case("close");
                keep_alive |= token.eq_ignore_ascii_case("keep-alive");
            }
        }
        !close && (is_http_1_1 || keep_alive)
    }
}
最终的生产级服务器
// Complete production-grade web server.
/// Production-style web server combining configuration, runtime statistics
/// and connection/shutdown tracking.
struct ProductionWebServer {
    /// Settings from `ServerConfig::load()`: address, pool size, limits, timeouts.
    config: ServerConfig,
    /// Shared statistics, cloned into every worker job.
    monitor: Arc<ServerMonitor>,
    /// Shared shutdown flag and active-connection counter, also used by the
    /// Ctrl-C handler.
    state: Arc<ServerState>,
}

impl ProductionWebServer {
    /// Creates a server from the configuration returned by `ServerConfig::load()`.
    ///
    /// # Errors
    /// Propagates the I/O error if the configuration cannot be loaded.
    fn new() -> std::io::Result<Self> {
        let config = ServerConfig::load()?;
        let monitor = Arc::new(ServerMonitor::new());
        let state = Arc::new(ServerState::new());
        Ok(Self {
            config,
            monitor,
            state,
        })
    }

    /// Runs the accept loop until shutdown is requested, then shuts down gracefully.
    ///
    /// The listener is non-blocking so that the loop can re-check
    /// `state.should_shutdown()` between accepts. Each accepted connection is
    /// handled as a job on the thread pool.
    ///
    /// # Errors
    /// Returns an error if the default files cannot be written or the listener
    /// cannot be bound or configured.
    fn run(&self) -> std::io::Result<()> {
        println!("启动生产级Web服务器...");
        println!("配置: {:#?}", self.config);
        self.setup_signal_handlers();
        self.create_default_files()?;
        let listener = TcpListener::bind(&self.config.address)?;
        listener.set_nonblocking(true)?;
        let thread_pool = ThreadPool::new(self.config.thread_pool_size);
        println!("服务器运行在 http://{}", self.config.address);
        // Main accept loop.
        while !self.state.should_shutdown() {
            match listener.accept() {
                Ok((mut stream, addr)) => {
                    // On BSD/macOS an accepted socket inherits O_NONBLOCK from
                    // the listener; the request handler expects blocking I/O.
                    if let Err(e) = stream.set_nonblocking(false) {
                        eprintln!("接受连接时出错: {}", e);
                        continue;
                    }
                    if !self.state.increment_connections() {
                        println!("拒绝新连接 from {} (服务器正在关闭)", addr);
                        continue;
                    }
                    self.monitor.increment_connections();
                    let state = Arc::clone(&self.state);
                    let monitor = Arc::clone(&self.monitor);
                    let handler = GracefulRequestHandler;
                    thread_pool.execute(move || {
                        // Presumably scopeguard-style: the counters are updated
                        // when this job's scope exits.
                        defer! {
                            state.decrement_connections();
                            monitor.decrement_connections();
                            monitor.record_request();
                        }
                        if let Err(e) = handler.handle_graceful_request(&mut stream) {
                            eprintln!("处理连接时出错: {}", e);
                        }
                    });
                    // Periodically report status.
                    self.print_status();
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // No pending connection: sleep briefly, then re-check the shutdown flag.
                    thread::sleep(Duration::from_millis(10));
                }
                Err(e) => {
                    eprintln!("接受连接时出错: {}", e);
                    // Back off so a persistent error (e.g. fd exhaustion) does not spin the CPU.
                    thread::sleep(Duration::from_millis(10));
                }
            }
        }
        self.initiate_shutdown()
    }

    /// Installs a Ctrl-C handler that asks the shared state to begin shutdown.
    ///
    /// # Panics
    /// Panics if the handler cannot be installed.
    fn setup_signal_handlers(&self) {
        let state = Arc::clone(&self.state);
        ctrlc::set_handler(move || {
            println!("\n收到关闭信号,开始优雅关闭...");
            state.initiate_shutdown();
        })
        .expect("设置信号处理器失败");
    }

    /// Waits up to `shutdown_timeout_secs` for active connections to finish,
    /// then prints the final statistics.
    ///
    /// Always returns `Ok(())`.
    fn initiate_shutdown(&self) -> std::io::Result<()> {
        println!("\n开始优雅关闭过程...");
        let timeout = Duration::from_secs(self.config.shutdown_timeout_secs);
        if self.state.wait_for_zero_connections(timeout) {
            println!("所有活动连接已完成");
        } else {
            println!("警告: 在超时前并非所有连接都完成");
        }
        // Print after draining so the "final" numbers include in-flight requests.
        println!("最终统计: {}", self.monitor.get_stats());
        println!("服务器关闭完成");
        Ok(())
    }

    /// Prints the current statistics whenever the recorded request count is a
    /// non-zero multiple of 100.
    fn print_status(&self) {
        let stats = self.monitor.get_stats();
        // Skip 0: until the first request is recorded, every accept would print.
        if stats.total_requests > 0 && stats.total_requests % 100 == 0 {
            println!("服务器状态: {}", stats);
        }
    }

    /// Writes `index.html` to the current working directory.
    ///
    /// The page shows the configured settings and includes a button that
    /// fetches `/stats` as JSON.
    ///
    /// # Errors
    /// Returns the I/O error if the file cannot be written.
    fn create_default_files(&self) -> std::io::Result<()> {
        let index_html = format!(
            r#"
<!DOCTYPE html>
<html>
<head>
<title>Rust Production Web Server</title>
<style>
body {{
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}}
.container {{
background: rgba(255, 255, 255, 0.1);
padding: 30px;
border-radius: 10px;
backdrop-filter: blur(10px);
}}
.stats {{
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 10px;
margin: 20px 0;
}}
.stat-item {{
background: rgba(255, 255, 255, 0.15);
padding: 15px;
border-radius: 5px;
text-align: center;
}}
</style>
</head>
<body>
<div class="container">
<h1>🏭 Rust 生产级 Web 服务器</h1>
<p>配置完善的生产环境Web服务器</p>
<div class="stats">
<div class="stat-item">
<strong>线程池大小</strong>
<div>{}</div>
</div>
<div class="stat-item">
<strong>最大连接数</strong>
<div>{}</div>
</div>
<div class="stat-item">
<strong>关闭超时</strong>
<div>{}秒</div>
</div>
<div class="stat-item">
<strong>静态文件目录</strong>
<div>{}</div>
</div>
</div>
<div style="text-align: center; margin-top: 20px;">
<button onclick="loadStats()" style="padding: 10px 20px; font-size: 16px;">
加载服务器统计
</button>
</div>
<div id="stats" style="margin-top: 20px;"></div>
</div>
<script>
async function loadStats() {{
const response = await fetch('/stats');
const stats = await response.json();
const statsDiv = document.getElementById('stats');
statsDiv.innerHTML = `
<div class="stat-item">
<strong>运行时间</strong>
<div>${{stats.uptime}}</div>
</div>
<div class="stat-item">
<strong>总请求数</strong>
<div>${{stats.total_requests}}</div>
</div>
<div class="stat-item">
<strong>活动连接</strong>
<div>${{stats.active_connections}}</div>
</div>
<div class="stat-item">
<strong>请求/秒</strong>
<div>${{stats.requests_per_second.toFixed(2)}}</div>
</div>
`;
}}
</script>
</body>
</html>
"#,
            self.config.thread_pool_size,
            self.config.max_connections,
            self.config.shutdown_timeout_secs,
            self.config.static_files_dir
        );
        fs::write("index.html", index_html)?;
        println!("已创建生产级版本的 index.html 文件");
        Ok(())
    }
}
/// Program entry point: builds the production server from its configuration
/// and runs it until shutdown, propagating any startup I/O error.
fn main() -> std::io::Result<()> {
    ProductionWebServer::new()?.run()
}
总结
本章详细介绍了如何使用Rust构建一个完整的Web服务器,从单线程基础版本逐步发展到生产级的多线程服务器:
- 单线程Web服务器:理解HTTP协议基础,实现基本的请求处理和响应生成
- 多线程架构:使用线程池处理并发连接,提高服务器性能
- 优雅停机:实现信号处理、连接跟踪和资源清理,确保服务器可以安全关闭
- 性能优化与扩展:添加配置管理、监控统计、连接池和Keep-Alive长连接等生产级特性
通过这个完整的项目,我们学习了:
- TCP网络编程和HTTP协议处理
- 多线程编程和线程池管理
- 资源管理和优雅关闭
- 配置系统和监控统计
- 错误处理和代码组织
这个Web服务器项目展示了Rust在系统编程中的强大能力,包括内存安全、高性能并发和可靠的错误处理。这些技能对于构建生产级的网络服务至关重要。





