当前有许多ai对话hook 如@ai-sdk/react的useChat @ant-design/x-sdk的useXChat 但在我的使用场景中都有所不足 遂自定义hook useChat
功能如下:
- 收集用户提供的的流式数据 不做任何更改
- 维护用户消息/流式数据的状态 允许停止传输
- 恢复整个session的会话 允许立刻进入流式传输状态
- 更流畅的session创建 用户发送消息立刻切换到会话页面 而不需要在创建页面等待
useChat/useChatId.ts
useChat通过chatId判断是否重置会话 仅仅使用sessionId作为chatId会导致创建session时会话也重置
useChatId可以让用户在创建session时维持chatId不变
/* eslint-disable react-hooks/refs */
import { useRef } from 'react'
import { useMemoizedFn } from 'ahooks'
import { match, P } from 'ts-pattern'
export function useChatId(options: {
sessionId?: string | null
setSessionId?: (sid: string) => void
}) {
const { sessionId, setSessionId } = options
// 上次调用时传入的sessionId
const prevSessionIdRef = useRef(sessionId)
// 上次调用时使用的chatId
const prevChatIdRef = useRef<string>(undefined)
// sessionId由无到有 可能是创建session 也可能进入已有session
// 保存新sessionId 控制chatId不变
const newSessionIdRef = useRef<string>(undefined)
// 用于引发会话刷新
const chatId = match({ psid: prevSessionIdRef.current, sid: sessionId })
.with({ psid: P.string, sid: P.string }, ({ psid, sid }) => {
if (psid !== sid) {
// session->另一session 切换id
return sid
} else {
if (typeof prevChatIdRef.current !== 'string') {
// 首次进入session prevChatIdRef尚未赋值
return sid
}
// session一致 id不变
return prevChatIdRef.current
}
})
.with({ psid: P.string, sid: P.nullish }, () => {
// session->会话创建 新id
return crypto.randomUUID()
})
.with({ psid: P.nullish, sid: P.nullish }, () => {
if (typeof prevChatIdRef.current !== 'string') {
// 首次进入会话创建 prevChatIdRef尚未赋值
return crypto.randomUUID()
}
// 始终处于新会话创建阶段 id不变
return prevChatIdRef.current
})
.with({ psid: P.nullish, sid: P.string }, ({ sid }) => {
const newSessionId = newSessionIdRef.current
newSessionIdRef.current = undefined
if (typeof newSessionId === 'string' && sid === newSessionId) {
// 新sessionId与预设值相同
// 新会话创建 id不变
return prevChatIdRef.current
}
// 进入已有session
return sid
})
.exhaustive()!
prevSessionIdRef.current = sessionId
prevChatIdRef.current = chatId
/** 在新建session后 改为调用此函数设置sessionId 维持chatId不变 */
const setNewSessionId = useMemoizedFn((sid: string) => {
newSessionIdRef.current = sid
setSessionId?.(sid)
})
return { chatId, setNewSessionId }
}
useChat/index.ts
要点如下
- 使用useCallback配合闭包 确保不会让过期的setState执行
- 允许用户同时resume已完成的历史消息和ai正在执行的消息
- 为resuming/streaming/submitting单独设置state 避免单一status表意不清
- 使用AbortController提供了停止能力 无论用户消息还是ai消息
- 通过id唯一标识消息 让流式数据和常规消息正确合并
/* eslint-disable react-hooks/refs */
import type { Dispatch, SetStateAction } from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useLatest, useMemoizedFn } from 'ahooks'
export { useChatId } from './useChatId'
export type AIMessage<AIMessagePart> = {
id: string
role: 'ai'
parts: AIMessagePart[]
status: 'streaming' | 'done' | 'aborted' | 'error'
}
export type UserMessage<UserMessagePart> = {
id: string
role: 'user'
parts: UserMessagePart[]
status: 'submitting' | 'done' | 'aborted' | 'error'
}
export type ChatMessage<AIMessagePart, UserMessagePart> =
| AIMessage<AIMessagePart>
| UserMessage<UserMessagePart>
export type UseChatOptions<AIMessagePart, UserMessagePart> = {
chatId: unknown
resumeMessages: (signal: AbortSignal) =>
| Promise<
| {
messages?: ChatMessage<AIMessagePart, UserMessagePart>[]
stream?: AsyncIterable<AIMessagePart>
}
| void
| undefined
| null
>
| void
| undefined
| null
sendUserMessage: (
parts: UserMessagePart[],
signal: AbortSignal,
) => Promise<AsyncIterable<AIMessagePart> | void | undefined | null> | void | undefined | null
/**
* 异常通知回调\
* resumeMessages、sendUserMessage、AI 流消费中任何非 abort 异常都会触发。\
* 仅作为通知,sendMessage 仍会向调用方 throw,调用方可自行决定是否再做处理。
*/
onError?: (error: unknown) => void
}
/** chatId的占位符 */
const DefaultChatId = Symbol('DefaultChatId')
/**
* 聊天会话 Hook 管理消息列表、发送消息、流式消费与中止。
*
* `chatId` 作为会话作用域 key 变更时内部状态(messages、submitting、streaming)会重置
* 并重新触发 `resumeMessages`,可以同时恢复历史消息和流式传输。
* *流式消息会自动续接在最后一条AI消息上。*
*
* ### 新建会话时保持 chatId 稳定
*
* 如果用sessionId作为chatId,那么 sessionId 从无到有(新会话创建 -> 服务端返回真实 sid)会导致丢失当前正在流式中的消息与状态。
* 使用 {@link useChatId} 可在这一过渡期保持 chatId 不变
* 仅在真正切换到其他已有会话时才更新。
*
* @example
* ```tsx
* const [sessionId, setSessionId] = useState<string | null>(null)
* const { chatId, setNewSessionId } = useChatId({ sessionId, setSessionId })
*
* const { messages, sendMessage } = useChat({
* chatId,
* resumeMessages: async (signal) => {
* if (!sessionId) return { messages: [] }
* // 允许返回stream继续流式传输
* const { messages, stream } = await fetchHistory(sessionId, signal)
* return { messages, stream }
* },
* sendUserMessage: async (parts, signal) => {
* // 没有 sessionId 时先自行新建 session 用 setNewSessionId 写回 chatId 保持不变
* let sid = sessionId
* if (!sid) {
* sid = await createSession(signal)
* // 调用setNewSessionId确保chatId不变
* setNewSessionId(sid)
* }
* const { stream } = await postMessage({ sessionId: sid, parts }, signal)
* return stream
* },
* })
* ```
*/
export function useChat<AIMessagePart, UserMessagePart>(
options: UseChatOptions<AIMessagePart, UserMessagePart>,
) {
const { chatId } = options
const optionsRef = useLatest(options)
const abortCtrlRef = useRef(new AbortController())
const stop = useMemoizedFn(() => {
abortCtrlRef.current.abort()
abortCtrlRef.current = new AbortController()
})
const [messages, setMessages, messageActions] = useMessages<AIMessagePart, UserMessagePart>(
chatId,
)
/** 是否正在恢复会话(不包括读流的过程) */
const [resuming, setResuming] = useSafeState(true, chatId)
/** 用户已提交但尚未收到响应的消息数量 */
const [submittingCount, setSubmittingCount] = useSafeState(0, chatId)
const invokeSendUserMessage = useCallback(
async (id: string, parts: UserMessagePart[], signal: AbortSignal) => {
setSubmittingCount((prev) => prev + 1)
try {
const stream = await optionsRef.current.sendUserMessage(parts, signal)
// 调用方可能未监听 signal
if (signal.aborted) throw new DOMException('Aborted', 'AbortError')
messageActions.updateMessage({ id, role: 'user', status: 'done' })
return stream
} catch (err) {
const aborted = signal.aborted
messageActions.updateMessage({
id,
role: 'user',
status: aborted ? 'aborted' : 'error',
})
throw err
} finally {
setSubmittingCount((prev) => prev - 1)
}
},
[messageActions, optionsRef, setSubmittingCount],
)
/** 正在被读取的 AI 流数量 */
const [streamingCount, setStreamingCount] = useSafeState(0, chatId)
const consumeAIStream = useCallback(
async (id: string, stream: AsyncIterable<AIMessagePart>, signal: AbortSignal) => {
setStreamingCount((prev) => prev + 1)
try {
for await (const part of stream) {
// 调用方可能未监听 signal
if (signal.aborted) throw new DOMException('Aborted', 'AbortError')
messageActions.updateMessage({ id, role: 'ai', parts: [part] })
}
messageActions.updateMessage({ id, role: 'ai', status: 'done' })
} catch (err) {
// 主动 stop 触发的中止标记为 aborted 真实异常才标记 error
const aborted = signal.aborted
messageActions.updateMessage({
id,
role: 'ai',
status: aborted ? 'aborted' : 'error',
})
throw err
} finally {
setStreamingCount((prev) => prev - 1)
}
},
[messageActions, setStreamingCount],
)
// 防止effect意外调用 例如hmr
const latestChatId = useRef<unknown>(DefaultChatId)
const isChatIdChange = latestChatId.current !== chatId
latestChatId.current = chatId
useEffect(() => {
// 组件卸载时无效
const cleanup = () => {
if (latestChatId.current !== chatId) stop()
}
if (!isChatIdChange) return cleanup
const signal = abortCtrlRef.current.signal
;(async () => {
try {
const res = await optionsRef.current.resumeMessages(signal)
// 调用方可能未监听 signal
if (signal.aborted) throw new DOMException('Aborted', 'AbortError')
if (!res) return
const { stream } = res
let { messages: resumed = [] } = res
let lastAiMessageId: string | undefined
const lastMessage = resumed.at(-1)
if (stream) {
if (lastMessage?.role === 'ai') {
lastAiMessageId = lastMessage.id
} else {
lastAiMessageId = crypto.randomUUID()
resumed = [
...resumed,
{ id: lastAiMessageId, role: 'ai', parts: [], status: 'streaming' },
]
}
}
if (resumed.length) {
messageActions.addMessages(resumed, 'pre')
}
if (lastAiMessageId && stream) {
setResuming(false)
await consumeAIStream(lastAiMessageId, stream, signal)
}
} catch (err) {
if (!signal.aborted) optionsRef.current.onError?.(err)
} finally {
setResuming(false)
}
})()
return cleanup
}, [chatId, consumeAIStream, isChatIdChange, messageActions, optionsRef, setResuming, stop])
// 组件卸载时清理
useEffect(() => {
return stop
}, [stop])
const sendMessage = useCallback(
async (parts: UserMessagePart[]) => {
const signal = abortCtrlRef.current.signal
const userMessageId = crypto.randomUUID()
messageActions.addMessages({
id: userMessageId,
role: 'user',
parts,
status: 'submitting',
})
try {
const stream = await invokeSendUserMessage(userMessageId, parts, signal)
if (!stream) return
const aiMessageId = crypto.randomUUID()
messageActions.addMessages({
id: aiMessageId,
role: 'ai',
parts: [],
status: 'streaming',
})
await consumeAIStream(aiMessageId, stream, signal)
} catch (err) {
if (!signal.aborted) optionsRef.current.onError?.(err)
throw err
}
},
[consumeAIStream, invokeSendUserMessage, messageActions, optionsRef],
)
return {
messages,
sendMessage,
/** 是否正在恢复会话(不含读流阶段) */
resuming,
/** 是否有 AI 流正在读取中 */
streaming: streamingCount > 0,
/** 是否有用户消息已发出但尚未收到响应 */
submitting: submittingCount > 0,
stop,
setMessages,
messageActions,
consumeAIStream,
}
}
/**
* 与作用域 key 绑定的 state\
* - value 用 ref 维护 key 变更当帧同步重置 避免 render 中 setValue 触发二次渲染\
* - setter 闭包在 useCallback 中捕获 render 当时的 key 异步回来时若 key 已变更 则跳过更新
*/
function useSafeState<T>(initial: T, key: unknown) {
const [, forceRender] = useState({})
const valueRef = useRef(initial)
const prevKeyRef = useRef(key)
const latestKeyRef = useRef(key)
if (!Object.is(prevKeyRef.current, key)) {
prevKeyRef.current = key
valueRef.current = initial
}
latestKeyRef.current = key
// 不能使用useMemorizedFn 需要闭包里的旧key与最新的key对比
const safeSetValue: Dispatch<SetStateAction<T>> = useCallback(
(next) => {
if (!Object.is(latestKeyRef.current, key)) return
valueRef.current =
typeof next === 'function' ? (next as (prev: T) => T)(valueRef.current) : next
forceRender({})
},
[key],
)
return [valueRef.current, safeSetValue] as const
}
/** 分配式的 PartialExcept 保留联合类型各分支 使 role 能作为 discriminator narrow 到对应分支 */
type PartialExcept<T, K extends PropertyKey> = T extends unknown
? Pick<T, Extract<keyof T, K>> & Partial<Omit<T, K>>
: never
function useMessages<AIMessagePart, UserMessagePart>(chatId: unknown) {
type Message = ChatMessage<AIMessagePart, UserMessagePart>
const [messages, setMessages] = useSafeState<Message[]>([], chatId)
const actions = useMemo(() => {
const addMessages = (messageOrList: Message | Message[], position: 'pre' | 'end' = 'end') => {
const list = Array.isArray(messageOrList) ? messageOrList : [messageOrList]
if (!list.length) return
setMessages((prev) => (position === 'end' ? [...prev, ...list] : [...list, ...prev]))
}
const updateMessage = (patch: PartialExcept<Message, 'id' | 'role'>) => {
if (!patch.parts?.length && patch.status === undefined) return
setMessages((prev) => {
const index = prev.findIndex((item) => item.id === patch.id && item.role === patch.role)
if (index === -1) return prev
const target = prev[index]
const merged = mergeMessage(target, patch)
if (merged === target) return prev
return prev.map((item, i) => (i === index ? merged : item))
})
}
return {
addMessages,
updateMessage,
}
}, [setMessages])
return [messages, setMessages, actions] as const
}
function mergeMessage<AIMessagePart, UserMessagePart>(
target: ChatMessage<AIMessagePart, UserMessagePart>,
patch: PartialExcept<ChatMessage<AIMessagePart, UserMessagePart>, 'id' | 'role'>,
): ChatMessage<AIMessagePart, UserMessagePart> {
// role 已在调用方通过 findIndex 匹配 下面仅通过窄化恢复 TS 联合类型收敛
if (target.role === 'ai' && patch.role === 'ai') {
return {
...target,
...(patch.status !== undefined && { status: patch.status }),
...(patch.parts?.length && { parts: [...target.parts, ...patch.parts] }),
}
}
if (target.role === 'user' && patch.role === 'user') {
return {
...target,
...(patch.status !== undefined && { status: patch.status }),
...(patch.parts?.length && { parts: [...target.parts, ...patch.parts] }),
}
}
return target
}