JAVA原生Servlet支持SSE

150 阅读4分钟

SSE原理

(MDN)SSE文档

SSE 非常轻量。当框架有严格的请求超时限制、而某个业务处理又非常耗时时,可以用它绕过超时限制;需要大批量推送数据时,也可以使用它。

SSE其实类似于文件下载,但是有特定的格式以让EventSource正常解析.

必要响应头: Content-Type: text/event-stream.

常用报文格式:

event: 事件1名称
data: 事件1消息

event: 事件2名称
data: 事件2消息

后端代码

直接调用 HttpServletRequest 的 startAsync 创建异步上下文,并在其他线程中使用这个异步上下文推送消息。

package local.my.demo.controller;

import cn.hutool.core.date.DateUtil;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;

import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Date;
import java.util.concurrent.SynchronousQueue;

@Controller
public class DemoController {

    /** One SSE message: the {@code event:} name and the {@code data:} payload. */
    static class SseMsg {
        String event;
        String data;
        public SseMsg(String event, String data) {
            this.event = event;
            this.data = data;
        }
    }

    /**
     * State shared by the three worker threads of one request. The fields are volatile because
     * they are written by one thread and read by others (plain array slots gave no visibility
     * guarantee).
     */
    private static final class SseState {
        /** Time of the last frame written to the client. */
        volatile long lastSendTime;
        /** Cleared by the sender when the stream ends; every worker loop stops on it. */
        volatile boolean live = true;
    }

    /** How long one hand-off attempt / heartbeat tick waits before re-checking {@link SseState#live}. */
    private static final long HANDOFF_POLL_MS = 200;

    /**
     * Hands {@code msg} to the sender thread, waiting until the sender takes it.
     * Unlike {@link SynchronousQueue#offer(Object)}, which drops the message whenever the sender
     * is not already blocked in {@code take()}, this never loses a message; it only gives up
     * once the connection is no longer live.
     *
     * @return true if the sender took the message, false if the stream ended first
     */
    private static boolean deliver(SynchronousQueue<SseMsg> queue, SseState state, SseMsg msg)
            throws InterruptedException {
        while (state.live) {
            if (queue.offer(msg, HANDOFF_POLL_MS, java.util.concurrent.TimeUnit.MILLISECONDS)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Simulated business work: pushes {@code max(max, 10)} "msg" events with artificial delays.
     *
     * @param max raw "max" request parameter; null/empty means 10
     * @throws NumberFormatException if {@code max} is not an integer
     */
    private static void simulateBusiness(String max, SynchronousQueue<SseMsg> queue, SseState state)
            throws InterruptedException {
        int maxInt = max == null || max.isEmpty() ? 10 : Integer.parseInt(max);
        for (int i = 0; i < Math.max(maxInt, 10); i++) {
            SseMsg msg = new SseMsg("msg", "模拟SSE" + i + " " + DateUtil.format(new Date(), "HH:mm:ss"));
            if (!deliver(queue, state, msg)) {
                return; // stream already ended (client gone / timeout): stop producing
            }
            // Simulated processing time
            Thread.sleep(i % 8 == 0 ? 5000 : 500);
        }
    }

    /**
     * Streams simulated business messages over a raw Servlet async context, interleaved with
     * "live" keep-alive events whenever nothing was sent for 1.5 s.
     */
    @RequestMapping(value = "/sse/test", method = RequestMethod.GET)
    public void sseServlet(HttpServletResponse resp, HttpServletRequest req) {
        resp.setContentType(MediaType.TEXT_EVENT_STREAM_VALUE);
        resp.setCharacterEncoding(StandardCharsets.UTF_8.name());
        AsyncContext aCtx = req.startAsync(req, resp);
        // Overall timeout; the connection is closed when it elapses
        aCtx.setTimeout(1000 * 60 * 10);
        SseState state = new SseState();
        SynchronousQueue<SseMsg> queue = new SynchronousQueue<>();

        // Keep-alive: send a heartbeat after 1.5 s of silence so per-gap timeouts
        // (e.g. in proxies / middleware) do not cut the connection.
        new Thread(() -> {
            try {
                while (state.live) {
                    long now = System.currentTimeMillis();
                    if (now - state.lastSendTime > 1500) {
                        // Non-blocking on purpose: if the sender is busy, a heartbeat is not needed.
                        queue.offer(new SseMsg("live", String.valueOf(now)));
                    }
                    // Sleep instead of busy-spinning a CPU core.
                    Thread.sleep(HANDOFF_POLL_MS);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }, "保活").start();

        // Sender: the only thread that writes to the response.
        new Thread(() -> {
            try {
                for (;;) {
                    SseMsg take = queue.take();
                    PrintWriter writer = aCtx.getResponse().getWriter();
                    writer.print("event: ");
                    writer.print(take.event);
                    writer.print("\n");
                    writer.print("data: ");
                    writer.print(take.data);
                    writer.print("\n\n");
                    writer.flush();
                    // PrintWriter swallows IOExceptions; checkError() is how a dropped client shows up.
                    if (writer.checkError()) {
                        break;
                    }
                    state.lastSendTime = System.currentTimeMillis();
                    if ("stop".equals(take.event)) {
                        break;
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                // Stop the other workers on every exit path, not only after "stop".
                state.live = false;
                try {
                    aCtx.complete();
                } catch (IllegalStateException ignored) {
                    // Already completed by the container (e.g. after the timeout elapsed).
                }
            }
        }, "发送").start();

        // Business: produce messages, then always signal the end of the stream.
        new Thread(() -> {
            try {
                try {
                    simulateBusiness(aCtx.getRequest().getParameter("max"), queue, state);
                } finally {
                    // Even if the business logic failed, tell the sender to finish the response.
                    deliver(queue, state, new SseMsg("stop", "stop"));
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }, "业务").start();
    }
}

前端代码

按事件类型分别监听,以便区分处理正常业务消息和其他消息(如保活消息)。

/**
 * Opens an EventSource on `url` and logs keep-alive ("live") and business ("msg") events.
 * The server ends the stream with a "stop" event; closing on it avoids the browser's
 * automatic reconnect, which would issue a new GET and restart the server-side job.
 * Any error also closes the source.
 * @param {string} url SSE endpoint
 * @returns {EventSource} the opened source, so callers can close it early
 */
function newSSE(url) {
    const evtSource = new EventSource(url, {withCredentials: true})
    evtSource.addEventListener('live', function (event) {
          var data = event.data;
          // Keep-alive message
          console.log("保活: "+data);
        }, false);
    evtSource.addEventListener('msg', function (event) {
          var data = event.data;
          // Regular business message
          console.log("消息: "+data);
        }, false);
    // Normal end of stream announced by the server.
    evtSource.addEventListener('stop', function () {
          console.log("close evtSource")
          evtSource.close()
        }, false);
    evtSource.onerror = function (event) {
        console.log("close evtSource")
        evtSource.close()
    };
    return evtSource
}

// Demo backend origin; ?max=300 asks the server for 300 simulated messages.
var sseHost = 'http://127.0.0.1:8080'
newSSE(`${sseHost}/sse/test?max=300`)

测试截图

提取成一个工具方法

将异步推送逻辑抽取成工具方法,便于在不同接口中复用。

package local.my.demo.controller;

import cn.hutool.json.JSONUtil;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.http.MediaType;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Iterator;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledFuture;
import java.util.stream.Stream;


@Component
@Slf4j
public class SseUtil implements InitializingBean, DisposableBean, ApplicationContextAware {

    private static final ThreadPoolTaskExecutor sseExecutor = new ThreadPoolTaskExecutor();
    private static final ThreadPoolTaskScheduler keepExecutor = new ThreadPoolTaskScheduler();
    private static ApplicationContext appCtx;

    public static void sendJson(final HttpServletRequest req,final HttpServletResponse resp,final int timeoutSec,final Stream<?> ...streams) {
        new AsyncCtx(req, resp, timeoutSec, streams);
    }

    @Override
    public void afterPropertiesSet() throws Exception {
        sseExecutor.setCorePoolSize(2);
        sseExecutor.setMaxPoolSize(16);
        //不等待,而是立即扩充线程池保障响应
        sseExecutor.setQueueCapacity(0);
        sseExecutor.setThreadNamePrefix("SSEWORK");
        sseExecutor.setKeepAliveSeconds(30);
        sseExecutor.setApplicationContext(appCtx);
        sseExecutor.setAllowCoreThreadTimeOut(true);
        sseExecutor.setWaitForTasksToCompleteOnShutdown(false);
        sseExecutor.initialize();
        keepExecutor.setPoolSize(1);
        keepExecutor.setRemoveOnCancelPolicy(true);
        keepExecutor.setApplicationContext(appCtx);
        keepExecutor.setWaitForTasksToCompleteOnShutdown(false);
        keepExecutor.setThreadNamePrefix("SSEKEEP");
        keepExecutor.initialize();
    }

    @Override
    public void destroy() throws Exception {
        sseExecutor.shutdown();
        keepExecutor.shutdown();
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        appCtx = applicationContext;
    }

    static ScheduledFuture<?>   keep(final AsyncCtx asyncCtx){
        long delaySec = 29;//单次推送最大间隔的一半
        final ScheduledFuture<?> f = keepExecutor.scheduleWithFixedDelay(()-> {
            if (asyncCtx.done || asyncCtx.ctx.getResponse() == null || asyncCtx.ctx.getResponse().isCommitted()){
                asyncCtx.keep.cancel(true);
                return;
            }
            //高频消息不发送心跳
            if (System.currentTimeMillis()-asyncCtx.lastTime < delaySec*1000L){
                return;
            }
            try {
                PrintWriter writer = asyncCtx.ctx.getResponse().getWriter();
                //空消息表示心跳
                writer.print("data: \n\n");
                writer.flush();
            } catch (Exception e) {
                log.error("sse keep writer error", e);
                asyncCtx.keep.cancel(true);
                asyncCtx.ctx.complete();
            }
        },Duration.ofSeconds(delaySec));
        return f;
    }

    private static class AsyncCtx{
        final AsyncContext ctx;
        volatile long lastTime;
        volatile boolean done=false;
        final ScheduledFuture<?> keep;
        AsyncCtx(final HttpServletRequest req,final HttpServletResponse resp,final int timeoutSec,final Stream<?> ...streams) {
            resp.setCharacterEncoding(StandardCharsets.UTF_8.name());
            resp.setContentType(MediaType.TEXT_EVENT_STREAM_VALUE);
            ctx = req.startAsync(req, resp);
            ctx.setTimeout(Math.max(timeoutSec, 10)*1000L);
            lastTime = System.currentTimeMillis();
            final AsyncCtx it = this;
            keep = keep(it);
            ctx.addListener(new AsyncListener() {
                public void onComplete(AsyncEvent event) throws IOException {
                    keep.cancel(true);
                }
                public void onTimeout(AsyncEvent event) throws IOException {
                    keep.cancel(true);
                }
                public void onError(AsyncEvent event) throws IOException {
                    keep.cancel(true);
                }
                public void onStartAsync(AsyncEvent event) throws IOException {}
            });
            sseSend(streams, this);
        }
    }

    private static void sseSend(Stream<?>[] streams, AsyncCtx asyncCtx) {
        sseExecutor.execute(() -> {
            if (asyncCtx.ctx.getResponse().isCommitted()){
                log.error("sse response is committed");
                return;
            }
            try {
                Stream.of(streams).forEach(stream -> stream.forEach(take -> {
                    try {
                        PrintWriter writer = asyncCtx.ctx.getResponse().getWriter();
                        String jsonStr = JSONUtil.toJsonStr(take);
                        String[] split = jsonStr.split("\n");
                        for (String s : split) {
                            writer.print("data: ");
                            writer.print(s);
                            writer.print("\n");
                        }
                        writer.print("\n");
                        writer.flush();
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                    asyncCtx.lastTime=System.currentTimeMillis();
                }));
            }catch (Exception e){
                log.error("sse data writer error",e);
            }finally {
                asyncCtx.ctx.complete();
                asyncCtx.done=true;
            }
        });
    }

}

调用案例(注意:该示例与前文 DemoController 使用相同的 GET /sse/test 映射,二者不能同时注册,否则 Spring 启动时会报映射冲突)

package local.my.demo.controller;

import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.map.MapUtil;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;

import java.util.stream.Stream;

@Controller
public class DemoController2 {

    /**
     * Demo endpoint: streams four sample items (two strings, a list and a map) as JSON SSE
     * messages through {@link SseUtil#sendJson}, with a 10-second overall timeout.
     */
    @GetMapping("/sse/test")
    public void sseTest(HttpServletRequest req, HttpServletResponse resp) {
        Stream<?> payload = Stream.of("[1]", "[2]", ListUtil.of("3"), MapUtil.of("4", "5"));
        SseUtil.sendJson(req, resp, 10, payload);
    }

}

也可以用Spring的SseEmitter实现

JAVA代码

package local.my.sb.demospringboot.controller;

import cn.hutool.core.date.DateUtil;
import cn.hutool.json.JSONUtil;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.Date;
import java.util.concurrent.SynchronousQueue;

@Controller
public class DemoController2 {


    /** One message; serialized to JSON as the SSE data payload. */
    public static class SseMsg {
        public String event;
        public String data;
        public SseMsg(String event, String data) {
            this.event = event;
            this.data = data;
        }
    }

    /**
     * State shared by the three worker threads of one request; volatile because it is written
     * and read from different threads (plain array slots gave no visibility guarantee).
     */
    private static final class SseState {
        /** Time of the last message sent to the client. */
        volatile long lastSendTime;
        /** Cleared when the stream ends; every worker loop stops on it. */
        volatile boolean live = true;
    }

    /** How long one hand-off attempt / heartbeat tick waits before re-checking {@link SseState#live}. */
    private static final long HANDOFF_POLL_MS = 200;

    /**
     * Hands {@code msg} to the sender thread, waiting until it is taken. Unlike
     * {@link SynchronousQueue#offer(Object)}, which drops the message whenever the sender is not
     * already blocked in {@code take()}, this never loses a message; it only gives up once the
     * stream is no longer live.
     *
     * @return true if the sender took the message, false if the stream ended first
     */
    private static boolean deliver(SynchronousQueue<SseMsg> queue, SseState state, SseMsg msg)
            throws InterruptedException {
        while (state.live) {
            if (queue.offer(msg, HANDOFF_POLL_MS, java.util.concurrent.TimeUnit.MILLISECONDS)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Simulated business work: pushes {@code max(max, 10)} "msg" messages with artificial delays.
     *
     * @throws NumberFormatException if {@code max} is not an integer
     */
    private static void simulateBusiness(String max, SynchronousQueue<SseMsg> queue, SseState state)
            throws InterruptedException {
        int maxInt = max == null || max.isEmpty() ? 10 : Integer.parseInt(max);
        for (int i = 0; i < Math.max(maxInt, 10); i++) {
            SseMsg msg = new SseMsg("msg", "模拟SSE" + i + " " + DateUtil.format(new Date(), "HH:mm:ss"));
            if (!deliver(queue, state, msg)) {
                return; // stream already ended: stop producing
            }
            // Simulated processing time
            Thread.sleep(i % 8 == 0 ? 5000 : 500);
        }
    }

    /**
     * Streams simulated business messages through Spring's {@link SseEmitter}; every message
     * (including "live" heartbeats and the final "stop") is sent as a JSON-encoded {@link SseMsg}.
     */
    @RequestMapping(value = "/sse/test2",method = RequestMethod.GET)
    public SseEmitter sse(@RequestParam(required = false) String max){
        SseEmitter sse = new SseEmitter(1000*60*10L/* overall timeout */);
        SseState state = new SseState();
        SynchronousQueue<SseMsg> queue = new SynchronousQueue<>();
        // Runs on normal completion, timeout and network error: stop all worker threads.
        sse.onCompletion(() -> state.live = false);

        // Keep-alive: send a heartbeat after 1.5 s of silence to avoid per-gap timeouts.
        new Thread(() -> {
            try {
                while (state.live) {
                    long now = System.currentTimeMillis();
                    if (now - state.lastSendTime > 1500) {
                        // Non-blocking on purpose: if the sender is busy, a heartbeat is not needed.
                        queue.offer(new SseMsg("live", ""));
                    }
                    // Sleep instead of busy-spinning a CPU core.
                    Thread.sleep(HANDOFF_POLL_MS);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }, "保活").start();

        // Sender: the only thread that calls sse.send().
        new Thread(() -> {
            try {
                for (;;) {
                    SseMsg take = queue.take();
                    sse.send(JSONUtil.toJsonStr(take));
                    state.lastSendTime = System.currentTimeMillis();
                    if ("stop".equals(take.event)) {
                        sse.complete();
                        break;
                    }
                }
            } catch (Exception e) {
                // A failed send (e.g. client gone) ends the stream; per the ResponseBodyEmitter#send
                // contract no completeWithError() is needed after an IOException.
                e.printStackTrace();
            } finally {
                // Stop the other workers on every exit path, not only after "stop".
                state.live = false;
            }
        }, "发送").start();

        // Business: produce messages, then always signal the end of the stream.
        new Thread(() -> {
            try {
                try {
                    simulateBusiness(max, queue, state);
                } finally {
                    // Even if the business logic failed, tell the sender to complete the emitter.
                    deliver(queue, state, new SseMsg("stop", ""));
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }, "业务").start();
        return sse;
    }
}

JS代码

/**
 * Opens an EventSource on `url`, logs every unnamed ("message") event and closes the
 * source on the first error so the browser stops reconnecting.
 * @param {string} url SSE endpoint
 */
function newSSE(url) {
    const source = new EventSource(url, {withCredentials: true});
    const logMessage = (msg) => console.log(`message: ${msg.data}`);
    const closeOnError = function () {
        console.log("close evtSource");
        source.close();
    };
    source.onmessage = logMessage;
    source.onerror = closeOnError;
}
// Open the stream against the SseEmitter endpoint (relative URL, i.e. same origin as the page).
newSSE('/sse/test2')

测试截图