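0. Preface
Maybe you have also thought about letting AI do your data analysis. My website used to call ChatGPT, but the $5 credit has run out, so I decided to switch providers. When I saw that the Doubao large model was giving away 500,000 free tokens, I set out to try it, mostly to learn something along the way.
1. Preparing the model
Backend: JDK 17 and Spring Boot 3.3.5. Frontend: Vue 3 and Arco Design.
1.1 Open the console
Go to the Volcano Engine (火山引擎) website, log in, and open the Console.
1.2 Create an API key
1.3 Create an inference endpoint
1.4 Enter the endpoint details and select a model
Open the API call settings. The page shows sample code; copy the API key and the value passed to .model() (the endpoint ID, which starts with "ep-").
2. Backend code
2.0 Add the SDK to pom.xml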
<dependency>
    <groupId>com.volcengine</groupId>
    <artifactId>volcengine-java-sdk-ark-runtime</artifactId>
    <version>0.1.134</version>
</dependency>
On JDK versions above 8 it is recommended to also add the following (JDK 11 and later no longer ship javax.annotation):
<dependency>
    <groupId>javax.annotation</groupId>
    <artifactId>javax.annotation-api</artifactId>
    <version>1.3.2</version>
</dependency>
2.1 AIConfig
import com.volcengine.ark.runtime.service.ArkService;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.time.Duration;

@Configuration
@ConfigurationProperties(prefix = "doubaoai")
@Data
public class DoubaoAIConfig {

    // Bound from doubaoai.apiKey in application.yml
    private String apiKey;

    @Bean
    public ArkService getDoubaoAIService() {
        return ArkService.builder()
                .apiKey(apiKey)
                // Endpoint for the North China 2 (Beijing) region
                .baseUrl("https://ark.cn-beijing.volces.com/api/v3/")
                .timeout(Duration.ofSeconds(120))
                .connectTimeout(Duration.ofSeconds(20))
                .retryTimes(2)
                .build();
    }
}
Remember to add the API key to application.yml (apiKey must be indented under doubaoai):
doubaoai:
  apiKey: xxxxxxxxxxxxxxxxxx
2.2 AIManager
To make the AI service reusable, I wrapped the calls in an AIManager class.
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import com.volcengine.ark.runtime.service.ArkService;
import jakarta.annotation.Resource;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
// BusinessException and ErrorCode are the project's own classes (imports omitted)

@Component
public class DoubaoAIManager {

    @Resource
    private ArkService arkService;

    private static final double STABLE_TEMPERATURE = 0.01;   // more stable output
    private static final double UNSTABLE_TEMPERATURE = 0.99; // more random output

    /**
     * Generic request
     *
     * @param messages    conversation to send (system, user, ...)
     * @param temperature sampling temperature
     * @return content of the first choice, or "" if there is none
     */
    public String doDoubaoAIRequest(List<ChatMessage> messages, Double temperature) {
        // Build the request
        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
                .model("ep-xxxxxxxx") // replace with your inference endpoint ID
                .temperature(temperature)
                .messages(messages)
                .build();
        try {
            // Return the content of the first choice, or "" if there are no choices
            return arkService.createChatCompletion(chatCompletionRequest).getChoices()
                    .stream()
                    .map(choice -> choice.getMessage().getContent())
                    .findFirst()
                    .orElse("")
                    .toString();
        } catch (Exception e) {
            e.printStackTrace();
            throw new BusinessException(ErrorCode.SYSTEM_ERROR, e.getMessage());
        }
    }

    /**
     * Generic request with one system message and one user message
     *
     * @param systemMessage instructions for the model (customizable)
     * @param userMessage   user input
     * @param temperature   sampling temperature
     * @return model reply
     */
    public String doDoubaoAIRequest(String systemMessage, String userMessage, Double temperature) {
        List<ChatMessage> chatMessageList = new ArrayList<>();
        ChatMessage systemChatMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(systemMessage).build();
        ChatMessage userChatMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(userMessage).build();
        chatMessageList.add(systemChatMessage);
        chatMessageList.add(userChatMessage);
        return doDoubaoAIRequest(chatMessageList, temperature);
    }

    /**
     * System + user request with higher randomness (temperature 0.99)
     */
    public String doDoubaoUnstableAIRequest(String systemMessage, String userMessage) {
        return doDoubaoAIRequest(systemMessage, userMessage, UNSTABLE_TEMPERATURE);
    }

    /**
     * System + user request with higher stability (temperature 0.01)
     */
    public String doDoubaostableAIRequest(String systemMessage, String userMessage) {
        return doDoubaoAIRequest(systemMessage, userMessage, STABLE_TEMPERATURE);
    }
}
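To make the data flow of the two overloads visible without the SDK, here is a JDK-only sketch. Role, Message, Choice and ChatFlowSketch are hypothetical stand-ins, not the volcengine SDK classes. One design note: the sketch maps inside the Optional returned by findFirst(), because Stream.findFirst() throws a NullPointerException when the first mapped element is null, whereas Optional.map simply yields an empty Optional.

```java
import java.util.List;

// Hypothetical stand-ins for the SDK's role, message and choice types (illustration only).
enum Role { SYSTEM, USER, ASSISTANT }

record Message(Role role, String content) { }

record Choice(Message message) { }

final class ChatFlowSketch {
    private ChatFlowSketch() { }

    // What the (systemMessage, userMessage, temperature) overload sends: system first, then user.
    static List<Message> systemThenUser(String systemMessage, String userMessage) {
        return List.of(new Message(Role.SYSTEM, systemMessage), new Message(Role.USER, userMessage));
    }

    // Content of the first choice, or "" when there are no choices or the content is null.
    static String firstContent(List<Choice> choices) {
        return choices.stream()
                .findFirst()
                .map(choice -> choice.message().content())
                .orElse("");
    }
}
```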
2.3 Test class
With the core code in place, let's test it with a standard (non-streaming) request.
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import com.volcengine.ark.runtime.service.ArkService;
import jakarta.annotation.Resource;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import java.util.ArrayList;
import java.util.List;

// RANDOM_PORT starts a real embedded server; without it the test fails with
// "jakarta.websocket.server.ServerContainer not available"
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class DoubaoAI {

    @Resource
    private ArkService arkService;

    @Test
    void test() {
        System.out.println("\n----- standard request -----");
        final List<ChatMessage> messages = new ArrayList<>();
        final ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM)
                .content("You are Doubao, an AI assistant developed by ByteDance.").build();
        final ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER)
                .content("What are some common plants in the cabbage family (Brassicaceae)?").build();
        messages.add(systemMessage);
        messages.add(userMessage);
        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
                .model("ep-xxxxxxxx") // replace with your inference endpoint ID
                .messages(messages)
                .build();
        arkService.createChatCompletion(chatCompletionRequest).getChoices()
                .forEach(choice -> System.out.println(choice.getMessage().getContent()));
        // Shuts down the executor of the shared ArkService bean; only do this at the end of a test
        arkService.shutdownExecutor();
    }
}
The test receives a correct reply.
2.4 Business code
2.4.1 controller
import cn.hutool.core.io.FileUtil;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.google.gson.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
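import java.util.concurrent.ThreadPoolExecutor;
// Project classes (ChartService, UserService, DoubaoAIService, Chart, BiResponse, BaseResponse,
// ErrorCode, ThrowUtils, BusinessException, ResultUtils, ...) are imported from your own packages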
/**
* Chart API
*/
@RestController
@RequestMapping("/chart")
@Slf4j
public class ChartController {
@Resource
private ChartService chartService;
@Resource
private UserService userService;
@Resource
private DoubaoAIService doubaoAIService;
@Resource
private RedisCacheManager redisCacheManager;
@Resource
private RedisLimiterManager redisLimiterManager;
@Resource
private ThreadPoolExecutor threadPoolExecutor;
@Resource
private ChartMessageProducer chartMessageProducer;
@Resource
private TaskManager taskManager;
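/**
* AI analysis of an uploaded data file
*
* @param multipartFile uploaded data file (.xlsx, .xls or .csv)
* @param genChartByAiRequest chart name, analysis goal and chart type
* @param request current HTTP request (used to get the logged-in user)
* @return generated ECharts option, analysis conclusion and chart id
*/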
@PostMapping("/gen")
public BaseResponse<BiResponse> genChartByAi(@RequestPart("file") MultipartFile multipartFile,
GenChartByAiRequest genChartByAiRequest, HttpServletRequest request) {
String chartName = genChartByAiRequest.getChartName();
String goal = genChartByAiRequest.getGoal();
String chartType = genChartByAiRequest.getChartType();
// Validate the chart name
ThrowUtils.throwIf(StringUtils.isNotBlank(chartName) && chartName.length() > 100, ErrorCode.PARAMS_ERROR, "Chart name is too long");
User loginUser = userService.getLoginUser(request);
// Validate file size and extension
long size = multipartFile.getSize();
final long TEN_MB = 10 * 1024 * 1024L;
ThrowUtils.throwIf(size > TEN_MB, ErrorCode.PARAMS_ERROR, "File is larger than 10 MB");
String fileName = multipartFile.getOriginalFilename();
// Lower-case the extension so that e.g. DATA.XLSX is accepted, as in the frontend check
String suffix = StringUtils.lowerCase(FileUtil.getSuffix(fileName));
final List<String> validFileSuffix = Arrays.asList("xlsx", "csv", "xls");
ThrowUtils.throwIf(!validFileSuffix.contains(suffix), ErrorCode.PARAMS_ERROR, "Unsupported file type");
// Rate-limit each user
redisLimiterManager.doRateLimit("genChartByAi" + loginUser.getId());
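// Build the user message; the Chinese labels must match the format described in the system prompt
StringBuilder userInput = new StringBuilder();
userInput.append("分析需求:\n"); // "Analysis goal:"
// Analysis goal; the default means "please analyse the data sensibly for me"
String userGoal = "请帮我合理的分析一下数据";
if (StringUtils.isNotBlank(goal))
userGoal = goal;
// Append the requested chart type (",请使用" means ", please use")
if (StringUtils.isNotBlank(chartType))
userGoal += ",请使用" + chartType;
userInput.append(userGoal).append("\n");
userInput.append("原始数据:\n"); // "Raw data:"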
// Convert the upload into compact CSV text (Csv2String and ExcelUtils are project helpers)
String userData = "";
if (Objects.equals(suffix, "csv"))
userData = Csv2String.MultipartFileToString(multipartFile);
else
userData = ExcelUtils.excel2Csv(multipartFile);
BiResponse biResponse = new BiResponse();
Chart chart = new Chart();
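// Reject large inputs (the limit is compared with the character length of the CSV text;
// with GPT-3.5 larger inputs took more than 30 s to analyse)
if (userData.length() > DoubaoAIServiceImpl.SYNCHRO_MAX_TOKEN){
biResponse.setGenChart("ERROR");
biResponse.setGenResult("Too much data in the request; please trim the Excel file");
biResponse.setChartId(chart.getId());
return ResultUtils.success(biResponse);
}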
userInput.append(userData).append("\n");
String result = "";
try {
// Call the model; malformed replies are retried
result = doubaoAIService.doChatWithRetry(userInput.toString());
} catch (Exception e) {
// Retries exhausted or the call failed
throw new BusinessException(ErrorCode.SYSTEM_ERROR, "AI generation failed: " + e.getMessage());
}
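String[] splits = result.split("@@@@@");
// The retryer should only return well-formed replies, but guard against a missing part anyway
ThrowUtils.throwIf(splits.length < 3, ErrorCode.SYSTEM_ERROR, "AI reply has an unexpected format");
String genChart = splits[1].trim();
String genResult = splits[2].trim();
// Strip a leading "var option =" and a trailing ";" so the ECharts code parses as JSON
if (genChart.startsWith("var option =")) {
genChart = genChart.replaceFirst("^var\\s+option\\s*=\\s*", "");
}
if (genChart.endsWith(";")) {
genChart = genChart.substring(0, genChart.length() - 1).trim();
}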
JsonObject chartJson = JsonParser.parseString(genChart).getAsJsonObject();
// Without a chart name, use the generated title and make sure it ends with "图" or "表"
// ("chart"/"table"), appending "图" otherwise
if (StringUtils.isEmpty(chartName)){
String genChartName = chartJson.getAsJsonObject("title").get("text").getAsString();
if (!genChartName.endsWith("图") && !genChartName.endsWith("表"))
genChartName = genChartName + "图";
chart.setChartName(genChartName);
} else
chart.setChartName(chartName);
// Without a chart type, infer it from the first series' "type" (e.g. "bar") and translate it to Chinese
if (StringUtils.isEmpty(chartType)){
JsonArray seriesArray = chartJson.getAsJsonArray("series");
if (seriesArray != null && seriesArray.size() > 0) {
JsonObject firstSeries = seriesArray.get(0).getAsJsonObject();
String typeChart = firstSeries.get("type").getAsString();
String cnChartType = chartService.getChartTypeToCN(typeChart);
chart.setChartType(cnChartType);
}
}else
chart.setChartType(chartType);
// Add a toolbox with a "save as image" (download) button
JsonObject toolbox = new JsonObject();
toolbox.addProperty("show", true);
JsonObject saveAsImage = new JsonObject();
saveAsImage.addProperty("show", true);
// excludeComponents must be a JSON array, not the string "['toolbox']"
JsonArray excludeComponents = new JsonArray();
excludeComponents.add("toolbox");
saveAsImage.add("excludeComponents", excludeComponents);
saveAsImage.addProperty("pixelRatio", 2);
JsonObject feature = new JsonObject();
feature.add("saveAsImage", saveAsImage);
toolbox.add("feature", feature);
chartJson.add("toolbox", toolbox);
// Remove the generated title from the option
chartJson.remove("title");
String updatedGenChart = chartJson.toString();
chart.setGoal(userGoal);
chart.setChartData(userData);
chart.setGenChart(updatedGenChart);
chart.setGenResult(genResult);
chart.setUserId(loginUser.getId());
chart.setStatus("succeed");
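boolean saveResult = chartService.save(chart);
// If the insert failed there is no chart id to mark as failed, so fail the request instead
ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "Failed to save the chart");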
biResponse.setGenChart(updatedGenChart);
biResponse.setGenResult(genResult);
biResponse.setChartId(chart.getId());
return ResultUtils.success(biResponse);
}
/**
* Marks a chart as failed and records the reason
* @param chartId
* @param execMessage
*/
private void handleChartUpdateError(long chartId, String execMessage){
Chart updateChart = new Chart();
updateChart.setId(chartId);
updateChart.setStatus("failed");
updateChart.setExecMessage(execMessage);
boolean b = chartService.updateById(updateChart);
if (!b)
log.error("Failed to mark chart {} as failed: {}", chartId, execMessage);
}
}
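Before Gson ever sees the reply, the controller does a few fiddly string steps: split on the "@@@@@" delimiter, strip an optional "var option =" wrapper, and later derive a chart name from title.text. A JDK-only sketch of those steps, assuming the reply format requested by the system prompt; AiReply, AiReplyParser and defaultChartName are hypothetical names, and the Chinese suffix characters 图/表 are written as Unicode escapes so the file compiles under any source encoding.

```java
import java.util.Optional;

// Hypothetical holder for the two useful parts of the model's reply.
record AiReply(String genChart, String genResult) { }

final class AiReplyParser {
    static final String DELIMITER = "@@@@@";   // delimiter the system prompt asks for
    static final String CHART = "\u56fe";      // the character for "chart"
    static final String TABLE = "\u8868";      // the character for "table"

    private AiReplyParser() { }

    // Splits "<preamble>@@@@@<option JSON>@@@@@<conclusion>"; empty if a part is missing.
    static Optional<AiReply> parse(String raw) {
        if (raw == null) {
            return Optional.empty();
        }
        String[] parts = raw.split(DELIMITER);
        if (parts.length < 3) {
            return Optional.empty();
        }
        return Optional.of(new AiReply(stripOptionWrapper(parts[1]), parts[2].trim()));
    }

    // Removes a leading "var option =" and a trailing ";" so the rest can be parsed as JSON.
    static String stripOptionWrapper(String chartCode) {
        String code = chartCode.trim().replaceFirst("^var\\s+option\\s*=\\s*", "");
        if (code.endsWith(";")) {
            code = code.substring(0, code.length() - 1).trim();
        }
        return code;
    }

    // Appends the "chart" character unless the title already ends with "chart" or "table".
    static String defaultChartName(String title) {
        return (title.endsWith(CHART) || title.endsWith(TABLE)) ? title : title + CHART;
    }
}
```

Returning Optional keeps the "malformed reply" case explicit, so a caller can decide whether to retry or fail instead of hitting an ArrayIndexOutOfBoundsException.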
2.4.2 serviceImpl
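import com.github.rholder.retry.RetryException;
import com.github.rholder.retry.Retryer;
import com.github.rholder.retry.RetryerBuilder;
import com.github.rholder.retry.StopStrategies;
import com.github.rholder.retry.WaitStrategies;
import com.google.gson.JsonElement;
import com.google.gson.JsonParseException;
import com.google.gson.JsonParser;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

@Service
public class DoubaoAIServiceImpl implements DoubaoAIService {

    @Autowired
    private DoubaoAIManager doubaoAIManager;

    // Theoretical maximum amount of input data; analysing this much takes about 30 s
    // (the controller compares it with the character length of the CSV text)
    public static final Integer SYNCHRO_MAX_TOKEN = 340;

    // Retry malformed replies: at most 2 attempts in total, 2 s apart
    private final Retryer<String> retryer = RetryerBuilder.<String>newBuilder()
            .retryIfResult(result -> !isValidResult(result))
            .withStopStrategy(StopStrategies.stopAfterAttempt(2))
            .withWaitStrategy(WaitStrategies.fixedWait(2, TimeUnit.SECONDS))
            .build();

    /**
     * Entry point used by the controller. Not shown in the original code; this is an
     * assumed implementation that simply runs doChartChat through the retryer.
     */
    public String doChatWithRetry(String userMessage) {
        try {
            return retryer.call(() -> doChartChat(userMessage));
        } catch (ExecutionException | RetryException e) {
            throw new IllegalStateException("AI request failed after retries", e);
        }
    }

    public String doChartChat(String userMessage) {
        // The prompt stays in Chinese. In English: "You are a data analyst and front-end expert.
        // I will send 分析需求 (analysis goal) and 原始数据 (raw CSV data). Reply in exactly this
        // format, including the "@@@@@" delimiters, without extra opening, closing or comments:
        // @@@@@ {an ECharts v5 option object as JSON with title.text, a legend with black text,
        // black axis lines and labels, and tooltips on hover}
        // @@@@@ {a clear, detailed analysis conclusion without comments, braces or double quotes}"
        final String systemPrompt = "你是一个数据分析师和前端开发专家,接下来我会按照以下固定格式给你提供内容:\n" +
                "分析需求:\n" + "{数据分析的需求或者目标}\n" +
                "原始数据:\n" + "{csv格式的原始数据,用,作为分隔符}\n" +
                "请根据这两部分内容,按照以下指定格式生成内容,其中包括生成分隔符\"@@@@@\",同时分析结论请直接给出(此外不要输出任何多余的开头、结尾、注释)\n" +
                "@@@@@\n" +
                "{前端 Echarts V5 的 option 配置对象js代码(json格式),代码需要包括 title.text(需要该图的名称)部分、图例部分(即 legend 元素,文字部分应为黑色,图例线颜色与图例颜色相同),合理地将数据进行可视化,图表要求:1、若图表有轴线请将轴线画出,如 y 轴线,颜色为黑色 2、坐标字体为黑色,3.图表在鼠标悬停时可以显示数据}\n" +
                "@@@@@\n" +
                "{请直接明确的数据分析结论、越详细越好,不要生成多余的注释、大括号和双引号}";
        return doubaoAIManager.doDoubaostableAIRequest(systemPrompt, userMessage);
    }

    /**
     * Checks whether the model's reply has the expected structure
     * @param result raw reply
     * @return true if it has three "@@@@@" parts and the chart JSON has a non-empty title.text
     */
    private boolean isValidResult(String result) {
        if (result == null) {
            return false;
        }
        String[] splits = result.split("@@@@@");
        if (splits.length < 3) {
            return false;
        }
        String genChart = splits[1].trim();
        // Accept a leading "var option =", which the controller strips before parsing
        if (genChart.startsWith("var option =")) {
            genChart = genChart.substring("var option =".length()).trim();
        }
        try {
            JsonElement root = JsonParser.parseString(genChart);
            if (!root.isJsonObject()) {
                return false;
            }
            // "title" must be an object whose "text" is a non-empty string
            JsonElement title = root.getAsJsonObject().get("title");
            if (title == null || !title.isJsonObject()) {
                return false;
            }
            JsonElement titleText = title.getAsJsonObject().get("text");
            if (titleText == null || !titleText.isJsonPrimitive()) {
                return false;
            }
            return !titleText.getAsString().isEmpty();
        } catch (JsonParseException e) {
            // Not valid JSON
            return false;
        }
    }
}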
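The retryer re-runs the model call while isValidResult rejects the reply, stopping after two attempts with a fixed 2-second wait. The same policy can be written without the guava-retrying dependency. A minimal sketch (SimpleRetry and callUntilValid are hypothetical names); like the retryer as configured here, which only uses retryIfResult, it does not retry when the call itself throws:

```java
import java.util.function.Predicate;
import java.util.function.Supplier;

final class SimpleRetry {
    private SimpleRetry() { }

    // Calls task up to maxAttempts times, waiting waitMillis between attempts,
    // until isValid accepts the result. Exceptions from task propagate immediately.
    static <T> T callUntilValid(Supplier<T> task, Predicate<T> isValid, int maxAttempts, long waitMillis) {
        if (maxAttempts < 1) {
            throw new IllegalArgumentException("maxAttempts must be >= 1");
        }
        T last = null;
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            last = task.get();
            if (isValid.test(last)) {
                return last;
            }
            if (attempt < maxAttempts) {
                try {
                    Thread.sleep(waitMillis); // fixed wait, like WaitStrategies.fixedWait
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("Interrupted while waiting to retry", e);
                }
            }
        }
        throw new IllegalStateException("No valid result after " + maxAttempts + " attempts: " + last);
    }
}
```

Usage in the service would look like SimpleRetry.callUntilValid(() -> doChartChat(userMessage), this::isValidResult, 2, 2000).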
3. Frontend code
3.1 Check the generated API client
/**
 * AI analysis of an uploaded data file
 * @param file data file (.xls, .xlsx or .csv)
 * @param chartName chart name (optional)
 * @param chartType chart type (optional)
 * @param goal analysis goal (optional)
 * @returns BaseResponseBiResponse OK
 * @throws ApiError
 */
public static genChartByAi(
  file?: Blob,
  chartName?: string,
  chartType?: string,
  goal?: string,
): CancelablePromise<BaseResponseBiResponse> {
  return __request(OpenAPI, {
    method: 'POST',
    url: '/chart/gen',
    query: {
      'chartName': chartName,
      'chartType': chartType,
      'goal': goal,
    },
    formData: {
      'file': file,
    },
    // Important: the default is JSON; without this line the backend fails with
    // "Current request is not a multipart request"
    mediaType: 'multipart/form-data',
  });
}
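To see why the JSON default fails, it helps to look at what a multipart/form-data request carries: a Content-Type header with a boundary, and a body in which each part (here the single "file" part read by @RequestPart("file")) is framed by that boundary. The client above sends chartName, chartType and goal as query parameters, so only the file travels in the body. The Java sketch below builds such a body by hand purely as an illustration; MultipartSketch is a hypothetical name, the boundary is arbitrary, and quotes in file names are not escaped.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;

final class MultipartSketch {
    private MultipartSketch() { }

    // Content-Type header value the server must receive for a multipart upload.
    static String contentType(String boundary) {
        return "multipart/form-data; boundary=" + boundary;
    }

    // Body with a single file part named "file", framed by the boundary.
    static byte[] singleFileBody(String boundary, String fileName, String mimeType, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String head = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"file\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: " + mimeType + "\r\n\r\n";
        out.writeBytes(head.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        // Closing delimiter: boundary followed by "--"
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }
}
```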
3.2 Vue code
<a-button type="primary" @click="save" :disabled="isLoading">
<template #icon><icon-search /></template>
<template #default>Submit analysis</template>
</a-button>
import { reactive, ref } from "vue";
import { Message } from "@arco-design/web-vue";
// ChartControllerService comes from the generated API client (import path depends on your project)

const isLoading = ref(false);
const chartRef = ref(null);
const data = reactive({
  params: {
    chartType: "",
    goal: "",
    chartName: "",
    file: null,
  },
  result: "",
  options: null,
});
const filelist = ref([]);

// Check the extension and keep the file for later submission; returning false
// stops the upload component from uploading the file by itself
const beforeUpload = (file) => {
  const validFormats = [".xls", ".xlsx", ".csv"];
  const fileFormat = file.name
    .substring(file.name.lastIndexOf("."))
    .toLowerCase();
  if (!validFormats.includes(fileFormat)) {
    Message.error("Please check the file format!");
    return false;
  }
  filelist.value = [file];
  data.params.file = file;
  return false;
};

const save = () => {
  // Make sure a file has been selected; stop here otherwise
  if (!data.params.file || filelist.value.length === 0) {
    Message.error("Please upload a file first");
    return;
  }
  isLoading.value = true;
  const { file, chartType, goal, chartName } = data.params;
  // Mind the argument order: it must match genChartByAi(file, chartName, chartType, goal)
  ChartControllerService.genChartByAi(file, chartName, chartType, goal)
    .then((res) => {
      if (res.data && res.data.genChart === "ERROR") {
        // The backend rejected the input (e.g. too much data); genResult holds the reason
        Message.error(res.data.genResult);
      } else if (res.data && res.data.genResult) {
        data.result = res.data.genResult;
        data.options = JSON.parse(res.data.genChart);
      } else {
        Message.error("AI service call failed, please contact the administrator");
      }
    })
    .catch((error) => {
      Message.error(`Error: ${error}`);
    })
    .finally(() => {
      isLoading.value = false;
    });
};
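The upload rules (10 MB, .xls/.xlsx/.csv) are checked both in the browser and in the controller. Keeping one definition of them avoids drift between the two sides, for example on upper-case extensions such as DATA.XLSX, which beforeUpload accepts because it lower-cases the name first. A backend-side sketch of such a shared rule; UploadRules, suffixOf and isAcceptable are hypothetical names:

```java
import java.util.Locale;
import java.util.Set;

final class UploadRules {
    static final long MAX_BYTES = 10L * 1024 * 1024; // 10 MB, the same limit as TEN_MB in the controller
    static final Set<String> ALLOWED_SUFFIXES = Set.of("xls", "xlsx", "csv");

    private UploadRules() { }

    // Extension after the last dot, lower-cased; "" when the name has no extension.
    static String suffixOf(String fileName) {
        if (fileName == null) {
            return "";
        }
        int dot = fileName.lastIndexOf('.');
        return dot < 0 ? "" : fileName.substring(dot + 1).toLowerCase(Locale.ROOT);
    }

    // True when the file is small enough and has one of the allowed extensions.
    static boolean isAcceptable(String fileName, long sizeBytes) {
        return sizeBytes <= MAX_BYTES && ALLOWED_SUFFIXES.contains(suffixOf(fileName));
    }
}
```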
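4. Results
5. Open issues
Chart generation is slow: the endpoint takes 20 to 30 seconds to return a result.
request end, id: b59e9d87-fe0e-4cae-9e80-20fd264176e0, cost: 28396ms
The code may contain duplicated retry logic (the ArkService client is built with retryTimes(2), and DoubaoAIServiceImpl also wraps the call in its own Retryer); this needs further investigation.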
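One way to narrow this down is to log how long each individual model call takes: if a single call already costs about 28 s, the time is spent in the model rather than in repeated attempts. A minimal timing helper, assuming it would wrap the doChartChat call inside the retryer's callable (Timed and logElapsed are hypothetical names; System.out stands in for the project's logger):

```java
import java.util.function.Supplier;

final class Timed {
    private Timed() { }

    // Runs the call, prints how long it took (even if it throws), and returns its result.
    static <T> T logElapsed(String label, Supplier<T> call) {
        long start = System.nanoTime();
        try {
            return call.get();
        } finally {
            long elapsedMs = (System.nanoTime() - start) / 1_000_000;
            System.out.println(label + " took " + elapsedMs + " ms");
        }
    }
}
```

For example, retryer.call(() -> Timed.logElapsed("doubao attempt", () -> doChartChat(userMessage))) would print one line per attempt, so two lines for one request mean the retry fired.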