Spring AI Series: AI Models - Model

Original post from my knowledge base; comments and discussion are welcome.

AI Model API

A portable Model API across AI providers for Chat, Text-to-Image, Audio Transcription, Text-to-Speech, and Embedding models. Both synchronous and streaming API options are supported, and you can still drop down to access model-specific features when needed.

Models from OpenAI, Microsoft, Amazon (Bedrock), Google, Hugging Face, and more are supported.

The Three-Layer Structure of AI Models

  1. Domain abstractions: Model, StreamingModel
  2. Capability abstractions: ChatModel, EmbeddingModel
  3. Provider implementations: OpenAiChatModel, VertexAiGeminiChatModel, AnthropicChatModel

AI Model Domain Abstractions

Model

The API for invoking an AI model.

public interface Model<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {

	/**
	 * Executes a method call to the AI model.
	 */
	TRes call(TReq request);

}

StreamingModel

The API for invoking an AI model with a streaming response.

public interface StreamingModel<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> {

	/**
	 * Executes a method call to the AI model.
	 */
	Flux<TResChunk> stream(TReq request);

}

ModelRequest

The input to an AI model.

public interface ModelRequest<T> {

	/**
	 * Retrieves the instructions or input required by the AI model.
	 */
	T getInstructions(); // required input

	ModelOptions getOptions();

}
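
To make these generics concrete: for the chat capability, Prompt is the ModelRequest implementation, its instructions are a list of Message objects, and its options are ChatOptions. Below is a minimal sketch of building one; the Prompt(List, ChatOptions) constructor and the temperature builder property are assumed from the Spring AI 1.x API and may differ slightly between versions.

import java.util.List;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;

class PromptRequestExample {

	static Prompt buildPrompt() {
		// A Prompt is a ModelRequest<List<Message>>: the instructions are messages
		// and the ModelOptions are ChatOptions (model name, temperature, ...).
		List<Message> messages = List.of(
				new SystemMessage("You are a concise assistant."),
				new UserMessage("Summarize the Spring AI Model abstraction in one sentence."));

		ChatOptions options = ChatOptions.builder()
				.temperature(0.7) // assumed builder property; adjust to your version
				.build();

		Prompt prompt = new Prompt(messages, options);

		// ModelRequest accessors: getInstructions() returns the messages above.
		System.out.println(prompt.getInstructions());
		return prompt;
	}

}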

ModelResponse

The response from an AI model.

public interface ModelResponse<T extends ModelResult<?>> {

	/**
	 * Retrieves the result of the AI model.
	 */
	T getResult();

	List<T> getResults();

	ResponseMetadata getMetadata();

}

ModelResult

The output of an AI model.

public interface ModelResult<T> {

	/**
	 * Retrieves the output generated by the AI model.
	 */
	T getOutput();

	ResultMetadata getMetadata();

}
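
Putting the response-side interfaces together: for chat, ChatResponse acts as the ModelResponse and each Generation is a ModelResult whose output is the assistant message. Below is a minimal sketch of navigating a response, using only the accessors shown above; the null check mirrors the defensive style of ChatModel's default methods.

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;

class ChatResponseNavigation {

	static String firstText(ChatModel chatModel) {
		// ChatResponse is the ModelResponse for chat; each Generation is a ModelResult.
		ChatResponse response = chatModel.call(new Prompt("Tell me a joke"));

		// getResults() returns every generation, getResult() the primary one,
		// and getMetadata() carries response-level metadata such as usage.
		for (Generation generation : response.getResults()) {
			System.out.println(generation.getOutput().getText());
		}

		Generation primary = response.getResult();
		return (primary != null) ? primary.getOutput().getText() : "";
	}

}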

AI Model Capability Abstractions

ChatModel

The chat model: a model for text-based conversational interaction.

It works by taking a Prompt, or a partial conversation, as input and sending it to the backend large model. The model generates a conversational response based on its training data and its understanding of natural language, and the application can then present that response to the user or use it for further processing.

public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel {

	default String call(String message) {
		Prompt prompt = new Prompt(new UserMessage(message));
		Generation generation = call(prompt).getResult();
		return (generation != null) ? generation.getOutput().getText() : "";
	}

	default String call(Message... messages) {
		Prompt prompt = new Prompt(Arrays.asList(messages));
		Generation generation = call(prompt).getResult();
		return (generation != null) ? generation.getOutput().getText() : "";
	}

	@Override
	ChatResponse call(Prompt prompt);

	default ChatOptions getDefaultOptions() {
		return ChatOptions.builder().build();
	}

	@Override
	default Flux<ChatResponse> stream(Prompt prompt) {
		throw new UnsupportedOperationException("streaming is not supported");
	}

}
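
In application code you typically program against the ChatModel interface and let Spring inject whichever provider implementation is on the classpath. The sketch below uses only the call variants defined above; the ChatService class and the injected bean are illustrative assumptions.

import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.stereotype.Service;

@Service
class ChatService {

	private final ChatModel chatModel;

	ChatService(ChatModel chatModel) {
		this.chatModel = chatModel;
	}

	// Simplest form: call(String) wraps the text in a UserMessage internally.
	String ask(String question) {
		return this.chatModel.call(question);
	}

	// call(Message...) allows mixing system and user messages.
	String askWithRole(String question) {
		return this.chatModel.call(
				new SystemMessage("Answer in one short paragraph."),
				new UserMessage(question));
	}

}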

StreamingChatModel

The streaming-response API for chat models.

@FunctionalInterface
public interface StreamingChatModel extends StreamingModel<Prompt, ChatResponse> {

	default Flux<String> stream(String message) {
		Prompt prompt = new Prompt(message);
		return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
				|| response.getResult().getOutput().getText() == null) ? ""
						: response.getResult().getOutput().getText());
	}

	default Flux<String> stream(Message... messages) {
		Prompt prompt = new Prompt(Arrays.asList(messages));
		return stream(prompt).map(response -> (response.getResult() == null || response.getResult().getOutput() == null
				|| response.getResult().getOutput().getText() == null) ? ""
						: response.getResult().getOutput().getText());
	}

	@Override
	Flux<ChatResponse> stream(Prompt prompt);

}
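
A minimal sketch of consuming the streaming API: stream(String) emits the text of each chunk as it arrives, so the resulting Flux can be returned straight from a reactive endpoint, logged, or collected. The controller below is illustrative and assumes a ChatModel bean (which also implements StreamingChatModel) has been auto-configured.

import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import reactor.core.publisher.Flux;

@RestController
class StreamingChatController {

	private final StreamingChatModel streamingChatModel;

	StreamingChatController(StreamingChatModel streamingChatModel) {
		this.streamingChatModel = streamingChatModel;
	}

	// Returning a Flux streams each text chunk to the client as it is generated.
	@GetMapping(value = "/chat/stream", produces = "text/event-stream")
	Flux<String> stream(@RequestParam String message) {
		return this.streamingChatModel.stream(message);
	}

}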

EmbeddingModel

The embedding model.

Embedding works by converting text, images, and video into arrays of floating-point numbers, that is, vectors.
An embedding model (EmbeddingModel) is the model used to perform this conversion.

public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {

	@Override
	EmbeddingResponse call(EmbeddingRequest request);

	// Convenience methods for vector embedding

	default float[] embed(String text) {
		Assert.notNull(text, "Text must not be null");
		List<float[]> response = this.embed(List.of(text));
		return response.iterator().next();
	}

	/**
	 * Embeds the given document's content into a vector.
	 * @param document the document to embed.
	 * @return the embedded vector.
	 */
	float[] embed(Document document);

	default List<float[]> embed(List<String> texts) {
		Assert.notNull(texts, "Texts must not be null");
		return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
			.getResults()
			.stream()
			.map(Embedding::getOutput)
			.toList();
	}

	default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
		Assert.notNull(documents, "Documents must not be null");
		List<float[]> embeddings = new ArrayList<>(documents.size());
		List<List<Document>> batch = batchingStrategy.batch(documents);
		for (List<Document> subBatch : batch) {
			List<String> texts = subBatch.stream().map(Document::getText).toList();
			EmbeddingRequest request = new EmbeddingRequest(texts, options);
			EmbeddingResponse response = this.call(request);
			for (int i = 0; i < subBatch.size(); i++) {
				embeddings.add(response.getResults().get(i).getOutput());
			}
		}
		Assert.isTrue(embeddings.size() == documents.size(),
				"Embeddings must have the same number as that of the documents");
		return embeddings;
	}

	default EmbeddingResponse embedForResponse(List<String> texts) {
		Assert.notNull(texts, "Texts must not be null");
		return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()));
	}

	default int dimensions() {
		return embed("Test String").length;
	}

}
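
A minimal sketch of the common EmbeddingModel entry points, assuming an embedding model bean has been auto-configured by a provider starter; the EmbeddingService class is illustrative.

import java.util.List;

import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.stereotype.Service;

@Service
class EmbeddingService {

	private final EmbeddingModel embeddingModel;

	EmbeddingService(EmbeddingModel embeddingModel) {
		this.embeddingModel = embeddingModel;
	}

	// Single text -> one vector.
	float[] embedOne(String text) {
		return this.embeddingModel.embed(text);
	}

	// Batch of texts -> one vector per input, in the same order.
	List<float[]> embedMany(List<String> texts) {
		return this.embeddingModel.embed(texts);
	}

	// Full response when metadata (e.g. token usage) is also needed.
	EmbeddingResponse embedWithMetadata(List<String> texts) {
		return this.embeddingModel.embedForResponse(texts);
	}

	// Vector dimensionality of the backing model.
	int vectorSize() {
		return this.embeddingModel.dimensions();
	}

}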

AI Model Provider Implementations
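
The provider layer consists of concrete implementations such as OpenAiChatModel, VertexAiGeminiChatModel, and AnthropicChatModel listed at the top of this article. Each implements the capability interfaces above for a specific vendor, so application code that depends only on ChatModel or EmbeddingModel can switch providers without changing its call sites.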