今年最火的技术莫过于大模型了。国内的大模型也是百家争鸣，涌现了文心一言、ChatGLM、MiniMax 等多个大模型，同时很多大模型厂商也开放了自己的大模型 API 接口。因此，本文将主要通过调用 ChatGLM 的官方 API 接口来实现流式对话效果。
为什么要选 ChatGLM？ChatGLM 是由清华大学联合智谱公司开发的一款强大的语言生成模型。
技术栈为: SpringBoot + HTML + Server-Sent Events(SSE协议)
注释:Server-Sent Events协议简单来说就是服务端给客户端主动推送消息,在实现流式对话效果的应用场景下比WebSocket、Ajax更适用,更多内容可以参考:www.ruanyifeng.com/blog/2017/0…
实现效果
后端代码为:
- 调用ChatGLM官方api接口,并发送流式数据到客户端
package com.alibaba.controller;
import com.alibaba.utils.TokenAuthUtil;
import com.google.gson.Gson;
import com.zhipu.oapi.service.v3.ModelApiRequest;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.Arrays;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
* @author quanhangbo
* @date 2023/12/19 18:40
*/
@Slf4j
@RestController
public class ChatGLMController {
@GetMapping(value = "/api/model", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public ResponseEntity<ResponseBodyEmitter> chat(@RequestParam String query) {
log.info("query={}", query);
SseEmitter emitter = new SseEmitter();
ExecutorService threadPool = Executors.newFixedThreadPool(1);
threadPool.execute(() -> {
try {
ModelApiRequest modelApiRequest = new ModelApiRequest();
modelApiRequest.setRequestId(UUID.randomUUID().toString().replace("_", ""));
modelApiRequest.setTopP(0.7f);
modelApiRequest.setIncremental(true);
modelApiRequest.setTemperature(0.9f);
modelApiRequest.setPrompt(Arrays.asList(new ModelApiRequest.Prompt("user", query)));
String jsonPayload = new Gson().toJson(modelApiRequest);
// 输入你自己的api_key
String apiKey = "YOUR_API_KEY";
String[] ans = apiKey.split("\\.");
String token = TokenAuthUtil.getToken(ans[0], ans[1]);
OkHttpClient client = new OkHttpClient();
Request request = new Request.Builder()
.url("http://open.bigmodel.cn/api/paas/v3/model-api/chatglm_lite/sse-invoke")
.post(RequestBody.create(okhttp3.MediaType.parse("application/json"), jsonPayload))
.header("Authorization", "Bearer " + token)
.header("Content-Type", "application/json")
.header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")
.header("Cache-Control","no-cache")
.header("Accept", "text/event-stream")
.build();
try (Response response = client.newCall(request).execute()) {
ResponseBody body = response.body();
String responseData = body.string();
log.info("Sending SSE event: {}", responseData);
emitter.send(SseEmitter.event().data(responseData).name("message"));
} catch (IOException e) {
emitter.completeWithError(e);
e.printStackTrace();
} finally {
emitter.complete();
}
} catch (Exception e) {
e.printStackTrace();
}
});
// Thread thread = new Thread();
// thread.start();
return ResponseEntity.ok(emitter);
}
}
- 客户端和服务端交互的跨域配置
package com.alibaba.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
 * Registers the global CORS rules that let the browser page call the API from another origin.
 *
 * @author quanhangbo
 * @date 2023/12/19 19:50
 */
@Configuration
public class CorsConfiguration {

    /**
     * Opens every path to cross-origin requests, including credentialed ones.
     *
     * @return an MVC configurer that contributes the CORS mapping
     */
    @Bean
    public WebMvcConfigurer corsConfigurer() {
        return new WebMvcConfigurer() {
            @Override
            public void addCorsMappings(CorsRegistry registry) {
                registry.addMapping("/**")
                        .allowCredentials(true)
                        .allowedMethods("POST", "GET", "PUT", "OPTIONS", "DELETE")
                        // Spring Framework 5.3+ rejects allowedOrigins("*") together with
                        // allowCredentials(true); origin patterns are the supported way to
                        // accept any origin for credentialed requests.
                        // NOTE(review): narrow this pattern to the real front-end origin(s)
                        // before deploying anywhere public.
                        .allowedOriginPatterns("*")
                        .allowedHeaders("*");
            }
        };
    }
}
- ChatGLM鉴权工具类
package com.alibaba.utils;
import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.nio.charset.StandardCharsets;
import java.util.Calendar;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
 * Creates the signed JWT used as the ChatGLM bearer token and caches it per API key.
 *
 * @author quanhangbo
 * @date 2023/12/19 20:55
 */
public class TokenAuthUtil {

    /** Token lifetime written into the {@code exp} claim: 30 minutes, in milliseconds. */
    private static final long EXPIRE_MILLIS = 30 * 60 * 1000L;

    /** Cache entries are dropped one minute before the token itself expires. */
    private static final long CACHE_TTL_MILLIS = EXPIRE_MILLIS - (60 * 1000L);

    /**
     * Token cache keyed by API key. The TTL is a millisecond value, so it must be registered
     * with {@link TimeUnit#MILLISECONDS}; using seconds would keep tokens cached long after
     * their {@code exp} claim has passed.
     */
    public static Cache<String, String> cache = CacheBuilder.newBuilder()
            .expireAfterWrite(CACHE_TTL_MILLIS, TimeUnit.MILLISECONDS)
            .build();

    /**
     * Returns a token for the given credentials, reusing a cached one when available.
     *
     * @param apiKey    key identifier; stored in the {@code api_key} claim and used as cache key
     * @param apiSecret secret used as the HMAC-SHA256 signing key
     * @return the signed JWT
     */
    public static String getToken(String apiKey, String apiSecret) {
        String cached = cache.getIfPresent(apiKey);
        if (cached != null) {
            return cached;
        }

        // One clock reading so "exp" and "timestamp" are consistent with each other.
        long now = System.currentTimeMillis();

        Map<String, Object> payload = new HashMap<>();
        payload.put("api_key", apiKey);
        payload.put("exp", now + EXPIRE_MILLIS);
        payload.put("timestamp", now);

        Map<String, Object> headerClaims = new HashMap<>();
        headerClaims.put("alg", "HS256");
        headerClaims.put("sign_type", "SIGN");

        Algorithm algorithm = Algorithm.HMAC256(apiSecret.getBytes(StandardCharsets.UTF_8));
        String token = JWT.create().withPayload(payload).withHeader(headerClaims).sign(algorithm);
        cache.put(apiKey, token);
        return token;
    }
}
前端代码为:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ChatBot Interaction</title>
<style>
/* Page shell: dark background, children stacked in a centered full-height flex column. */
body {
font-family: 'Arial', sans-serif;
margin: 0;
padding: 0;
background-color: #333;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100vh;
}
/* Conversation panel: fixed-height, scrollable column that holds the message bubbles. */
#chat-container {
width: 95%;
max-width: 900px;
height: 75vh;
margin-top: 1.5vh;
background-color: #f0f0f0;
border-radius: 12px;
box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
overflow: auto;
display: flex;
flex-direction: column;
align-items: flex-start;
padding: 5px 20px;
}
/* Base bubble style shared by user and bot messages; long words wrap instead of overflowing. */
.message {
margin: 10px 0;
overflow-wrap: break-word;
font-size: 16px;
padding: 10px;
border-radius: 8px;
max-width: 70%;
}
/* User bubbles: green, pushed to the right edge of the panel. */
.user-message {
text-align: right;
align-self: flex-end;
background-color: #4CAF50;
color: #fff;
}
/* Bot bubbles: blue, left-aligned; pre-wrap keeps line breaks and spaces in the reply text. */
.bot-message {
text-align: left;
background-color: #2196F3;
color: #fff;
white-space: pre-wrap;
}
/* Input bar: pinned near the bottom of the viewport, holding the text field and send button. */
#input-container {
width: 80%;
max-width: 600px;
margin-top: 1vh;
background-color: #f0f0f0;
padding: 10px;
border-top: 1px solid #eee;
display: flex;
align-items: center;
justify-content: center;
position: fixed;
bottom: 3vh;
border-radius: 12px;
box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
}
/* Text field takes all horizontal space left over by the button. */
#user-input {
flex: 1;
padding: 8px;
margin-right: 8px;
border: 1px solid #ccc;
border-radius: 4px;
font-size: 16px;
}
/* Send button, colored to match the user bubbles. */
#send-button {
padding: 8px;
border: none;
background-color: #4CAF50;
color: #fff;
cursor: pointer;
border-radius: 4px;
font-size: 16px;
}
</style>
</head>
<body>
<div id="chat-container">
</div>
<div id="input-container">
<input type="text" id="user-input" placeholder="Type your message...">
<button id="send-button" onclick="sendMessage()">Send</button>
</div>
<script>
// Panel that message bubbles are appended to.
const chatContainer = document.getElementById('chat-container');
// Text field the user types the question into.
const userInput = document.getElementById('user-input');
// Current EventSource for the server stream; null until the first message is sent.
let eventSource = null;
// Bot messages received from the stream, held until a "finish" event drains them.
let messageQueue = [];
/**
 * Appends a new message bubble to the chat panel and scrolls the panel to the bottom.
 *
 * @param {string} sender    label shown in bold before the text
 * @param {string} content   message text
 * @param {string} className extra CSS class for the bubble (e.g. 'user-message')
 */
async function appendMessage(sender, content, className) {
    const bubble = document.createElement('div');
    bubble.className = `message ${className}`;

    // Build the bubble from text nodes rather than innerHTML so that markup typed by the
    // user or returned by the model is displayed literally instead of being executed.
    const label = document.createElement('strong');
    label.textContent = `${sender}:`;
    bubble.appendChild(label);
    // typeMessage() appends characters to the bubble's last child, so the message text must
    // remain a trailing text node right after the <strong> label.
    bubble.appendChild(document.createTextNode(` ${content}`));

    chatContainer.appendChild(bubble);
    chatContainer.scrollTop = chatContainer.scrollHeight;
}
/**
 * Renders a message with a typewriter effect, one character every 50 ms.
 * Characters are added to the last bubble in the panel when its class matches
 * `className`; otherwise a new bubble is started with the current character.
 *
 * @param {string} message   text to type out
 * @param {string} sender    label used if a new bubble has to be created
 * @param {string} className CSS class identifying the bubble kind
 */
async function typeMessage(message, sender, className) {
    const pause = () => new Promise(done => setTimeout(done, 50));

    for (const glyph of message) {
        await pause();
        const tail = chatContainer.lastChild;
        const continuesBubble = Boolean(tail) && tail.className.includes(className);
        if (!continuesBubble) {
            await appendMessage(sender, glyph, className);
            continue;
        }
        tail.lastChild.textContent += glyph;
    }
}
/**
 * Handles one server-sent event.
 *  - "message": queues the second line of the event data, if there is one.
 *  - "add":     queues the event data as-is.
 *  - "finish":  types out every queued message in order.
 *
 * @param {MessageEvent} event event delivered by the EventSource
 */
async function handleEventStreamData(event) {
    const data = event.data;
    console.log(data);

    const enqueue = content =>
        messageQueue.push({ sender: 'ChatBot', content: content, className: 'bot-message' });

    switch (event.type) {
        case 'message': {
            const secondLine = data.split("\n")[1];
            // Single-line data has no second line; queuing `undefined` would make
            // typeMessage() throw later and abort draining the rest of the queue.
            if (secondLine) {
                enqueue(secondLine);
            }
            break;
        }
        case 'add':
            enqueue(data);
            break;
        case 'finish':
            while (messageQueue.length > 0) {
                const next = messageQueue.shift();
                await typeMessage(next.content, next.sender, next.className);
            }
            break;
        default:
            break;
    }
}
/**
 * Sends the text in the input field to the backend and subscribes to the reply stream.
 * Does nothing for blank input, or while a previous stream is still open; in the latter
 * case the typed text is left in the field so it can be sent afterwards.
 */
async function sendMessage() {
    const userMessage = userInput.value.trim();
    if (userMessage === '') return;

    // A request is only issued when no stream is open. Bail out before touching the UI so a
    // message is never shown (and the input never cleared) without actually being sent.
    if (eventSource && eventSource.readyState !== EventSource.CLOSED) return;

    appendMessage('You', userMessage, 'user-message');

    // Encode the query so characters such as '&', '#', '+' or spaces survive the URL.
    const url = `http://localhost:9090/api/model?query=${encodeURIComponent(userMessage)}`;
    const source = new EventSource(url, { withCredentials: true });
    eventSource = source;

    for (const type of ['message', 'add', 'finish']) {
        source.addEventListener(type, handleEventStreamData);
    }

    // EventSource has no "close" event. Closing inside onerror stops the browser from
    // reconnecting automatically (which would re-send the same query).
    source.onerror = function () {
        source.close();
        console.log("SSE connection is closed");
    };

    userInput.value = '';
}
// Pressing Enter in the input field sends the message.
document.getElementById('user-input').addEventListener('keydown', async function (event) {
    // While an IME composition is active (e.g. Chinese pinyin input), Enter only confirms the
    // candidate text and must not submit the half-typed message.
    if (event.isComposing) return;
    if (event.key === 'Enter') {
        await sendMessage();
    }
});
</script>
</body>
</html>
更多代码参考github.com/HangboQuan/…, 有问题可以提issue。