Implementing an HTTP Proxy Server with Java Sockets (HTTP Version), with Support for Tampering with Requests

Context

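Some features of a certain video player app are paid. While using it to consolidate my understanding of the SSL and HTTP protocols, I analyzed its HTTP request messages and found that this restriction could be bypassed by tampering with those messages.

Tampering with HTTPS traffic has two requirements:

  • The proxy must parse the HTTP messages and reassemble them before passing them on to the client or the server.
  • During the HTTPS (SSL) handshake, the proxy must impersonate the server so that the client's handshake checks succeed. This is essentially a man-in-the-middle attack.

The packet-capture tools we use every day are exactly this kind of middleman. The client is made to trust a self-signed CA certificate, and that CA then issues forged server certificates that are handed to the client.

Implementing such a proxy involves two main difficulties: 1. simulating the HTTPS handshake; 2. parsing HTTP messages and reassembling them.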

Let's first briefly review the whole HTTPS flow, from the handshake to sending messages, as shown in the figure below:

[Figure: the HTTPS flow from handshake to message exchange]

Analyzing how to proxy HTTPS requests

Back to the topic of this article: how can we proxy HTTPS requests?

Now that we understand how HTTPS communicates, there are two ways to proxy HTTPS requests:

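Option 1: Obtain the private key of the certificate authority that issued the target site's HTTPS certificate, i.e. the CA root certificate's private key, and use it to issue a new certificate to the proxied client.

Option 2: Generate our own CA certificate and import it into the client to be proxied so that the client trusts it. Then dynamically generate HTTPS certificates for the requests to be proxied.

Option 1 is practically impossible. Root CA private keys are reportedly stored offline and are out of reach for ordinary people; it's no joke that an HTTPS certificate can cost over a thousand yuan a year. The same applies to the certificates of intermediate CAs.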

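After this analysis, option 2 is clearly the more feasible way to proxy the requests. The proxying process can be sketched as follows:

[Figure: the HTTPS proxying process]

Even so, implementing an HTTPS proxy server yourself is still fairly difficult; the details of the HTTPS handshake alone take a lot of effort. For the well-known Netty framework, however, this is a small matter. Its SslContext class implements these details for us, and we only need to call it to complete the HTTPS handshake. That's the power of a framework!

However, to learn the HTTP and SSL protocols and Java NIO in depth, this article first builds the proxy server on plain Java sockets. This HTTP version skips HTTPS interception; I'll cover the rest later when I get the chance. Thanks for your patience.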

For a detailed look at the HTTPS protocol, see my other article, "HTTPS Principles Analysis and Practice".

Parsing the HTTP protocol with Java sockets

The code is a bit rough and has not been optimized; please bear with me!

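  • Bootstrap class: accepts incoming connections. The main thread loops waiting for new sockets and submits each one to a thread pool, so that another thread handles the request.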
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ServerSocketThreadPool {
    private static final Logger logger = LoggerFactory.getLogger(ServerSocketThreadPool.class);

    public static void main(String[] args) {
        int port = 8888;
        int cores = Runtime.getRuntime().availableProcessors();
        ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(cores, cores * 2, 3, TimeUnit.SECONDS,
                new ArrayBlockingQueue<>(30000), new ThreadPoolExecutor.AbortPolicy());
        // try-with-resources closes the listening socket if the accept loop fails
        try (ServerSocket serverSocket = new ServerSocket()) {
            ResponseProcessorHolder.initProcessor();
            // Initialize the server socket and bind it to port 8888
            serverSocket.bind(new InetSocketAddress(port));
            logger.info("============ Proxy server started ==========");
            logger.info("============ Listening on port {} ==========", port);
            logger.info("===================================");
            int count = 1;
            while (true) {
                // The ServerSocket listens on local port 8888. New connections wait in the OS accept queue;
                // with the default backlog of 50, further connection attempts may be refused (Connection refused).
                // accept() blocks while the queue is empty, then takes the next connection off the queue.
                Socket socket = serverSocket.accept();
                // New connection: let a worker thread process its messages.
                // From here on the socket is owned (and closed) by ProxyHandler.
                logger.info("New connection {}: {}", count++, socket);
                threadPoolExecutor.submit(new ProxyHandler(threadPoolExecutor, socket));
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // Stop accepting new tasks once the accept loop has exited
            if (!threadPoolExecutor.isShutdown()) {
                threadPoolExecutor.shutdown();
            }
        }
    }
}
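  • Proxy handler: it holds the client socket, from which we read the client's input and parse the request headers. If the request method is CONNECT, what follows is HTTPS traffic; the client sends CONNECT to set up a tunnel through the proxy. In this section we do not forward HTTPS requests, because doing so causes performance problems. HTTPS traffic is encrypted, so we cannot know the length of the body. The only way to read it is an endless loop that keeps reading client data until the connection closes, which ties up a thread for the whole lifetime of the connection. The figure below shows this loop.

[Figure: reading HTTPS data from the client in an endless loop]

    In this article we skip HTTPS proxying and parsing. After parsing the request headers, the proxy handler opens a socket connection to the remote server. Once connected, it forwards the client's request, and the remote server sends its response back to the proxy. The proxy then parses that response, applies the response-replacement rules, and writes the result back to the client.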
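Before it can call `proxyConnectToRemoteServer`, the handler needs the target host and port. The article doesn't show how `Request.getHost()` and `getPort()` derive them. Below is a minimal sketch that assumes the two request-line forms a proxy normally receives:

  • `CONNECT host:port` (authority-form) for HTTPS tunnels.
  • An absolute URL such as `GET http://host/path` for plain HTTP.

`ProxyTarget` is a hypothetical name, not part of the article's code.

```java
import java.net.URI;

/** Hypothetical helper: work out where the proxy should connect for a given request line. */
final class ProxyTarget {
    final String host;
    final int port;
    final boolean tunnel; // true for CONNECT, i.e. an HTTPS tunnel

    private ProxyTarget(String host, int port, boolean tunnel) {
        this.host = host;
        this.port = port;
        this.tunnel = tunnel;
    }

    /** Parses e.g. "CONNECT example.com:443 HTTP/1.1" or "GET http://example.com/ HTTP/1.1". */
    static ProxyTarget parse(String requestLine) {
        String[] parts = requestLine.trim().split(" ");
        if (parts.length != 3) {
            throw new IllegalArgumentException("Malformed request line: " + requestLine);
        }
        String target = parts[1];
        if ("CONNECT".equals(parts[0])) {
            // authority-form "host:port": the port is mandatory
            int colon = target.lastIndexOf(':');
            if (colon <= 0) {
                throw new IllegalArgumentException("CONNECT target without port: " + target);
            }
            return new ProxyTarget(target.substring(0, colon), Integer.parseInt(target.substring(colon + 1)), true);
        }
        // absolute-form "http://host[:port]/path": fall back to the scheme's default port
        URI uri = URI.create(target);
        int port = uri.getPort();
        if (port == -1) {
            port = "https".equalsIgnoreCase(uri.getScheme()) ? 443 : 80;
        }
        return new ProxyTarget(uri.getHost(), port, false);
    }
}
```

An origin-form request (`GET /path`), as sent directly to a server rather than to a proxy, would need the `Host` header instead; this sketch does not handle that case.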

import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.util.UUID;
import java.util.concurrent.ThreadPoolExecutor;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

public class ProxyHandler implements Runnable {
    private static final String CRLF = "\r\n";
    private static final String D_CRLF = "\r\n\r\n";
    private static final String CHUNKED_END = "\r\n0\r\n";
    private static final Logger logger = LoggerFactory.getLogger(ProxyHandler.class);

    private final ThreadPoolExecutor threadPoolExecutor;
    private final Socket socket;
    private String remoteHost;

    public ProxyHandler(ThreadPoolExecutor threadPoolExecutor, Socket socket) {
        this.threadPoolExecutor = threadPoolExecutor;
        this.socket = socket;
    }

    public String getRemoteHost() {
        return remoteHost;
    }

    public ProxyHandler setRemoteHost(String remoteHost) {
        this.remoteHost = remoteHost;
        return this;
    }

    @Override
    public void run() {
        MDC.put("traceId", UUID.randomUUID().toString().replace("-", ""));
        long begin = System.currentTimeMillis();
        OutputStream clientOutput = null;
        InputStream clientInput = null;

        Socket proxySocket = null;
        InputStream proxyInput = null;
        OutputStream proxyOutput = null;
        MyHttpClient myHttpClient = new MyHttpClient(threadPoolExecutor);
        try {
            // Stream of bytes coming from the client
            clientInput = socket.getInputStream();
            // Stream used to write back to the client
            clientOutput = socket.getOutputStream();
            Request request = myHttpClient.getRequest();
            Response response = myHttpClient.getResponse();
            // Parse the client's request headers
            Header header = request.parseHead(clientInput);
            request.setHeader(header);
            request.setAttribute(request);
            if (!Filter.permit(request.getHost())) {
                return;
            }
            logger.info(" [request headers {}]:\n{}", request.getUrl(), request.getHeader().getOrigin());
            // Connect the proxy to the remote server; keep the reference so finally can close it
            proxySocket = myHttpClient.proxyConnectToRemoteServer(request.getHost(), request.getPort());
            proxyInput = proxySocket.getInputStream();
            proxyOutput = proxySocket.getOutputStream();
            // The HTTP method tells us whether this is HTTPS: HTTPS requests use the CONNECT method.
            // For HTTPS, first confirm the tunnel to the client; this reply is only sent when proxying HTTPS.
            if (request.isHttps()) {
                clientOutput.write("HTTP/1.1 200 Connection Established\r\n\r\n".getBytes());
                clientOutput.flush();
            }
            myHttpClient.sendToServer(proxySocket, clientInput, clientOutput);
            ResponseProcessorHolder.invokeProcessor(request, response, clientOutput, proxyInput);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            logger.info("[{}] elapsed time: {} ms", myHttpClient.getRequest().getUrl(), System.currentTimeMillis() - begin);
            MDC.remove("traceId");
            // Close the remote side first, then the client side
            closeQuietly(proxyInput);
            closeQuietly(proxyOutput);
            closeQuietly(proxySocket);
            closeQuietly(clientInput);
            closeQuietly(clientOutput);
            closeQuietly(socket);
        }
    }

    private static void closeQuietly(Closeable closeable) {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
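For CONNECT, the handler above replies `200 Connection Established`. As noted earlier, though, an encrypted tunnel carries no length the proxy can read, so the bytes can only be relayed in a loop until a side closes. Below is a minimal sketch of such a relay. It assumes plain blocking sockets and one extra thread per tunnel, which is exactly the per-connection cost described above. `TunnelRelay` is a hypothetical name, not part of the article's code.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;

/** Hypothetical sketch: blind byte relay for an HTTPS (CONNECT) tunnel. */
final class TunnelRelay {

    /** Copies bytes from in to out until in reaches end of stream; returns the number of bytes copied. */
    static long pump(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[8192];
        long total = 0;
        int n;
        // read() blocks until data arrives; -1 means the peer closed its sending side
        while ((n = in.read(buffer)) != -1) {
            out.write(buffer, 0, n);
            out.flush(); // don't let TLS records sit in a buffer while the peer waits for them
            total += n;
        }
        return total;
    }

    /** Relays both directions between client and server until both sides are done. */
    static void relay(Socket client, Socket server) throws InterruptedException {
        // server -> client on a second thread, because each direction blocks in read() independently
        Thread downstream = new Thread(() -> {
            try {
                pump(server.getInputStream(), client.getOutputStream());
                client.shutdownOutput(); // pass the server's half-close on to the client
            } catch (IOException e) {
                // a socket was closed by the other direction; nothing more to relay
            }
        });
        downstream.start();
        try {
            // client -> server on the current (thread-pool) thread
            pump(client.getInputStream(), server.getOutputStream());
            server.shutdownOutput(); // pass the client's half-close on to the server
            downstream.join();       // let the server finish sending its side
        } catch (IOException e) {
            // connection reset or closed: fall through and tear everything down
        } finally {
            closeQuietly(client);
            closeQuietly(server);
            downstream.join(); // returns at once if the thread has already finished
        }
    }

    private static void closeQuietly(Socket socket) {
        try {
            socket.close();
        } catch (IOException ignored) {
            // nothing useful to do while tearing down
        }
    }
}
```

In practice you would also set `setSoTimeout` on both sockets, so that a peer that never closes doesn't pin the two threads forever.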

Parsing the request headers

public Header parseHead(InputStream inputStream) throws Exception {
    Header header = new Header();
    String line;
    LineBuffer lineBuffer = new LineBuffer(1024);
    StringBuilder headBuffer = new StringBuilder();
    // readLine splits on line breaks and consumes the terminator, so "\r\n" has to be appended again.
    // Read the HTTP request headers and extract the method, URL, Host header, etc.
    int index = 0;
    while (null != (line = lineBuffer.readLine(inputStream))) {
        headBuffer.append(line).append("\r\n");
        // An empty line marks the end of the headers
        if (line.isEmpty()) {
            break;
        }
        if (index == 0) { // Request line: METHOD SP request-target SP HTTP-version
            String[] split = line.split(" ");
            if (split[0].equals("CONNECT")) {
                header.addElement("isHttps", "true");
            } else {
                header.addElement("isHttps", "false");
                header.addElement("requestMethod", split[0]);
            }
            header.addElement("url", split[1]);
            header.addElement("protocol", split[2]);
        } else {
            int colon = line.indexOf(':');
            // Skip malformed lines without a "name: value" separator instead of throwing
            if (colon > 0) {
                String key = line.substring(0, colon).trim();
                String value = line.substring(colon + 1).trim();
                header.addElement(key, value);
            }
        }
        index++;
    }
    header.setOrigin(headBuffer.toString());
    return header;
}
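`parseHead` relies on a `LineBuffer` class that the article doesn't show. Whatever it looks like, it must not read past the blank line that ends the headers. Otherwise body bytes would be swallowed before `writeMsg` forwards them.

Below is a hypothetical minimal stand-in, named `HeaderLineReader` here. It reads one byte at a time, which is simple but slow. Wrapping the socket stream in a `BufferedInputStream` would cut down on system calls, as long as that same buffered stream is then used to read the body.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

/** Hypothetical stand-in for the article's LineBuffer: reads one header line without over-reading. */
final class HeaderLineReader {
    private final int maxLineLength;

    HeaderLineReader(int maxLineLength) {
        this.maxLineLength = maxLineLength;
    }

    /** Returns the next line without its CRLF/LF terminator, or null at end of stream. */
    String readLine(InputStream in) throws IOException {
        ByteArrayOutputStream line = new ByteArrayOutputStream();
        int b;
        // One byte per read(): we stop right after '\n', so the body stays in the stream
        while ((b = in.read()) != -1) {
            if (b == '\n') {
                return decode(line);
            }
            // Guard against unbounded header lines
            if (line.size() > maxLineLength) {
                throw new IOException("Header line longer than " + maxLineLength + " bytes");
            }
            line.write(b);
        }
        // End of stream: return a trailing partial line, or null if nothing was read
        return line.size() > 0 ? decode(line) : null;
    }

    private static String decode(ByteArrayOutputStream line) {
        // Header bytes are ASCII/ISO-8859-1; drop the CR that precedes LF
        String s = new String(line.toByteArray(), StandardCharsets.ISO_8859_1);
        return s.endsWith("\r") ? s.substring(0, s.length() - 1) : s;
    }
}
```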

Code for sending the request to the remote server

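/**
 * Send the request to the remote server.
 */
public void sendToServer(Socket remoteSocket, InputStream clientInput, OutputStream clientOutput) throws Exception {
    if (remoteSocket.isConnected()) {
        // Forward the data to the remote server
        writeMsg(request, response, clientInput, remoteSocket.getOutputStream());
    }
}

private void writeMsg(Request request, Response response, InputStream clientInput, OutputStream proxyOutput) {
    try {
        // parseHead has already consumed the header bytes from clientInput,
        // so we can only forward the header text reconstructed during parsing
        if (!request.isHttps()) {
            proxyOutput.write(request.getHeader().getOrigin().getBytes());
        }
        long contentLength = request.getContentLength();
        // A plain HTTP request without a body (GET, or no Content-Length) is complete here;
        // reading further would block until the client sends its next request.
        // Note: chunked request bodies are not handled by this code.
        if (!request.isHttps() && ("GET".equals(request.getRequestMethod()) || contentLength <= 0)) {
            proxyOutput.flush();
            return;
        }
        byte[] bytes = new byte[1024 << 3];
        int n = -1;
        long tempLength = 0;
        StringBuilder builder = new StringBuilder();
        while (true) {
            // Plain HTTP: read at most the remaining body bytes; HTTPS tunnel: relay until EOF
            int toRead = request.isHttps() ? bytes.length : (int) Math.min(bytes.length, contentLength - tempLength);
            if (toRead <= 0 || (n = clientInput.read(bytes, 0, toRead)) == -1) {
                break;
            }
            proxyOutput.write(bytes, 0, n);
            proxyOutput.flush();
            // TODO charset: decoding with the platform default may garble non-ASCII bodies
            builder.append(new String(bytes, 0, n));
            tempLength += n;
        }
        log.info("{} request body: {}", request.getUrl(), builder);
    } catch (Exception e) {
        e.printStackTrace();
    }
    // The MDC traceId is cleared by ProxyHandler once the whole exchange has finished
}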
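The body-forwarding loop above should stop exactly at `Content-Length`. Below is a minimal sketch of that step on its own, assuming the header has already been parsed into a `long`; `BodyForwarder` is a hypothetical name.

Capping each `read` at the remaining length matters on keep-alive connections. The client may already have sent its next request on the same socket, and those bytes must not be forwarded as part of this body.

```java
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/** Hypothetical helper: forward exactly contentLength body bytes from client to server. */
final class BodyForwarder {
    static void forward(InputStream clientInput, OutputStream serverOutput, long contentLength) throws IOException {
        byte[] buffer = new byte[8192];
        long remaining = contentLength;
        while (remaining > 0) {
            // Never ask for more than what is left of this body
            int n = clientInput.read(buffer, 0, (int) Math.min(buffer.length, remaining));
            if (n == -1) {
                throw new EOFException("Client closed the connection with " + remaining + " body bytes missing");
            }
            serverOutput.write(buffer, 0, n);
            remaining -= n;
        }
        serverOutput.flush();
    }
}
```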
  • Handling the server response

// Post-processing: parse the response, apply the replacements, write the result to the client
public static void invokeProcessor(Request request, Response response, OutputStream clientOutput, InputStream proxyInput) throws Exception {
    // Parse the response returned by the remote server
    parseResponse(proxyInput, request, response);
    // Send the modified response message back to the client
    byte[] replace = replace(request, response);
    clientOutput.write(replace);
    clientOutput.flush();
    log.info(" [sent replaced response {}]:\n{}", request.getUrl(), new String(replace));
}
    
  • Parsing the response headers

public Header parseHeader(InputStream in) throws Exception {
    Header header = new Header();
    int n;
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    int index = 0;
    StringBuilder originHeaderBuffer = new StringBuilder();
    // Read byte by byte so that none of the body following the headers is consumed
    while ((n = in.read()) != -1) {
        baos.write(n);
        // A complete line has been read
        if ('\n' == n) {
            String line = baos.toString();
            baos.reset();
            originHeaderBuffer.append(line);
            // An empty line (just "\r\n") marks the end of the headers
            if (line.trim().isEmpty()) {
                break;
            }
            if (index == 0) { // Status line: HTTP-version SP status-code SP reason-phrase
                // Limit 3: everything after the second space is the reason phrase, which may contain spaces
                String[] parts = line.trim().split(" ", 3);
                header.addElement("protocol", parts[0]);
                header.addElement("status", parts.length > 1 ? parts[1] : "");
                header.addElement("message", parts.length > 2 ? parts[2] : "");
            } else {
                int colon = line.indexOf(':');
                // Skip malformed lines without a "name: value" separator instead of throwing
                if (colon > 0) {
                    String key = line.substring(0, colon).trim();
                    String value = line.substring(colon + 1).trim();
                    header.addElement(key, value);
                }
            }
            index++;
        }
    }
    header.setOrigin(originHeaderBuffer.toString());
    return header;
}
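The status line has to be split into `protocol`, `status` and `message` so that `headerToByte` can reassemble it later. The reason phrase can contain spaces (`404 Not Found`) or be empty, so splitting into at most three parts is the key detail. Below is a standalone sketch with a hypothetical `StatusLine` class that also shows the round trip back to text.

```java
/** Hypothetical sketch: split an HTTP status line into version, status code and reason phrase. */
final class StatusLine {
    final String protocol;
    final int status;
    final String message;

    private StatusLine(String protocol, int status, String message) {
        this.protocol = protocol;
        this.status = status;
        this.message = message;
    }

    static StatusLine parse(String line) {
        // Limit 3: everything after the second space belongs to the reason phrase
        String[] parts = line.trim().split(" ", 3);
        if (parts.length < 2) {
            throw new IllegalArgumentException("Malformed status line: " + line);
        }
        String message = parts.length == 3 ? parts[2] : ""; // the reason phrase may be empty
        return new StatusLine(parts[0], Integer.parseInt(parts[1]), message);
    }

    @Override
    public String toString() {
        // Serialize back to wire format; an empty reason phrase leaves a trailing space, which HTTP/1.1 allows
        return protocol + " " + status + " " + message + "\r\n";
    }
}
```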

  • Parsing the server's response body
/**
 * Parse the data returned by the remote server.
 */
public static void parseResponse(InputStream in, Request request, Response response) throws Exception {
    // Parse the response headers
    Header header = response.parseHeader(in);
    response.setHeader(header);
    // Set the other derived attributes
    response.setAttribute(response);
    log.debug("[{}] [response headers]:\n{}", request.getUrl(), header.getOrigin());
    ByteArrayOutputStream contentBytes;
    if (response.isChunked()) {
        contentBytes = parseChunked(in, request, response);
    } else {
        contentBytes = response.parseContentByContentLength(in, response.getContentLength());
    }

    byte[] contentBuf = contentBytes.toByteArray();
    // Decompress gzip-encoded bodies so that the processors work on plain data
    if ("gzip".equals(response.getHeader().getValue("Content-Encoding"))) {
        contentBuf = GzipUtils.uncompress(contentBuf);
    }
    String context = dataByteToString(contentBuf, response.getCharset()).trim();
    response.setContentData(contentBuf);
    response.setContext(context);
    response.setContentLength(contentBuf.length);
    if (ContentType.hasTextContentType(response.getContentType())) {
        log.info("[{}] [parsed response body]:\n{}", request.getUrl(), context);
    }
}
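`GzipUtils.uncompress` and `GzipUtils.compress` are not shown in the article. Below is a hypothetical minimal version built only on `java.util.zip`. The method names match the article's calls, but the implementation is an assumption.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/** Hypothetical stand-in for the article's GzipUtils, using only java.util.zip. */
final class GzipUtils {
    private GzipUtils() {
    }

    static byte[] compress(byte[] data) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        // Closing the GZIPOutputStream writes the gzip trailer (CRC32 and uncompressed size)
        try (GZIPOutputStream gzip = new GZIPOutputStream(bytes)) {
            gzip.write(data);
        }
        return bytes.toByteArray();
    }

    static byte[] uncompress(byte[] data) throws IOException {
        try (GZIPInputStream gzip = new GZIPInputStream(new ByteArrayInputStream(data))) {
            return gzip.readAllBytes(); // Java 9+
        }
    }
}
```

Once the body has been decompressed, the response sent to the client must either be compressed again or lose its `Content-Encoding: gzip` header. The article's `amend` step takes the first route.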
  • Parsing a chunked response body
public static ByteArrayOutputStream parseChunked(InputStream inputStream, Request request, Response response) throws IOException {
    // Decoded body: the chunk payloads concatenated, without the size lines
    ByteArrayOutputStream contentBytes = new ByteArrayOutputStream();
    byte[] bytes = new byte[1024 << 3];
    long contentLength = 0L;
    while (true) {
        // Each chunk starts with "<hex size>[;extensions]\r\n"
        String sizeLine = readCrlfLine(inputStream);
        if (sizeLine == null) {
            break; // the stream ended before the last chunk
        }
        int semicolon = sizeLine.indexOf(';');
        String hex = (semicolon >= 0 ? sizeLine.substring(0, semicolon) : sizeLine).trim();
        if (hex.isEmpty()) {
            continue;
        }
        long chunkedLen = Long.parseLong(hex, 16);
        // A zero-size chunk marks the end of the body
        if (chunkedLen == 0) {
            break;
        }
        // Copy exactly chunkedLen bytes of payload
        long remaining = chunkedLen;
        while (remaining > 0) {
            int n = inputStream.read(bytes, 0, (int) Math.min(bytes.length, remaining));
            if (n == -1) {
                throw new EOFException("Stream ended inside a chunk");
            }
            contentBytes.write(bytes, 0, n);
            remaining -= n;
        }
        contentLength += chunkedLen;
        // Consume the CRLF that follows the chunk data
        readCrlfLine(inputStream);
    }
    response.setContentLength(contentLength);
    return contentBytes;
}

// Reads one line terminated by "\n" (the "\r" is dropped); returns null at end of stream
private static String readCrlfLine(InputStream in) throws IOException {
    ByteArrayOutputStream line = new ByteArrayOutputStream();
    int b;
    while ((b = in.read()) != -1) {
        if (b == '\n') {
            return line.toString().replace("\r", "");
        }
        line.write(b);
    }
    return line.size() > 0 ? line.toString() : null;
}
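To check the chunked framing without a live server, here is a self-contained decoder sketch that can be fed a byte array. `ChunkedDecoder` is a hypothetical name. It ignores chunk extensions (`;name=value`) and skips any trailer fields after the last chunk.

```java
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

/** Hypothetical sketch: decode a Transfer-Encoding: chunked body from a stream. */
final class ChunkedDecoder {
    static byte[] decode(InputStream in) throws IOException {
        ByteArrayOutputStream body = new ByteArrayOutputStream();
        while (true) {
            String sizeLine = readLine(in);
            // chunk-size may be followed by ";name=value" extensions, which are ignored here
            int semicolon = sizeLine.indexOf(';');
            String hex = (semicolon >= 0 ? sizeLine.substring(0, semicolon) : sizeLine).trim();
            long size = Long.parseLong(hex, 16);
            if (size == 0) {
                // last-chunk: skip optional trailer fields up to the terminating empty line
                while (!readLine(in).isEmpty()) {
                    // trailer field ignored
                }
                return body.toByteArray();
            }
            byte[] chunk = in.readNBytes((int) size); // Java 11+
            if (chunk.length != size) {
                throw new EOFException("Stream ended inside a chunk");
            }
            body.write(chunk);
            readLine(in); // CRLF after the chunk data
        }
    }

    /** Reads a line terminated by LF, dropping CR; fails on premature end of stream. */
    private static String readLine(InputStream in) throws IOException {
        ByteArrayOutputStream line = new ByteArrayOutputStream();
        int b;
        while ((b = in.read()) != '\n') {
            if (b == -1) {
                throw new EOFException("Stream ended inside a chunk header");
            }
            if (b != '\r') {
                line.write(b);
            }
        }
        return new String(line.toByteArray(), StandardCharsets.ISO_8859_1);
    }
}
```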
  • Replacing the response body
/**
 * Run every registered ResponseProcessor over the response, then rebuild the raw response bytes.
 *
 * @author jiaww5
 */
public static byte[] replace(Request request, Response response) throws Exception {
    for (ResponseProcessor responseProcessor : responseProcessorList) {
        responseProcessor.repalce(request, response);
    }
    // After replacement, parse the (possibly edited) raw header text back into a Header
    Header header = response.parseHeader(new ByteArrayInputStream(response.getHeader().getOrigin().getBytes()));
    response.setHeader(header);
    // Fix up the message so that its framing matches the new body
    return amend(response);
}
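The `ResponseProcessor` interface itself is not shown in the article. Below is one plausible shape, assuming each processor decides by URL whether it applies and then rewrites the decoded text body. The names `BodyRewriter`, `TextReplaceRule` and `RewriteChain` are hypothetical.

```java
import java.util.List;

/** Hypothetical shape of a response processor: match by URL, rewrite the decoded text body. */
interface BodyRewriter {
    boolean matches(String url);

    String rewrite(String body);
}

/** Example rule: literal text replacement for URLs that start with a given prefix. */
final class TextReplaceRule implements BodyRewriter {
    private final String urlPrefix;
    private final String target;
    private final String replacement;

    TextReplaceRule(String urlPrefix, String target, String replacement) {
        this.urlPrefix = urlPrefix;
        this.target = target;
        this.replacement = replacement;
    }

    @Override
    public boolean matches(String url) {
        return url.startsWith(urlPrefix);
    }

    @Override
    public String rewrite(String body) {
        // String.replace is literal, so regex metacharacters in target need no escaping
        return body.replace(target, replacement);
    }
}

/** Applies every matching rule in order, similar to the loop in replace(). */
final class RewriteChain {
    static String applyAll(List<BodyRewriter> rules, String url, String body) {
        String result = body;
        for (BodyRewriter rule : rules) {
            if (rule.matches(url)) {
                result = rule.rewrite(result);
            }
        }
        return result;
    }
}
```

After rewriting, the text must be encoded back to bytes with the response's charset before the length is recomputed. Otherwise multi-byte characters make `Content-Length` wrong.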

Fixing up the message after replacement

/**
 * Fix up the message after replacement: re-compress the body if needed and
 * make Content-Length or the chunk framing match the new body.
 *
 * @param response the response after replacement
 * @return the complete raw response (header + body) to send to the client
 * @author jiaww5
 */
private static byte[] amend(Response response) throws Exception {
    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
    Header header = response.getHeader();
    String contentEncoding = header.getValue("Content-Encoding");
    // The body was decompressed while parsing, so compress it again before sending
    if ("gzip".equals(contentEncoding)) {
        response.setContentData(GzipUtils.compress(response.getContentData()));
    }
    if (!response.isChunked()) {
        // Fixed-length body: Content-Length must equal the byte length of the final body
        header.setElement("Content-Length", String.valueOf(response.getContentData().length));
    } else {
        // Chunked body: split it into chunks again
        response.setContentData(chunkedData(response.getContentData()));
    }
    // Write the header
    byteArrayOutputStream.write(headerToByte(header));
    // Write the body
    byteArrayOutputStream.write(response.getContentData());
    return byteArrayOutputStream.toByteArray();
}

// Serialize the Header back to raw bytes. The status-line elements (protocol, status, message)
// must come first in the element list; every other element becomes "Name: value".
private static byte[] headerToByte(Header header) {
    StringBuilder builder = new StringBuilder();
    for (HeaderElement next : header.getHeaderElements()) {
        String name = next.getName();
        String value = next.getValue();
        if ("protocol".equalsIgnoreCase(name) || "status".equalsIgnoreCase(name)) {
            builder.append(value).append(' ');
        } else if ("message".equalsIgnoreCase(name)) {
            builder.append(value).append("\r\n");
        } else {
            builder.append(name).append(": ").append(value).append("\r\n");
        }
    }
    // Blank line that terminates the header section
    builder.append("\r\n");
    return builder.toString().getBytes();
}
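A common pitfall in this fix-up step is that `Content-Length` counts bytes, not characters, so it has to be computed from the encoded (and, if applicable, re-compressed) body. Below is a minimal sketch that serializes a status line, headers and body. It assumes the headers are kept in an insertion-ordered map, whereas the article's `Header` class keeps its own element list. `RawResponse` is a hypothetical name. Unlike the article, which re-chunks chunked bodies, this sketch always switches to fixed-length framing.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical sketch: serialize a response whose Content-Length matches the body. */
final class RawResponse {
    static byte[] serialize(String statusLine, Map<String, String> headers, byte[] body) {
        // Copy so that the caller's map is untouched; LinkedHashMap keeps the header order
        Map<String, String> out = new LinkedHashMap<>(headers);
        // Header names are case-insensitive: drop any existing framing headers first
        out.keySet().removeIf(k -> k.equalsIgnoreCase("Content-Length") || k.equalsIgnoreCase("Transfer-Encoding"));
        out.put("Content-Length", String.valueOf(body.length)); // bytes, not characters

        StringBuilder head = new StringBuilder(statusLine).append("\r\n");
        for (Map.Entry<String, String> e : out.entrySet()) {
            head.append(e.getKey()).append(": ").append(e.getValue()).append("\r\n");
        }
        head.append("\r\n"); // the blank line ends the header section

        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        byte[] headBytes = head.toString().getBytes(StandardCharsets.ISO_8859_1);
        bytes.write(headBytes, 0, headBytes.length);
        bytes.write(body, 0, body.length);
        return bytes.toByteArray();
    }
}
```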


If the body is chunked, split it into chunks again

/**
 * Split data into chunks for Transfer-Encoding: chunked.
 *
 * @author jiaww5
 */
private static byte[] chunkedData(byte[] contentData) throws Exception {
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    // Maximum payload size of each chunk (4 KB)
    int segmentLen = 1024 << 2;
    for (int offset = 0; offset < contentData.length; offset += segmentLen) {
        int len = Math.min(segmentLen, contentData.length - offset);
        // Each chunk: "<hex size>\r\n<data>\r\n"
        outputStream.write(Integer.toHexString(len).getBytes());
        outputStream.write("\r\n".getBytes());
        outputStream.write(contentData, offset, len);
        outputStream.write("\r\n".getBytes());
    }
    // Terminating zero-size chunk plus the empty line that ends the message.
    // An empty body correctly produces only this terminator.
    outputStream.write("0\r\n\r\n".getBytes());
    return outputStream.toByteArray();
}

public static String dataByteToString(byte[] contentData, String charset) throws UnsupportedEncodingException {
    if (StringUtils.isNotBlank(charset)) {
        return new String(contentData, charset);
    }
    return new String(contentData);
}
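`dataByteToString` gets its charset from `response.getCharset()`, whose implementation is not shown. Presumably the charset comes from the `charset` parameter of the `Content-Type` header. Below is a hypothetical sketch of that extraction, named `ContentTypeCharset`. When the parameter is missing or unknown, it falls back to a caller-supplied default instead of the platform default that `new String(bytes)` would use.

```java
import java.nio.charset.Charset;
import java.nio.charset.IllegalCharsetNameException;
import java.nio.charset.UnsupportedCharsetException;

/** Hypothetical helper: pick the charset for decoding a text body from its Content-Type header. */
final class ContentTypeCharset {
    static Charset of(String contentType, Charset fallback) {
        if (contentType != null) {
            // e.g. text/html; charset="UTF-8": parameters are separated by ';'
            for (String param : contentType.split(";")) {
                String[] kv = param.trim().split("=", 2);
                if (kv.length == 2 && kv[0].trim().equalsIgnoreCase("charset")) {
                    String name = kv[1].trim().replace("\"", "");
                    try {
                        return Charset.forName(name);
                    } catch (IllegalCharsetNameException | UnsupportedCharsetException e) {
                        return fallback; // malformed or unsupported charset name
                    }
                }
            }
        }
        return fallback;
    }
}
```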

If you need the source code, please leave a comment.