LangChain4j LLM API 详解并接入OpenAI实现聊天

3,801 阅读9分钟

#LangChain4j系列:一文带你入门LangChain4j框架 文章对LangChain4j框架有个整体的认识,并接入本地大模型、OpenAI大模型两种模型实现简单的对话功能。本篇文章将介绍如何使用 high-level API AiServices 。

LangChain4j 支持的LLMs

Provider
「提供商」
Streaming
「流式返回」
Tools
「 函数调用 」
支持模态类型Local
「本地部署」
Native
「支持原生」
Zhipu AI文本、图片
Qianfan文本
ChatGLM文本
Ollama文本、图片
Hugging Face文本
OpenAI文本、图片
Azure OpenAI文本、图片
Amazon Bedrock文本
Anthropic文本、图片
DashScope文本、图片
Google Vertex AI Gemini文本、图片、音频、视频、PDF
Google Vertex AI PaLM 2text
Jlama文本
LocalAI文本
Mistral AI文本
Cloudflare Workers AI文本

与 Spring AI 框架支持的大模型基本相似。下面我们将详细介绍LangChain4j支持的 LLM API。上一节中我们知道 LangChain4j 支持 low-level API 和 high-level API 两种。

low-level LLM API

Model API

  • LanguageModel:API 非常简单。参数类型:String 返回类型:String。 目前废弃!!!!
  • ChatLanguageModel:取代 LanguageModel,参数:ChatMessage,返回:AiMessage。

    对于LanguageModel已不再扩展,可以把这个忘掉了。在实际开发中使用ChatLanguageModel

    public interface ChatLanguageModel {
        // 默认方法,通过用户输入,大模型生成结果
        default String generate(String userMessage) {
            return generate(UserMessage.from(userMessage)).content().text();
        }
        // 默认方法,通过用户输入,大模型生成结果,根据消息的序列返回
        // 消息的序列:System (optional) - User - AI - User - AI - User 
        default Response<AiMessage> generate(ChatMessage... messages) {
            return generate(asList(messages));
        }
        // 同上,只是参数不同
        Response<AiMessage> generate(List<ChatMessage> messages);
    
        // 默认方法,根据一组用户输入和一组用户自定义的工具,大模型生成返回内容
        default Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
            throw new IllegalArgumentException("Tools are currently not supported by this model");
        }
        // 默认方法,根据一组用户输入和一个用户自定义的工具,大模型生成返回内容
        default Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
            throw new IllegalArgumentException("Tools are currently not supported by this model");
        }
    }
    
  • EmbeddingModel:嵌入模型,将文本转换向量。
    public interface EmbeddingModel {
        // 功能:将输入文本转换嵌入(一组向量)
        default Response<Embedding> embed(String text) {
            return embed(TextSegment.from(text));
        }
        // 将文本片段转换为嵌入(一组向量)
        default Response<Embedding> embed(TextSegment textSegment) {
            Response<List<Embedding>> response = embedAll(singletonList(textSegment));
            ValidationUtils.ensureEq(response.content().size(), 1,
                    "Expected a single embedding, but got %d", response.content().size());
            return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason());
        }
        // // 将一组文本片段转换为嵌入(一组向量)
        Response<List<Embedding>> embedAll(List<TextSegment> textSegments);
    }
    
  • ImageModel:图片模型,可以生成和编辑图片。
    public interface ImageModel {
      
        // 根据用户的输入生成一张图片对象,Image 包含url、base64、mimeType等
        Response<Image> generate(String prompt);
        
        // 默认方法,根据用户输入和要生成图片的张数,生成一组图片对象, 
        // 具体使用还的参照图片的模型的特点,设置参数,并不是所有的模型对参数都支持
        default Response<List<Image>> generate(String prompt, int n) {
            throw new IllegalArgumentException("Operation is not supported");
        }
        // 默认方法,根据存在的图片和用户输入的要求,编辑图片
        default Response<Image> edit(Image image, String prompt) {
            throw new IllegalArgumentException("Operation is not supported");
        }
        // 同上,重点是mask参数,后续在研究
        default Response<Image> edit(Image image, Image mask, String prompt) {
            throw new IllegalArgumentException("Operation is not supported");
        }
    }
    
  • ModerationModel:模型可以检查文本是否包含有害内容。
    public interface ModerationModel {
        // 审核给定的文本,判断是否有有害内容
        Response<Moderation> moderate(String text);
    
        // 默认方法,审核给定的文本,判断是否有有害内容
        default Response<Moderation> moderate(Prompt prompt) {
            return moderate(prompt.text());
        }
        // 废弃!!!!
        @SuppressWarnings("deprecation")
        default Response<Moderation> moderate(ChatMessage message) {
            return moderate(message.text());
        }
    
        // 审核一组ChatMessage,判断是否有有害内容
        Response<Moderation> moderate(List<ChatMessage> messages);
        
        // 审核文本片段,判断是否有有害内容
        default Response<Moderation> moderate(TextSegment textSegment) {
            return moderate(textSegment.text());
        }
    }
    
  • ScoringModel:可以针对查询对多段文本进行评分(或排名)。
    public interface ScoringModel {
        // 根据给定的查询为给定的文本打分。即根据query给text打分
        default Response<Double> score(String text, String query) {
            return score(TextSegment.from(text), query);
        }
        // 根据给定的查询为给定的文本打分。即根据query给TextSegment打分
        default Response<Double> score(TextSegment segment, String query) {
            Response<List<Double>> response = scoreAll(singletonList(segment), query);
            ensureEq(response.content().size(), 1,
                    "Expected a single score, but received %d", response.content().size());
            return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason());
        }
        // 根据给定的查询为给定的文本打分。即根据query给一组TextSegment打分
        Response<List<Double>> scoreAll(List<TextSegment> segments, String query);
    }
    

ChatMessage 类型

image.png ChatMessage 有四种类型的聊天消息;

  • UserMessage:用户的消息。用户可以是应用程序的最终用户,也可以是应用程序本身。根据 LLM支持 UserMessage 的模态。可以只包含文本,也可以包含文本和/或图像。
  • AiMessage:由 AI 生成的消息,通常用于响应 UserMessage
  • ToolExecutionResultMessage:是 ToolExecutionRequest 的结果,函数调用时使用,开发人员一般无需关心。
  • SystemMessage:定义扮演什么角色、它应该如何表现、以什么方式回答等的说明,比其他类型的消息优先级更高。

    最好不要让最终用户自由访问定义或注入一些输入。通常,它位于对话的开头。

在使用上还可以使用多ChatMessage,一并传给大模型,实现上下文聊天记忆,后续文章中会有介绍。

high-level LLM API

low-level API 使用上非常灵活,但是享受了自由的同时也迫使你编写大量的样板代码。

在实现 LLM-powered 应用程序通常不仅需要单个组件,需要多个组件协同工作(例如,提示模板、聊天记忆、LLMs输出解析器、RAG组件:嵌入模型和存储),并且通常涉及多个交互,因此编排它们变得更加繁琐。

为让开发者更专注于业务逻辑,而不是低级实现细节。LangChain4j 中目前有两个高级概念可以帮助解决这个问题:AI ServicesChains

高级API注解

  • @AiService:定义一个集成大模型的服务,标注在接口上,无需实现方法。

    参数说明:

    • wiringMode:类注入类型,有两种方式,AUTOMATIC「自动注入」 和 EXPLICIT「手动指定名称」。
    • chatModel:指定使用大模型,如果wiringMode =EXPLICIT,则chatModel指定模型的Bean名称。
    • streamingChatModel:指定支持流式响应的大模型,设置规则与wiringMode设置有关系。
    • chatMemory:指定上下文记忆Bean,设置规则与wiringMode设置有关系。
    • chatMemoryProvider:指定上下文记忆Provider,设置规则与wiringMode设置有关系。
    • contentRetriever:内容获取器,在实现RAG时,需要从向量数据库、文件等获取内容,作为大模型交互的上下文信息使用。
    • retrievalAugmentor:检索聚合器,也是在实现RAG时,将多个组件进行聚合来增强RAG能力,参考:juejin.cn/post/745438… 文件
    • tools:指定需要调用的函数集合,支持函数调用。
  • @Tool:函数定义,可以设置在 @AiService tools的值上。
  • @MemoryId:定义上下文记忆Id,支持上下文记忆能力。
  • @SystemMessage:定义系统消息,定义模版方式两种:字符串/@SystemMessage(fromResource = "my-prompt-template.txt")
  • @UserMessage:定义用户消息,定义模板方式有两种:字符串/@UserMessage(fromResource = "my-prompt-template.txt")
  • @V:定义模板中的参数名称,一般与@SystemMessage或者@UserMessage配合使用,如果方法仅有一个参数可以使用模板中使用{{it}},如果有多个参数需要@V定义名称。
  • @UserName:@UserName注释的方法参数的值将被注入UserMessage的字段“name”
  • @Moderate: 指定方法需要使用审核模型,对输入/输出内容进行检查。

AiService支持其它高级功能

  • 格式化输入/输出
  • Chat memory
  • Tools/Function Calling
  • RAG(Retrieval-augmented Generation)
  • Chains(连接多个组件)

AiServices/DefaultAiServices

可以通过AiServices编程式创建 AI Service 实例,如下为核心的一些源码实现;

  public abstract class AiServices<T> {
      // 根据定义接口,指定语言模型创建 AI Service 实例
      public static <T> T create(Class<T> aiService, ChatLanguageModel chatLanguageModel) {
          return builder(aiService)
                  .chatLanguageModel(chatLanguageModel)
                  .build();
      }
      // 根据定义接口,指定流式语言模型创建 AI Service 实例
      public static <T> T create(Class<T> aiService, StreamingChatLanguageModel streamingChatLanguageModel) {
          return builder(aiService)
                  .streamingChatLanguageModel(streamingChatLanguageModel)
                  .build();
      }
      public static <T> AiServices<T> builder(Class<T> aiService) {
          AiServiceContext context = new AiServiceContext(aiService);
          for (AiServicesFactory factory : loadFactories(AiServicesFactory.class)) {
              return factory.create(context);
          }
          return new DefaultAiServices<>(context);
      }
      // 指定大语言模型
      public AiServices<T> chatLanguageModel(ChatLanguageModel chatLanguageModel) {
          context.chatModel = chatLanguageModel;
          return this;
      }
      // 指定流式大语言模型
      public AiServices<T> streamingChatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
          context.streamingChatModel = streamingChatLanguageModel;
          return this;
      }

      // 指定SystemMessageProvider
      public AiServices<T> systemMessageProvider(Function<Object, String> systemMessageProvider) {
          context.systemMessageProvider = systemMessageProvider.andThen(Optional::ofNullable);
          return this;
      }

      // 指定聊天记忆实现类
      public AiServices<T> chatMemory(ChatMemory chatMemory) {
          context.chatMemories = new ConcurrentHashMap<>();
          context.chatMemories.put(DEFAULT, chatMemory);
          return this;
      }
      // 指定聊天记忆函数,一般用于实现按照聊天维度进行数据隔离
      public AiServices<T> chatMemoryProvider(ChatMemoryProvider chatMemoryProvider) {
          context.chatMemories = new ConcurrentHashMap<>();
          context.chatMemoryProvider = chatMemoryProvider;
          return this;
      }
      // 指定审核模型
      public AiServices<T> moderationModel(ModerationModel moderationModel) {
          context.moderationModel = moderationModel;
          return this;
      }
      // 指定自定义一组工具类
      public AiServices<T> tools(Object... objectsWithTools) {
          return tools(Arrays.asList(objectsWithTools));
      }
      public AiServices<T> tools(List<Object> objectsWithTools) { // TODO Collection?
          // TODO validate uniqueness of tool names
          context.toolSpecifications = new ArrayList<>();
          context.toolExecutors = new HashMap<>();

          for (Object objectWithTool : objectsWithTools) {
              if (objectWithTool instanceof Class) {
                  throw illegalConfiguration("Tool '%s' must be an object, not a class", objectWithTool);
              }

              for (Method method : objectWithTool.getClass().getDeclaredMethods()) {
                  if (method.isAnnotationPresent(Tool.class)) {
                      ToolSpecification toolSpecification = toolSpecificationFrom(method);
                      context.toolSpecifications.add(toolSpecification);
                      context.toolExecutors.put(toolSpecification.name(), new DefaultToolExecutor(objectWithTool, method));
                  }
              }
          }

          return this;
      }

      // 废弃!!!!!
      @Deprecated
      public AiServices<T> retriever(Retriever<TextSegment> retriever) {
          if (contentRetrieverSet || retrievalAugmentorSet) {
              throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
          }
          if (retriever != null) {
              AiServices<T> withContentRetriever = contentRetriever(retriever.toContentRetriever());
              retrieverSet = true;
              return withContentRetriever;
          }
          return this;
      }
      // 
      public AiServices<T> contentRetriever(ContentRetriever contentRetriever) {
          if (retrieverSet || retrievalAugmentorSet) {
              throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
          }
          contentRetrieverSet = true;
          context.retrievalAugmentor = DefaultRetrievalAugmentor.builder()
                  .contentRetriever(ensureNotNull(contentRetriever, "contentRetriever"))
                  .build();
          return this;
      }
      // 
      public AiServices<T> retrievalAugmentor(RetrievalAugmentor retrievalAugmentor) {
          if (retrieverSet || contentRetrieverSet) {
              throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
          }
          retrievalAugmentorSet = true;
          context.retrievalAugmentor = ensureNotNull(retrievalAugmentor, "retrievalAugmentor");
          return this;
      }
  }

@AiService实例化

langchain4j-spring-boot-starter-0.31.0.jar: AiServicesAutoConfig 实现实例化AI Service 实例。 源码实现非常简单,自行查看源码!!

使用high-level API实战

1. 依赖包

<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-spring-boot-starter</artifactId>
    <version>${langchain4j.version}</version>
</dependency>
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-open-ai-spring-boot-starter</artifactId>
    <version>${langchain4j.version}</version>
</dependency>

2. yml配置

server:
  port: 8801
spring:
  application:
    name: chat-high-level-service
langchain4j:
  open-ai:
    chat-model:
      base-url: xxx
      api-key: xx

3. AiService

package org.ivy.aiservice.service;

import dev.langchain4j.service.spring.AiService;
import dev.langchain4j.service.spring.AiServiceWiringMode;

/**
 * 通过 @AiService 注解声明一个 AI 助手接口,并指定 chatModel 为 openAiChatModel。
 */
@AiService(
        wiringMode = AiServiceWiringMode.EXPLICIT, // 指定注入方式
        chatModel = "openAiChatModel", // 指定chatModel为OpenAiChatModel
        tools = {"calculator"} // 指定自定义的工具
)
public interface Assistant {
    String chat(String userMessage);
}

4. Tools

package org.ivy.aiservice.func;

import dev.langchain4j.agent.tool.Tool;
import org.springframework.stereotype.Component;

@Component
public class Calculator {

    @Tool("Calculates the length of a string")
    int stringLength(String s) {
        System.out.println("Called stringLength with s='" + s + "'");
        return s.length();
    }

    @Tool("Calculates the sum of two numbers")
    int add(int a, int b) {
        System.out.println("Called add with a=" + a + ", b=" + b);
        return a + b;
    }

    @Tool("Calculates the square root of a number")
    double sqrt(int x) {
        System.out.println("Called sqrt with x=" + x);
        return Math.sqrt(x);
    }
}

5. Controller

package org.ivy.aiservice;

import org.ivy.aiservice.service.Assistant;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RequestMapping("/hl")
@RestController
public class HighLevelChatController {
    private final Assistant assistant;

    public HighLevelChatController(Assistant assistant) {
        this.assistant = assistant;
    }

    @GetMapping("/chat")
    public String chat(
            @RequestParam(value = "prompt",
                    defaultValue = "What is the square root of the sum of the numbers of letters in the words "hello" and "world"?")
            String prompt) {
        return assistant.chat(prompt);
    }
}

6. 测试结果

image.png 借助工具完成复杂的任务处理。

总结

先对LLM API 有个简单的认识,对于格式化输出、RAG以及连接多个服务都没有进行代码示例。先对整个有个感觉,后面会不断进行更新文章对每一部分都进行详细的分析和实战。

示例代码:Github仓库