【LLM实战】基于WebFlux手撕简易的SSE协议

59 阅读5分钟

一、SSE协议简介

SSE 协议诞生较早,但最近因大模型的兴起,SSE协议重新火热了起来。 对于SSE协议的历史,有兴趣的小伙伴可以自行AI。略...

SSE协议的主要特点:

  1. 基于HTTP协议设计的,所以无需额外的组件支持
  2. 服务器返回的响应头为 Content-Type: text/event-stream
  3. SSE协议返回的数据由多个数据块组成,不同数据块之间以 \n\n 分隔,最后以data: [DONE] 标记结束
  4. 数据块内包含多个约定的字段id: data: 等,不同字段以\n分割

二、SSE协议实战

本文以 SpringBoot 项目为背景,使用WebFlux响应式编程框架中的WebClient组件,介绍了如何应用SSE协议。

2.1 发送SSE协议的数据

2.2 接收SSE协议的数据

本小节主要介绍了,在对接LLM大模型的接口时,如何处理返回的数据。其中包含标准的SSE协议,也包含了自定义的数据。

2.3 标准SSE协议解析

介绍:WebClient组件提供的很多灵活的接口,可以十分方便的解析SSE协议。

@Slf4j
@Builder
public class LLMHttpStreamUtils {

    private volatile String baseUrl;

    private volatile String path;

    private volatile Map<String, String> headers;

    private volatile String body;

    private volatile HttpStreamCallback callback;

    /**
     * 发起请求的线程等待Stream将全部数据返回,等待的最大时间
     */
    private static final int STREAM_API_TIME_OUT = 6 * 60 * 1000;

    /**
     * 按照标准的SSE协议解析
     */
    public void postByStream() {
        // 初始化计数器
        AtomicInteger contentCount = new AtomicInteger(0);

        // 用于等待流式接口返回结束
        CountDownLatch countDownLatch = new CountDownLatch(1);
        boolean awaitResult = false;


        // request
        WebClient.create(baseUrl)
                .post()
                .uri(path)
                .headers(httpHeaders -> {
                    // header
                    httpHeaders.add(HttpHeaders.CONNECTION, "Keep-Alive");
                    httpHeaders.add(HttpHeaders.ACCEPT_CHARSET, "UTF-8");
                    httpHeaders.add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
                    for (Map.Entry<String, String> entry : headers.entrySet()) {
                        httpHeaders.add(entry.getKey(), entry.getValue());
                    }
                })
                .body(Mono.just(body), String.class)
                .exchangeToFlux(res -> {
                    if (res.statusCode().is2xxSuccessful()) {
                        // 2xx
                        // 关键:按照期望的格式解析数据
                        return res.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {
                        });
                    }
                    // 4xx 5xx
                    int code = res.statusCode().value();
                    String message = res.statusCode().getReasonPhrase();
                    log.error("[LLM-Stream] request failed. code {}, message:{}", code, message);
                    if (callback != null) {
                        callback.onError(code, message);
                    }
                    return Flux.error(new RuntimeException("[LLM-Stream] Error while fetching. code: " + code + ", message: " + message));
                })
                .doOnError(WebClientResponseException.class, err -> {
                    log.error("[LLM-Stream] [ERROR] status:{}, msg:{}", err.getRawStatusCode(), err.getResponseBodyAsString());
                })
                .doFinally(sig -> {
                    log.info("[LLM-Stream] [TERMINATED] with sig: {}", sig);
                    // 通知 controller 线程停止等待
                    countDownLatch.countDown();

                })
                .subscribe(content -> {
                    contentCount.getAndIncrement();
                    if (callback != null) {
                        callback.onContent(content.id(), content.data());
                    }

                }, error -> {
                    // 此处拿到的是前面返回的 Flux.error 对象
                    log.error("[LLM-Stream] [SUBSCRIBE] on error: {}", error.getMessage());

                }, () -> {
                    log.info("[LLM-Stream] [SUBSCRIBE] on complete.");

                });

        try {
            // 等待:流式接口的数据在子进程中返回
            awaitResult = countDownLatch.await(STREAM_API_TIME_OUT, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        log.info("[LLM-Stream] [POST By STREAM] on complete. count:{}, awaitResult:{}", contentCount.get(), awaitResult);
    }

关键点:

  1. 指定数据格式为.bodyToFlux(new ParameterizedTypeReference<ServerSentEvent<String>>() {}),自动完成解析
  2. 指定数据格式为.bodyToFlux(String.class),可以直接去SSE协议中data部分数据

2.4 手撕简易SSE协议

介绍:如果对接的LLM大模型接口中,有时返回SSE协议的数据,有时又会返回一些自定义的JSON内容。那该如何正确解析呢? 同时,在接口返回的数据中,如果遇到“粘包”和“拆包”问题,又该如何处理?

@Slf4j
@Builder
public class LLMHttpStreamUtils {

    private volatile String baseUrl;

    private volatile String path;

    private volatile Map<String, String> headers;

    private volatile String body;

    private volatile HttpStreamCallback callback;

    /**
     * 发起请求的线程等待Stream将全部数据返回,等待的最大时间
     */
    private static final int STREAM_API_TIME_OUT = 6 * 60 * 1000;

    /**
     * SSE 协议中的分隔符
     */
    private static final String SSE_PREFIX_ID = "id: ";
    private static final String SSE_PREFIX_DATA = "data: ";
    private static final String SSE_SEPARATOR_GROUP = "\n\n";

    /**
     * SSE 协议中的分隔符的长度
     */
    private static final int SSE_PREFIX_ID_LENGTH = SSE_PREFIX_ID.length();
    private static final int SSE_PREFIX_DATA_LENGTH = SSE_PREFIX_DATA.length();
    private static final int SSE_SEPARATOR_GROUP_LENGTH = SSE_SEPARATOR_GROUP.length();

    /**
     * CustomServerSentEvent
     * <p>
     * 参考 https://www.cnblogs.com/Fzeng/p/18735256
     * 应对 算法侧接口返回的两种数据结构:
     * 1. 标准的SSE协议
     * 2. 包含提示和错误信息的JSON,如 {"code":xxx,"message":"Exceeded maximum session limit"}
     */
    public void postByStream() {
        // 初始化计数器
        AtomicInteger contentCount = new AtomicInteger(0);

        // 用于等待流式接口返回结束
        CountDownLatch countDownLatch = new CountDownLatch(1);
        boolean awaitResult = false;

        // 用于存储返回的数据流:解决 拆包和粘包 问题
        StringBuffer contentBuffer = new StringBuffer();

        // request
        WebClient.create(baseUrl)
                .post()
                .uri(path)
                .headers(httpHeaders -> {
                    // header
                    httpHeaders.add(HttpHeaders.CONNECTION, "Keep-Alive");
                    httpHeaders.add(HttpHeaders.ACCEPT_CHARSET, "UTF-8");
                    httpHeaders.add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
                    for (Map.Entry<String, String> entry : headers.entrySet()) {
                        httpHeaders.add(entry.getKey(), entry.getValue());
                    }
                })
                .body(Mono.just(body), String.class)
                .exchangeToFlux(res -> {
                    if (res.statusCode().is2xxSuccessful()) {
                        // 2xx
                        // 关键:指定 DataBuffer.class 类型来处理
                        return res.bodyToFlux(DataBuffer.class);
                    }
                    // 4xx 5xx
                    int code = res.statusCode().value();
                    String message = res.statusCode().getReasonPhrase();
                    log.error("[LLM-Stream] request failed. code {}, message:{}", code, message);
                    if (callback != null) {
                        callback.onError(code, message);
                    }
                    return Flux.error(new RuntimeException("[LLM-Stream] Error while fetching. code: " + code + ", message: " + message));
                })
                .doOnError(WebClientResponseException.class, err -> {
                    log.error("[LLM-Stream] [ERROR] status:{}, msg:{}", err.getRawStatusCode(), err.getResponseBodyAsString());
                })
                .doFinally(sig -> {
                    log.info("[LLM-Stream] [TERMINATED] with sig: {}", sig);
                    // 在finally中,清空 contentBuffer 中的数据
                    // 场景:若接收到的数据不符合SSE格式,则会残留在 contentBuffer 中。所以,需要清空
                    if (callback != null && !contentBuffer.isEmpty()) {
                        String remain = contentBuffer.toString();
                        log.info("[LLM- Stream] [TERMINATED] contentBuffer flush: {}", remain);
                        callback.onContent("-1", remain);
                    }
                    // 通知 controller 线程停止等待
                    countDownLatch.countDown();

                })
                .subscribe(content -> {
//                    log.debug("[LLM-Stream] [SUBSCRIBE] on subscribe3. content: {}", content.toString());
                    contentCount.getAndIncrement();
                    if (callback != null) {
                        parseSSE(content, contentBuffer).forEach(c -> {
                            callback.onContent(c.id(), c.data());
                        });
                    }

                }, error -> {
                    // 此处拿到的是前面返回的 Flux.error 对象
                    log.error("[LLM-Stream] [SUBSCRIBE] on error: {}", error.getMessage());

                }, () -> {
                    log.info("[LLM-Stream] [SUBSCRIBE] on complete.");

                });

        try {
            // 等待:流式接口的数据在子进程中返回
            awaitResult = countDownLatch.await(STREAM_API_TIME_OUT, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        log.info("[LLM-Stream] [POST By STREAM] on complete. count:{}, awaitResult:{}", contentCount.get(), awaitResult);
    }

    /**
     * 根据响应的数据,按照约定的 SSE 协议解析
     * 约定:
     * 1、返回的一组数据中,仅包含 id 和 data 两个字段,
     * 2、每个字段之间,通过一个换行符(\n)进行分割
     * 3、每一组数据之间通过两个换行符(\n\n)进行分割
     */
    public static List<ServerSentEvent<String>> parseSSE(DataBuffer dataBuffer, StringBuffer contentBuffer) {
        contentBuffer.append(dataBuffer.toString(StandardCharsets.UTF_8));
        List<ServerSentEvent<String>> contentList = new ArrayList<>();
        int separatorIndex;
        while ((separatorIndex = contentBuffer.indexOf(SSE_SEPARATOR_GROUP)) != -1) {
            // 从 dataBuffer 中读取一段数据
            byte[] contentBytes = new byte[separatorIndex + SSE_SEPARATOR_GROUP_LENGTH];
            String content = contentBuffer.substring(0, contentBytes.length - 1);
            contentBuffer.delete(0, contentBytes.length);

            // 检查 数据的有效性
            if (!content.startsWith(SSE_PREFIX_ID)) {
                throw new RuntimeException("[LLM-Stream] Error: " + content);
            }

            // 解析 id: data: 标志,并校验有效性
            int idIndex = content.indexOf(SSE_PREFIX_ID);
            int dataIndex = content.indexOf(SSE_PREFIX_DATA);
            int idValueIndex = idIndex + SSE_PREFIX_ID_LENGTH;
            int dataValueIndex = dataIndex + SSE_PREFIX_DATA_LENGTH;
            if (idIndex == -1 || dataIndex == -1 || idValueIndex >= dataIndex) {
                throw new RuntimeException("[LLM-Stream] Error: parse {id}/{data} failed: " + content);
            }
            String idValue = content.substring(idValueIndex, dataIndex).trim();
            String dataValue = content.substring(dataValueIndex);
            ServerSentEvent<String> sse = ServerSentEvent.<String>builder().id(idValue).data(dataValue).build();
            //保存结果
            contentList.add(sse);
        }
        return contentList;
    }

关键:

  1. 指定数据格式为.bodyToFlux(DataBuffer.class),然后按照SSE协议格式,在parseSSE()中手动完成解析。
  2. DataBuffer 中的数据读取到 StringBuffer 中,按照分隔符\n\n循环解析,避免出现“粘包”和“拆包”问题。
  3. doFinally() 中清空 StringBuffer ,用于返回自定义的数据,如包含提示或错误信息的JSON。因为这些自定义的数据不符合 SSE协议格式,所以无法在parseSSE()中完成解析,会残留在StringBuffer 中。

2.5 补充

公共的数据结构

public interface HttpStreamCallback {

    void onContent(String id, String data);

    void onError(int code, String message);
}