Empowering AI Data Analysis and Visualization with the Doubao Large Model

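0. Preface

You may also have thought about letting AI do your data analysis. My site previously used ChatGPT, but the $5 credit has run out, so I decided to switch providers. When I saw that the Doubao large model gives away 500,000 tokens, I decided to try it right away as a learning exercise.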

1. Preparing the large model

The backend uses JDK 17 and Spring Boot 3.3.5; the frontend uses Vue 3 and Arco Design.

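1.1 Open the console

Go to the Volcengine (火山引擎) website, log in, and open the "Console".

1.2 Create an API key

1.3 Create an inference endpoint

1.4 Fill in the required information and select a model

Open the API call settings. Sample code is shown there; copy the apiKey and the value passed to .model(...) (the endpoint ID, which starts with "ep-").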

2. Writing the backend code

2.0 Add the SDK to pom.xml

<dependency>
    <groupId>com.volcengine</groupId>
    <artifactId>volcengine-java-sdk-ark-runtime</artifactId>
    <version>0.1.134</version>
</dependency>
For JDK versions newer than 8, it is also recommended to add:
<dependency>
  <groupId>javax.annotation</groupId>
  <artifactId>javax.annotation-api</artifactId>
  <version>1.3.2</version>
</dependency>

2.1 Wrapping AIConfig

import com.volcengine.ark.runtime.service.ArkService;
import jakarta.annotation.Resource;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.time.Duration;

@Configuration
@ConfigurationProperties(prefix = "doubaoai")
@Data
public class DoubaoAIConfig {
    private String apiKey;

    @Bean
    public ArkService getDoubaoAIService() {
        return ArkService.builder()
                .apiKey(apiKey)
                // Endpoint region: North China 2 (Beijing) is used as an example
                .baseUrl("https://ark.cn-beijing.volces.com/api/v3/")
                .timeout(Duration.ofSeconds(120))
                .connectTimeout(Duration.ofSeconds(20))
                .retryTimes(2)
                .build();
    }
}

Remember to add the apiKey to application.yml:

doubaoai:
  apiKey: xxxxxxxxxxxxxxxxxx

2.2 Wrapping AIManager

To make the AI service more reusable, I defined an AIManager here.

import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import com.volcengine.ark.runtime.service.ArkService;
import jakarta.annotation.Resource;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;


@Component
public class DoubaoAIManager {
    @Resource
    ArkService arkService;

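    private static final double STABLE_TEMPERATURE = 0.01; // more deterministic output
    private static final double UNSTABLE_TEMPERATURE = 0.99; // more random output

    /**
     * Generic request.
     *
     * @param messages    chat messages to send (system, user, ...)
     * @param temperature sampling temperature
     * @return content of the first choice, or an empty string if there is none
     */
    public String doDoubaoAIRequest(List<ChatMessage> messages, Double temperature) {
        // Build the request
        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
                .model("ep-replace-me") // replace with your inference endpoint ID
                .temperature(temperature)
                .messages(messages)
                .build();
        // Return the content of the first choice, or an empty string if there is none
        try {
            return arkService.createChatCompletion(chatCompletionRequest).getChoices()
                    .stream()
                    .map(choice -> choice.getMessage().getContent())
                    .findFirst()
                    .orElse("")
                    .toString();
        } catch (Exception e) {
            e.printStackTrace();
            // BusinessException and ErrorCode are project classes that are not shown in this post
            throw new BusinessException(ErrorCode.SYSTEM_ERROR, e.getMessage());
        }
    }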

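    /**
     * Generic request (simplified message passing).
     *
     * @param systemMessage system prompt
     * @param userMessage   user input
     * @param temperature   sampling temperature
     * @return content of the model's reply
     */
    // Custom messages can be passed in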
    public String doDoubaoAIRequest(String systemMessage, String userMessage, Double temperature) {
        List<ChatMessage> chatMessageList = new ArrayList<>();
        ChatMessage systemChatMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(systemMessage).build();
        ChatMessage userChatMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(userMessage).build();
        chatMessageList.add(systemChatMessage);
        chatMessageList.add(userChatMessage);
        return doDoubaoAIRequest(chatMessageList, temperature);
    }

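    /**
     * Generic request (simplified message passing, more random output).
     *
     * @param systemMessage system prompt
     * @param userMessage   user input
     * @return content of the model's reply
     */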
    public String doDoubaoUnstableAIRequest(String systemMessage, String userMessage) {
        return doDoubaoAIRequest(systemMessage, userMessage, UNSTABLE_TEMPERATURE);
    }

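    /**
     * Generic request (simplified message passing, more stable output).
     *
     * @param systemMessage system prompt
     * @param userMessage   user input
     * @return content of the model's reply
     */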
    public String doDoubaostableAIRequest(String systemMessage, String userMessage) {
        return doDoubaoAIRequest(systemMessage, userMessage,STABLE_TEMPERATURE);
    }
}

2.3 Writing a test class

With the core code in place, let's test it (using the standard request style).

import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import com.volcengine.ark.runtime.service.ArkService;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import jakarta.annotation.Resource;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;

// Without this, the test fails with "Java jakarta.websocket.server.ServerContainer not available"
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class DoubaoAI {
    @Resource
    private ArkService arkService;
    @Test
    void test() {
        System.out.println("\n----- standard request -----");
        final List<ChatMessage> messages = new ArrayList<>();
        // System prompt: "You are Doubao, an AI assistant developed by ByteDance"
        final ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content("你是豆包,是由字节跳动开发的 AI 人工智能助手").build();
        // User prompt: "What are some common cruciferous plants?"
        final ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content("常见的十字花科植物有哪些?").build();
        messages.add(systemMessage);
        messages.add(userMessage);

        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
                .model("ep-replace-me") // replace with your inference endpoint ID
                .messages(messages)
                .build();

        arkService.createChatCompletion(chatCompletionRequest).getChoices().forEach(choice -> System.out.println(choice.getMessage().getContent()));
        arkService.shutdownExecutor();
    }
}

A correct reply is returned.

2.4 Writing the business code

2.4.1 Controller

import cn.hutool.core.io.FileUtil;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.google.gson.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadPoolExecutor;


/**
 * Chart API
 *
 */
@RestController
@RequestMapping("/chart")
@Slf4j
public class ChartController {

    @Resource
    private ChartService chartService;

    @Resource
    private UserService userService;

    @Resource
    private DoubaoAIService doubaoAIService;

    @Resource
    private RedisCacheManager redisCacheManager;

    @Resource
    private RedisLimiterManager redisLimiterManager;

    @Resource
    private ThreadPoolExecutor threadPoolExecutor;

    @Resource
    private ChartMessageProducer chartMessageProducer;

    @Resource
    private TaskManager taskManager;

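    /**
     * AI analysis of an uploaded file.
     *
     * @param multipartFile       uploaded data file (xlsx, xls, or csv)
     * @param genChartByAiRequest chart name, analysis goal, and chart type
     * @param request             current HTTP request, used to resolve the logged-in user
     * @return the generated chart option, the analysis conclusion, and the chart ID
     */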
    @PostMapping("/gen")
    public BaseResponse<BiResponse> genChartByAi(@RequestPart("file") MultipartFile multipartFile,
                                                 GenChartByAiRequest genChartByAiRequest, HttpServletRequest request) {
        String chartName = genChartByAiRequest.getChartName();
        String goal = genChartByAiRequest.getGoal();
        String chartType = genChartByAiRequest.getChartType();
        // Validate parameters
        ThrowUtils.throwIf(StringUtils.isNotBlank(chartName) && chartName.length() > 100, ErrorCode.PARAMS_ERROR, "Chart name is too long");
        User loginUser = userService.getLoginUser(request);
        // Validate file size and suffix
        long size = multipartFile.getSize();
        final long TEN_MB = 10 * 1024 * 1024L;
        ThrowUtils.throwIf(size > TEN_MB, ErrorCode.PARAMS_ERROR, "File is larger than 10 MB");
        String fileName = multipartFile.getOriginalFilename();
        String suffix = FileUtil.getSuffix(fileName);
        final List<String> validFileSuffix = Arrays.asList("xlsx", "csv", "xls");
        ThrowUtils.throwIf(!validFileSuffix.contains(suffix), ErrorCode.PARAMS_ERROR, "Unsupported file suffix");
        // Per-user rate limiting
        redisLimiterManager.doRateLimit("genChartByAi" + loginUser.getId());
        // Build the user input; the Chinese labels must match the format described in the system prompt
        StringBuilder userInput = new StringBuilder();
        userInput.append("分析需求:\n"); // "Analysis goal:"
        // Append the analysis goal, falling back to a default
        String userGoal = "请帮我合理的分析一下数据"; // "Please analyze the data sensibly"
        if (StringUtils.isNotBlank(goal))
            userGoal = goal;
        // Append the requested chart type, if any
        if (StringUtils.isNotBlank(chartType))
            userGoal += ",请使用" + chartType; // ", please use <chartType>"
        userInput.append(userGoal).append("\n");
        userInput.append("原始数据:\n"); // "Raw data:"
        // Convert the uploaded file to compact CSV text
        String userData = "";
        if (Objects.equals(suffix, "csv"))
            userData = Csv2String.MultipartFileToString(multipartFile);
        else
            userData = ExcelUtils.excel2Csv(multipartFile);
        BiResponse biResponse = new BiResponse();
        Chart chart = new Chart();
        // Data-size check: with gpt-3.5, analysis of larger inputs took more than 30s
        if (userData.length() > SYNCHRO_MAX_TOKEN){
            biResponse.setGenChart("ERROR");
            biResponse.setGenResult("The request data is too large; please trim the Excel file");
            biResponse.setChartId(chart.getId());
            return ResultUtils.success(biResponse);
        }
        userInput.append(userData).append("\n");
        String result = "";
        try {
            // Call the model with retry logic
            result = doubaoAIService.doChatWithRetry(userInput.toString());
        } catch (Exception e) {
            // If the call still fails after retrying, return an error
            throw new BusinessException(ErrorCode.SYSTEM_ERROR, e + ", AI generation failed");
        }
        String[] splits = result.split("@@@@@");

        String genChart = splits[1].trim();
        String genResult = splits[2].trim();
        // Strip a leading "var option =" from the ECharts code
        if (genChart.startsWith("var option =")) {
            // Backslashes are doubled so that the regex engine sees \s
            genChart = genChart.replaceFirst("var\\s+option\\s*=\\s*", "");
        }

        JsonObject chartJson = JsonParser.parseString(genChart).getAsJsonObject();

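        // Set the chart name; if none was given, use the generated title and make sure it ends with "图" (chart) or "表" (table)
        if (StringUtils.isEmpty(chartName)){
            String genChartName = String.valueOf(chartJson.getAsJsonObject("title").get("text"));
            // Remove the double quotes that JsonElement.toString() leaves around the title
            genChartName = genChartName.replace("\"","");
            if ( !genChartName.endsWith("图") && !genChartName.endsWith("表"))
                genChartName = genChartName + "图";
            chart.setChartName(genChartName);
        } else
            chart.setChartName(chartName);
        // Set the chart type; if none was given, derive it from the first series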
        if (StringUtils.isEmpty(chartType)){
            JsonArray seriesArray = chartJson.getAsJsonArray("series");
            if (seriesArray.size() > 0) {
                JsonObject firstSeries = seriesArray.get(0).getAsJsonObject();
                String typeChart = firstSeries.getAsJsonObject().get("type").getAsString();
                String CnChartType = chartService.getChartTypeToCN(typeChart);
                chart.setChartType(CnChartType);
            }
        }else
            chart.setChartType(chartType);
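        // Add a download (save-as-image) button
        JsonObject toolbox = new JsonObject();
        toolbox.addProperty("show", true);
        JsonObject saveAsImage = new JsonObject();
        saveAsImage.addProperty("show", true);
        // excludeComponents is an array in the ECharts option, so build a JSON array rather than a string
        JsonArray excludeComponents = new JsonArray();
        excludeComponents.add("toolbox");
        saveAsImage.add("excludeComponents", excludeComponents);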
        saveAsImage.addProperty("pixelRatio", 2);
        JsonObject feature = new JsonObject();
        feature.add("saveAsImage", saveAsImage);
        toolbox.add("feature", feature);
        chartJson.add("toolbox", toolbox);
        chartJson.remove("title");
        String updatedGenChart = chartJson.toString();
        chart.setGoal(userGoal);
        chart.setChartData(userData);
        chart.setGenChart(updatedGenChart);
        chart.setGenResult(genResult);
        chart.setUserId(loginUser.getId());
        chart.setStatus("succeed");
        boolean saveResult = chartService.save(chart);
        if (!saveResult)
            handleChartUpdateError(chart.getId(),"Failed to save chart information");
        biResponse.setGenChart(updatedGenChart);
        biResponse.setGenResult(genResult);
        biResponse.setChartId(chart.getId());
        return ResultUtils.success(biResponse);
    }

    /**
     * Marks a chart as failed and records the error message.
     * @param chartId     ID of the chart to update
     * @param execMessage error message to store
     */
    private void handleChartUpdateError(long chartId, String execMessage){
        Chart updateChart = new Chart();
        updateChart.setId(chartId);
        updateChart.setStatus("failed");
        updateChart.setExecMessage(execMessage);
        boolean b = chartService.updateById(updateChart);
        if (!b)
            log.error("Failed to mark chart as failed: " + chartId + ": " + execMessage);
    }
}
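
The reply handling in the controller above (split on "@@@@@", then strip a leading "var option =") can be tried in isolation. The following is only a minimal sketch in plain Java, not the author's code: the class AiReplyParser, its nested Parsed type, and the parse method are hypothetical names introduced for illustration.

```java
import java.util.regex.Pattern;

public class AiReplyParser {
    // Delimiter that the system prompt asks the model to emit between sections
    private static final String DELIMITER = "@@@@@";
    // Optional leading "var option =" (backslashes are doubled in Java source)
    private static final Pattern OPTION_PREFIX = Pattern.compile("^var\\s+option\\s*=\\s*");

    /** Hypothetical holder for the two parts of a reply. */
    public static final class Parsed {
        public final String chartJson;
        public final String conclusion;

        Parsed(String chartJson, String conclusion) {
            this.chartJson = chartJson;
            this.conclusion = conclusion;
        }
    }

    /** Splits a reply into the chart option text and the conclusion text. */
    public static Parsed parse(String reply) {
        String[] parts = reply.split(DELIMITER);
        // parts[0] is whatever precedes the first delimiter (normally empty)
        if (parts.length < 3) {
            throw new IllegalArgumentException("reply does not contain two delimited sections");
        }
        String chart = OPTION_PREFIX.matcher(parts[1].trim()).replaceFirst("");
        return new Parsed(chart, parts[2].trim());
    }
}
```

Checking the number of parts before indexing avoids an ArrayIndexOutOfBoundsException when the model ignores the requested format.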

2.4.2 ServiceImpl

@Service
public class DoubaoAIServiceImpl implements DoubaoAIService {
    @Autowired
    private DoubaoAIManager doubaoAIManager;

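    // Theoretical maximum amount of input data; processing it takes about 30s
    public static final Integer SYNCHRO_MAX_TOKEN = 340;
    // Retry policy: stop after 2 attempts in total, waiting 2s between attempts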
    private final Retryer<String> retryer = RetryerBuilder.<String>newBuilder()
            .retryIfResult(result -> (!isValidResult(result)))
            .withStopStrategy(StopStrategies.stopAfterAttempt(2))
            .withWaitStrategy(WaitStrategies.fixedWait(2, java.util.concurrent.TimeUnit.SECONDS))
            .build();

    public String doChartChat(String userMessage) {
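        // System prompt, kept in Chinese as it is sent to the model. In short: "You are a data analyst and
        // front-end expert. Input arrives as an analysis goal plus raw CSV data. Reply with the delimiter
        // @@@@@, then an ECharts V5 option object as JSON (with title.text, a legend, black axis lines and
        // labels, and tooltips on hover), then @@@@@ again, then a detailed plain-text conclusion, with no
        // extra preamble, ending, or comments."
        final String systemPrompt = "你是一个数据分析师和前端开发专家,接下来我会按照以下固定格式给你提供内容:\n" +
                "分析需求:\n" + "{数据分析的需求或者目标}\n" +
                "原始数据:\n" + "{csv格式的原始数据,用,作为分隔符}\n" +
                "请根据这两部分内容,按照以下指定格式生成内容,其中包括生成分隔符\"@@@@@\",同时分析结论请直接给出(此外不要输出任何多余的开头、结尾、注释)\n" +
                "@@@@@\n" +
                "{前端 Echarts V5 的 option 配置对象js代码(json格式),代码需要包括 title.text(需要该图的名称)部分、图例部分(即 legend 元素,文字部分应为黑色,图例线颜色与图例颜色相同),合理地将数据进行可视化,图表要求:1、若图表有轴线请将轴线画出,如 y 轴线,颜色为黑色 2、坐标字体为黑色,3.图表在鼠标悬停时可以显示数据\n" +
                "@@@@@\n" +
                "{请直接明确的数据分析结论、越详细越好,不要生成多余的注释、大括号和双引号}";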
        return doubaoAIManager.doDoubaostableAIRequest(systemPrompt, userMessage);
    }



    /**
     * Checks whether the model's reply has the expected structure.
     * @param result raw reply from the model
     * @return true if the reply has both sections and the chart JSON has a non-empty title.text
     */
    private boolean isValidResult(String result) {
        String[] splits = result.split("@@@@@");
        if (splits.length < 3)
            return false;
        String genChart = splits[1].trim();
        try {
            JsonObject chartJson = JsonParser.parseString(genChart).getAsJsonObject();
            // The chart must have a "title" field
            if (!chartJson.has("title")) {
                return false;
            }
            // "title" must contain a non-null "text" field
            JsonElement titleElement = chartJson.getAsJsonObject("title").get("text");
            if (titleElement == null || titleElement.isJsonNull()) {
                return false;
            }
            String titleText = titleElement.getAsString();
            if (titleText.isEmpty()) {
                return false;
            }
        } catch (JsonSyntaxException | IllegalStateException | ClassCastException e) {
            // Malformed JSON, a non-object root, or a non-object "title": treat the reply as invalid
            return false;
        }
        return true;
    }
}
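
The controller calls doubaoAIService.doChatWithRetry, but that method is not shown in the post; the service only declares a Retryer that retries while isValidResult is false. As a rough illustration of the same idea (retry until the result passes validation, with a fixed wait), here is a minimal sketch in plain Java. It deliberately does not use the retry library from the original code, and SimpleRetry, callWithRetry, and all parameter names are hypothetical.

```java
import java.util.function.Predicate;
import java.util.function.Supplier;

public class SimpleRetry {
    /**
     * Calls the supplier until its result passes validation or the attempts run out.
     *
     * @param call        the operation to run, e.g. a chat request
     * @param isValid     validation applied to each result
     * @param maxAttempts total number of attempts (must be at least 1)
     * @param waitMillis  fixed wait between attempts
     */
    public static String callWithRetry(Supplier<String> call, Predicate<String> isValid,
                                       int maxAttempts, long waitMillis) {
        if (maxAttempts < 1) {
            throw new IllegalArgumentException("maxAttempts must be at least 1");
        }
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            String result = call.get();
            if (isValid.test(result)) {
                return result;
            }
            // Wait only if another attempt will follow
            if (attempt < maxAttempts) {
                try {
                    Thread.sleep(waitMillis);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("interrupted while waiting to retry", e);
                }
            }
        }
        throw new IllegalStateException("no valid result after " + maxAttempts + " attempts");
    }
}
```

Under these assumptions, a call such as SimpleRetry.callWithRetry(() -> doChartChat(input), this::isValidResult, 2, 2000) would correspond to the policy configured above: two attempts in total with a 2s pause.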

3. Writing the frontend code

3.1 Checking the generated API client

/**
 * @param file the data file to upload
 * @param chartName
 * @param chartType
 * @param goal
 * @returns BaseResponseBiResponse OK
 * @throws ApiError
 */
public static genChartByAi(
    file?: Blob,
    chartName?: string,
    chartType?: string,
    goal?: string,
): CancelablePromise<BaseResponseBiResponse> {
    return __request(OpenAPI, {
        method: 'POST',
        url: '/chart/gen',
        query: {
            'chartName': chartName,
            'chartType': chartType,
            'goal': goal,
        },
        formData: {
            'file': file,
        },
        // Note: the default is JSON; if this is not changed, the backend fails with "Current request is not a multipart"
        mediaType: 'multipart/form-data',
    });
}

3.2 Writing the Vue code

<a-button type="primary" @click="save" :disabled="isLoading">
  <template #icon><icon-search /></template>
  <template #default>Submit for analysis</template>
</a-button>

const isLoading = ref(false);

const chartRef = ref(null);
const data = reactive({
  params: {
    chartType: "",
    goal: "",
    chartName: "",
    file: null,
  },
  result: "",
  options: null,
});
const filelist = ref([]);
const beforeUpload = (file) => {
  const validFormats = [".xls", ".xlsx", ".csv"];
  const fileFormat = file.name
      .substring(file.name.lastIndexOf("."))
      .toLowerCase();
  if (!validFormats.includes(fileFormat)) {
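    // Show a prompt asking the user to check the format
    // (Message or the popup component of any other UI library can be used here)
    message.error("Please check the file format!");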
    return false;
  }
  filelist.value = [file];
  data.params.file = file;
  return false;
};

const save = () => {
  isLoading.value = true;
  // Make sure a file has been uploaded
  if (!data.params.file || filelist.value.length === 0) {
    isLoading.value = false;
    message.error("Please upload at least one file");
    return; // return early and skip the rest
  }

  const file = data.params.file;
  const chartType = data.params.chartType;
  const goal = data.params.goal;
  const chartName = data.params.chartName;
  // Note the argument order!
  ChartControllerService.genChartByAi(file, chartName,chartType, goal)
      .then((res) => {
        if (res.data && res.data.genResult) {
          data.result = res.data.genResult;
          data.options = JSON.parse(res.data.genChart);
          isLoading.value = false;
        } else {
          message.error("The AI service call failed; please contact the administrator");
          isLoading.value = false;
        }
      })
      .catch((error) => {
        message.error(`Error: ${error}`);
        isLoading.value = false;
      });
};

4. Results

[Screenshot of the page running on localhost]

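5. Known issues

Generating a chart takes too long: the endpoint needs 20 to 30 seconds to return a result.

request end, id: b59e9d87-fe0e-4cae-9e80-20fd264176e0, cost: 28396ms

The code may contain duplicated retries (for example, retryTimes(2) on the ArkService in addition to the Retryer in the service layer); this needs further investigation.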
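
To see why stacked retries matter for latency, the following minimal sketch estimates the worst case when an outer retry wraps a client that also retries on its own. It assumes, without having verified it against the SDK, that retryTimes(n) means up to n extra attempts per call; RetryBudget and its method names are hypothetical, and the result is only an upper bound because the two layers react to different failures.

```java
public class RetryBudget {
    /**
     * Upper bound on the number of HTTP calls when each outer attempt
     * may itself be retried innerRetries times by the client.
     */
    public static int worstCaseCalls(int outerAttempts, int innerRetries) {
        return outerAttempts * (1 + innerRetries);
    }

    /**
     * Upper bound on the total time: every call takes perCallMillis,
     * and the outer layer waits outerWaitMillis between its attempts.
     */
    public static long worstCaseMillis(int outerAttempts, int innerRetries,
                                       long perCallMillis, long outerWaitMillis) {
        long callTime = (long) worstCaseCalls(outerAttempts, innerRetries) * perCallMillis;
        long waitTime = (long) (outerAttempts - 1) * outerWaitMillis;
        return callTime + waitTime;
    }
}
```

With the settings used in this post (2 outer attempts, retryTimes(2)), worstCaseCalls(2, 2) gives 6 calls under that assumption, so logging each attempt is a quick way to check whether a slow request was a single slow call or several retried ones.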

6. References

Volcengine SDK installation and initialization