diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java index 82b3b5d9577..c7914556029 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java @@ -19,6 +19,7 @@ import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.model.SpringAIModelProperties; import org.springframework.ai.model.SpringAIModels; import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; @@ -42,6 +43,7 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; /** * Chat {@link AutoConfiguration Auto-configuration} for ZhiPuAI. @@ -63,14 +65,15 @@ public class ZhiPuAiChatAutoConfiguration { @ConditionalOnMissingBean public ZhiPuAiChatModel zhiPuAiChatModel(ZhiPuAiConnectionProperties commonProperties, ZhiPuAiChatProperties chatProperties, ObjectProvider restClientBuilderProvider, - RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, - ObjectProvider observationRegistry, + ObjectProvider webClientBuilderProvider, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry, ObjectProvider observationConvention, ToolCallingManager toolCallingManager, ObjectProvider toolExecutionEligibilityPredicate) { var zhiPuAiApi = zhiPuAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), chatProperties.getApiKey(), commonProperties.getApiKey(), - restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler); + restClientBuilderProvider.getIfAvailable(RestClient::builder), + webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler); var chatModel = new ZhiPuAiChatModel(zhiPuAiApi, chatProperties.getOptions(), toolCallingManager, retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), @@ -82,7 +85,8 @@ public ZhiPuAiChatModel zhiPuAiChatModel(ZhiPuAiConnectionProperties commonPrope } private ZhiPuAiApi zhiPuAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey, - RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler) { String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; Assert.hasText(resolvedBaseUrl, "ZhiPuAI base URL must be set"); @@ -90,7 +94,14 @@ private ZhiPuAiApi zhiPuAiApi(String baseUrl, String commonBaseUrl, String apiKe String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; Assert.hasText(resolvedApiKey, "ZhiPuAI API key must be set"); - return new ZhiPuAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler); + return ZhiPuAiApi.builder() + .baseUrl(resolvedBaseUrl) + .apiKey(new SimpleApiKey(resolvedApiKey)) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(responseErrorHandler) + .build(); + } } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiEmbeddingAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiEmbeddingAutoConfiguration.java index 52fd055e48b..a80913cdd3d 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiEmbeddingAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiEmbeddingAutoConfiguration.java @@ -19,6 +19,7 @@ import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.model.SpringAIModelProperties; import org.springframework.ai.model.SpringAIModels; import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; @@ -37,6 +38,7 @@ import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; /** * Embedding {@link AutoConfiguration Auto-configuration} for ZhiPuAI. @@ -54,13 +56,16 @@ public class ZhiPuAiEmbeddingAutoConfiguration { @Bean @ConditionalOnMissingBean public ZhiPuAiEmbeddingModel zhiPuAiEmbeddingModel(ZhiPuAiConnectionProperties commonProperties, - ZhiPuAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder, - RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, - ObjectProvider observationRegistry, + ZhiPuAiEmbeddingProperties embeddingProperties, + ObjectProvider restClientBuilderProvider, + ObjectProvider webClientBuilderProvider, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry, ObjectProvider observationConvention) { var zhiPuAiApi = zhiPuAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), - embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); + embeddingProperties.getApiKey(), commonProperties.getApiKey(), + restClientBuilderProvider.getIfAvailable(RestClient::builder), + webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler); var embeddingModel = new ZhiPuAiEmbeddingModel(zhiPuAiApi, embeddingProperties.getMetadataMode(), embeddingProperties.getOptions(), retryTemplate, @@ -72,7 +77,8 @@ public ZhiPuAiEmbeddingModel zhiPuAiEmbeddingModel(ZhiPuAiConnectionProperties c } private ZhiPuAiApi zhiPuAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey, - RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler) { String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; Assert.hasText(resolvedBaseUrl, "ZhiPuAI base URL must be set"); @@ -80,7 +86,13 @@ private ZhiPuAiApi zhiPuAiApi(String baseUrl, String commonBaseUrl, String apiKe String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; Assert.hasText(resolvedApiKey, "ZhiPuAI API key must be set"); - return new ZhiPuAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler); + return ZhiPuAiApi.builder() + .baseUrl(resolvedBaseUrl) + .apiKey(new SimpleApiKey(resolvedApiKey)) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(responseErrorHandler) + .build(); } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index cac51cbe232..10213b9734b 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,8 +32,11 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.NoopApiKey; +import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; @@ -41,25 +44,57 @@ import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -// @formatter:off /** - * Single class implementation of the ZhiPuAI Chat Completion API and + * Single class implementation of the + * ZhiPuAI Chat Completion API and * ZhiPuAI Embedding API. * * @author Geng Rong * @author Thomas Vitale + * @author YunKui Lu * @since 1.0.0 */ public class ZhiPuAiApi { + /** + * Returns a builder pre-populated with the current configuration for mutation. + */ + public Builder mutate() { + return new Builder(this); + } + + public static Builder builder() { + return new Builder(); + } + public static final String DEFAULT_CHAT_MODEL = ChatModel.GLM_4_Air.getValue(); + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.Embedding_2.getValue(); + + public static final String DEFAULT_EMBEDDINGS_PATH = "/v4/embeddings"; + + public static final String DEFAULT_COMPLETIONS_PATH = "/v4/chat/completions"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + private final String baseUrl; + + private final ApiKey apiKey; + + private final MultiValueMap headers; + + private final String completionsPath; + + private final String embeddingsPath; + + private final ResponseErrorHandler responseErrorHandler; + private final RestClient restClient; private final WebClient webClient; @@ -68,137 +103,203 @@ public class ZhiPuAiApi { /** * Create a new chat completion api with default base URL. - * * @param zhiPuAiToken ZhiPuAI apiKey. + * @deprecated Use {@link #builder()} instead. */ + @Deprecated public ZhiPuAiApi(String zhiPuAiToken) { this(ZhiPuApiConstants.DEFAULT_BASE_URL, zhiPuAiToken); } /** * Create a new chat completion api. - * * @param baseUrl api base URL. * @param zhiPuAiToken ZhiPuAI apiKey. + * @deprecated Use {@link #builder()} instead. */ + @Deprecated public ZhiPuAiApi(String baseUrl, String zhiPuAiToken) { this(baseUrl, zhiPuAiToken, RestClient.builder()); } /** * Create a new chat completion api. - * * @param baseUrl api base URL. * @param zhiPuAiToken ZhiPuAI apiKey. * @param restClientBuilder RestClient builder. + * @deprecated Use {@link #builder()} instead. */ + @Deprecated public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder) { this(baseUrl, zhiPuAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); } /** * Create a new chat completion api. - * * @param baseUrl api base URL. * @param zhiPuAiToken ZhiPuAI apiKey. * @param restClientBuilder RestClient builder. * @param responseErrorHandler Response error handler. + * @deprecated Use {@link #builder()} instead. */ - public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + @Deprecated + public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { + this(baseUrl, new SimpleApiKey(zhiPuAiToken), new LinkedMultiValueMap<>(), DEFAULT_COMPLETIONS_PATH, + DEFAULT_EMBEDDINGS_PATH, restClientBuilder, WebClient.builder(), responseErrorHandler); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey ZhiPuAI apiKey. + * @param headers the http headers to use. + * @param completionsPath the path to the chat completions endpoint. + * @param embeddingsPath the path to the embeddings endpoint. + * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + * @param responseErrorHandler Response error handler. + */ + private ZhiPuAiApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, + String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler) { + Assert.hasText(completionsPath, "Completions Path must not be null"); + Assert.hasText(embeddingsPath, "Embeddings Path must not be null"); + Assert.notNull(headers, "Headers must not be null"); + + this.baseUrl = baseUrl; + this.apiKey = apiKey; + this.headers = headers; + this.completionsPath = completionsPath; + this.embeddingsPath = embeddingsPath; + this.responseErrorHandler = responseErrorHandler; Consumer authHeaders = h -> { - h.setBearerAuth(zhiPuAiToken); h.setContentType(MediaType.APPLICATION_JSON); + h.addAll(headers); }; - this.restClient = restClientBuilder - .baseUrl(baseUrl) - .defaultHeaders(authHeaders) - .defaultStatusHandler(responseErrorHandler) - .build(); - - this.webClient = WebClient.builder() // FIXME: use a builder instead - .baseUrl(baseUrl) - .defaultHeaders(authHeaders) - .build(); + this.restClient = restClientBuilder.clone() + .baseUrl(baseUrl) + .defaultHeaders(authHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + + // @formatter:off + this.webClient = webClientBuilder.clone() + .baseUrl(baseUrl) + .defaultHeaders(authHeaders) + .build(); // @formatter:on } - public static String getTextContent(List content) { + public static String getTextContent(List content) { return content.stream() - .filter(c -> "text".equals(c.type())) - .map(ChatCompletionMessage.MediaContent::text) - .reduce("", (a, b) -> a + b); + .filter(c -> "text".equals(c.type())) + .map(ChatCompletionMessage.MediaContent::text) + .reduce("", (a, b) -> a + b); } /** * Creates a model response for the given chat conversation. - * * @param chatRequest The chat completion request. - * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. */ public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + return chatCompletionEntity(chatRequest, new LinkedMultiValueMap<>()); + } + + /** + * Creates a model response for the given chat conversation. + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code + * and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest, + MultiValueMap additionalHttpHeader) { Assert.notNull(chatRequest, "The request body can not be null."); Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); + // @formatter:off return this.restClient.post() - .uri("/v4/chat/completions") - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + // @formatter:on } /** * Creates a streaming chat response for the given chat conversation. - * - * @param chatRequest The chat completion request. Must have the stream property set to true. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. * @return Returns a {@link Flux} stream from chat completion chunks. */ public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest, + MultiValueMap additionalHttpHeader) { Assert.notNull(chatRequest, "The request body can not be null."); Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); AtomicBoolean isInsideTool = new AtomicBoolean(false); + // @formatter:off return this.webClient.post() - .uri("/v4/chat/completions") - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) - .map(chunk -> { - if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { - isInsideTool.set(true); - } - return chunk; - }) - .windowUntil(chunk -> { - if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - .concatMapIterable(window -> { - Mono monoChunk = window.reduce( - new ChatCompletionChunk(null, null, null, null, null, null), - this.chunkMerger::merge); - return List.of(monoChunk); - }) - .flatMap(mono -> mono); + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:on + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono monoChunk = window + .reduce(new ChatCompletionChunk(null, null, null, null, null, null), this.chunkMerger::merge); + return List.of(monoChunk); + }) + .flatMap(mono -> mono); } /** * Creates an embedding vector representing the input text or token array. - * * @param embeddingRequest The embedding request. * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. - * @param Type of the entity in the data list. Can be a {@link String} or {@link List} of tokens (e.g. - * Integers). For embedding multiple inputs in a single request, You can pass a {@link List} of {@link String} or - * {@link List} of {@link List} of tokens. For example: + * @param Type of the entity in the data list. Can be a {@link String} or + * {@link List} of tokens (e.g. Integers). For embedding multiple inputs in a single + * request, You can pass a {@link List} of {@link String} or {@link List} of + * {@link List} of tokens. For example: * *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
*/ @@ -206,7 +307,8 @@ public ResponseEntity> embeddings(EmbeddingRequest< Assert.notNull(embeddingRequest, "The request body can not be null."); - // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single // request, pass an array of strings or array of token arrays. Assert.notNull(embeddingRequest.input(), "The input can not be null."); Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, @@ -215,17 +317,49 @@ public ResponseEntity> embeddings(EmbeddingRequest< if (embeddingRequest.input() instanceof List list) { Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); Assert.isTrue(list.size() <= 512, "The list must be 512 dimensions or less"); - Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer - || list.get(0) instanceof List, + Assert.isTrue( + list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, "The input must be either a String, or a List of Strings or list of list of integers."); } return this.restClient.post() - .uri("/v4/embeddings") - .body(embeddingRequest) - .retrieve() - .toEntity(new ParameterizedTypeReference<>() { - }); + .uri(this.embeddingsPath) + .headers(this::addDefaultHeadersIfMissing) + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + + private void addDefaultHeadersIfMissing(HttpHeaders headers) { + if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + headers.setBearerAuth(this.apiKey.getValue()); + } + } + + // Package-private getters for mutate/copy + String getBaseUrl() { + return this.baseUrl; + } + + ApiKey getApiKey() { + return this.apiKey; + } + + MultiValueMap getHeaders() { + return this.headers; + } + + String getCompletionsPath() { + return this.completionsPath; + } + + String getEmbeddingsPath() { + return this.embeddingsPath; + } + + ResponseErrorHandler getResponseErrorHandler() { + return this.responseErrorHandler; } /** @@ -233,14 +367,21 @@ public ResponseEntity> embeddings(EmbeddingRequest< * ZhiPuAI Model. */ public enum ChatModel implements ChatModelDescription { + + // @formatter:off GLM_4("GLM-4"), + GLM_4V("glm-4v"), + GLM_4_Air("glm-4-air"), + GLM_4_AirX("glm-4-airx"), + GLM_4_Flash("glm-4-flash"), - GLM_3_Turbo("GLM-3-Turbo"); - public final String value; + GLM_3_Turbo("GLM-3-Turbo"); // @formatter:on + + public final String value; ChatModel(String value) { this.value = value; @@ -254,12 +395,14 @@ public String getValue() { public String getName() { return this.value; } + } /** * The reason the model stopped generating tokens. */ public enum ChatCompletionFinishReason { + /** * The model hit a natural stop point or a provided stop sequence. */ @@ -285,6 +428,7 @@ public enum ChatCompletionFinishReason { */ @JsonProperty("tool_call") TOOL_CALL + } /** @@ -303,7 +447,7 @@ public enum EmbeddingModel { */ Embedding_3("Embedding-3"); - public final String value; + public final String value; EmbeddingModel(String value) { this.value = value; @@ -312,11 +456,12 @@ public enum EmbeddingModel { public String getValue() { return this.value; } - } + } /** - * Represents a tool the model may call. Currently, only functions are supported as a tool. + * Represents a tool the model may call. Currently, only functions are supported as a + * tool. */ @JsonInclude(JsonInclude.Include.NON_NULL) public static class FunctionTool { @@ -325,7 +470,7 @@ public static class FunctionTool { @JsonProperty("type") private Type type = Type.FUNCTION; - // The function definition. + // The function definition. @JsonProperty("function") private Function function; @@ -338,9 +483,7 @@ public FunctionTool() { * @param type the tool type * @param function function definition */ - public FunctionTool( - Type type, - Function function) { + public FunctionTool(Type type, Function function) { this.type = type; this.function = function; } @@ -373,11 +516,13 @@ public void setFunction(Function function) { * Create a tool of type 'function' and the given function definition. */ public enum Type { + /** * Function tool type. */ @JsonProperty("function") FUNCTION + } /** @@ -403,18 +548,15 @@ private Function() { /** * Create tool function definition. - * - * @param description A description of what the function does, used by the model to choose when and how to call - * the function. - * @param name The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, - * with a maximum length of 64. - * @param parameters The parameters the functions accepts, described as a JSON Schema object. To describe a - * function that accepts no parameters, provide the value {"type": "object", "properties": {}}. + * @param description A description of what the function does, used by the + * model to choose when and how to call the function. + * @param name The name of the function to be called. Must be a-z, A-Z, 0-9, + * or contain underscores and dashes, with a maximum length of 64. + * @param parameters The parameters the functions accepts, described as a JSON + * Schema object. To describe a function that accepts no parameters, provide + * the value {"type": "object", "properties": {}}. */ - public Function( - String description, - String name, - Map parameters) { + public Function(String description, String name, Map parameters) { this.description = description; this.name = name; this.parameters = parameters; @@ -422,7 +564,6 @@ public Function( /** * Create tool function definition. - * * @param description tool function description. * @param name tool function name. * @param jsonSchema tool function schema as json. @@ -467,6 +608,7 @@ public void setJsonSchema(String jsonSchema) { } } + } /** @@ -474,32 +616,40 @@ public void setJsonSchema(String jsonSchema) { * * @param messages A list of messages comprising the conversation so far. * @param model ID of the model to use. - * @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input - * tokens and generated tokens is limited by the model's context length. + * @param maxTokens The maximum number of tokens to generate in the chat completion. + * The total length of input tokens and generated tokens is limited by the model's + * context length. * @param stop Up to 4 sequences where the API will stop generating further tokens. - * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events as - * they become available, with the stream terminated by a data: [DONE] message. - * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend - * altering this or top_p but not both. - * @param topP An alternative to sampling with temperature, called nucleus sampling, where the model considers the - * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - * probability mass are considered. We generally recommend altering this or temperature but not both. - * @param tools A list of tools the model may call. Currently, only functions are supported as a tool. Use this to - * provide a list of functions the model may generate JSON inputs for. - * @param toolChoice Controls which (if any) function is called by the model. none means the model will not call a - * function and instead generates a message. auto means the model can pick between generating a message or calling a - * function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces - * the model to call that function. none is the default when no functions are present. auto is the default if - * functions are present. Use the {@link ToolChoiceBuilder} to create the tool choice value. - * @param user A unique identifier representing your end-user, which can help ZhiPuAI to monitor and detect abuse. - * @param requestId A unique identifier for the request. If set, the request will be logged and can be used for - * debugging purposes. - * @param doSample If set, the model will use sampling to generate the next token. If not set, the model will use - * greedy decoding to generate the next token. + * @param stream If set, partial message deltas will be sent.Tokens will be sent as + * data-only server-sent events as they become available, with the stream terminated + * by a data: [DONE] message. + * @param temperature What sampling temperature to use, between 0 and 1. Higher values + * like 0.8 will make the output more random, while lower values like 0.2 will make it + * more focused and deterministic. We generally recommend altering this or top_p but + * not both. + * @param topP An alternative to sampling with temperature, called nucleus sampling, + * where the model considers the results of the tokens with top_p probability mass. So + * 0.1 means only the tokens comprising the top 10% probability mass are considered. + * We generally recommend altering this or temperature but not both. + * @param tools A list of tools the model may call. Currently, only functions are + * supported as a tool. Use this to provide a list of functions the model may generate + * JSON inputs for. + * @param toolChoice Controls which (if any) function is called by the model. none + * means the model will not call a function and instead generates a message. auto + * means the model can pick between generating a message or calling a function. + * Specifying a particular function via {"type: "function", "function": {"name": + * "my_function"}} forces the model to call that function. none is the default when no + * functions are present. auto is the default if functions are present. Use the + * {@link ToolChoiceBuilder} to create the tool choice value. + * @param user A unique identifier representing your end-user, which can help ZhiPuAI + * to monitor and detect abuse. + * @param requestId A unique identifier for the request. If set, the request will be + * logged and can be used for debugging purposes. + * @param doSample If set, the model will use sampling to generate the next token. If + * not set, the model will use greedy decoding to generate the next token. */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest( + public record ChatCompletionRequest(// @formatter:off @JsonProperty("messages") List messages, @JsonProperty("model") String model, @JsonProperty("max_tokens") Integer maxTokens, @@ -511,70 +661,73 @@ public record ChatCompletionRequest( @JsonProperty("tool_choice") Object toolChoice, @JsonProperty("user") String user, @JsonProperty("request_id") String requestId, - @JsonProperty("do_sample") Boolean doSample) { + @JsonProperty("do_sample") Boolean doSample) { // @formatter:on /** - * Shortcut constructor for a chat completion request with the given messages and model. - * + * Shortcut constructor for a chat completion request with the given messages and + * model. * @param messages A list of messages comprising the conversation so far. * @param model ID of the model to use. * @param temperature What sampling temperature to use, between 0 and 1. */ public ChatCompletionRequest(List messages, String model, Double temperature) { - this(messages, model, null, null, false, temperature, null, - null, null, null, null, null); + this(messages, model, null, null, false, temperature, null, null, null, null, null, null); } /** - * Shortcut constructor for a chat completion request with the given messages, model and control for streaming. - * + * Shortcut constructor for a chat completion request with the given messages, + * model and control for streaming. * @param messages A list of messages comprising the conversation so far. * @param model ID of the model to use. * @param temperature What sampling temperature to use, between 0 and 1. - * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events - * as they become available, with the stream terminated by a data: [DONE] message. + * @param stream If set, partial message deltas will be sent.Tokens will be sent + * as data-only server-sent events as they become available, with the stream + * terminated by a data: [DONE] message. */ - public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { - this(messages, model, null, null, stream, temperature, null, - null, null, null, null, null); + public ChatCompletionRequest(List messages, String model, Double temperature, + boolean stream) { + this(messages, model, null, null, stream, temperature, null, null, null, null, null, null); } /** - * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice. - * Streaming is set to false, temperature to 0.8 and all other parameters are null. - * + * Shortcut constructor for a chat completion request with the given messages, + * model, tools and tool choice. Streaming is set to false, temperature to 0.8 and + * all other parameters are null. * @param messages A list of messages comprising the conversation so far. * @param model ID of the model to use. - * @param tools A list of tools the model may call. Currently, only functions are supported as a tool. + * @param tools A list of tools the model may call. Currently, only functions are + * supported as a tool. * @param toolChoice Controls which (if any) function is called by the model. */ - public ChatCompletionRequest(List messages, String model, - List tools, Object toolChoice) { - this(messages, model, null, null, false, 0.8, null, - tools, toolChoice, null, null, null); + public ChatCompletionRequest(List messages, String model, List tools, + Object toolChoice) { + this(messages, model, null, null, false, 0.8, null, tools, toolChoice, null, null, null); } /** - * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice. - * Streaming is set to false, temperature to 0.8 and all other parameters are null. - * + * Shortcut constructor for a chat completion request with the given messages, + * model, tools and tool choice. Streaming is set to false, temperature to 0.8 and + * all other parameters are null. * @param messages A list of messages comprising the conversation so far. - * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events - * as they become available, with the stream terminated by a data: [DONE] message. + * @param stream If set, partial message deltas will be sent.Tokens will be sent + * as data-only server-sent events as they become available, with the stream + * terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, Boolean stream) { - this(messages, null, null, null, stream, null, null, - null, null, null, null, null); + this(messages, null, null, null, stream, null, null, null, null, null, null, null); } /** - * Helper factory that creates a tool_choice of type 'none', 'auto' or selected function by name. + * Helper factory that creates a tool_choice of type 'none', 'auto' or selected + * function by name. */ public static class ToolChoiceBuilder { + /** * Model can pick between generating a message or calling a function. */ public static final String AUTO = "auto"; + /** * Model will not call a function and instead generates a message */ @@ -586,43 +739,46 @@ public static class ToolChoiceBuilder { public static Object function(String functionName) { return Map.of("type", "function", "function", Map.of("name", functionName)); } + } /** * An object specifying the format that the model must output. + * * @param type Must be one of 'text' or 'json_object'. */ @JsonInclude(Include.NON_NULL) - public record ResponseFormat( - @JsonProperty("type") String type) { + public record ResponseFormat(@JsonProperty("type") String type) { } } /** * Message comprising the conversation. * - * @param rawContent The contents of the message. Can be either a {@link MediaContent} or a {@link String}. - * The response message content is always a {@link String}. - * @param role The role of the messages author. Could be one of the {@link Role} types. - * @param name An optional name for the participant. Provides the model information to differentiate between - * participants of the same role. In case of Function calling, the name is the function name that the message is - * responding to. - * @param toolCallId Tool call that this message is responding to. Only applicable for the {@link Role#TOOL} role - * and null otherwise. - * @param toolCalls The tool calls generated by the model, such as function calls. Applicable only for - * {@link Role#ASSISTANT} role and null otherwise. + * @param rawContent The contents of the message. Can be either a {@link MediaContent} + * or a {@link String}. The response message content is always a {@link String}. + * @param role The role of the messages author. Could be one of the {@link Role} + * types. + * @param name An optional name for the participant. Provides the model information to + * differentiate between participants of the same role. In case of Function calling, + * the name is the function name that the message is responding to. + * @param toolCallId Tool call that this message is responding to. Only applicable for + * the {@link Role#TOOL} role and null otherwise. + * @param toolCalls The tool calls generated by the model, such as function calls. + * Applicable only for {@link Role#ASSISTANT} role and null otherwise. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ChatCompletionMessage( + public record ChatCompletionMessage(// @formatter:off @JsonProperty("content") Object rawContent, @JsonProperty("role") Role role, @JsonProperty("name") String name, @JsonProperty("tool_call_id") String toolCallId, - @JsonProperty("tool_calls") List toolCalls) { + @JsonProperty("tool_calls") List toolCalls) { // @formatter:on /** - * Create a chat completion message with the given content and role. All other fields are null. + * Create a chat completion message with the given content and role. All other + * fields are null. * @param content The contents of the message. * @param role The role of the author of this message. */ @@ -647,6 +803,7 @@ public String content() { * The role of the author of this message. */ public enum Role { + /** * System message. */ @@ -671,21 +828,21 @@ public enum Role { } /** - * An array of content parts with a defined type. - * Each MediaContent can be of either "text" or "image_url" type. Not both. + * An array of content parts with a defined type. Each MediaContent can be of + * either "text" or "image_url" type. Not both. * - * @param type Content type, each can be of type text or image_url. + * @param type Content type, each can be of type text or image_url. * @param text The text content of the message. - * @param imageUrl The image content of the message. You can pass multiple - * images by adding multiple image_url content parts. Image input is only - * supported when using the glm-4v model. + * @param imageUrl The image content of the message. You can pass multiple images + * by adding multiple image_url content parts. Image input is only supported when + * using the glm-4v model. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record MediaContent( - @JsonProperty("type") String type, - @JsonProperty("text") String text, - @JsonProperty("image_url") ImageUrl imageUrl) { + public record MediaContent(// @formatter:off + @JsonProperty("type") String type, + @JsonProperty("text") String text, + @JsonProperty("image_url") ImageUrl imageUrl) { // @formatter:on /** * Shortcut constructor for a text content. @@ -705,75 +862,82 @@ public MediaContent(ImageUrl imageUrl) { /** * The image content of the message. - * @param url Either a URL of the image or the base64 encoded image data. - * The base64 encoded image data must have a special prefix in the following format: - * "data:{mimetype};base64,{base64-encoded-image-data}". + * + * @param url Either a URL of the image or the base64 encoded image data. The + * base64 encoded image data must have a special prefix in the following + * format: "data:{mimetype};base64,{base64-encoded-image-data}". * @param detail Specifies the detail level of the image. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ImageUrl( - @JsonProperty("url") String url, - @JsonProperty("detail") String detail) { + public record ImageUrl(// @formatter:off + @JsonProperty("url") String url, + @JsonProperty("detail") String detail) { // @formatter:on public ImageUrl(String url) { this(url, null); } } } + /** * The relevant tool call. * - * @param id The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the - * Submit tool outputs to run endpoint. - * @param type The type of tool call the output is required for. For now, this is always function. + * @param id The ID of the tool call. This ID must be referenced when you submit + * the tool outputs in using the Submit tool outputs to run endpoint. + * @param type The type of tool call the output is required for. For now, this is + * always function. * @param function The function definition. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ToolCall( + public record ToolCall(// @formatter:off @JsonProperty("id") String id, - @JsonProperty("type") String type, - @JsonProperty("function") ChatCompletionFunction function) { + @JsonProperty("type") String type, + @JsonProperty("function") ChatCompletionFunction function) { // @formatter:on } /** * The function definition. * * @param name The name of the function. - * @param arguments The arguments that the model expects you to pass to the function. + * @param arguments The arguments that the model expects you to pass to the + * function. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ChatCompletionFunction( + public record ChatCompletionFunction(// @formatter:off @JsonProperty("name") String name, - @JsonProperty("arguments") String arguments) { + @JsonProperty("arguments") String arguments) { // @formatter:on } } /** - * Represents a chat completion response returned by model, based on the provided input. + * Represents a chat completion response returned by model, based on the provided + * input. * * @param id A unique identifier for the chat completion. - * @param choices A list of chat completion choices. Can be more than one if n is greater than 1. - * @param created The Unix timestamp (in seconds) of when the chat completion was created. + * @param choices A list of chat completion choices. Can be more than one if n is + * greater than 1. + * @param created The Unix timestamp (in seconds) of when the chat completion was + * created. * @param model The model used for the chat completion. - * @param systemFingerprint This fingerprint represents the backend configuration that the model runs with. Can be - * used in conjunction with the seed request parameter to understand when backend changes have been made that might - * impact determinism. + * @param systemFingerprint This fingerprint represents the backend configuration that + * the model runs with. Can be used in conjunction with the seed request parameter to + * understand when backend changes have been made that might impact determinism. * @param object The object type, which is always chat.completion. * @param usage Usage statistics for the completion request. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ChatCompletion( + public record ChatCompletion(// @formatter:off @JsonProperty("id") String id, @JsonProperty("choices") List choices, @JsonProperty("created") Long created, @JsonProperty("model") String model, @JsonProperty("system_fingerprint") String systemFingerprint, @JsonProperty("object") String object, - @JsonProperty("usage") Usage usage) { + @JsonProperty("usage") Usage usage) { // @formatter:on /** * Chat completion choice. @@ -785,11 +949,11 @@ public record ChatCompletion( */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record Choice( + public record Choice(// @formatter:off @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, @JsonProperty("index") Integer index, @JsonProperty("message") ChatCompletionMessage message, - @JsonProperty("logprobs") LogProbs logprobs) { + @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } } @@ -801,8 +965,8 @@ public record Choice( */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record LogProbs( - @JsonProperty("content") List content) { + public record LogProbs(// @formatter:off + @JsonProperty("content") List content) { // @formatter:on /** * Message content tokens with log probability information. @@ -812,35 +976,37 @@ public record LogProbs( * @param probBytes A list of integers representing the UTF-8 bytes representation * of the token. Useful in instances where characters are represented by multiple * tokens and their byte representations must be combined to generate the correct - * text representation. Can be null if there is no bytes representation for the token. - * @param topLogprobs List of the most likely tokens and their log probability, - * at this token position. In rare cases, there may be fewer than the number of + * text representation. Can be null if there is no bytes representation for the + * token. + * @param topLogprobs List of the most likely tokens and their log probability, at + * this token position. In rare cases, there may be fewer than the number of * requested top_logprobs returned. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record Content( + public record Content(// @formatter:off @JsonProperty("token") String token, @JsonProperty("logprob") Float logprob, @JsonProperty("bytes") List probBytes, - @JsonProperty("top_logprobs") List topLogprobs) { + @JsonProperty("top_logprobs") List topLogprobs) { // @formatter:on /** * The most likely tokens and their log probability, at this token position. * * @param token The token. * @param logprob The log probability of the token. - * @param probBytes A list of integers representing the UTF-8 bytes representation - * of the token. Useful in instances where characters are represented by multiple - * tokens and their byte representations must be combined to generate the correct - * text representation. Can be null if there is no bytes representation for the token. + * @param probBytes A list of integers representing the UTF-8 bytes + * representation of the token. Useful in instances where characters are + * represented by multiple tokens and their byte representations must be + * combined to generate the correct text representation. Can be null if there + * is no bytes representation for the token. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record TopLogProbs( + public record TopLogProbs(// @formatter:off @JsonProperty("token") String token, @JsonProperty("logprob") Float logprob, - @JsonProperty("bytes") List probBytes) { + @JsonProperty("bytes") List probBytes) { // @formatter:on } } } @@ -848,41 +1014,45 @@ public record TopLogProbs( /** * Usage statistics for the completion request. * - * @param completionTokens Number of tokens in the generated completion. Only applicable for completion requests. + * @param completionTokens Number of tokens in the generated completion. Only + * applicable for completion requests. * @param promptTokens Number of tokens in the prompt. - * @param totalTokens Total number of tokens used in the request (prompt + completion). + * @param totalTokens Total number of tokens used in the request (prompt + + * completion). */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record Usage( + public record Usage(// @formatter:off @JsonProperty("completion_tokens") Integer completionTokens, @JsonProperty("prompt_tokens") Integer promptTokens, - @JsonProperty("total_tokens") Integer totalTokens) { + @JsonProperty("total_tokens") Integer totalTokens) { // @formatter:on } /** - * Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + * Represents a streamed chunk of a chat completion response returned by model, based + * on the provided input. * * @param id A unique identifier for the chat completion. Each chunk has the same ID. - * @param choices A list of chat completion choices. Can be more than one if n is greater than 1. - * @param created The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same - * timestamp. + * @param choices A list of chat completion choices. Can be more than one if n is + * greater than 1. + * @param created The Unix timestamp (in seconds) of when the chat completion was + * created. Each chunk has the same timestamp. * @param model The model used for the chat completion. - * @param systemFingerprint This fingerprint represents the backend configuration that the model runs with. Can be - * used in conjunction with the seed request parameter to understand when backend changes have been made that might - * impact determinism. + * @param systemFingerprint This fingerprint represents the backend configuration that + * the model runs with. Can be used in conjunction with the seed request parameter to + * understand when backend changes have been made that might impact determinism. * @param object The object type, which is always 'chat.completion.chunk'. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ChatCompletionChunk( + public record ChatCompletionChunk(// @formatter:off @JsonProperty("id") String id, @JsonProperty("choices") List choices, @JsonProperty("created") Long created, @JsonProperty("model") String model, @JsonProperty("system_fingerprint") String systemFingerprint, - @JsonProperty("object") String object) { + @JsonProperty("object") String object) { // @formatter:on /** * Chat completion choice. @@ -894,11 +1064,11 @@ public record ChatCompletionChunk( */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ChunkChoice( + public record ChunkChoice(// @formatter:off @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, @JsonProperty("index") Integer index, @JsonProperty("delta") ChatCompletionMessage delta, - @JsonProperty("logprobs") LogProbs logprobs) { + @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } } @@ -906,52 +1076,27 @@ public record ChunkChoice( * Represents an embedding vector returned by embedding endpoint. * * @param index The index of the embedding in the list of embeddings. - * @param embedding The embedding vector, which is a list of floats. The length of vector depends on the model. + * @param embedding The embedding vector, which is a list of floats. The length of + * vector depends on the model. * @param object The object type, which is always 'embedding'. */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record Embedding( + public record Embedding(// @formatter:off @JsonProperty("index") Integer index, @JsonProperty("embedding") float[] embedding, - @JsonProperty("object") String object) { + @JsonProperty("object") String object) { // @formatter:on /** - * Create an embedding with the given index, embedding and object type set to 'embedding'. - * + * Create an embedding with the given index, embedding and object type set to + * 'embedding'. * @param index The index of the embedding in the list of embeddings. - * @param embedding The embedding vector, which is a list of floats. The length of vector depends on the model. + * @param embedding The embedding vector, which is a list of floats. The length of + * vector depends on the model. */ public Embedding(Integer index, float[] embedding) { this(index, embedding, "embedding"); } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof Embedding embedding1)) { - return false; - } - return Objects.equals(this.index, embedding1.index) && Arrays.equals(this.embedding, embedding1.embedding) && Objects.equals(this.object, embedding1.object); - } - - @Override - public int hashCode() { - int result = Objects.hash(this.index, this.object); - result = 31 * result + Arrays.hashCode(this.embedding); - return result; - } - - @Override - public String toString() { - return "Embedding{" + - "index=" + this.index + - ", embedding=" + Arrays.toString(this.embedding) + - ", object='" + this.object + '\'' + - '}'; - } } /** @@ -962,24 +1107,22 @@ public String toString() { * @param model ID of the model to use. */ @JsonInclude(Include.NON_NULL) - public record EmbeddingRequest( + public record EmbeddingRequest(// @formatter:off @JsonProperty("input") T input, @JsonProperty("model") String model, - @JsonProperty("dimensions") Integer dimensions) { - + @JsonProperty("dimensions") Integer dimensions) { // @formatter:on /** - * Create an embedding request with the given input. Encoding model is set to 'embedding-2'. - * - * @param input Input text to embed. - */ + * Create an embedding request with the given input. Encoding model is set to + * 'embedding-2'. + * @param input Input text to embed. + */ public EmbeddingRequest(T input) { this(input, DEFAULT_EMBEDDING_MODEL, null); } /** * Create an embedding request with the given input and model. - * * @param input * @param model */ @@ -999,12 +1142,104 @@ public EmbeddingRequest(T input, String model) { */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record EmbeddingList( + public record EmbeddingList(// @formatter:off @JsonProperty("object") String object, @JsonProperty("data") List data, @JsonProperty("model") String model, - @JsonProperty("usage") Usage usage) { + @JsonProperty("usage") Usage usage) { // @formatter:on + } + + public static class Builder { + + private Builder() { + } + + public Builder(ZhiPuAiApi api) { + this.baseUrl = api.getBaseUrl(); + this.apiKey = api.getApiKey(); + this.headers = new LinkedMultiValueMap<>(api.getHeaders()); + this.completionsPath = api.getCompletionsPath(); + this.embeddingsPath = api.getEmbeddingsPath(); + this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder(); + this.webClientBuilder = api.webClient != null ? api.webClient.mutate() : WebClient.builder(); + this.responseErrorHandler = api.getResponseErrorHandler(); + } + + private String baseUrl = ZhiPuApiConstants.DEFAULT_BASE_URL; + + private ApiKey apiKey; + + private MultiValueMap headers = new LinkedMultiValueMap<>(); + + private String completionsPath = DEFAULT_COMPLETIONS_PATH; + + private String embeddingsPath = DEFAULT_EMBEDDINGS_PATH; + + private RestClient.Builder restClientBuilder = RestClient.builder(); + + private WebClient.Builder webClientBuilder = WebClient.builder(); + + private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER; + + public Builder baseUrl(String baseUrl) { + Assert.hasText(baseUrl, "baseUrl cannot be null or empty"); + this.baseUrl = baseUrl; + return this; + } + + public Builder apiKey(ApiKey apiKey) { + Assert.notNull(apiKey, "apiKey cannot be null"); + this.apiKey = apiKey; + return this; + } + + public Builder apiKey(String simpleApiKey) { + this.apiKey = new SimpleApiKey(simpleApiKey); + return this; + } + + public Builder headers(MultiValueMap headers) { + Assert.notNull(headers, "headers cannot be null"); + this.headers = headers; + return this; + } + + public Builder completionsPath(String completionsPath) { + Assert.hasText(completionsPath, "completionsPath cannot be null or empty"); + this.completionsPath = completionsPath; + return this; + } + + public Builder embeddingsPath(String embeddingsPath) { + Assert.hasText(embeddingsPath, "embeddingsPath cannot be null or empty"); + this.embeddingsPath = embeddingsPath; + return this; + } + + public Builder restClientBuilder(RestClient.Builder restClientBuilder) { + Assert.notNull(restClientBuilder, "restClientBuilder cannot be null"); + this.restClientBuilder = restClientBuilder; + return this; + } + + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "webClientBuilder cannot be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { + Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null"); + this.responseErrorHandler = responseErrorHandler; + return this; + } + + public ZhiPuAiApi build() { + Assert.notNull(this.apiKey, "apiKey must be set"); + return new ZhiPuAiApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath, this.embeddingsPath, + this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler); + } + } } -// @formatter:on diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java index a175a3058fa..2f514f63932 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ public class ChatCompletionRequestTests { @Test public void createRequestWithChatOptions() { - var client = new ZhiPuAiChatModel(new ZhiPuAiApi("TEST"), + var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(), ZhiPuAiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()); var prompt = client.buildRequestPrompt(new Prompt("Test message content")); @@ -63,7 +63,7 @@ public void promptOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new ZhiPuAiChatModel(new ZhiPuAiApi("TEST"), + var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(), ZhiPuAiChatOptions.builder().model("DEFAULT_MODEL").build()); var request = client.createRequest(new Prompt("Test message content", @@ -89,7 +89,7 @@ public void defaultOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new ZhiPuAiChatModel(new ZhiPuAiApi("TEST"), + var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(), ZhiPuAiChatOptions.builder() .model("DEFAULT_MODEL") .toolCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService()) diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java index 00a760cb1a2..dee82d1bc59 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ public class ZhiPuAiTestConfiguration { @Bean public ZhiPuAiApi zhiPuAiApi() { - return new ZhiPuAiApi(getApiKey()); + return ZhiPuAiApi.builder().apiKey(getApiKey()).build(); } @Bean diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiBuilderTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiBuilderTests.java new file mode 100644 index 00000000000..b193a5a9672 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiBuilderTests.java @@ -0,0 +1,333 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.zhipuai.api; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.opentest4j.AssertionFailedError; + +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +class ZhiPuAiApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + private static final String TEST_COMPLETIONS_PATH = "/test/completions"; + + private static final String TEST_EMBEDDINGS_PATH = "/test/embeddings"; + + @Test + void testMinimalBuilder() { + ZhiPuAiApi api = ZhiPuAiApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + WebClient.Builder webClientBuilder = WebClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + ZhiPuAiApi api = ZhiPuAiApi.builder() + .apiKey(TEST_API_KEY) + .baseUrl(TEST_BASE_URL) + .headers(headers) + .completionsPath(TEST_COMPLETIONS_PATH) + .embeddingsPath(TEST_EMBEDDINGS_PATH) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testDefaultValues() { + ZhiPuAiApi api = ZhiPuAiApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + // We can't directly test the default values as they're private fields, + // but we know the builder succeeded with defaults + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> ZhiPuAiApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> ZhiPuAiApi.builder().baseUrl("").build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> ZhiPuAiApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> ZhiPuAiApi.builder().headers(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidCompletionsPath() { + assertThatThrownBy(() -> ZhiPuAiApi.builder().completionsPath("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("completionsPath cannot be null or empty"); + + assertThatThrownBy(() -> ZhiPuAiApi.builder().completionsPath(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("completionsPath cannot be null or empty"); + } + + @Test + void testInvalidEmbeddingsPath() { + assertThatThrownBy(() -> ZhiPuAiApi.builder().embeddingsPath("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("embeddingsPath cannot be null or empty"); + + assertThatThrownBy(() -> ZhiPuAiApi.builder().embeddingsPath(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("embeddingsPath cannot be null or empty"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> ZhiPuAiApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidWebClientBuilder() { + assertThatThrownBy(() -> ZhiPuAiApi.builder().webClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("webClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> ZhiPuAiApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + /** + * Tests the behavior of the {@link ZhiPuAiApi} class when using dynamic API + *

+ * This test refers to OpenAiApiBuilderTests. + */ + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + ZhiPuAiApi api = ZhiPuAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + {} + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + ZhiPuAiApi.ChatCompletionMessage chatCompletionMessage = new ZhiPuAiApi.ChatCompletionMessage("Hello world", + ZhiPuAiApi.ChatCompletionMessage.Role.USER); + ZhiPuAiApi.ChatCompletionRequest request = new ZhiPuAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false); + ResponseEntity response = api.chatCompletionEntity(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.chatCompletionEntity(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + @Test + void dynamicApiKeyRestClientWithAdditionalAuthorizationHeader() throws InterruptedException { + ZhiPuAiApi api = ZhiPuAiApi.builder().apiKey(() -> { + throw new AssertionFailedError("Should not be called, API key is provided in headers"); + }).baseUrl(mockWebServer.url("/").toString()).build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + {} + """); + mockWebServer.enqueue(mockResponse); + + ZhiPuAiApi.ChatCompletionMessage chatCompletionMessage = new ZhiPuAiApi.ChatCompletionMessage("Hello world", + ZhiPuAiApi.ChatCompletionMessage.Role.USER); + ZhiPuAiApi.ChatCompletionRequest request = new ZhiPuAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false); + + MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); + ResponseEntity response = api.chatCompletionEntity(request, additionalHeaders); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key"); + } + + @Test + void dynamicApiKeyWebClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + ZhiPuAiApi api = ZhiPuAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + {} + """.replace("\n", "")); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + ZhiPuAiApi.ChatCompletionMessage chatCompletionMessage = new ZhiPuAiApi.ChatCompletionMessage("Hello world", + ZhiPuAiApi.ChatCompletionMessage.Role.USER); + ZhiPuAiApi.ChatCompletionRequest request = new ZhiPuAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true); + List response = api.chatCompletionStream(request).collectList().block(); + assertThat(response).hasSize(1); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.chatCompletionStream(request).collectList().block(); + assertThat(response).hasSize(1); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + @Test + void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws InterruptedException { + ZhiPuAiApi api = ZhiPuAiApi.builder().apiKey(() -> { + throw new AssertionFailedError("Should not be called, API key is provided in headers"); + }).baseUrl(mockWebServer.url("/").toString()).build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + {} + """.replace("\n", "")); + mockWebServer.enqueue(mockResponse); + + ZhiPuAiApi.ChatCompletionMessage chatCompletionMessage = new ZhiPuAiApi.ChatCompletionMessage("Hello world", + ZhiPuAiApi.ChatCompletionMessage.Role.USER); + ZhiPuAiApi.ChatCompletionRequest request = new ZhiPuAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true); + MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); + List response = api.chatCompletionStream(request, additionalHeaders) + .collectList() + .block(); + assertThat(response).hasSize(1); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key"); + } + + @Test + void dynamicApiKeyRestClientEmbeddings() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + ZhiPuAiApi api = ZhiPuAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + {} + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + ZhiPuAiApi.EmbeddingRequest request = new ZhiPuAiApi.EmbeddingRequest<>("Hello world"); + ResponseEntity> response = api.embeddings(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.embeddings(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java index 44f9cd79f63..27a376ab0e7 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ @EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+") public class ZhiPuAiApiIT { - ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + ZhiPuAiApi zhiPuAiApi = ZhiPuAiApi.builder().apiKey(System.getenv("ZHIPU_AI_API_KEY")).build(); @Test void chatCompletionEntity() { diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java index c45b8b0171b..05e4341ba2d 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,7 +48,7 @@ public class ZhiPuAiApiToolFunctionCallIT { MockWeatherService weatherService = new MockWeatherService(); - ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + ZhiPuAiApi zhiPuAiApi = ZhiPuAiApi.builder().apiKey(System.getenv("ZHIPU_AI_API_KEY")).build(); private static T fromJson(String json, Class targetClass) { try { diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java index e24f846c02c..953c7c3bb4e 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -159,7 +159,7 @@ public TestObservationRegistry observationRegistry() { @Bean public ZhiPuAiApi zhiPuAiApi() { - return new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + return ZhiPuAiApi.builder().apiKey(System.getenv("ZHIPU_AI_API_KEY")).build(); } @Bean diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java index 4238088e890..f6a33037566 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/ZhiPuAiEmbeddingModelObservationIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,7 +99,7 @@ public TestObservationRegistry observationRegistry() { @Bean public ZhiPuAiApi zhiPuAiApi() { - return new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + return ZhiPuAiApi.builder().apiKey(System.getenv("ZHIPU_AI_API_KEY")).build(); } @Bean