400行代码实现ChatGLM大模型流式对话效果

阅读 1,058 · 约 3 分钟

今年最火的技术莫过于大模型了，国内的大模型也是百家争鸣，涌现了诸如文心一言、ChatGLM、MiniMax 等多个大模型；同时，很多大模型厂商也开放了自己的大模型 API 接口。因此，本文将以调用 ChatGLM 的官方 API 接口为例，实现流式对话效果。

为什么选择 ChatGLM？ChatGLM 是由清华大学与智谱公司联合开发的一款强大的语言生成模型。

技术栈：Spring Boot + HTML + Server-Sent Events（SSE 协议）

注：Server-Sent Events 协议简单来说就是由服务端主动向客户端推送消息。它基于普通的 HTTP 长连接，只需服务端到客户端的单向推送，浏览器还会自动重连，因此在流式对话这一场景下比 WebSocket 或 Ajax 轮询更适用。更多内容可以参考：www.ruanyifeng.com/blog/2017/05/server-sent_events.html

实现效果如下（动图：chatglm_async_stream.gif）：

后端代码如下：

  1. 调用 ChatGLM 官方 API 接口，并将返回的数据以 SSE 的形式发送到客户端（注：当前实现会先读取上游的完整响应，再一次性转发给浏览器）
package com.alibaba.controller;  
  
import com.alibaba.utils.TokenAuthUtil;  
import com.google.gson.Gson;  
import com.zhipu.oapi.service.v3.ModelApiRequest;  
import lombok.extern.slf4j.Slf4j;  
import okhttp3.*;  
import org.springframework.http.MediaType;  
import org.springframework.http.ResponseEntity;  
import org.springframework.web.bind.annotation.GetMapping;  
import org.springframework.web.bind.annotation.RequestParam;  
import org.springframework.web.bind.annotation.RestController;  
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;  
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;  
  
import java.io.IOException;  
import java.util.Arrays;  
import java.util.UUID;  
import java.util.concurrent.ExecutorService;  
import java.util.concurrent.Executors;  
  
/**  
* @author quanhangbo  
* @date 2023/12/19 18:40  
*/  
@Slf4j  
@RestController  
public class ChatGLMController {  
  
    @GetMapping(value = "/api/model", produces = MediaType.TEXT_EVENT_STREAM_VALUE)  
    public ResponseEntity<ResponseBodyEmitter> chat(@RequestParam String query) {  
        log.info("query={}", query);  
        SseEmitter emitter = new SseEmitter();  

        ExecutorService threadPool = Executors.newFixedThreadPool(1);  
        threadPool.execute(() -> {  
            try {  
                ModelApiRequest modelApiRequest = new ModelApiRequest();  
                modelApiRequest.setRequestId(UUID.randomUUID().toString().replace("_", ""));  
                modelApiRequest.setTopP(0.7f);  
                modelApiRequest.setIncremental(true);  
                modelApiRequest.setTemperature(0.9f);  
                modelApiRequest.setPrompt(Arrays.asList(new ModelApiRequest.Prompt("user", query)));  

                String jsonPayload = new Gson().toJson(modelApiRequest);  

                // 输入你自己的api_key  
                String apiKey = "YOUR_API_KEY";  
                String[] ans = apiKey.split("\\.");  
                String token = TokenAuthUtil.getToken(ans[0], ans[1]);  

                OkHttpClient client = new OkHttpClient();  
                Request request = new Request.Builder()  
                .url("http://open.bigmodel.cn/api/paas/v3/model-api/chatglm_lite/sse-invoke")  
                .post(RequestBody.create(okhttp3.MediaType.parse("application/json"), jsonPayload))  
                .header("Authorization", "Bearer " + token)  
                .header("Content-Type", "application/json")  
                .header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")  
                .header("Cache-Control","no-cache")  
                .header("Accept", "text/event-stream")  
                .build();  

                try (Response response = client.newCall(request).execute()) {  
                    ResponseBody body = response.body();  
                    String responseData = body.string();  
                    log.info("Sending SSE event: {}", responseData);  
                    emitter.send(SseEmitter.event().data(responseData).name("message"));  
                } catch (IOException e) {  
                    emitter.completeWithError(e);  
                    e.printStackTrace();  
                } finally {  
                    emitter.complete();  
                }  
            } catch (Exception e) {  
                e.printStackTrace();  
            }  
        });  
        // Thread thread = new Thread();  
        // thread.start();  
        return ResponseEntity.ok(emitter);  
    } 
}
  2. 配置客户端与服务端交互的跨域（CORS）规则
package com.alibaba.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

/**
 * Global CORS policy so the stand-alone chat page (presumably served from a different origin than
 * the API) can call the backend.
 *
 * @author quanhangbo
 * @date 2023/12/19 19:50
 */
@Configuration
public class CorsConfiguration {

    /**
     * Registers one CORS mapping covering every path.
     *
     * <p>The page opens its {@code EventSource} with {@code withCredentials: true}, so credentials
     * must stay allowed. Since Spring 5.3, {@code allowCredentials(true)} combined with
     * {@code allowedOrigins("*")} is rejected with an {@link IllegalArgumentException} when a CORS
     * request is processed. {@code allowedOriginPatterns("*")} is accepted and echoes the caller's
     * origin instead. In production, list the real front-end origin(s) rather than a wildcard.
     *
     * @return the MVC configurer carrying the CORS mapping
     */
    @Bean
    public WebMvcConfigurer corsConfigurer() {
        return new WebMvcConfigurer() {
            @Override
            public void addCorsMappings(CorsRegistry registry) {
                registry.addMapping("/**")
                        .allowCredentials(true)
                        .allowedMethods("POST", "GET", "PUT", "OPTIONS", "DELETE")
                        .allowedOriginPatterns("*")
                        .allowedHeaders("*");
            }
        };
    }
}
  3. ChatGLM 鉴权工具类
package com.alibaba.utils;

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import java.nio.charset.StandardCharsets;
import java.util.Calendar;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
 * Creates and caches the signed JWT bearer tokens sent to the ChatGLM open API.
 *
 * @author quanhangbo
 * @date 2023/12/19 20:55
 */
public class TokenAuthUtil {
    /** Lifetime of a generated token, in milliseconds (30 minutes). */
    private static final long expireMillis = 30 * 60 * 1000L;

    /**
     * Token cache keyed by API key id. Entries expire one minute before the token's own
     * {@code exp}, so a cached token is never handed out after it has expired. The duration is
     * given in MILLISECONDS to match {@link #expireMillis}; as SECONDS it would keep tokens for
     * about 20 days.
     */
    public static Cache<String, String> cache = CacheBuilder.newBuilder()
            .expireAfterWrite(expireMillis - (60 * 1000L), TimeUnit.MILLISECONDS)
            .build();

    /**
     * Returns a token for the given credentials, reusing a cached one while it is still valid.
     *
     * @param apiKey    the id part of the API key; also used as the cache key
     * @param apiSecret the secret part of the API key, used as the HMAC-SHA256 signing key
     * @return a signed JWT whose {@code exp} and {@code timestamp} claims are epoch milliseconds
     */
    public static String getToken(String apiKey, String apiSecret) {
        String token = cache.getIfPresent(apiKey);
        if (token != null) {
            return token;
        }
        // Read the clock once so "timestamp" and "exp" are exactly expireMillis apart.
        long now = System.currentTimeMillis();
        Algorithm algorithm = Algorithm.HMAC256(apiSecret.getBytes(StandardCharsets.UTF_8));
        Map<String, Object> payload = new HashMap<>();
        payload.put("api_key", apiKey);
        payload.put("exp", now + expireMillis);
        payload.put("timestamp", now);
        Map<String, Object> headerClaims = new HashMap<>();
        headerClaims.put("alg", "HS256");
        // NOTE(review): custom header field, presumably required by the provider's auth scheme.
        headerClaims.put("sign_type", "SIGN");
        token = JWT.create().withPayload(payload).withHeader(headerClaims).sign(algorithm);
        cache.put(apiKey, token);
        return token;
    }
}

前端代码如下：

<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ChatBot Interaction</title>
    <style>
        /* Page: dark background; lays out the chat panel and input bar as a centered, full-height column. */
        body {
            font-family: 'Arial', sans-serif;
            margin: 0;
            padding: 0;
            background-color: #333;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            height: 100vh;
        }

        /* Scrollable panel the script appends message bubbles to; children stack left-aligned by default. */
        /* NOTE(review): width 95% plus 40px horizontal padding (default content-box sizing) may overflow narrow viewports. */
        #chat-container {
            width: 95%;
            max-width: 900px;
            height: 75vh;
            margin-top: 1.5vh;
            background-color: #f0f0f0;
            border-radius: 12px;
            box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
            overflow: auto;
            display: flex;
            flex-direction: column;
            align-items: flex-start;
            padding: 5px 20px;
        }

        /* Shared bubble look for user and bot messages; long words wrap instead of overflowing. */
        .message {
            margin: 10px 0;
            overflow-wrap: break-word;
            font-size: 16px;
            padding: 10px;
            border-radius: 8px;
            max-width: 70%;
        }

        /* User bubbles: green, pushed to the right edge of the flex column. */
        .user-message {
            text-align: right;
            align-self: flex-end;
            background-color: #4CAF50;
            color: #fff;
        }

        /* Bot bubbles: blue, left-aligned; pre-wrap keeps newlines and spaces in the typed-out answer. */
        .bot-message {
            text-align: left;
            background-color: #2196F3;
            color: #fff;
            white-space: pre-wrap;
        }

        /* Input bar pinned near the bottom of the viewport (position: fixed), centered horizontally by body's flex layout. */
        #input-container {
            width: 80%;
            max-width: 600px;
            margin-top: 1vh;
            background-color: #f0f0f0;
            padding: 10px;
            border-top: 1px solid #eee;
            display: flex;
            align-items: center;
            justify-content: center;
            position: fixed;
            bottom: 3vh;
            border-radius: 12px;
            box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
        }

        /* Text box grows to fill the bar next to the Send button. */
        #user-input {
            flex: 1;
            padding: 8px;
            margin-right: 8px;
            border: 1px solid #ccc;
            border-radius: 4px;
            font-size: 16px;
        }

        /* Send button, colored to match the user bubbles. */
        #send-button {
            padding: 8px;
            border: none;
            background-color: #4CAF50;
            color: #fff;
            cursor: pointer;
            border-radius: 4px;
            font-size: 16px;
        }
    </style>
</head>

<body>
    <!-- Scrollable list that the script appends message bubbles to. -->
    <div id="chat-container">
    </div>

    <!-- Input bar: pressing Enter in the box or clicking Send calls sendMessage(). -->
    <div id="input-container">
        <input type="text" id="user-input" placeholder="Type your message...">
        <button id="send-button" onclick="sendMessage()">Send</button>
    </div>

    <script>
        // Shared state for the chat handlers below.
        let eventSource = null;    // active EventSource, or null before the first message
        const messageQueue = [];   // bot text fragments waiting to be typed out on 'finish'

        // DOM elements used throughout the script.
        const chatContainer = document.getElementById('chat-container');
        const userInput = document.getElementById('user-input');

        /**
         * Appends a chat bubble made of a bold "sender:" label followed by the message text,
         * then scrolls the chat to the bottom.
         * The text is inserted as a text node, not parsed as HTML, so user input and model output
         * cannot inject markup or script. typeMessage appends to this trailing text node.
         */
        async function appendMessage(sender, content, className) {
            const message = document.createElement('div');
            message.className = `message ${className}`;
            const label = document.createElement('strong');
            label.textContent = `${sender}:`;
            message.appendChild(label);
            message.appendChild(document.createTextNode(` ${content}`));
            chatContainer.appendChild(message);
            chatContainer.scrollTop = chatContainer.scrollHeight;
        }

        /**
         * Reveals `message` one character at a time, 50 ms apart. Characters are appended to the
         * last chat element when it is already a `className` bubble; otherwise a new bubble is
         * started with that character.
         */
        async function typeMessage(message, sender, className) {
            const pause = () => new Promise(done => setTimeout(done, 50));
            for (const ch of message) {
                await pause();
                const tail = chatContainer.lastChild;
                const continuesBubble = Boolean(tail) && tail.className.includes(className);
                if (!continuesBubble) {
                    await appendMessage(sender, ch, className);
                    continue;
                }
                tail.lastChild.textContent += ch;
            }
        }

        /**
         * Handles one SSE event from the backend.
         * 'message' and 'add' events queue a text fragment; 'finish' types out every queued
         * fragment in order.
         * NOTE(review): for 'message' the text is taken from the second line of event.data,
         * presumably because the backend forwards the raw upstream SSE text. Confirm against it.
         */
        async function handleEventStreamData(event) {
            const data = event.data;
            console.log(data);
            switch (event.type) {
                case 'message': {
                    // Skip events without a second line rather than queueing `undefined`.
                    // That value would make typeMessage throw on 'finish' and strand the rest of the queue.
                    const content = data.split("\n")[1];
                    if (content !== undefined) {
                        messageQueue.push({ sender: 'ChatBot', content, className: 'bot-message' });
                    }
                    break;
                }
                case 'add':
                    messageQueue.push({ sender: 'ChatBot', content: data, className: 'bot-message' });
                    break;
                case 'finish':
                    while (messageQueue.length > 0) {
                        const message = messageQueue.shift();
                        await typeMessage(message.content, message.sender, message.className);
                    }
                    break;
            }
        }

        /**
         * Sends the input box's text. It shows the text as a user bubble and opens an EventSource
         * to the backend, whose 'message' / 'add' / 'finish' events feed handleEventStreamData.
         * Only one answer streams at a time.
         */
        async function sendMessage() {
            const userMessage = userInput.value.trim();
            if (userMessage === '') return;
            // An answer is still streaming: keep the text in the box instead of showing it as sent
            // while never actually sending it.
            if (eventSource && eventSource.readyState !== EventSource.CLOSED) return;

            appendMessage('You', userMessage, 'user-message');
            // Encode the query so characters like '&', '#', '?' and '+' reach the server intact.
            const url = `http://localhost:9090/api/model?query=${encodeURIComponent(userMessage)}`;
            const source = new EventSource(url, { withCredentials: true });
            source.addEventListener('message', handleEventStreamData);
            source.addEventListener('add', handleEventStreamData);
            source.addEventListener('finish', handleEventStreamData);
            // EventSource has no 'close' event. A dropped connection (including the server ending
            // the stream) fires 'error' and would auto-reconnect and re-send the query, so close
            // this source here.
            source.onerror = function () {
                source.close();
            };
            eventSource = source;
            userInput.value = '';
        }

        // Send on Enter, except while an IME (e.g. Chinese pinyin input) is composing text.
        // That Enter only confirms the candidate and must not submit the message.
        // keyCode 229 covers browsers that report the composing keydown without isComposing.
        userInput.addEventListener('keydown', async function (event) {
            if (event.key !== 'Enter' || event.isComposing || event.keyCode === 229) {
                return;
            }
            await sendMessage();
        });
    </script>
</body>
</html>

完整代码请参考 github.com/HangboQuan/…，有问题欢迎提 issue。