Skip to content

Commit

Permalink
Merge pull request #171 from yuluo-yx/1207-yuluo/add-json-mode
Browse files Browse the repository at this point in the history
feat: add response_format params
  • Loading branch information
chickenlj authored Dec 10, 2024
2 parents 21f1ab0 + 64eff6a commit 37c8643
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
package com.alibaba.cloud.ai.dashscope.api;

import java.io.File;
import java.io.FileInputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import com.alibaba.cloud.ai.dashscope.common.DashScopeException;
import com.alibaba.cloud.ai.dashscope.common.ErrorCodeEnum;
import com.alibaba.cloud.ai.dashscope.rag.DashScopeDocumentRetrieverOptions;
Expand All @@ -8,30 +20,27 @@
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.InputStreamResource;
import org.springframework.http.*;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.File;
import java.io.FileInputStream;
import java.net.URI;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static com.alibaba.cloud.ai.dashscope.common.DashScopeApiConstants.DEFAULT_BASE_URL;

Expand Down Expand Up @@ -858,6 +867,7 @@ public record ChatCompletionRequestParameter(@JsonProperty("result_format") Stri
@JsonProperty("repetition_penalty") Double repetitionPenalty,
@JsonProperty("presence_penalty") Double presencePenalty, @JsonProperty("temperature") Double temperature,
@JsonProperty("stop") List<Object> stop, @JsonProperty("enable_search") Boolean enableSearch,
@JsonProperty("response_format") DashScopeResponseFormat responseFormat,
@JsonProperty("incremental_output") Boolean incrementalOutput,
@JsonProperty("tools") List<FunctionTool> tools, @JsonProperty("tool_choice") Object toolChoice,
@JsonProperty("stream") Boolean stream,
Expand All @@ -867,7 +877,7 @@ public record ChatCompletionRequestParameter(@JsonProperty("result_format") Stri
* shortcut constructor for chat request parameter
*/
public ChatCompletionRequestParameter() {
this(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);
this(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);
}

/**
Expand All @@ -894,15 +904,6 @@ public static Object function(String functionName) {
}

}

/**
* An object specifying the format that the model must output.
*
* @param type Must be one of 'text' or 'json_object'.
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ResponseFormat(@JsonProperty("type") String type) {
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.cloud.ai.dashscope.api;

import java.util.Objects;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

/**
* Lets you specify the format of the returned content. Valid values: {"type": "text"} or
* {"type": "json_object"}. When set to {"type": "json_object"}, a JSON string in standard
* format is output. Params reference:
* <a href="https://help.aliyun.com/zh/dashscope/developer-reference/qwen-api">...</a>
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
*/

@JsonInclude(JsonInclude.Include.NON_NULL)
public class DashScopeResponseFormat {

/**
* Parameters must be one of 'text' or 'json_object'.
*/
@JsonProperty("type")
private Type type;

public Type getType() {

return type;
}

public void setType(Type type) {

this.type = type;
}

public DashScopeResponseFormat() {
}

public DashScopeResponseFormat(Type type) {

this.type = type;
}

public static Builder builder() {

return new Builder();
}

/**
* Builder for {@link DashScopeResponseFormat}.
*/
public static class Builder {

private Type type;

public Builder type(Type type) {

this.type = type;
return this;
}

public DashScopeResponseFormat build() {

return new DashScopeResponseFormat(this.type);
}

}

@Override
public String toString() {

return "DashScopeResponseFormat { " + "type='" + type + '}';
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
DashScopeResponseFormat that = (DashScopeResponseFormat) o;
return Objects.equals(type, that.type);
}

@Override
public int hashCode() {
return Objects.hash(type);
}

/**
* ResponseFormat type. Valid values: {"type": "text"} or {"type": "json_object"}.
*/
public enum Type {

/**
* Generates a text response. (default)
*/
@JsonProperty("text")
TEXT,

/**
* Enables JSON mode, which guarantees the message the model generates is valid
* JSON string.
*/
@JsonProperty("json_object")
JSON_OBJECT,

}

}
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.cloud.ai.dashscope.chat;

import java.util.ArrayList;
Expand All @@ -9,6 +25,19 @@
import java.util.concurrent.ConcurrentHashMap;

import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletion;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionChunk;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionFinishReason;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionMessage;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionMessage.ChatCompletionFunction;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionMessage.MediaContent;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionMessage.ToolCall;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionOutput;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionOutput.Choice;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionRequest;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionRequestInput;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionRequestParameter;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.FunctionTool;
import com.alibaba.cloud.ai.dashscope.chat.observation.DashScopeChatModelObservationConvention;
import com.alibaba.cloud.ai.dashscope.metadata.DashScopeAiUsage;
import com.alibaba.cloud.ai.observation.conventions.AiProvider;
Expand All @@ -17,16 +46,7 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.*;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import reactor.core.publisher.Flux;

import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.*;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionMessage.*;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionOutput.Choice;
import reactor.core.publisher.Mono;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand All @@ -35,6 +55,14 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
Expand Down Expand Up @@ -419,15 +447,29 @@ private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
}

private ChatCompletionRequestParameter toDashScopeRequestParameter(DashScopeChatOptions options, boolean stream) {

if (options == null) {
return new ChatCompletionRequestParameter();
}

Boolean incrementalOutput = stream && options.getIncrementalOutput();
return new ChatCompletionRequestParameter("message", options.getSeed(), options.getMaxTokens(),
options.getTopP(), options.getTopK(), options.getRepetitionPenalty(), options.getPresencePenalty(),
options.getTemperature(), options.getStop(), options.getEnableSearch(), incrementalOutput,
options.getTools(), options.getToolChoice(), stream, options.getVlHighResolutionImages());
return new ChatCompletionRequestParameter(
"message",
options.getSeed(),
options.getMaxTokens(),
options.getTopP(),
options.getTopK(),
options.getRepetitionPenalty(),
options.getPresencePenalty(),
options.getTemperature(),
options.getStop(),
options.getEnableSearch(),
options.getResponseFormat(),
incrementalOutput,
options.getTools(),
options.getToolChoice(),
stream, options.getVlHighResolutionImages()
);
}

}
Loading

0 comments on commit 37c8643

Please # to comment.