AI Chat Data Management: Building an SSE Streaming Hook with useChat

useChat Hook

Overview

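useChat is a full-featured React Hook for building chat applications with streaming responses. It provides message management, streaming data processing, error handling, retries, and session management.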

Core architecture

1. Type system

1.1 Basic message structure
interface Message<TMetadata = Record<string, unknown>> {
  id: string;           // Unique identifier
  role: "user" | "assistant";  // Message role
  content: string;      // Message content
  metadata?: TMetadata;  // Extension metadata
}
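1.2 Configuration options
interface ChatOptions<TMetadata = Record<string, unknown>> {
  sendFullHistory?: boolean;      // Send the full conversation history
  maxHistoryLength?: number;      // Maximum number of history messages to send
  headers?: Record<string, string>; // Custom request headers
  timeoutMs?: number;              // Timeout in milliseconds
  maxRetries?: number;            // Maximum number of retries
  retryDelayMs?: number;          // Delay between retries in milliseconds
  parseChunk?: (data: string) => any; // Parses the payload of one SSE data: line
  onStatusChange?: (status: ChatStatus) => void; // Called on every status change
  onChunkError?: (error: Error) => void; // Called when a chunk fails to parse
  sessionId?: string;             // Session ID
  metadata?: TMetadata;           // Global metadata
}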
1.3 Status and error types
type ChatStatus = "ready" | "submitted" | "streaming" | "error";

interface ChatError {
  message: string;
  code: "TIMEOUT" | "ABORT" | "NETWORK" | "HTTP" | "PARSE" | "UNKNOWN";
  details?: unknown;
}

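Core implementation

2.1 Streaming data processing architecture

2.1.1 TransformStream pipeline design

useChat builds its data processing pipeline from Web Streams API primitives:

ReadableStream (HTTP response body)
    ↓
TextDecoderStream (bytes → text)
    ↓
LineSplitterStream (split into lines)
    ↓
SSEParserStream (parse SSE data: lines)
    ↓
React state (message updates)
2.1.2 LineSplitterStream implementation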
function createLineSplitterStream(): TransformStream<string, string> {
  let buffer = "";
  return new TransformStream<string, string>({
    transform(chunk, controller) {
      buffer += chunk;
      const lines = buffer.split(/\r?\n/);
      buffer = lines.pop() || "";
      for (const line of lines) {
        if (line) controller.enqueue(line);
      }
    },
    flush(controller) {
      if (buffer) controller.enqueue(buffer);
    },
  });
}

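How it works

  • Keeps a buffer that accumulates incoming text chunks
  • Splits the buffer on the regular expression \r?\n (LF or CRLF line endings)
  • Keeps the incomplete last line in the buffer until more text arrives, and drops empty lines
  • Flushes whatever remains in the buffer when the stream ends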
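To see why the buffer matters, here is a minimal synchronous sketch of the same logic without Web Streams. `splitLines` is a hypothetical helper, not part of the hook: it takes the leftover buffer plus a new chunk and returns the complete lines and the new leftover, so an SSE line that arrives split across two network chunks is only emitted once its newline shows up.

```typescript
// Hypothetical helper mirroring the transform() step of createLineSplitterStream.
interface SplitResult {
  lines: string[]; // complete, non-empty lines ready for the SSE parser
  rest: string; // trailing partial line to keep in the buffer
}

function splitLines(buffer: string, chunk: string): SplitResult {
  const parts = (buffer + chunk).split(/\r?\n/);
  // The last element is "" (chunk ended with a newline) or a partial line.
  const rest = parts.pop() ?? "";
  return { lines: parts.filter((line) => line !== ""), rest };
}

// Simulate two network chunks that cut "data: world" in half.
let buffer = "";
const emitted: string[] = [];
for (const chunk of ["data: hello\r\ndata: wor", "ld\n\n"]) {
  const { lines, rest } = splitLines(buffer, chunk);
  emitted.push(...lines);
  buffer = rest;
}
console.log(emitted); // [ 'data: hello', 'data: world' ]
```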
2.1.3 SSEParserStream implementation
function createSSEParserStream<TChunk>(
  parseChunk: (data: string) => TChunk | null,
  onChunkError?: (error: Error) => void
): TransformStream<string, TChunk> {
  return new TransformStream<string, TChunk>({
    transform(line, controller) {
      const trimmedLine = line.trim();
      // Only data: lines carry a payload; event:, id:, retry: and comments are ignored
      if (!trimmedLine.startsWith("data:")) return;
      const data = trimmedLine.slice(5).trim();
      // Skip empty payloads and the [DONE] end marker
      if (!data || data === "[DONE]") return;
      try {
        const content = parseChunk(data);
        if (content !== null) controller.enqueue(content);
      } catch (err) {
        // A malformed chunk is reported and skipped; the stream keeps going
        onChunkError?.(err instanceof Error ? err : new Error(String(err)));
      }
    },
  });
}

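SSE parsing

  • Recognizes lines that start with the data: prefix
  • Filters out empty payloads and the [DONE] marker
  • Delegates payload parsing to a customizable parseChunk function
  • Isolates errors: a chunk whose parsing throws is reported through onChunkError and skipped, while the stream keeps going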
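The parser above treats every `data:` line as an independent chunk, which works when the server sends exactly one `data:` line per event, as OpenAI-style streams do. The SSE specification (WHATWG HTML, "Server-sent events") frames events differently: an event ends at a blank line, multiple `data:` lines are joined with `\n`, only a single space after the colon is stripped, and lines starting with `:` are comments. The sketch below, `parseSSEEvents` (a hypothetical name), applies those rules to an already-decoded text buffer; it is a reference for comparison, not the hook's code.

```typescript
// Hypothetical spec-style SSE framing over an already-decoded text buffer.
function parseSSEEvents(text: string): string[] {
  const events: string[] = [];
  let dataLines: string[] = [];

  for (const line of text.split(/\r\n|\r|\n/)) {
    if (line === "") {
      // Blank line: dispatch the pending event if it carried any data.
      if (dataLines.length > 0) events.push(dataLines.join("\n"));
      dataLines = [];
      continue;
    }
    if (line.startsWith(":")) continue; // comment or keep-alive line

    const colon = line.indexOf(":");
    const field = colon === -1 ? line : line.slice(0, colon);
    let value = colon === -1 ? "" : line.slice(colon + 1);
    if (value.startsWith(" ")) value = value.slice(1); // strip one leading space only

    if (field === "data") dataLines.push(value);
    // event:, id: and retry: are ignored here, as in createSSEParserStream.
  }
  // An event that is not terminated by a blank line is discarded.
  return events;
}

const stream =
  ': keep-alive\ndata: {"a":1}\n\ndata: line one\ndata: line two\n\ndata: [DONE]\n\n';
// Three events: '{"a":1}', 'line one\nline two' and '[DONE]'
console.log(parseSSEEvents(stream));
```

Filtering `[DONE]` and parsing JSON would still happen afterwards, exactly as in the hook's parser.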

2.2 State management

2.2.1 Core state
const [messages, setMessages] = useState<Message<TMetadata>[]>([]);
const [status, setStatus] = useState<ChatStatus>("ready");
const [error, setError] = useState<ChatError | null>(null);
const [currentSessionId, setCurrentSessionId] = useState<string | undefined>(options.sessionId);
2.2.2 Refs
const abortControllerRef = useRef<AbortController | null>(null);
const lastRequestRef = useRef<{
  message: string;
  options: Partial<ChatOptions<TMetadata>>;
} | null>(null);

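2.3 Message sending flow

2.3.1 Full lifecycle
  1. Validation: check the current status and reject empty messages
  2. Status update: switch to "submitted"
  3. Message construction: append the user message and an empty assistant placeholder
  4. Option merging: merge the hook-level options with the per-call options
  5. History selection: decide which earlier messages to send, based on sendFullHistory and maxHistoryLength
  6. Retry loop: retry failed requests after a fixed delay (retryDelayMs); aborts and timeouts are not retried
  7. Stream processing: pipe the response body through the transform pipeline
  8. State sync: update the assistant message and the status as chunks arrive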
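Step 5 can be isolated as a pure function. `selectHistory` below is a hypothetical helper that follows the hook's rule: with `sendFullHistory` off only the new user message is sent; with it on, the earlier messages plus the new one are sent, truncated to the last `maxHistoryLength` entries. `slice(-Infinity)` returns the whole array, which is why `Infinity` works as the default limit, but `slice(-0)` is `slice(0)` and also returns the whole array, so the sketch clamps the limit to at least 1.

```typescript
interface ChatMessage {
  id: string;
  role: "user" | "assistant";
  content: string;
}

// Hypothetical helper: decide which messages go into the request body.
function selectHistory(
  previous: ChatMessage[],
  userMessage: ChatMessage,
  sendFullHistory: boolean,
  maxHistoryLength: number = Infinity
): ChatMessage[] {
  if (!sendFullHistory) return [userMessage];
  // Always keep at least the new message; slice(-0) would return everything.
  const limit = Math.max(1, maxHistoryLength);
  return [...previous, userMessage].slice(-limit);
}

const history: ChatMessage[] = [
  { id: "1", role: "user", content: "Hi" },
  { id: "2", role: "assistant", content: "Hello!" },
];
const next: ChatMessage = { id: "3", role: "user", content: "How are you?" };

console.log(selectHistory(history, next, false).map((m) => m.id)); // [ '3' ]
console.log(selectHistory(history, next, true).map((m) => m.id)); // [ '1', '2', '3' ]
console.log(selectHistory(history, next, true, 2).map((m) => m.id)); // [ '2', '3' ]
```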
2.3.2 Stream processing flow
// 1. response.body is already a ReadableStream<Uint8Array>, so it can be piped directly
if (!response.body) throw new Error("Response body is not readable");

// 2. Build the pipeline
const pipeline = response.body
  .pipeThrough(new TextDecoderStream())
  .pipeThrough(createLineSplitterStream())
  .pipeThrough(createSSEParserStream(mergedOptions.parseChunk, onChunkError));

// 3. Consume the parsed chunks (this excerpt assumes parseChunk returns strings)
const processedReader = pipeline.getReader();
while (true) {
  const { done, value } = await processedReader.read();
  if (done) break;
  if (value) {
    setMessages(prev => prev.map(msg =>
      msg.id === assistantMessageId
        ? { ...msg, content: msg.content + value }
        : msg
    ));
  }
}
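The `setMessages` updater above is easier to test as a pure function. The sketch below, `applyChunk` (a hypothetical name), assumes the parser yields either a plain string, as the `parseChunk` examples in this article return, or an object tagged `text` / `annotation`, as the full source at the end of the article expects. Text is appended to the assistant message and annotation data is merged into its `metadata`.

```typescript
interface ChatMessage {
  id: string;
  role: "user" | "assistant";
  content: string;
  metadata?: Record<string, unknown>;
}

// Assumed chunk shapes produced by parseChunk.
type ChatChunk =
  | string
  | { type: "text"; content: string }
  | { type: "annotation"; data: Record<string, unknown> };

// Hypothetical pure updater: apply one parsed chunk to the target message.
function applyChunk(messages: ChatMessage[], targetId: string, chunk: ChatChunk): ChatMessage[] {
  return messages.map((msg) => {
    if (msg.id !== targetId) return msg; // other messages keep their identity
    if (typeof chunk === "string") {
      return { ...msg, content: msg.content + chunk };
    }
    if (chunk.type === "text") {
      return { ...msg, content: msg.content + chunk.content };
    }
    return { ...msg, metadata: { ...msg.metadata, ...chunk.data } };
  });
}

let state: ChatMessage[] = [
  { id: "u1", role: "user", content: "Hi" },
  { id: "a1", role: "assistant", content: "" },
];
state = applyChunk(state, "a1", "Hel");
state = applyChunk(state, "a1", { type: "text", content: "lo" });
state = applyChunk(state, "a1", { type: "annotation", data: { sources: ["doc-1"] } });
console.log(state[1].content); // Hello
```

Inside the hook this would be called as `setMessages((prev) => applyChunk(prev, assistantMessageId, value))`.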

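2.4 Error handling

2.4.1 Error categories
  • TIMEOUT: the request timed out
  • ABORT: the user cancelled the request
  • NETWORK: a network failure
  • HTTP: a non-2xx HTTP status
  • PARSE: a data parsing error
  • UNKNOWN: any other error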
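Classifying by `err.name === "AbortError"` or by searching the message for "HTTP" is fragile: when `AbortController.abort(reason)` is called with a reason, current `fetch` implementations reject with that reason rather than a generic `AbortError`. A more robust approach, sketched below, checks the `AbortSignal` itself (`signal.aborted` and `signal.reason` are standard) and throws a dedicated class for HTTP failures. `HttpError`, `timeoutReason` and `classifyError` are hypothetical names; the `"TimeoutError"` name is chosen to match the `DOMException` name that `AbortSignal.timeout()` uses. Treating `TypeError` as a network error is a heuristic, since `fetch` reports network failures that way but other bugs can throw `TypeError` too.

```typescript
type ChatErrorCode = "TIMEOUT" | "ABORT" | "NETWORK" | "HTTP" | "PARSE" | "UNKNOWN";

// Hypothetical error thrown when response.ok is false.
class HttpError extends Error {
  readonly status: number;
  constructor(status: number) {
    super(`HTTP error, status ${status}`);
    this.name = "HttpError";
    this.status = status;
  }
}

// Hypothetical abort reason used by the timeout timer.
function timeoutReason(): Error {
  const err = new Error("Request Timeout");
  err.name = "TimeoutError";
  return err;
}

function classifyError(err: unknown, signal: AbortSignal): ChatErrorCode {
  if (signal.aborted) {
    // The signal records why it was aborted, regardless of how fetch re-threw it.
    const reason: unknown = signal.reason;
    return reason instanceof Error && reason.name === "TimeoutError" ? "TIMEOUT" : "ABORT";
  }
  if (err instanceof HttpError) return "HTTP";
  if (err instanceof SyntaxError) return "PARSE"; // e.g. JSON.parse on a malformed chunk
  if (err instanceof TypeError) return "NETWORK"; // heuristic: fetch network failures
  return "UNKNOWN";
}

const timedOut = new AbortController();
timedOut.abort(timeoutReason());
console.log(classifyError(timedOut.signal.reason, timedOut.signal)); // TIMEOUT

const user = new AbortController();
user.abort(); // no reason given: the signal gets an "AbortError" DOMException
console.log(classifyError(user.signal.reason, user.signal)); // ABORT

const idle = new AbortController();
console.log(classifyError(new HttpError(502), idle.signal)); // HTTP
console.log(classifyError(new TypeError("Failed to fetch"), idle.signal)); // NETWORK
```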
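2.4.2 Retry strategy
// One initial attempt plus up to maxRetries retries, with a fixed delay between them
for (let attempt = 0; attempt <= mergedOptions.maxRetries; attempt++) {
  try {
    // Send the request and consume the stream
    return; // success: stop retrying
  } catch (err) {
    // isAbortOrTimeout stands in for the error classification in the full source
    if (isAbortOrTimeout(err)) {
      return; // aborts and timeouts are not retried
    }
    if (attempt < mergedOptions.maxRetries) {
      await new Promise((resolve) => setTimeout(resolve, mergedOptions.retryDelayMs));
    }
  }
}
// Every attempt failed: record the last error and switch to "error"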
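The loop above waits a constant `retryDelayMs` between attempts, and so does the full source at the end of the article, so this is a fixed-delay retry rather than exponential backoff. If exponential backoff is wanted, the delay can be computed per retry. The sketch below is assumption-based: `retryDelay`, `withRetry`, `maxDelayMs` and the `shouldRetry` predicate are hypothetical names, and retry index 0 is the first retry.

```typescript
// Hypothetical helper: exponential backoff, capped at maxDelayMs.
function retryDelay(retryIndex: number, baseMs: number, maxDelayMs: number): number {
  return Math.min(maxDelayMs, baseMs * 2 ** retryIndex);
}

// Hypothetical wrapper: one initial attempt plus up to maxRetries retries.
async function withRetry<T>(
  run: () => Promise<T>,
  maxRetries: number,
  baseMs: number,
  maxDelayMs: number,
  shouldRetry: (err: unknown) => boolean
): Promise<T> {
  for (let attempt = 0; ; attempt++) {
    try {
      return await run();
    } catch (err) {
      // Give up on non-retryable errors (abort, timeout) or when retries run out.
      if (attempt >= maxRetries || !shouldRetry(err)) throw err;
      await new Promise((resolve) => setTimeout(resolve, retryDelay(attempt, baseMs, maxDelayMs)));
    }
  }
}

console.log([0, 1, 2, 3, 4].map((i) => retryDelay(i, 1000, 8000)));
// [ 1000, 2000, 4000, 8000, 8000 ]
```

The cap keeps a long retry chain from waiting minutes between attempts; adding random jitter to each delay is a common refinement when many clients retry at once.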

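2.5 Session management

2.5.1 Session ID handling
  • The initial session can be set with options.sessionId
  • Changing options.sessionId clears the current messages
  • setMessages(messages, newSessionId) replaces the messages and switches the session ID
2.5.2 Reacting to session changes
// Track the previous options.sessionId so that only real changes clear the
// conversation; a session ID set through setMessages is then not overwritten
const prevSessionIdRef = useRef(options.sessionId);
useEffect(() => {
  if (options.sessionId !== prevSessionIdRef.current) {
    prevSessionIdRef.current = options.sessionId;
    clearMessages();
    setCurrentSessionId(options.sessionId);
  }
}, [options.sessionId, clearMessages]);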

Advanced features

3.1 Metadata

A TypeScript generic parameter lets messages carry metadata of any shape:

interface MyMetadata {
  confidence?: number;
  sources?: string[];
  timestamp?: Date;
}

const { sendMessage } = useChat<MyMetadata>('/api/chat', {
  metadata: { confidence: 0.95 }
});

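3.2 Custom chunk parsing

A custom parseChunk function adapts the hook to other response formats. It receives the payload of a single data: line and returns the chunk to apply, or null to skip it: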

const parseChunk = (data: string) => {
  const parsed = JSON.parse(data);
  if (parsed.type === 'delta') {
    return parsed.content;
  }
  return null;
};
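The example above returns a plain string. The full source at the end of the article, however, distinguishes `"text"` chunks (appended to `content`) from `"annotation"` chunks (merged into `metadata`). A parser feeding that branch could look like the sketch below; the server event shapes (`delta` with `content`, `sources` with a `sources` array) and the name `parseTaggedChunk` are invented for illustration.

```typescript
// Chunk shapes consumed by the hook's streaming loop.
type ChatChunk =
  | { type: "text"; content: string }
  | { type: "annotation"; data: Record<string, unknown> };

// Hypothetical server events: {"type":"delta","content":"..."} and
// {"type":"sources","sources":["..."]}. Anything else is skipped.
const parseTaggedChunk = (data: string): ChatChunk | null => {
  const parsed: unknown = JSON.parse(data); // a SyntaxError here is passed to onChunkError
  if (typeof parsed !== "object" || parsed === null) return null;
  const event = parsed as Record<string, unknown>;

  if (event.type === "delta" && typeof event.content === "string") {
    return { type: "text", content: event.content };
  }
  if (event.type === "sources" && Array.isArray(event.sources)) {
    return { type: "annotation", data: { sources: event.sources } };
  }
  return null;
};

console.log(parseTaggedChunk('{"type":"delta","content":"Hi"}'));
// { type: 'text', content: 'Hi' }
```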

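3.3 Lifecycle callbacks

  • onStatusChange: called on every status change
  • onChunkError: called when parseChunk throws for a chunk; the chunk is skipped and the stream continues

Usage examples

4.1 Basic usage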

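import { useState } from 'react';
import useChat from '@/hooks/use-chat';

function ChatComponent() {
  const { messages, status, error, sendMessage } = useChat('/api/chat');
  const [input, setInput] = useState('');

  const handleSend = async () => {
    const text = input;
    setInput('');
    try {
      await sendMessage(text);
    } catch (err) {
      // sendMessage throws for empty input or while a request is in flight
      console.warn(err);
    }
  };

  return (
    <div>
      {messages.map(msg => (
        <div key={msg.id} className={msg.role}>
          {msg.content}
        </div>
      ))}
      {error && <p>{error.message}</p>}
      <input value={input} onChange={e => setInput(e.target.value)} />
      <button onClick={handleSend} disabled={status === 'submitted' || status === 'streaming'}>
        Send
      </button>
    </div>
  );
}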

4.2 Advanced configuration

const { messages, status, sendMessage, abortRequest } = useChat('/api/chat', {
  sendFullHistory: true,
  maxHistoryLength: 50,
  timeoutMs: 10000,
  maxRetries: 3,
  retryDelayMs: 1000,
  headers: { 'Authorization': 'Bearer token' },
  onStatusChange: (status) => console.log('Status:', status),
  parseChunk: (data) => {
    try {
      const parsed = JSON.parse(data);
      return parsed.choices[0]?.delta?.content || null;
    } catch {
      return null;
    }
  }
});

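Performance

5.1 Resource cleanup

  • The timeout timer is cleared and the AbortController reference released after every request
  • In-flight requests are aborted when the component unmounts
  • The returned functions are memoized with useCallback to limit re-renders in consuming components

5.2 Benefits of streaming

  • The pipeline itself only buffers the current partial line, so its overhead does not grow with the response size (the accumulated message text in React state still does)
  • The UI updates as chunks arrive instead of waiting for the complete response
  • Long responses render progressively; since every chunk triggers a setMessages call, very high chunk rates may be worth batching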

Debugging

6.1 Debugging tips

const { error, status } = useChat('/api/chat', {
  onStatusChange: console.log,
  onChunkError: console.error,
});

// Inspect the error
if (error) {
  console.error('Chat error:', error.code, error.message, error.details);
}

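6.2 Common problems

  1. Timeouts: adjust timeoutMs (it bounds the wait for the response headers, not the whole stream)
  2. CORS errors: check the server's CORS configuration, including the preflight response when custom headers such as Authorization are sent
  3. Parse errors: verify the parseChunk implementation
  4. Unexpected status: check the status transitions (ready → submitted → streaming → ready, or → error)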
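The checklist above can also be surfaced in the UI by mapping `error.code` to a hint. The sketch below is one possible mapping; `hintFor` and the hint wording are hypothetical, and the `ChatError` shape is copied from section 1.3.

```typescript
interface ChatError {
  message: string;
  code: "TIMEOUT" | "ABORT" | "NETWORK" | "HTTP" | "PARSE" | "UNKNOWN";
  details?: unknown;
}

// Hypothetical mapping from error codes to troubleshooting hints.
const hints: Record<ChatError["code"], string> = {
  TIMEOUT: "No response headers within timeoutMs; raise timeoutMs or check the server.",
  ABORT: "The request was cancelled, e.g. by abortRequest.",
  NETWORK: "fetch failed; check connectivity and the server's CORS headers.",
  HTTP: "The server returned a non-2xx status; inspect error.details.",
  PARSE: "A chunk could not be parsed; check the parseChunk implementation.",
  UNKNOWN: "Unexpected failure; log error.details for the original exception.",
};

function hintFor(error: ChatError | null): string | null {
  return error ? `${error.message} (${error.code}): ${hints[error.code]}` : null;
}

console.log(hintFor({ message: "Request timed out", code: "TIMEOUT" }));
```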

Extension guide

7.1 Integrating other UI libraries

The hook is UI-agnostic: any component library can render the returned messages and status.

7.2 Custom storage

Persist the conversation by combining the returned messages with setMessages:

// Save to localStorage
useEffect(() => {
  localStorage.setItem('chat', JSON.stringify(messages));
}, [messages]);

// Restore from localStorage
const loadMessages = () => {
  const saved = localStorage.getItem('chat');
  if (saved) setMessages(JSON.parse(saved));
};
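`JSON.parse` returns whatever was stored, so a corrupted or outdated `localStorage` entry would reach `setMessages` unchecked. A small type guard, sketched below with hypothetical names (`isStoredMessage`, `parseStoredMessages`), keeps invalid data out; it only validates the fields of the `Message` interface from section 1.1.

```typescript
interface StoredMessage {
  id: string;
  role: "user" | "assistant";
  content: string;
  metadata?: Record<string, unknown>;
}

// Hypothetical type guard for one persisted message.
function isStoredMessage(value: unknown): value is StoredMessage {
  if (typeof value !== "object" || value === null) return false;
  const m = value as Record<string, unknown>;
  return (
    typeof m.id === "string" &&
    (m.role === "user" || m.role === "assistant") &&
    typeof m.content === "string" &&
    (m.metadata === undefined || (typeof m.metadata === "object" && m.metadata !== null))
  );
}

// Returns [] instead of throwing when the stored value is missing or malformed.
function parseStoredMessages(raw: string | null): StoredMessage[] {
  if (!raw) return [];
  try {
    const parsed: unknown = JSON.parse(raw);
    return Array.isArray(parsed) && parsed.every(isStoredMessage) ? parsed : [];
  } catch {
    return [];
  }
}
```

Inside a component, restoring then becomes `setMessages(parseStoredMessages(localStorage.getItem('chat')))`.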

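API reference

useChat parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| apiUrl | string | - | Chat API endpoint |
| options | ChatOptions | {} | Configuration options |

Return value

| Property | Type | Description |
|---|---|---|
| messages | Message[] | Current message list |
| status | ChatStatus | Current status |
| error | ChatError \| null | Error information |
| sendMessage | (message, options?) => Promise<void> | Sends a message |
| abortRequest | () => void | Aborts the current request |
| clearMessages | () => void | Clears the message list |
| retry | () => Promise<void> | Retries the last failed request |
| setMessages | (messages, newSessionId?) => void | Replaces the message list |
| resetConversation | (newSessionId?) => void | Resets the conversation |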
Full source

import { generateId } from "@/lib/generateId";
import { useState, useCallback, useRef, useEffect } from "react";

// === Type definitions ===
interface Message<TMetadata = Record<string, unknown>> {
  id: string;
  role: "user" | "assistant";
  content: string;
  metadata?: TMetadata;
}

// One parsed SSE chunk: plain text, a tagged text delta, or an annotation
// whose data is merged into the assistant message's metadata
type ChatChunk =
  | string
  | { type: "text"; content: string }
  | { type: "annotation"; data: Record<string, unknown> };

interface ChatOptions<TMetadata = Record<string, unknown>> {
  sendFullHistory?: boolean;
  maxHistoryLength?: number;
  headers?: Record<string, string>;
  timeoutMs?: number; // bounds the wait for response headers, not the whole stream
  maxRetries?: number; // retries after the first attempt
  retryDelayMs?: number;
  parseChunk?: (data: string) => ChatChunk | null;
  onStatusChange?: (status: ChatStatus) => void;
  onChunkError?: (error: Error) => void;
  sessionId?: string;
  metadata?: TMetadata; // default metadata for user messages
}

type ChatStatus = "ready" | "submitted" | "streaming" | "error";

interface ChatError {
  message: string;
  code: "TIMEOUT" | "ABORT" | "NETWORK" | "HTTP" | "PARSE" | "UNKNOWN";
  details?: unknown;
}

interface RequestPayload<TMetadata = Record<string, unknown>> {
  messages: Message<TMetadata>[];
  stream: boolean;
  sessionId?: string;
}

interface UseChatReturn<TMetadata = Record<string, unknown>> {
  messages: Message<TMetadata>[];
  status: ChatStatus;
  error: ChatError | null;
  sendMessage: (
    message: string,
    options?: Partial<ChatOptions<TMetadata>>
  ) => Promise<void>;
  abortRequest: () => void;
  clearMessages: () => void;
  retry: () => Promise<void>;
  setMessages: (messages: Message<TMetadata>[], newSessionId?: string) => void;
  resetConversation: (newSessionId?: string) => void;
}

// Thrown for non-2xx responses, so classification does not depend on message text
class HttpError extends Error {
  readonly status: number;
  constructor(status: number) {
    super(`HTTP error! Status: ${status}`);
    this.name = "HttpError";
    this.status = status;
  }
}

// Abort reasons: the name tells timeouts, user aborts and internal resets apart
type AbortKind = "TimeoutError" | "UserAbort" | "ResetAbort";

function abortReason(kind: AbortKind, message: string): Error {
  const reason = new Error(message);
  reason.name = kind;
  return reason;
}

// Default parser: expects each data payload to be JSON already shaped like ChatChunk
function defaultParseChunk(data: string): ChatChunk | null {
  try {
    return (JSON.parse(data) as ChatChunk | null) ?? null;
  } catch (e) {
    console.warn("Failed to parse SSE chunk:", data, e);
    // Malformed chunks are skipped here, so onChunkError only sees errors from custom parsers
    return null;
  }
}

function createLineSplitterStream(): TransformStream<string, string> {
  let buffer = "";
  return new TransformStream<string, string>({
    transform(chunk, controller) {
      buffer += chunk;
      const lines = buffer.split(/\r?\n/);
      buffer = lines.pop() || ""; // keep the incomplete last line
      for (const line of lines) {
        if (line) {
          // Only non-empty lines are forwarded
          controller.enqueue(line);
        }
      }
    },
    flush(controller) {
      // When the stream ends, forward whatever is left in the buffer.
      // The readable side closes automatically after flush, so terminate() is not needed.
      if (buffer) {
        controller.enqueue(buffer);
        buffer = "";
      }
    },
  });
}

/**
 * Creates a TransformStream that parses SSE data: lines into chunks.
 */
function createSSEParserStream(
  parseChunk: (data: string) => ChatChunk | null,
  onChunkError?: (error: Error) => void
): TransformStream<string, ChatChunk> {
  return new TransformStream<string, ChatChunk>({
    transform(line, controller) {
      const trimmedLine = line.trim();
      // event:, id:, retry: and comment lines are ignored
      if (!trimmedLine.startsWith("data:")) return;
      const data = trimmedLine.slice(5).trim();
      // Skip empty payloads and the [DONE] marker
      if (!data || data === "[DONE]") return;
      try {
        const content = parseChunk(data);
        if (content !== null) {
          controller.enqueue(content);
        }
      } catch (err) {
        console.error("Error parsing SSE chunk:", data, err);
        onChunkError?.(err instanceof Error ? err : new Error(String(err)));
      }
    },
  });
}

// Applies one parsed chunk to the assistant message
function applyChunk<TMetadata>(
  msg: Message<TMetadata>,
  chunk: ChatChunk
): Message<TMetadata> {
  if (typeof chunk === "string") {
    return { ...msg, content: msg.content + chunk };
  }
  if (chunk.type === "text") {
    return { ...msg, content: msg.content + chunk.content };
  }
  if (chunk.type === "annotation") {
    const previous = (msg.metadata ?? {}) as unknown as Record<string, unknown>;
    return {
      ...msg,
      metadata: { ...previous, ...chunk.data } as unknown as TMetadata,
    };
  }
  return msg;
}

const useChat = <TMetadata = Record<string, unknown>>(
  apiUrl: string,
  options: ChatOptions<TMetadata> = {}
): UseChatReturn<TMetadata> => {
  const [messages, setMessages] = useState<Message<TMetadata>[]>([]);
  const [status, setStatus] = useState<ChatStatus>("ready");
  const [error, setError] = useState<ChatError | null>(null);
  const [currentSessionId, setCurrentSessionId] = useState<string | undefined>(
    options.sessionId
  );

  const abortControllerRef = useRef<AbortController | null>(null);
  const lastRequestRef = useRef<{
    message: string;
    options: Partial<ChatOptions<TMetadata>>;
  } | null>(null);
  // Latest message list, so sendMessage and retry never read a stale closure
  const messagesRef = useRef<Message<TMetadata>[]>([]);
  // Latest callbacks: inline callbacks must not change the identity of the
  // returned functions, or effects depending on them would re-run every render
  const callbacksRef = useRef({
    onStatusChange: options.onStatusChange,
    onChunkError: options.onChunkError,
  });
  const prevSessionIdRef = useRef(options.sessionId);

  useEffect(() => {
    messagesRef.current = messages;
  }, [messages]);

  useEffect(() => {
    callbacksRef.current = {
      onStatusChange: options.onStatusChange,
      onChunkError: options.onChunkError,
    };
  }, [options.onStatusChange, options.onChunkError]);

  const {
    sendFullHistory = false,
    maxHistoryLength = Infinity,
    headers = {},
    timeoutMs = 30000,
    maxRetries = 3,
    retryDelayMs = 1000,
    parseChunk = defaultParseChunk,
  } = options;

  const setStatusWithCallback = useCallback((newStatus: ChatStatus) => {
    setStatus(newStatus);
    callbacksRef.current.onStatusChange?.(newStatus);
  }, []);

  const sendMessage = useCallback(
    async (
      message: string,
      sendOptions: Partial<ChatOptions<TMetadata>> = {}
    ) => {
      // "error" is accepted so that new messages and retry() work after a failure
      if (status === "submitted" || status === "streaming") {
        throw new Error("Cannot send a message while a request is in progress");
      }
      if (!message.trim()) {
        throw new Error("Message must not be empty");
      }

      setStatusWithCallback("submitted");
      setError(null);
      lastRequestRef.current = { message, options: sendOptions };

      const userMessage: Message<TMetadata> = {
        id: generateId(),
        role: "user",
        content: message,
        metadata: sendOptions.metadata ?? options.metadata,
      };
      const assistantMessageId = `assistant-${userMessage.id}`;
      // Read the history before appending the new pair
      const history = messagesRef.current;

      // Append the user message and an empty assistant placeholder
      setMessages((prev) => [
        ...prev,
        userMessage,
        { id: assistantMessageId, role: "assistant", content: "" },
      ]);

      const mergedOptions = {
        sendFullHistory,
        maxHistoryLength,
        timeoutMs,
        maxRetries,
        retryDelayMs,
        parseChunk, // the custom parser or the default one
        ...sendOptions,
        headers: { ...headers, ...sendOptions.headers },
      };

      // Math.max keeps at least the new message: slice(-0) would return everything
      const messagesToSend = mergedOptions.sendFullHistory
        ? [...history, userMessage].slice(
            -Math.max(1, mergedOptions.maxHistoryLength)
          )
        : [userMessage];

      // Reports an aborted request; resets (clearMessages, setMessages, unmount)
      // have already updated the state themselves and stay silent
      const reportAbort = (signal: AbortSignal, details: unknown) => {
        const reason: unknown = signal.reason;
        const kind = reason instanceof Error ? reason.name : "";
        if (kind === "ResetAbort") return;
        const isTimeout = kind === "TimeoutError";
        setError({
          message: isTimeout ? "Request timed out" : "Request was aborted by the user",
          code: isTimeout ? "TIMEOUT" : "ABORT",
          details,
        });
        setStatusWithCallback("ready");
      };

      let lastError: Error | null = null;
      let lastCode: ChatError["code"] = "UNKNOWN";
      let activeController: AbortController | null = null;

      try {
        // One initial attempt plus up to maxRetries retries
        for (let attempt = 0; attempt <= mergedOptions.maxRetries; attempt++) {
          const controller = new AbortController();
          activeController = controller;
          abortControllerRef.current = controller;
          // The timer only bounds the wait for the response headers
          const timeoutId = setTimeout(
            () => controller.abort(abortReason("TimeoutError", "Request Timeout")),
            mergedOptions.timeoutMs
          );

          try {
            const response = await fetch(apiUrl, {
              method: "POST",
              headers: {
                "Content-Type": "application/json",
                Accept: "text/event-stream",
                ...mergedOptions.headers,
              },
              body: JSON.stringify({
                messages: messagesToSend,
                stream: true,
                sessionId: sendOptions.sessionId || currentSessionId,
              } as RequestPayload<TMetadata>),
              signal: controller.signal,
            });
            clearTimeout(timeoutId);

            if (!response.ok) {
              throw new HttpError(response.status);
            }
            if (!response.body) {
              throw new Error("Response body is not readable");
            }

            setStatusWithCallback("streaming");
            // Drop partial output left behind by a previous failed attempt
            setMessages((prev) =>
              prev.map((msg) =>
                msg.id === assistantMessageId ? { ...msg, content: "" } : msg
              )
            );

            // response.body is already a ReadableStream<Uint8Array>, so it is piped
            // directly; aborting the controller errors it and rejects read() below
            const processedReader = response.body
              .pipeThrough(new TextDecoderStream()) // bytes to text
              .pipeThrough(createLineSplitterStream()) // split into lines
              .pipeThrough(
                createSSEParserStream(mergedOptions.parseChunk, (err) =>
                  callbacksRef.current.onChunkError?.(err)
                )
              ) // parse data: lines into chunks
              .getReader();

            while (true) {
              const result = await processedReader.read();
              if (result.done) break;
              const chunk = result.value;
              setMessages((prev) =>
                prev.map((msg) =>
                  msg.id === assistantMessageId ? applyChunk(msg, chunk) : msg
                )
              );
            }

            setStatusWithCallback("ready");
            return;
          } catch (err: unknown) {
            lastError = err instanceof Error ? err : new Error(String(err));
            if (controller.signal.aborted) {
              // Aborts and timeouts are never retried
              reportAbort(controller.signal, lastError);
              return;
            }
            lastCode =
              err instanceof HttpError
                ? "HTTP"
                : err instanceof TypeError
                ? "NETWORK" // fetch reports network failures as TypeError
                : "UNKNOWN";
          } finally {
            clearTimeout(timeoutId);
          }

          if (attempt < mergedOptions.maxRetries) {
            console.log(
              `Request failed (${lastError?.message}), ${
                mergedOptions.maxRetries - attempt
              } retries left`
            );
            await new Promise((resolve) =>
              setTimeout(resolve, mergedOptions.retryDelayMs)
            );
            // abortRequest() during the delay cancels the remaining retries
            if (controller.signal.aborted) {
              reportAbort(controller.signal, lastError);
              return;
            }
          }
        }

        setError({
          message: lastError?.message || "Request failed after retries",
          code: lastCode,
          details: lastError,
        });
        setStatusWithCallback("error");
      } finally {
        // Release the controller unless an abort or a newer request replaced it
        if (abortControllerRef.current === activeController) {
          abortControllerRef.current = null;
        }
      }
    },
    [
      apiUrl,
      status,
      sendFullHistory,
      maxHistoryLength,
      headers,
      timeoutMs,
      maxRetries,
      retryDelayMs,
      parseChunk,
      currentSessionId,
      options.metadata,
      setStatusWithCallback,
    ]
  );

  const abortRequest = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort(abortReason("UserAbort", "User Aborted"));
      abortControllerRef.current = null;
      setStatusWithCallback("ready");
    }
  }, [setStatusWithCallback]);

  const clearMessages = useCallback(() => {
    // Stop any in-flight stream so it cannot write into the cleared list
    abortControllerRef.current?.abort(abortReason("ResetAbort", "Messages cleared"));
    abortControllerRef.current = null;
    messagesRef.current = [];
    setMessages([]);
    setError(null);
    setStatusWithCallback("ready");
    lastRequestRef.current = null;
  }, [setStatusWithCallback]);

  const setMessagesWithReset = useCallback(
    (newMessages: Message<TMetadata>[], newSessionId?: string) => {
      // 1. Abort any in-flight stream so it cannot write into the new list
      abortControllerRef.current?.abort(
        abortReason("ResetAbort", "SetMessages called, aborting current stream")
      );
      abortControllerRef.current = null;

      // 2. Make sure every message has an id
      const messagesWithIds = newMessages.map((msg) => ({
        ...msg,
        id: msg.id || generateId(),
      }));

      // 3. Replace the message list
      messagesRef.current = messagesWithIds;
      setMessages(messagesWithIds);

      // 4. Reset related state
      setError(null);
      setStatusWithCallback("ready");
      lastRequestRef.current = null;

      // 5. Switch the session if a new id was given
      if (newSessionId !== undefined) {
        setCurrentSessionId(newSessionId);
      }
    },
    [setStatusWithCallback]
  );

  const resetConversation = useCallback(
    (newSessionId?: string) => {
      clearMessages();
      // Fall back to the session from options when no new id is given
      const nextSessionId = newSessionId ?? options.sessionId;
      if (nextSessionId !== undefined) {
        setCurrentSessionId(nextSessionId);
      }
    },
    [clearMessages, options.sessionId]
  );

  // React only to changes of options.sessionId itself, so a session ID set
  // through setMessages(messages, newSessionId) is not overwritten
  useEffect(() => {
    if (options.sessionId !== prevSessionIdRef.current) {
      prevSessionIdRef.current = options.sessionId;
      clearMessages();
      setCurrentSessionId(options.sessionId);
    }
  }, [options.sessionId, clearMessages]);

  // Abort any in-flight request when the component unmounts
  useEffect(() => {
    return () => {
      abortControllerRef.current?.abort(
        abortReason("ResetAbort", "Component unmounted")
      );
    };
  }, []);

  const retry = useCallback(async () => {
    const last = lastRequestRef.current;
    if (status !== "error" || !last) {
      return;
    }
    // Remove the failed user/assistant pair so the retry does not duplicate it
    const trimmed = messagesRef.current.slice(0, -2);
    messagesRef.current = trimmed;
    setMessages(trimmed);
    await sendMessage(last.message, last.options);
  }, [status, sendMessage]);

  return {
    messages,
    status,
    error,
    sendMessage,
    abortRequest,
    clearMessages,
    setMessages: setMessagesWithReset,
    retry,
    resetConversation,
  };
};

export default useChat;