Rust 构建高性能 HTTP/HTTPS 代理服务器

851 阅读7分钟

在现代网络应用中,代理服务器扮演着至关重要的角色。它们不仅可以用于负载均衡、缓存和安全过滤,还可以用于调试和分析网络流量。本文将介绍如何使用Rust和Tokio异步运行时构建一个简单的HTTP/HTTPS代理服务器。我们将深入探讨代码的各个部分,并解释其工作原理。

1. 项目概述

我们的代理服务器将支持HTTP和HTTPS协议。对于HTTP请求,代在现代网络应用中,代理服务器扮演着至关重要的角色。它们不仅可以用于负载均衡、缓存和安全过滤,还可以用于调试和分析网络流量。

本文将介绍如何使用Rust和Tokio异步运行时构建一个简单的HTTP/HTTPS代理服务器。我们将深入探讨代码的各个部分,并解释其工作原理。在现代网络应用中,代理服务器扮演着至关重要的角色。

它们不仅可以用于负载均衡、缓存和安全过滤,还可以用于调试和分析网络流量。本文将介绍如何使用Rust和Tokio异步运行时构建一个简单的HTTP/HTTPS代理服务器。我们将深入探讨代码的各个部分,并解释其工作原理。理服务器将直接转发请求和响应。对于HTTPS请求,代理服务器将建立一个隧道,允许客户端与目标服务器进行安全的通信。

2. 依赖库

在开始编写代码之前,我们需要引入一些依赖库:

tokio:一个异步运行时,用于处理网络I/O操作。

url:用于解析和操作URL。

std::io:标准库中的I/O模块。

std::sync::Arc和tokio::sync::Mutex:用于处理并发和共享状态。

[dependencies]
tokio = { version = "1", features = ["full"] }
url = "2.2"

3. 错误处理

在代理服务器中,错误处理是一个重要的方面。我们定义了一个自定义错误类型ProxyError,并实现了From trait,以便从标准库和其他库的错误类型转换为ProxyError。

#[derive(Debug)]
struct ProxyError(String);

impl std::fmt::Display for ProxyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}"self.0)
    }
}

impl std::error::Error for ProxyError {}

impl From<io::Error> for ProxyError {
    fn from(error: io::Error) -> Self {
        ProxyError(error.to_string())
    }
}

impl From<url::ParseError> for ProxyError {
    fn from(error: url::ParseError) -> Self {
        ProxyError(error.to_string())
    }
}

4. 主函数

主函数负责启动代理服务器并监听来自客户端的连接。每当有新的连接时,它将创建一个异步任务来处理该连接。

#[tokio::main]
async fn main() -> Result<(), ProxyError> {
    let listener = TcpListener::bind("127.0.0.1:8080").await?;
    println!("Proxy server listening on 127.0.0.1:8080");

    loop {
        let (client_stream, addr) = listener.accept().await?;
        println!("New connection from: {}", addr);
        
        task::spawn(async move {
            if let Err(e) = handle_client(client_stream).await {
                eprintln!("Error handling client {}: {:?}", addr, e);
            }
        });
    }
}

5. 处理客户端请求

handle_client函数负责读取客户端的HTTP请求,并根据请求类型(HTTP或HTTPS)调用相应的处理函数。

async fn handle_client(mut client_stream: TcpStream) -> Result<(), ProxyError> {
    // 读取并解析初始请求
    let request = read_http_request(&mut client_stream).await?;
    
    if request.method == "CONNECT" {
        // 处理 HTTPS 请求
        handle_https_tunnel(client_stream, &request).await?;
    } else {
        // 处理 HTTP 请求
        handle_http_proxy(client_stream, request).await?;
    }

    Ok(())
}

6. 读取HTTP请求

read_http_request函数负责从客户端流中读取HTTP请求,并解析请求行、头部和请求体。

async fn read_http_request(stream: &mut TcpStream) -> Result<HttpRequest, ProxyError> {
    let mut headers = Vec::new();
    let mut header_buffer = Vec::new();
    let mut temp_buffer = [0u81];
    let mut header_size = 0;

    // 读取 HTTP 头部
    loop {
        if header_size >= MAX_HEADER_SIZE {
            return Err(ProxyError("HTTP header too large".to_string()));
        }

        stream.read_exact(&mut temp_buffer).await?;
        header_buffer.push(temp_buffer[0]);
        header_size += 1;

        if header_buffer.len() >= 4 &&
           header_buffer[header_size - 4] == b'\r' &&
           header_buffer[header_size - 3] == b'\n' &&
           header_buffer[header_size - 2] == b'\r' &&
           header_buffer[header_size - 1] == b'\n' {
            break;
        }
    }

    let header_str = String::from_utf8_lossy(&header_buffer);
    let mut lines = header_str.lines();

    // 解析请求行
    let request_line = lines.next().ok_or_else(|| ProxyError("Empty request".to_string()))?;

    let partsVec<&str> = request_line.split_whitespace().collect();
    if parts.len() != 3 {
        return Err(ProxyError("Invalid request line".to_string()));
    }

    let method = parts[0].to_string();
    let uri = if method == "CONNECT" {
        Url::parse(&format!("https://{}", parts[1]))?
    } else {
        Url::parse(parts[1])?
    };
    let version = parts[2].to_string();

    // 解析头部
    let mut content_length = 0;
    let mut content_type = String::new();
    for line in lines {
        if line.is_empty() {
            break;
        }
        if let Some((name, value)) = line.split_once(':') {
            let name = name.trim().to_lowercase();
            let value = value.trim().to_string();
            
            if name == "content-length" {
                content_length = value.parse::<usize>().unwrap_or(0);
            } else if name == "content-type" {
                content_type = value.clone();
            }
            
            headers.push((name, value));
        }
    }

    // 读取请求体
    let mut body = Vec::new();
    if content_length > 0 {
        println!("\n[Request Body Info] Content-Type: {}, Length: {} bytes", 
            content_type, content_length);
        
        let mut remaining = content_length;
        let mut buffer = vec![0; BUFFER_SIZE.min(remaining)];
        
        while remaining > 0 {
            let to_read = buffer.len().min(remaining);
            let n = stream.read(&mut buffer[..to_read]).await?;
            if n == 0 {
                break;
            }
            body.extend_from_slice(&buffer[..n]);
            remaining -= n;
        }

        // 打印请求体内容
        if PRINT_RESPONSE_BODY {
            if is_text_request(&content_type) {
                if let Ok(body_str) = String::from_utf8(body.clone()) {
                    println!("\n[Request Body Preview]");
                    let display_len = body_str.len().min(MAX_BODY_PRINT_LENGTH);
                    println!("{}", &body_str[..display_len]);
                    if body_str.len() > MAX_BODY_PRINT_LENGTH {
                        println!("... (truncated {} remaining bytes)", 
                            body_str.len() - MAX_BODY_PRINT_LENGTH);
                    }
                }
            } else if content_type.starts_with("multipart/form-data") {
                println!("\n[Multipart Form Data]");
                println!("Binary data: {} bytes", body.len());
            } else if !content_type.is_empty() {
                println!("\n[Binary Request Body]");
                println!("Size: {} bytes", body.len());
            }
        }
    }

    Ok(HttpRequest {
        method,
        uri,
        version,
        headers,
        body,
    })
}

7. 处理HTTPS隧道

对于HTTPS请求,代理服务器需要建立一个隧道,允许客户端与目标服务器进行安全的通信。handle_https_tunnel函数负责建立隧道并处理双向数据传输。

async fn handle_https_tunnel(mut client_stream: TcpStream, request: &HttpRequest) -> Result<(), ProxyError> {
    let host = request.uri.host_str().ok_or_else(|| ProxyError("Missing host".to_string()))?;
    let port = request.uri.port().unwrap_or(443);

    // 连接到目标服务器
    let server_stream = match TcpStream::connect(format!("{}:{}", host, port)).await {
        Ok(stream) => stream,
        Err(e) => {
            let error_response = format!(
                "HTTP/1.1 502 Bad Gateway\r\n\
                 Connection: close\r\n\
                 Content-Length: 0\r\n\r\n"
            );
            client_stream.write_all(error_response.as_bytes()).await?;
            return Err(e.into());
        }
    };

    // 发送成功响应给客户端
    client_stream.write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n").await?;

    // 设置双向数据传输
    let (mut client_reader, mut client_writer) = tokio::io::split(client_stream);
    let (mut server_reader, mut server_writer) = tokio::io::split(server_stream);

    // 创建两个任务来处理双向数据传输
    let client_to_server = tokio::spawn(async move {
        let result = tokio::io::copy(&mut client_reader, &mut server_writer).await;
        if let Ok(bytes) = result {
            println!("Client to server: {} bytes transferred", bytes);
        }
        result
    });

    let server_to_client = tokio::spawn(async move {
        let result = tokio::io::copy(&mut server_reader, &mut client_writer).await;
        if let Ok(bytes) = result {
            println!("Server to client: {} bytes transferred", bytes);
        }
        result
    });

    // 等待任意一个传输完成或出错
    match tokio::try_join!(client_to_server, server_to_client) {
        Ok(_) => Ok(()),
        Err(e) => Err(ProxyError(e.to_string()))
    }
}

8. 处理HTTP代理

对于HTTP请求,代理服务器将直接转发请求和响应。handle_http_proxy函数负责连接到目标服务器,发送请求,并处理响应。

async fn handle_http_proxy(mut client_stream: TcpStream, request: HttpRequest) -> Result<(), ProxyError> {
    let host = request.uri.host_str().ok_or_else(|| ProxyError("Missing host".to_string()))?;
    let port = request.uri.port().unwrap_or(80);

    println!("\n[Request] {} {} {}", request.method, request.uri, request.version);
    for (name, value) in &request.headers {
        println!("[Request Header] {}: {}", name, value);
    }

    // 连接到目标服务器
    let mut server_stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
    println!("[Connected] to {}:{}", host, port);

    // 构建新的请求
    let mut new_request = format!(
        "{} {} {}\r\n",
        request.method,
        request.uri.path(),
        request.version
    );

    // 添加修改后的请求头
    for (name, value) in request.headers {
        if name.to_lowercase() != "proxy-connection" {
            new_request.push_str(&format!("{}: {}\r\n", name, value));
        }
    }
    new_request.push_str("\r\n");

    // 发送请求到目标服务器
    server_stream.write_all(new_request.as_bytes()).await?;
    
    // 如果有请求体,发送请求体
    if !request.body.is_empty() {
        server_stream.write_all(&request.body).await?;
    }

    // 处理响应
    let mut response_buffer = Vec::new();
    let mut buffer = vec![0; BUFFER_SIZE];
    let mut is_header = true;
    let mut headers_str = String::new();
    let mut content_type = String::new();

    loop {
        let n = match server_stream.read(&mut buffer).await {
            Ok(0) => break, // 连接关闭
            Ok(n) => n,
            Err(e) => return Err(ProxyError(e.to_string())),
        };

        response_buffer.extend_from_slice(&buffer[..n]);

        // 如果还在处理头部,尝试解析和打印头部信息
        if is_header {
            if let Ok(current_str) = String::from_utf8(response_buffer.clone()) {
                if let Some(header_end) = current_str.find("\r\n\r\n") {
                    headers_str = current_str[..header_end].to_string();
                    println!("\n[Response Headers]");
                    
                    // 解析并存储Content-Type
                    for line in headers_str.lines() {
                        println!("{}", line);
                        if line.to_lowercase().starts_with("content-type:") {
                            content_type = line.split(':').nth(1).unwrap_or("").trim().to_string();
                        }
                    }
                    
                    is_header = false;
                    
                    // 开始新的响应体部分
                    println!("\n[Response Body]");
                    if !PRINT_RESPONSE_BODY {
                        println!("Body printing disabled");
                    }
                }
            }
        }

        if let Err(e) = client_stream.write_all(&buffer[..n]).await {
            return Err(ProxyError(e.to_string()));
        }
    }

    // 打印响应体(如果启用)
    if PRINT_RESPONSE_BODY {
        if let Some(header_end) = String::from_utf8_lossy(&response_buffer).find("\r\n\r\n") {
            let body_start = header_end + 4;
            let body = &response_buffer[body_start..];
            
            if is_text_response(&content_type) {
                let body_str = String::from_utf8_lossy(body);
                let display_len = body_str.len().min(MAX_BODY_PRINT_LENGTH);
                println!("{}", &body_str[..display_len]);
                if body_str.len() > MAX_BODY_PRINT_LENGTH {
                    println!("\n... (truncated {} remaining bytes)", body_str.len() - MAX_BODY_PRINT_LENGTH);
                }
            } else if PRINT_BINARY_RESPONSE {
                println!("Binary response ({} bytes):", body.len());
                for chunk in body.chunks(16).take(32) {
                    print!("{:02X} ", chunk[0]);
                    for &byte in chunk.iter().skip(1) {
                        print!("{:02X} ", byte);
                    }
                    println!();
                }
                if body.len() > 512 {
                    println!("... (truncated {} remaining bytes)", body.len() - 512);
                }
            } else {
                println!("Binary response ({} bytes, not displayed)", body.len());
            }
        }
    }

    println!("\n[Response Complete] Total bytes received: {}", response_buffer.len());
    println!("Content-Type: {}", content_type);

    Ok(())
}

9. 总结

通过本文,我们详细介绍了如何使用Rust和Tokio构建一个异步的HTTP/HTTPS代理服务器。我们讨论了如何处理不同类型的请求、建立HTTPS隧道、以及如何处理和转发响应。这个代理服务器不仅可以用于调试和分析网络流量,还可以作为学习Rust和异步编程的一个很好的示例。

希望这篇文章能帮助你更好地理解代理服务器的工作原理,并激发你进一步探索Rust和异步编程的兴趣。