AI对话-处理SSE数据,实现逐字回复的效果

1,642 阅读2分钟

前言

AI对话逐字回复类似打字机的效果想必大家都已经习以为常了,最近正好实现了一个AI对话生成PPT的需求,这个逐字回复的效果,问了周围同僚发现也有不了解如何实现的。所以这里分享出来,倒也没有什么复杂的知识,作为一种实现方案供大家了解,希望能对需要的人有所帮助。

假设现在我们输入了问题:

image.png

首先请求接口,后端开启事件流,实时地向客户端以流的形式推送数据

接口返回数据块的例子,结构可参考各大AI平台:

data:{
    "event": "message", 
    "conversation_id": "xxx", 
    "message_id": "xxx", 
    "created_at": 1718246242, 
    "task_id": "xxx", 
    "id": "xxx", 
    "answer": "\u53d1\u9732\u3001\u62a4\u53d1\u7d20\uff0c\u4ee5\u53ca"  // 转为中文字符为:"暴露、挥发物,以及"
}

处理SSE数据

下面的代码是一个流式数据处理器,能够从服务器接收连续的数据流,解析并根据数据类型调用不同的处理函数。这种模式常用于实时通信或实时数据处理的场景。

主要的处理方法:

  1. ssePost - 这个函数用于发送一个POST请求,并调用 handleStream 处理返回的SSE流。
  2. handleStream - 这个函数用于处理SSE响应流。
// 发送对话请求
const ssePost = (    {
    onData,
    onCompleted,
    onThought,
    onFile,
    onMessageEnd,
    onMessageReplace,
    onWorkflowStarted,
    onWorkflowFinished,
    onNodeStarted,
    onNodeFinished,
    onTextChunk,
    onTextReplace,
    onError,
}) => {
    // 使用 AbortController 来控制请求的取消,例如用户取消对话
    const abortController = new AbortController()

    // 构造请求头和请求体
    const options = {
        method: 'POST',
        signal: abortController.signal,
        headers: new Headers({
            'Content-Type': 'application/json',
        }),
    }

    options.body = JSON.stringify({
        inputMessage: '输入的问题'
    })
    
    fetch(`接口地址`, options)
        .then((res) => {
            // 响应状态码不是2xx或3xx,显示错误信息并调用 onError 回调
            if (!/^(2|3)\d{2}$/.test(String(res.status))) {
                res.json().then((data) => {
                    message.error(data.message || '服务端错误')
                })
                onError?.('服务端错误')
                return
            }
            // 响应成功,调用 handleStream 函数来处理SSE流
            return handleStream(res, (str, isFirstMessage, moreInfo) => {
                if (moreInfo.errorMessage) {
                    onError?.(moreInfo.errorMessage, moreInfo.errorCode)
                    if (moreInfo.errorMessage !== '用户取消请求')
                        message.error(moreInfo.errorMessage)
                    return
                }
                onData?.(str, isFirstMessage, moreInfo)
            }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onTextChunk, onTextReplace)
        }).catch((e) => {
            if (e.toString() !== '用户取消请求')
                message.error(e)
            onError?.(e)
        })

}
// 处理SSE响应流
const handleStream = (
    response,
    onData,
    onCompleted,
    onThought,
    onMessageEnd,
    onMessageReplace,
    onFile,
    onWorkflowStarted,
    onWorkflowFinished,
    onNodeStarted,
    onNodeFinished,
    onTextChunk,
    onTextReplace,
) => {
    if (!response.ok)
        throw new Error('Network response was not ok')

    // 通过创建的 reader 对象,可以使用 reader.read() 方法来异步地从流中读取数据
    const reader = response.body?.getReader()
    // 创建一个 TextDecoder 对象,用于将流中的二进制数据解码为UTF-8格式的字符串
    const decoder = new TextDecoder('utf-8')
    let buffer = ''     // 累积从流中读取的数据
    let bufferObj       // 临时存储解析后的JSON对象
    let isFirstMessage = true

    // 用于递归读取流中的数据
    function read() {
        let hasError = false

        // 使用 reader.read() 从流中异步读取数据,返回一个包含 done 和 value 属性的对象
        reader?.read().then((result) => {
            if (result.done) { // 表示流已结束,调用 onCompleted 回调函数
                onCompleted && onCompleted(false,isFirstMessage)
                return
            }
            buffer += decoder.decode(result.value, { stream: true })
            // 如果 result.done 为 true,表示流已结束,调用 onCompleted 回调函数
            const lines = buffer.split('\n')
            try {
                lines.forEach((message) => {
                    if (message.startsWith('data:')) { // 接口约定:以data:开始
                        try {
                            bufferObj = JSON.parse(message.substring(5)) // 去掉 "data:" 并解析为 JSON
                        }
                        catch (e) {
                            // 消息被截断
                            onData('', isFirstMessage, {
                                conversationId: bufferObj?.conversation_id,
                                messageId: bufferObj?.message_id,
                            })
                            return
                        }

                        // 检查JSON对象是否有错误状态或缺少 event 字段,并相应地处理
                        if (bufferObj.status === 400 || !bufferObj.event) {
                            onData('', false, {
                                conversationId: undefined,
                                messageId: '',
                                errorMessage: bufferObj?.message,
                                errorCode: bufferObj?.code,
                            })
                            hasError = true
                            onCompleted?.(true, bufferObj?.message)
                            return
                        }

                        // 根据接口返回的 event 字段的值调用相应的回调函数处理不同类型的事件
                        if (bufferObj.event === 'message' || bufferObj.event === 'agent_message') {
                            onData(unicodeToChar(bufferObj.answer), isFirstMessage, {
                                conversationId: bufferObj.conversation_id,
                                taskId: bufferObj.task_id,
                                messageId: bufferObj.id,
                            })
                            isFirstMessage = false
                        }
                        else if (bufferObj.event === 'agent_thought') {
                            onThought?.(bufferObj)
                        }
                        else if (bufferObj.event === 'message_file') {
                            onFile?.(bufferObj)
                        }
                        else if (bufferObj.event === 'message_end') {
                            onMessageEnd?.(bufferObj)
                        }
                        else if (bufferObj.event === 'message_replace') {
                            onMessageReplace?.(bufferObj)
                        }
                        else if (bufferObj.event === 'workflow_started') {
                            onWorkflowStarted?.(bufferObj)
                        }
                        else if (bufferObj.event === 'workflow_finished') {
                            onWorkflowFinished?.(bufferObj)
                        }
                        else if (bufferObj.event === 'node_started') {
                            onNodeStarted?.(bufferObj)
                        }
                        else if (bufferObj.event === 'node_finished') {
                            onNodeFinished?.(bufferObj)
                        }
                        else if (bufferObj.event === 'text_chunk') {
                            onTextChunk?.(bufferObj)
                        }
                        else if (bufferObj.event === 'text_replace') {
                            onTextReplace?.(bufferObj)
                        }
                    }
                })

                // 更新 buffer 为未处理的剩余数据,即最后一行
                // 在处理流数据时,最后读取的数据可能不会立即构成一个完整的消息。
                // 例如,如果流在消息中间断开,那么最后读取的数据可能只是部分消息。通过将这部分数据保留在 buffer 中,可以确保在下一次调用 read 函数时,这部分数据能够被正确地继续处理。
                buffer = lines[lines.length - 1]
            }
            catch (e) {
                onData('', false, {
                    conversationId: undefined,
                    messageId: '',
                    errorMessage: `${e}`,
                })
                hasError = true
                onCompleted?.(true, e)
                return
            }
            if (!hasError)
                read()
        })
    }
    read()
}
// 将Unicode编码转换为字符,例如:"Hello, \u0041\u006C\u006C\u006F!",转为:"Hello, ALLO!"
function unicodeToChar(text) {
    if (!text)
        return ''

    return text.replace(/\u[0-9a-f]{4}/g, (_match, p1) => {
        return String.fromCharCode(parseInt(p1, 16))
    })
}

展示数据

大模型返回的数据可能包含空格、换行、字体加粗、代码公式等内容,如果你的项目使用 React,可以使用 react-markdown 库,配合 remark-math、remark-breaks、rehype-katex、remark-gfm 等与 react-markdown 库配合使用的插件,用于扩展 Markdown 的功能和改进其渲染效果。

<ReactMarkdown
  remarkPlugins={[remarkMath, remarkGfm, remarkBreaks]}
  rehypePlugins={[rehypeKatex]}
  components={components} // 自定义渲染
>
  // 内容
  {markdownText}
</ReactMarkdown>