Implementing an HTTP proxy server with Java sockets (HTTP version), with support for tampering with requests
Context
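A certain video player charges for some of its features. To consolidate my understanding of SSL and HTTP, I analyzed its HTTP request messages and found that this restriction could be bypassed by tampering with the HTTP messages. To tamper with HTTPS traffic, the proxy has to parse the HTTP protocol and reassemble the messages before passing them on to the client or the server. In addition, during the HTTPS (SSL) handshake the proxy must impersonate the server so that handshake verification succeeds, much like a man-in-the-middle attack. The packet-capture tools we use every day work the same way: the client is made to trust a self-signed certificate, that certificate issues a forged server certificate, and the forged certificate is handed to the client. Implementing such a proxy involves two main difficulties: 1. simulating the HTTPS handshake; 2. parsing the HTTP protocol and reassembling the messages.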
Let's first take a brief look at the whole HTTPS flow, from the handshake to sending messages, as shown in the figure below:
Analyzing how to proxy HTTPS requests
Back to the topic of this article: how can HTTPS requests be proxied?
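Now that we understand how HTTPS communication works, there are two ways to proxy HTTPS requests:
1. Obtain the private key of the certificate authority that issued the target site's HTTPS certificate (that is, the private key of the CA root certificate), then issue a new certificate ourselves and return it to the proxied client.
2. Generate our own CA certificate, import it into the client to be proxied so that the client trusts it, and then dynamically generate an HTTPS certificate for each request to be proxied.
Obtaining a CA root private key is practically impossible: root CA keys are reportedly stored offline, out of reach of ordinary people (there is a reason an HTTPS certificate can cost a thousand yuan or more per year), and the same applies to the certificates of the CA's resellers.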
Following this analysis, approach 2 is the more feasible way to proxy the requests. The proxying process is roughly as shown in the figure below:
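Given the analysis above, implementing an HTTPS proxy server yourself is genuinely difficult: the details of the HTTPS handshake alone take a great deal of effort. For the equally famous Netty framework, though, none of this is a problem! Netty's SslContext class implements these details for us, and we only need to call it to complete the HTTPS handshake. That is what a powerful framework does for you! However, to study HTTP, SSL and Java NIO in depth, this article first builds the proxy server with plain Java sockets; the rest will follow later when I get the chance. I hope readers will bear with me.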
For a detailed look at the HTTPS protocol, see my other article: HTTPS原理剖析及实战 (HTTPS principles dissected, with practice).
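Parsing HTTP with Java sockets
The code is rather rough and unoptimized, so please bear with me!
- Bootstrap class: accepts incoming sockets. The main thread loops waiting for new connections; each new socket is submitted to a thread pool so that another thread handles the request.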
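import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ServerSocketThreadPool {
    private static final Logger logger = LoggerFactory.getLogger(ServerSocketThreadPool.class);

    public static void main(String[] args) {
        ThreadPoolExecutor threadPoolExecutor = null;
        ServerSocket serverSocket = null;
        try {
            ResponseProcessorHolder.initProcessor();
            // Create the server socket and bind it to port 8888
            serverSocket = new ServerSocket();
            int port = 8888;
            serverSocket.bind(new InetSocketAddress(port));
            int cores = Runtime.getRuntime().availableProcessors();
            threadPoolExecutor = new ThreadPoolExecutor(cores, cores * 2, 3, TimeUnit.SECONDS,
                    new ArrayBlockingQueue<>(30000), new ThreadPoolExecutor.AbortPolicy());
            logger.info("============Proxy server started==========");
            logger.info("============Listening on port {}==========", port);
            logger.info("==========================================");
            int count = 1;
            while (true) {
                // The bound ServerSocket queues new connections in the OS accept queue (default backlog: 50);
                // when that queue is full, further connections may be refused ("Connection refused").
                // accept() blocks while the queue is empty, then takes the next connection off it.
                Socket socket = serverSocket.accept();
                // New connection: let a pool thread handle the request
                logger.info("New connection {}: {}", count++, socket);
                try {
                    threadPoolExecutor.submit(new ProxyHandler(threadPoolExecutor, socket));
                } catch (RejectedExecutionException e) {
                    // AbortPolicy throws when both the pool and its queue are full: drop only this connection
                    logger.warn("Thread pool saturated, closing {}", socket);
                    try {
                        socket.close();
                    } catch (IOException closeError) {
                        closeError.printStackTrace();
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // Close the listening socket; accepted sockets are closed by their ProxyHandler
            if (serverSocket != null) {
                try {
                    serverSocket.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            if (threadPoolExecutor != null && !threadPoolExecutor.isShutdown()) {
                threadPoolExecutor.shutdown();
            }
        }
    }
}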
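The pool above uses `ThreadPoolExecutor.AbortPolicy`: once every thread is busy and the queue is full, `submit` throws `RejectedExecutionException`. That exception is unchecked, so a `catch (IOException e)` around the accept loop does not intercept it. The following sketch (the class and method names are hypothetical) reproduces the rejection with a deliberately tiny pool of one thread and a one-slot queue.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

// Hypothetical demo class: 1 worker thread + a queue of capacity 1
class RejectionDemo {
    /** Submits `tasks` tasks that block until released; returns how many were rejected. */
    static int countRejected(int tasks) {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(1, 1, 3, TimeUnit.SECONDS,
                new ArrayBlockingQueue<>(1), new ThreadPoolExecutor.AbortPolicy());
        CountDownLatch release = new CountDownLatch(1);
        int rejected = 0;
        try {
            for (int i = 0; i < tasks; i++) {
                try {
                    // Each task keeps its worker busy until the latch is released
                    pool.submit(() -> {
                        release.await();
                        return null;
                    });
                } catch (RejectedExecutionException e) {
                    // In the proxy, the accepted socket would have to be closed here
                    rejected++;
                }
            }
        } finally {
            release.countDown();
            pool.shutdown();
            try {
                pool.awaitTermination(5, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return rejected;
    }
}
```

With this configuration the first task runs, the second waits in the queue, and every further task is rejected, so `countRejected(3)` returns 1. The real pool rejects the same way, only much later (after 30000 queued connections).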
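- Proxy handler: the handler holds the client socket, from which it reads the client's input. It parses the request header from the client's input stream; if the request method is CONNECT, what follows is HTTPS traffic, because a client sends CONNECT to set up a tunnel through the proxy server. In this section we do not forward HTTPS requests yet, because that would cause performance problems: HTTPS payloads are encrypted, so the proxy cannot know the body length and, as shown in the figure below, can only read the client's data in an endless loop until the stream ends.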
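The "endless loop" for encrypted tunnel traffic is essentially a byte pump that copies one stream to another until the sender closes its side. Below is a minimal sketch of one direction of such a relay; `Relay` and `pump` are hypothetical names, not the article's code. A full CONNECT tunnel would run two pumps, client to server and server to client, for example one of them on a separate thread.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Hypothetical helper: one direction of a blind TCP relay (e.g. client -> server in a CONNECT tunnel)
class Relay {
    /** Copies bytes until `in` reaches end-of-stream; returns the number of bytes copied. */
    static long pump(InputStream in, OutputStream out) throws IOException {
        byte[] buf = new byte[8192];
        long total = 0;
        int n;
        // read() blocks until data arrives or the peer closes; there is no length to stop at,
        // so only end-of-stream (-1) ends the loop
        while ((n = in.read(buf)) != -1) {
            out.write(buf, 0, n);
            // Flush right away so interactive traffic (e.g. TLS handshake records) is not held back
            out.flush();
            total += n;
        }
        return total;
    }
}
```

On Java 9 and later, `InputStream.transferTo(OutputStream)` performs the same copy loop, without the per-write flush.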
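In this article we skip HTTPS proxying and parsing. After parsing the request header, the proxy handler opens a socket connection to the remote server and forwards the client's request over it; the remote server sends its response back to the proxy. The proxy then parses that response, applies the configured replacements, and writes the result back to the client.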
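The handler connects to `request.getHost()` and `request.getPort()`, whose implementation is not shown. As a hedged sketch of how those values can be derived: a CONNECT request carries `host:port` in its request line, a plain-HTTP request sent to a proxy usually carries an absolute URL, and otherwise the Host header names the server. `ProxyTarget` and `resolve` are hypothetical names, and IPv6 literals are not handled.

```java
import java.net.URI;

// Hypothetical helper (not the article's Request class): decides where the proxy should connect
class ProxyTarget {
    final String host;
    final int port;

    ProxyTarget(String host, int port) {
        this.host = host;
        this.port = port;
    }

    static ProxyTarget resolve(String method, String target, String hostHeader) {
        if ("CONNECT".equalsIgnoreCase(method)) {
            // Authority form, e.g. "CONNECT example.com:443 HTTP/1.1"
            return fromAuthority(target, 443);
        }
        if (target.startsWith("http://")) {
            // Absolute form, e.g. "GET http://example.com:8080/a HTTP/1.1"
            URI uri = URI.create(target);
            return new ProxyTarget(uri.getHost(), uri.getPort() == -1 ? 80 : uri.getPort());
        }
        // Origin form ("/path"): the Host header names the server
        return fromAuthority(hostHeader, 80);
    }

    // Splits "host[:port]"; IPv6 literals such as "[::1]:8080" are not handled
    private static ProxyTarget fromAuthority(String authority, int defaultPort) {
        int colon = authority.lastIndexOf(':');
        if (colon < 0) {
            return new ProxyTarget(authority, defaultPort);
        }
        return new ProxyTarget(authority.substring(0, colon),
                Integer.parseInt(authority.substring(colon + 1)));
    }
}
```

For example, `resolve("CONNECT", "example.com:443", null)` yields host `example.com` and port 443, while `resolve("GET", "/index.html", "example.org")` falls back to the Host header and port 80.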
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.util.UUID;
import java.util.concurrent.ThreadPoolExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

public class ProxyHandler implements Runnable {
    private static final String CRLF = "\r\n";
    private static final String D_CRLF = "\r\n\r\n";
    private static final String CHUNKED_END = "\r\n0\r\n";
    private static final Logger logger = LoggerFactory.getLogger(ProxyHandler.class);
    private final ThreadPoolExecutor threadPoolExecutor;
    private final Socket socket;
    private String remoteHost;

    public String getRemoteHost() {
        return remoteHost;
    }

    public ProxyHandler setRemoteHost(String remoteHost) {
        this.remoteHost = remoteHost;
        return this;
    }

    public ProxyHandler(ThreadPoolExecutor threadPoolExecutor, Socket socket) {
        this.threadPoolExecutor = threadPoolExecutor;
        this.socket = socket;
    }

    @Override
    public void run() {
        MDC.put("traceId", UUID.randomUUID().toString().replace("-", ""));
        long begin = System.currentTimeMillis();
        OutputStream clientOutput = null;
        InputStream clientInput = null;
        Socket proxySocket = null;
        InputStream proxyInput = null;
        OutputStream proxyOutput = null;
        MyHttpClient myHttpClient = new MyHttpClient(threadPoolExecutor);
        try {
            // Stream carrying the data sent by the client
            clientInput = socket.getInputStream();
            // Stream used to write back to the client
            clientOutput = socket.getOutputStream();
            Request request = myHttpClient.getRequest();
            Response response = myHttpClient.getResponse();
            // Parse the client's request header
            Header header = request.parseHead(clientInput);
            request.setHeader(header);
            request.setAttribute(request);
            if (!Filter.permit(request.getHost())) {
                return;
            }
            logger.info(" [Request header {}]:\n{}", request.getUrl(), request.getHeader().getOrigin());
            // Connect the proxy to the remote server (kept in proxySocket so that finally can close it)
            proxySocket = myHttpClient.proxyConnectToRemoteServer(request.getHost(), request.getPort());
            proxyInput = proxySocket.getInputStream();
            proxyOutput = proxySocket.getOutputStream();
            // HTTPS requests arrive as CONNECT (detected in parseHead). Confirm the tunnel to the client first;
            // this reply is only sent when proxying HTTPS
            if (request.isHttps()) {
                clientOutput.write("HTTP/1.1 200 Connection Established\r\n\r\n".getBytes());
                clientOutput.flush();
            }
            myHttpClient.sendToServer(proxySocket, clientInput, clientOutput);
            ResponseProcessorHolder.invokeProcessor(request, response, clientOutput, proxyInput);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            logger.info("[{}] elapsed time: {} ms", myHttpClient.getRequest().getUrl(), System.currentTimeMillis() - begin);
            MDC.remove("traceId");
            // Close every resource independently so that one failure does not skip the others
            closeQuietly(proxyInput);
            closeQuietly(proxyOutput);
            closeQuietly(proxySocket);
            closeQuietly(clientInput);
            closeQuietly(clientOutput);
            closeQuietly(socket);
        }
    }

    private static void closeQuietly(Closeable closeable) {
        if (closeable != null) {
            try {
                closeable.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}
- Parsing the request header
public Header parseHead(InputStream inputStream) throws Exception {
    Header header = new Header();
    String line;
    LineBuffer lineBuffer = new LineBuffer(1024);
    StringBuilder headBuffer = new StringBuilder();
    // readLine splits on the line terminator and consumes it, so CRLF is appended back.
    // Reads the request header and extracts the request line (method, URL, protocol) and fields such as Host
    int index = 0;
    while (null != (line = lineBuffer.readLine(inputStream))) {
        headBuffer.append(line).append("\r\n");
        // An empty line marks the end of the header
        if (line.length() == 0) {
            break;
        }
        if (index == 0) {
            // Request line: "<method> <url> <protocol>"
            String[] split = line.split(" ");
            if (split.length < 3) {
                throw new IllegalArgumentException("Malformed request line: " + line);
            }
            if ("CONNECT".equals(split[0])) {
                header.addElement("isHttps", "true");
            } else {
                header.addElement("isHttps", "false");
                header.addElement("requestMethod", split[0]);
            }
            header.addElement("url", split[1]);
            header.addElement("protocol", split[2]);
        } else {
            // Header field "Name: value"; lines without a colon are skipped
            int colon = line.indexOf(':');
            if (colon > 0) {
                header.addElement(line.substring(0, colon).trim(), line.substring(colon + 1).trim());
            }
        }
        index++;
    }
    header.setOrigin(headBuffer.toString());
    return header;
}
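parseHead relies on the custom LineBuffer class, which is not shown, and the rest of the pipeline assumes that after parsing, clientInput is positioned exactly at the start of the body. If a line reader reads ahead into an internal buffer, those body bytes never reach writeMsg. A hedged sketch (`HeaderReader` is a hypothetical name) that reads one byte at a time up to the blank line ending the header and leaves everything after it in the stream:

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

// Hypothetical helper: reads exactly one HTTP header block, never touching the body
class HeaderReader {
    /**
     * Reads up to and including "\r\n\r\n". Returns the header text, or null if the
     * stream ends first. Bare-LF line endings are not supported in this sketch.
     */
    static String readHeaderBlock(InputStream in, int maxBytes) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        int matched = 0; // how many bytes of "\r\n\r\n" have been seen in a row
        int b;
        while ((b = in.read()) != -1) {
            buf.write(b);
            if (buf.size() > maxBytes) {
                throw new IOException("header larger than " + maxBytes + " bytes");
            }
            boolean expected = (matched % 2 == 0) ? b == '\r' : b == '\n';
            if (expected) {
                matched++;
                if (matched == 4) {
                    // ISO-8859-1 maps every byte to exactly one char
                    return new String(buf.toByteArray(), StandardCharsets.ISO_8859_1);
                }
            } else {
                // A stray CR may start a new terminator sequence
                matched = (b == '\r') ? 1 : 0;
            }
        }
        return null;
    }
}
```

The byte-at-a-time reads are slow on an unbuffered socket stream, which is the price of never over-reading; the article's response-side parseHeader makes the same trade-off.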
- Sending the request to the remote server
/**
 * Sends the client's request to the remote server
 */
public void sendToServer(Socket remoteSocket, InputStream clientInput, OutputStream clientOutput) throws Exception {
    if (remoteSocket.isConnected()) {
        // Forward the data to the remote server
        writeMsg(request, response, clientInput, remoteSocket.getOutputStream());
    }
}

private void writeMsg(Request request, Response response, InputStream clientInput, OutputStream proxyOutput) {
    try {
        // parseHead has already consumed the header part of clientInput,
        // so forward the header text saved during parsing instead
        if (!request.isHttps()) {
            proxyOutput.write(request.getHeader().getOrigin().getBytes());
            proxyOutput.flush();
        }
        long contentLength = request.getContentLength();
        // Plain HTTP: a GET, or a request without Content-Length, has no body to forward here
        // (chunked request bodies are not handled); otherwise read() would block until the client closes
        if (!request.isHttps() && ("GET".equals(request.getRequestMethod()) || contentLength <= 0)) {
            return;
        }
        byte[] bytes = new byte[1024 << 3];
        int n;
        long tempLength = 0;
        StringBuilder builder = new StringBuilder();
        // For HTTPS (no Content-Length) this keeps reading until the client closes the stream
        while ((n = clientInput.read(bytes, 0, bytes.length)) != -1) {
            proxyOutput.write(bytes, 0, n);
            proxyOutput.flush();
            // TODO: charset handling; the platform default charset is used for logging
            builder.append(new String(bytes, 0, n));
            tempLength += n;
            if (contentLength > 0 && tempLength >= contentLength) {
                break;
            }
        }
        log.info("{} request body: {}", request.getUrl(), builder);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
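writeMsg stops once it has seen contentLength bytes, but a single read() may return more than the remaining body, for example when a client pipelines its next request on the same connection. A hedged sketch (`BodyForwarder.copyExactly` is a hypothetical name) that never asks read() for more than what is left and reports a truncated body:

```java
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Hypothetical helper: forwards a Content-Length delimited body and nothing more
class BodyForwarder {
    /** Copies exactly `length` bytes from in to out; throws EOFException if the stream ends early. */
    static void copyExactly(InputStream in, OutputStream out, long length) throws IOException {
        byte[] buf = new byte[8192];
        long remaining = length;
        while (remaining > 0) {
            // Never request more than what is left, so bytes of a following request stay in the stream
            int n = in.read(buf, 0, (int) Math.min(buf.length, remaining));
            if (n == -1) {
                throw new EOFException(remaining + " body bytes missing");
            }
            out.write(buf, 0, n);
            remaining -= n;
        }
        out.flush();
    }
}
```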
- Handling the server response
// Post-processing: parse the response, apply the replacements, and write the result back to the client
public static void invokeProcessor(Request request, Response response, OutputStream clientOutput, InputStream proxyInput) throws Exception {
    // Parse the response returned by the remote server
    parseResponse(proxyInput, request, response);
    // Send the response, with replacements applied, to the client
    byte[] replace = replace(request, response);
    clientOutput.write(replace);
    clientOutput.flush();
    log.info("[{}] [response after replacement]:\n{}", request.getUrl(), new String(replace));
}
- Parsing the response header
public Header parseHeader(InputStream in) throws Exception {
    Header header = new Header();
    int n;
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    int index = 0;
    StringBuilder originHeaderBuffer = new StringBuilder();
    // Read one byte at a time so that reading stops right after the header
    // and the body is left in the stream
    while ((n = in.read()) != -1) {
        baos.write(n);
        // A full line (ending in \n) has been read
        if ('\n' == n) {
            String line = baos.toString();
            baos.reset();
            originHeaderBuffer.append(line);
            // A blank line ends the header
            if (line.trim().isEmpty()) {
                break;
            }
            if (index == 0) {
                // Status line, e.g. "HTTP/1.1 404 Not Found": protocol, status code, reason phrase
                String[] parts = line.trim().split(" ", 3);
                header.addElement("protocol", parts[0]);
                header.addElement("status", parts.length > 1 ? parts[1] : "");
                header.addElement("message", parts.length > 2 ? parts[2] : "");
            } else {
                // Header field "Name: value"; lines without a colon are skipped
                int colon = line.indexOf(':');
                if (colon > 0) {
                    header.addElement(line.substring(0, colon).trim(), line.substring(colon + 1).trim());
                }
            }
            index++;
        }
    }
    header.setOrigin(originHeaderBuffer.toString());
    return header;
}
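parseHeader stores field names exactly as the server sent them, and later code looks them up with `getValue("Content-Encoding")`. HTTP field names are case-insensitive, and some servers send them in lowercase (`content-encoding: gzip`); whether the article's `Header.getValue` already ignores case is not shown. A hedged sketch of a case-insensitive lookup that, like a list-based header, keeps field order and duplicate fields such as Set-Cookie (`Headers` is a hypothetical class):

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical header container: keeps fields in arrival order (duplicates included)
// and looks names up case-insensitively
class Headers {
    private final List<String[]> fields = new ArrayList<>();

    /** Parses one "Name: value" line; lines without a colon are ignored. */
    void addLine(String line) {
        int colon = line.indexOf(':');
        if (colon <= 0) {
            return;
        }
        fields.add(new String[] {line.substring(0, colon).trim(), line.substring(colon + 1).trim()});
    }

    /** Returns the first value whose name matches ignoring case, or null. */
    String get(String name) {
        for (String[] field : fields) {
            if (field[0].equalsIgnoreCase(name)) {
                return field[1];
            }
        }
        return null;
    }

    int size() {
        return fields.size();
    }
}
```

A map keyed by name would be simpler, but it would merge repeated fields, and headerToByte relies on the original field order when rebuilding the header.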
- Parsing the server response body
/**
 * Parses the response returned by the remote server
 */
public static void parseResponse(InputStream in, Request request, Response response) throws Exception {
    // Parse the response header
    Header header = response.parseHeader(in);
    response.setHeader(header);
    // Set the other attributes derived from the header
    response.setAttribute(response);
    log.debug("[{}] [response header]:\n{}", request.getUrl(), header.getOrigin());
    ByteArrayOutputStream contentBytes;
    if (response.isChunked()) {
        contentBytes = parseChunked(in, request, response);
    } else {
        contentBytes = response.parseContentByContentLength(in, response.getContentLength());
    }
    byte[] contentBuf = contentBytes.toByteArray();
    // Decompress gzip-encoded bodies so that they can be inspected and modified
    if ("gzip".equals(response.getHeader().getValue("Content-Encoding"))) {
        contentBuf = GzipUtils.uncompress(contentBuf);
    }
    String context = dataByteToString(contentBuf, response.getCharset()).trim();
    response.setContentData(contentBuf);
    response.setContext(context);
    response.setContentLength(contentBuf.length);
    if (ContentType.hasTextContentType(response.getContentType())) {
        log.info("[{}] [parsed response body]:\n{}", request.getUrl(), context);
    }
}
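The `GzipUtils.compress`/`uncompress` helpers used here are not shown. A hedged guess at an equivalent built only on `java.util.zip` (`GzipSketch` is a hypothetical name), assuming the helpers work on whole byte arrays:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

// Hypothetical stand-in for GzipUtils, using only the JDK
class GzipSketch {
    static byte[] compress(byte[] data) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        // Closing the GZIPOutputStream writes the gzip trailer (CRC-32 and size)
        try (GZIPOutputStream gzip = new GZIPOutputStream(bos)) {
            gzip.write(data);
        }
        return bos.toByteArray();
    }

    static byte[] uncompress(byte[] data) throws IOException {
        try (GZIPInputStream gzip = new GZIPInputStream(new ByteArrayInputStream(data))) {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            byte[] buf = new byte[8192];
            int n;
            while ((n = gzip.read(buf)) != -1) {
                bos.write(buf, 0, n);
            }
            return bos.toByteArray();
        }
    }
}
```

Closing the output stream before calling `toByteArray()` matters: without the trailer, clients reject the body as truncated gzip data.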
- Parsing a chunked response body
public static ByteArrayOutputStream parseChunked(InputStream inputStream, Request request, Response response) throws IOException {
    // Decoded body data
    ByteArrayOutputStream contentBytes = new ByteArrayOutputStream();
    int n;
    byte[] bytes = new byte[1024 << 3];
    // Bytes of the current chunk-size line
    ByteArrayOutputStream lengthBytes = new ByteArrayOutputStream();
    // Size of the chunk currently being copied
    long chunkedLen = 0;
    // Bytes of the current chunk copied so far
    long writeCount = 0;
    // true while copying chunk data, false while reading a chunk-size line
    boolean beginWriteData = false;
    long contentLength = 0L;
    OUT:
    while ((n = inputStream.read(bytes, 0, bytes.length)) != -1) {
        for (int i = 0; i < n; i++) {
            byte t = bytes[i];
            if (beginWriteData) {
                // Copy exactly chunkedLen bytes of data
                if (writeCount < chunkedLen) {
                    contentBytes.write(t);
                    writeCount++;
                    continue;
                }
                // Chunk finished: this byte (the CR after the data) starts the next size line
                writeCount = 0;
                beginWriteData = false;
            }
            // Reading a size line. It also collects the CRLF that ends the previous chunk;
            // trim() removes it, and the resulting empty string is skipped below
            lengthBytes.write(t);
            if (t == '\n') {
                String lengthStr = new String(lengthBytes.toByteArray()).trim();
                // Ignore chunk extensions such as "1a;name=value"
                int semi = lengthStr.indexOf(';');
                if (semi >= 0) {
                    lengthStr = lengthStr.substring(0, semi).trim();
                }
                if (lengthStr.length() != 0) {
                    chunkedLen = Long.parseLong(lengthStr, 16);
                    // A zero-size chunk marks the end of the body (trailer fields are not read)
                    if (chunkedLen == 0) {
                        break OUT;
                    }
                    contentLength += chunkedLen;
                    lengthBytes.reset();
                    // Start copying the chunk data
                    beginWriteData = true;
                }
            }
        }
    }
    response.setContentLength(contentLength);
    return contentBytes;
}
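parseChunked tracks chunk boundaries byte by byte inside a buffered read loop. For comparison, here is a hedged, line-oriented sketch of the same decoding (`ChunkedDecoder` is a hypothetical name, not the article's code): read a hex size line, drop any `;extension`, copy that many bytes, consume the CRLF after the data, and stop at the zero-size chunk after skipping optional trailer fields. Unlike a loop over `read(bytes)`, it never consumes bytes beyond the end of the chunked body.

```java
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

// Hypothetical line-oriented chunked decoder (an alternative to parseChunked)
class ChunkedDecoder {
    static byte[] decode(InputStream in) throws IOException {
        ByteArrayOutputStream body = new ByteArrayOutputStream();
        while (true) {
            String sizeLine = readLine(in);
            // Drop chunk extensions such as "1a;name=value"
            int semi = sizeLine.indexOf(';');
            String hex = (semi >= 0 ? sizeLine.substring(0, semi) : sizeLine).trim();
            long size = Long.parseLong(hex, 16);
            if (size == 0) {
                // Skip optional trailer fields up to the empty line that ends the message
                while (!readLine(in).isEmpty()) {
                    // discard trailer line
                }
                return body.toByteArray();
            }
            for (long i = 0; i < size; i++) {
                int b = in.read();
                if (b == -1) {
                    throw new EOFException("chunk truncated");
                }
                body.write(b);
            }
            // Every chunk's data is followed by CRLF
            readLine(in);
        }
    }

    /** Reads up to LF, drops a trailing CR, and returns the line (ISO-8859-1). */
    private static String readLine(InputStream in) throws IOException {
        ByteArrayOutputStream line = new ByteArrayOutputStream();
        int b;
        while ((b = in.read()) != '\n') {
            if (b == -1) {
                throw new EOFException("stream ended inside a line");
            }
            line.write(b);
        }
        String s = new String(line.toByteArray(), StandardCharsets.ISO_8859_1);
        return s.endsWith("\r") ? s.substring(0, s.length() - 1) : s;
    }
}
```

Reading one byte at a time from a raw socket stream is slow; a real proxy would add buffering that it controls itself, so that no bytes are lost to a reader it does not own.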
- Replacing the response body
/**
 * @Author: jiaww5
 * @Description: Runs every registered ResponseProcessor on the response, then rebuilds the header and fixes up the data
 */
public static byte[] replace(Request request, Response response) throws Exception {
    for (ResponseProcessor responseProcessor : responseProcessorList) {
        responseProcessor.repalce(request, response);
    }
    // The processors may have edited the raw header text, so parse it back into a Header
    Header header = response.parseHeader(new ByteArrayInputStream(response.getHeader().getOrigin().getBytes()));
    response.setHeader(header);
    // Fix up the data (gzip, Content-Length or chunking)
    return amend(response);
}
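The ResponseProcessor interface and its `repalce` method are not shown in the article. As a hedged illustration of the chain that replace() runs, here is a sketch with its own hypothetical `BodyProcessor` interface that works on the decoded body text: each processor receives the previous one's output, and a processor only rewrites responses whose URL matches.

```java
import java.util.List;

// Hypothetical interface: the article's ResponseProcessor is not shown, so this one works on body text
interface BodyProcessor {
    /** Returns the (possibly modified) body for a response to `url`. */
    String process(String url, String body);
}

// Replaces a literal string in responses whose URL contains a given fragment
class ReplaceProcessor implements BodyProcessor {
    private final String urlFragment;
    private final String target;
    private final String replacement;

    ReplaceProcessor(String urlFragment, String target, String replacement) {
        this.urlFragment = urlFragment;
        this.target = target;
        this.replacement = replacement;
    }

    @Override
    public String process(String url, String body) {
        return url.contains(urlFragment) ? body.replace(target, replacement) : body;
    }

    /** Runs the processors in order; each one receives the previous one's output. */
    static String applyAll(List<? extends BodyProcessor> processors, String url, String body) {
        String result = body;
        for (BodyProcessor processor : processors) {
            result = processor.process(url, result);
        }
        return result;
    }
}
```

Whatever form the processors take, the modified text still has to be encoded back to bytes and re-framed, which is the job of amend below.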
- Fixing up the data after replacement
/**
 * @Author: jiaww5
 * @Description: Fixes up the data after replacement: re-compresses a gzip-encoded body,
 * then either updates Content-Length or re-chunks the body
 * response: the response after replacement
 */
private static byte[] amend(Response response) throws Exception {
    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
    Header header = response.getHeader();
    String contentEncoding = header.getValue("Content-Encoding");
    // parseResponse decompressed the body, so compress it again before sending
    if ("gzip".equals(contentEncoding)) {
        response.setContentData(GzipUtils.compress(response.getContentData()));
    }
    if (!response.isChunked()) {
        // Content-Length framing: the value must match the final byte size of the body
        header.setElement("Content-Length", String.valueOf(response.getContentData().length));
    } else {
        // Chunked framing: split the body into chunks again
        response.setContentData(chunkedData(response.getContentData()));
    }
    // Write the header
    byteArrayOutputStream.write(headerToByte(header));
    // Write the body
    byteArrayOutputStream.write(response.getContentData());
    return byteArrayOutputStream.toByteArray();
}
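amend recomputes Content-Length from `getContentData().length`, that is, from the byte size of the final (possibly re-compressed) body; computing it from a String's `length()` instead would be wrong for any non-ASCII text. A hedged sketch of the same fix-up for the Content-Length branch, with hypothetical names. It also drops any Transfer-Encoding field, because HTTP/1.1 forbids sending Content-Length together with Transfer-Encoding.

```java
import java.nio.charset.StandardCharsets;
import java.util.List;

// Hypothetical helper mirroring amend()'s Content-Length branch: rebuilds the header
// so that its framing matches the final body bytes
class HeaderFixer {
    static byte[] withBody(String statusLine, List<String[]> fields, byte[] body) {
        StringBuilder head = new StringBuilder(statusLine).append("\r\n");
        for (String[] field : fields) {
            // Framing fields are recomputed below, so skip the old ones
            if (field[0].equalsIgnoreCase("Content-Length")
                    || field[0].equalsIgnoreCase("Transfer-Encoding")) {
                continue;
            }
            head.append(field[0]).append(": ").append(field[1]).append("\r\n");
        }
        // Content-Length counts bytes, not chars: U+00E9 is one char but two UTF-8 bytes
        head.append("Content-Length: ").append(body.length).append("\r\n\r\n");
        byte[] headBytes = head.toString().getBytes(StandardCharsets.ISO_8859_1);
        byte[] message = new byte[headBytes.length + body.length];
        System.arraycopy(headBytes, 0, message, 0, headBytes.length);
        System.arraycopy(body, 0, message, headBytes.length, body.length);
        return message;
    }
}
```

For a body of `"h\u00e9llo"` encoded as UTF-8 the rebuilt header says `Content-Length: 6`, although the string has only 5 chars.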
private static byte[] headerToByte(Header header) {
List<HeaderElement> headerElements = header.getHeaderElements();
Iterator<HeaderElement> iterator = headerElements.iterator();
StringBuilder builder = new StringBuilder();
while (iterator.hasNext()) {
HeaderElement next = iterator.next();
String name = next.getName();
String value = next.getValue();
if ("protocol".equalsIgnoreCase(name)) {
builder.append(value + " ");
} else if ("status".equalsIgnoreCase(name)) {
builder.append(value + " ");
} else if ("message".equalsIgnoreCase(name)) {
builder.append(value + "\r\n");
} else {
builder.append(name + ": " + value + "\r\n");
}
}
builder.append("\r\n");
return builder.toString().getBytes();
}
- Re-chunking the data if the response is chunked
/**
 * @Author: jiaww5
 * @Description: Splits data into chunks for chunked transfer encoding
 */
private static byte[] chunkedData(byte[] contentData) throws Exception {
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    // Chunk size: 4 KB
    int segmentLen = 1024 << 2;
    for (int offset = 0; offset < contentData.length; offset += segmentLen) {
        int len = Math.min(segmentLen, contentData.length - offset);
        // Each chunk: <size in hex>\r\n<data>\r\n
        outputStream.write(Integer.toHexString(len).getBytes());
        outputStream.write("\r\n".getBytes());
        outputStream.write(contentData, offset, len);
        outputStream.write("\r\n".getBytes());
    }
    // Last chunk "0\r\n" followed by the empty line that ends the body;
    // an empty body therefore becomes just "0\r\n\r\n"
    outputStream.write("0\r\n\r\n".getBytes());
    return outputStream.toByteArray();
}
public static String dataByteToString(byte[] contentData, String charset) throws UnsupportedEncodingException {
if (StringUtils.isNotBlank(charset)) {
return new String(contentData, charset);
}
return new String(contentData);
}
If you need the source code, please leave a comment.