OpenAI API - chat completion by stream

1,383 阅读2分钟

之前实现了 OpenAI API 的 Java 版本(参考文章:OpenAi API JAVA版)。

使用 Vue 3 + Java 实现了类似官方 Playground 的功能,产品体验地址:h5.felh.xyz

最近实现了在调用 create chat completion 接口时传入参数 stream=true,并使用 server-sent events(SSE)接收流式返回的结果。


客户端接收信息的Listener

package xyz.felh.openai;

import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import xyz.felh.openai.completion.chat.ChatCompletion;

@Slf4j
public abstract class StreamChatCompletionListener {

    private String clientId;

    public String getClientId() {
        return clientId;
    }

    public void setClientId(String clientId) {
        this.clientId = clientId;
    }

    /**
     * Invoked when an event source has been accepted by the remote peer and may begin transmitting
     * events.
     *
     * @param requestId request ID
     * @param response  OK http response
     */
    public void onOpen(String requestId, Response response) {
        log.info("onOpen: {}", requestId);
    }

    /**
     * event line
     *
     * @param requestId      request ID
     * @param chatCompletion return chat completion
     */
    public void onEvent(String requestId, ChatCompletion chatCompletion) {
        log.info("onEvent: {}", requestId);
    }

    /**
     * event message finished
     *
     * @param requestId request ID
     */
    public void onEventDone(String requestId) {
        log.info("onEventDone: {}", requestId);
    }

    /**
     * <p>
     * No further calls to this listener will be made.
     *
     * @param requestId request ID
     */
    public void onClosed(String requestId) {
        log.info("onClosed: {}", requestId);
    }

    /**
     * Invoked when an event source has been closed due to an error reading from or writing to the
     * network. Incoming events may have been lost. No further calls to this listener will be made.
     *
     * @param requestId request ID
     * @param t         throwable
     * @param response  response
     */
    public void onFailure(String requestId, Throwable t, Response response) {
        log.error("onFailure: {}", requestId, t);
    }

}

创建SSE客户端代码

    /**
     * Create a chat completion in streaming mode and forward every server-sent event to the
     * registered {@link StreamChatCompletionListener}s.
     *
     * <p>Nothing is returned: results are delivered only through the listener callbacks, each
     * tagged with {@code requestId}. Every registered listener is notified, regardless of its
     * client ID.
     *
     * <p>Note: this method sets {@code stream=true} on the supplied {@code request} object.
     *
     * @param requestId request ID, every observer is unique
     * @param request   detail of request; its {@code stream} flag is set to {@code true}
     * @throws IllegalArgumentException if {@code request} cannot be serialized to JSON
     */
    public void createSteamChatCompletion(final String requestId, CreateChatCompletionRequest request) {
        request.setStream(true);
        final String requestJson;
        try {
            requestJson = defaultObjectMapper().writeValueAsString(request);
        } catch (JsonProcessingException e) {
            throw new IllegalArgumentException("Unable to serialize chat completion request " + requestId, e);
        }
        // The request body is JSON; its Content-Type is taken from the RequestBody media type.
        // Only the Accept header asks for an SSE response.
        Request okHttpRequest = new Request.Builder()
                .url(BASE_URL + "/v1/chat/completions")
                .header("Accept", "text/event-stream")
                .post(RequestBody.create(requestJson, MediaType.parse("application/json")))
                .build();
        EventSource.Factory factory = EventSources.createFactory(client);
        EventSourceListener eventSourceListener = new EventSourceListener() {
            @Override
            public void onOpen(@NonNull EventSource eventSource, @NonNull Response response) {
                streamChatCompletionListeners.forEach(it -> it.onOpen(requestId, response));
            }

            @Override
            public void onEvent(@NonNull EventSource eventSource, @Nullable String id, @Nullable String type, @NonNull String data) {
                // "[DONE]" is the terminal marker of the stream, not a JSON chunk.
                if ("[DONE]".equals(data)) {
                    streamChatCompletionListeners.forEach(it -> it.onEventDone(requestId));
                    return;
                }
                final ChatCompletion chatCompletion;
                try {
                    chatCompletion = defaultObjectMapper().readValue(data, ChatCompletion.class);
                } catch (JsonProcessingException e) {
                    // Throwing from the callback aborts the stream; presumably okhttp-sse reports it
                    // to onFailure below - verify against the okhttp-sse version in use.
                    throw new IllegalStateException(
                            "Unable to parse chat completion chunk for request " + requestId + ": " + data, e);
                }
                streamChatCompletionListeners.forEach(it -> it.onEvent(requestId, chatCompletion));
            }

            @Override
            public void onClosed(@NonNull EventSource eventSource) {
                streamChatCompletionListeners.forEach(it -> it.onClosed(requestId));
            }

            @Override
            public void onFailure(@NonNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
                // Both t and response may be null; listeners must handle that.
                streamChatCompletionListeners.forEach(it -> it.onFailure(requestId, t, response));
            }
        };
        factory.newEventSource(okHttpRequest, eventSourceListener);
    }

客户端使用例子

    /**
     * Streams a short chat completion and waits until the stream closes or fails.
     *
     * <p>The service delivers results asynchronously through the listener, so the test blocks on a
     * latch; otherwise the test method would return before any event arrives and could never fail.
     */
    @Test
    public void createStreamChatCompletion() throws InterruptedException {
        final String streamRequestId = "request_id_123";
        final java.util.concurrent.CountDownLatch finished = new java.util.concurrent.CountDownLatch(1);
        final java.util.concurrent.atomic.AtomicReference<String> failure =
                new java.util.concurrent.atomic.AtomicReference<>();
        StreamChatCompletionListener listener = new StreamChatCompletionListener() {
            @Override
            public void onEvent(String requestId, ChatCompletion chatCompletion) {
                // Listeners receive events of every request, so ignore other streams; also guard
                // against chunks that carry no choices.
                if (!streamRequestId.equals(requestId)
                        || chatCompletion.getChoices() == null
                        || chatCompletion.getChoices().isEmpty()) {
                    return;
                }
                log.info("model gpt-3.5-turbo: {}", chatCompletion.getChoices().get(0).getDelta().getContent());
            }

            @Override
            public void onClosed(String requestId) {
                if (streamRequestId.equals(requestId)) {
                    finished.countDown();
                }
            }

            @Override
            public void onFailure(String requestId, Throwable t, Response response) {
                if (!streamRequestId.equals(requestId)) {
                    return;
                }
                // t and response may both be null, so never dereference t directly.
                failure.set("stream failed: status=" + (response == null ? "n/a" : response.code()) + ", error=" + t);
                log.error("onFailure: {}", requestId, t);
                finished.countDown();
            }
        };
        listener.setClientId("client_id");
        // NOTE(review): the listener is never removed from the shared service; remove it afterwards
        // if the service offers a way to do so.
        getOpenAiService().addStreamChatCompletionListener(listener);
        CreateChatCompletionRequest chatCompletionRequest = CreateChatCompletionRequest.builder()
                .messages(Collections.singletonList(new ChatMessage(ChatMessageRole.USER, "What's 1+1? Answer in one word.")))
                .model("gpt-3.5-turbo")
                .build();
        getOpenAiService().createSteamChatCompletion(streamRequestId, chatCompletionRequest);
        if (!finished.await(60, java.util.concurrent.TimeUnit.SECONDS)) {
            throw new AssertionError("stream did not finish within 60 seconds");
        }
        if (failure.get() != null) {
            throw new AssertionError(failure.get());
        }
    }