Java 大模型代理：/v1/chat/completions 接口实现
1. Controller 层
/**
 * OpenAI-compatible chat completions proxy endpoint.
 * <p>
 * Streaming requests get the service's {@link Flux} back as-is, with {@code text/event-stream} as a
 * declared producible type. Non-streaming requests are handed to {@code jsonOut}. NOTE(review):
 * {@code jsonOut} is defined outside this view and presumably writes the aggregated JSON body
 * directly to the response, which is why {@code null} is returned afterwards. Confirm this.
 *
 * @param request chat request; {@code stream} may be omitted and is then treated as {@code false}
 * @return the streamed chunks, or {@code null} when the response was already written by {@code jsonOut}
 */
@SneakyThrows
@PostMapping(value = "/v1/chat/completions", produces = {TEXT_EVENT_STREAM_VALUE, APPLICATION_JSON_VALUE})
@Operation(summary = AGENT_SERVICE_CHAT, description = AGENT_SERVICE)
@Logger
@RateLimiter
@Monitor
public Flux<AgentChatVo> chat(@RequestBody AgentChatRequest request) {
    Flux<AgentChatVo> agentChatVoFlux = agentServiceService.chat(request);
    // `stream` is optional in the chat completions protocol and defaults to false.
    // Comparing against Boolean.TRUE avoids the NPE that unboxing a null Boolean would throw.
    if (!Boolean.TRUE.equals(request.getStream())) {
        jsonOut(agentChatVoFlux);
        return null;
    }
    return agentChatVoFlux;
}
2. Service 层
/**
 * Proxy implementation of the chat completions API.
 * <p>
 * The request's {@code model} is looked up among the configured {@link ChatApi} rows. The row's
 * {@code type} selects one of three upstream integrations (ModelArts, AIGC, self-hosted), and the
 * upstream response is converted to {@link AgentChatVo} chunks.
 */
@Service
@Slf4j
public class AgentServiceServiceImpl implements AgentServiceService {

    /**
     * NOTE(review): currently unused. The AIGC call sends the raw IAM access token in the
     * Authorization header; confirm whether a "Bearer " prefix is expected upstream.
     */
    private static final String BEARER = "Bearer ";

    /** Matches the SSE terminator payload that ends a streamed completion. */
    private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;

    @Resource
    private AiChatAIGCIDProperties aiChatAIGCIDProperties;

    @Resource
    private AiChatAIGCProperties aiChatAIGCProperties;

    @Resource
    private CommonService commonService;

    @Resource
    private ChatApiService chatApiService;

    @Resource
    private DynamicRouteService dynamicRouteService;

    @Resource(name = "ignoreSSLWebClient")
    private WebClient ignoreSSLWebClient;

    @Resource(name = "webClient")
    private WebClient webClient;

    /**
     * Dispatch table mapping API type to request handler (ModelArts, AIGC, self-hosted).
     * <p>
     * This is an immutable per-instance map. The handlers are bound to {@code this}, and the former
     * approach (a shared static {@code HashMap} refilled from an instance initializer) was not
     * thread-safe. It also kept whichever instance happened to be constructed last.
     */
    private final Map<AiChatApiTypeEnum, Function<AgentChatRequest, Flux<AgentChatVo>>> typeHandlers = Map.of(
            MODEL_ARTS, this::modelArtsTypeOperation,
            AIGC, this::aigcTypeOperation,
            OWNS, this::ownsTypeOperation);

    /**
     * Main chat entry point. Validates the model and dispatches to the handler for its API type.
     *
     * @param request request
     * @return Flux<AgentChatVo>; a single empty {@link AgentChatVo} when the type is not recognised
     * @throws AppException (400) when {@code request.getModel()} is not a configured model
     */
    @Override
    public Flux<AgentChatVo> chat(AgentChatRequest request) {
        // Validate the model and obtain its configured API type.
        String type = checkModel(request);
        Optional<AiChatApiTypeEnum> apiType = AiChatApiTypeEnum.from(type);
        if (apiType.isEmpty()) {
            return emptyResponse();
        }
        // Pick the strategy for this API type; guard against enum values with no registered handler.
        Function<AgentChatRequest, Flux<AgentChatVo>> handler = typeHandlers.get(apiType.get());
        if (handler == null) {
            log.warn("no chat handler registered for api type {} (model {})", apiType.get(), request.getModel());
            return emptyResponse();
        }
        return handler.apply(request);
    }

    /**
     * Returns the configured API type of {@code request.getModel()}.
     *
     * @throws AppException (400) listing all configured models when the model is unknown
     */
    private String checkModel(AgentChatRequest request) {
        List<ChatApi> chatApis = chatApiService.getChatApis();
        for (ChatApi chatApi : chatApis) {
            if (StringUtils.equals(chatApi.getModel(), request.getModel())) {
                return chatApi.getType();
            }
        }
        List<String> allModels = chatApis.stream().map(ChatApi::getModel).distinct().collect(Collectors.toList());
        String typeStr = String.join(",", allModels);
        throw new AppException(HttpStatus.BAD_REQUEST.value(), "model 只能是 " + typeStr);
    }

    /**
     * Looks up the {@link ChatApi} row for {@code model}. Fails with a 400 instead of the
     * {@code NoSuchElementException} that an unchecked {@code Optional.get()} would raise.
     */
    private ChatApi requireChatApi(String model) {
        return chatApiService.findByModel(model)
                .orElseThrow(() -> new AppException(HttpStatus.BAD_REQUEST.value(), "model 不存在: " + model));
    }

    /** Null-safe read of the optional {@code stream} flag; absent means non-streaming. */
    private static boolean isStream(AgentChatRequest request) {
        return Boolean.TRUE.equals(request.getStream());
    }

    /** Media type to request from upstream: SSE when streaming, JSON otherwise. */
    private static MediaType resolveMediaType(AgentChatRequest request) {
        return isStream(request) ? TEXT_EVENT_STREAM : APPLICATION_JSON;
    }

    /** Fresh placeholder response per call, so callers never share (and mutate) one instance. */
    private static Flux<AgentChatVo> emptyResponse() {
        return Flux.just(new AgentChatVo());
    }

    /** Normalises the role of every message, tolerating a missing message list. */
    private static void initMessageRoles(AgentChatRequest request) {
        if (request.getMessages() != null) {
            request.getMessages().forEach(Message::initRole);
        }
    }

    /**
     * Handles a ModelArts-type request.
     *
     * @param request request; its messages and model are updated in place before forwarding
     * @return Flux<AgentChatVo>
     */
    private Flux<AgentChatVo> modelArtsTypeOperation(AgentChatRequest request) {
        ChatApi chatApi = requireChatApi(request.getModel());
        // Build the upstream payload: normalise roles and swap in the upstream model name.
        initMessageRoles(request);
        request.setModel(getModelIfOriginalModelNonNull(chatApi));
        return getModelArtsApiResponse(resolveMediaType(request), chatApi, request);
    }

    /** POSTs the request to the ModelArts endpoint, authenticating with the row's CSB token header. */
    @NotNull
    private Flux<AgentChatVo> getModelArtsApiResponse(MediaType mediaType, ChatApi chatApi, AgentChatRequest request) {
        WebClient.ResponseSpec retrieve = webClient.post()
                .uri(getModelArtsApiUrl(chatApi))
                .accept(mediaType)
                .header(CSB_TOKEN, chatApi.getHeader().get(CSB_TOKEN))
                .body(BodyInserters.fromValue(request))
                .retrieve();
        return toAgentChatVoFlux(mediaType, retrieve);
    }

    /**
     * Decodes the upstream body. For JSON it is read directly as {@link AgentChatVo}; otherwise it is
     * read as SSE chunks up to the "[DONE]" marker.
     */
    @NotNull
    private Flux<AgentChatVo> toAgentChatVoFlux(MediaType mediaType, WebClient.ResponseSpec retrieve) {
        if (APPLICATION_JSON.equals(mediaType)) {
            return retrieve.bodyToFlux(AgentChatVo.class);
        }
        return getStreamFilterDone(retrieve);
    }

    /** Reads SSE payloads as strings, stops at "[DONE]" (excluded), and parses each payload as JSON. */
    @NotNull
    private Flux<AgentChatVo> getStreamFilterDone(WebClient.ResponseSpec retrieve) {
        return retrieve.bodyToFlux(String.class)
                // cancels the flux stream after the "[DONE]" is received.
                .takeUntil(SSE_DONE_PREDICATE)
                // filters out the "[DONE]" message.
                .filter(SSE_DONE_PREDICATE.negate())
                .map(content -> JsonUtils.parseJsonStringToEntity(content, AgentChatVo.class));
    }

    /**
     * Builds the ModelArts URL by appending the row's {@code param} map as a query string.
     * Values are appended verbatim (not URL-encoded), as before. With no params the bare URL is
     * returned; previously this case threw from {@code substring(1)}.
     */
    @NotNull
    private String getModelArtsApiUrl(ChatApi chatApi) {
        Map<String, String> param = chatApi.getParam();
        if (param == null || param.isEmpty()) {
            return chatApi.getUrl();
        }
        String query = param.entrySet().stream()
                .map(entry -> entry.getKey() + "=" + entry.getValue())
                .collect(Collectors.joining("&"));
        return chatApi.getUrl() + "?" + query;
    }

    /** Upstream model name: {@code originalModel} when non-blank, else the client-facing {@code model}. */
    private String getModelIfOriginalModelNonNull(ChatApi chatApi) {
        return StringUtils.isBlank(chatApi.getOriginalModel()) ? chatApi.getModel() : chatApi.getOriginalModel();
    }

    /**
     * Handles an AIGC-type request.
     *
     * @param request request
     * @return Flux<AgentChatVo>; a single empty {@link AgentChatVo} when the model has no AiChatModelEnum entry
     */
    @SneakyThrows
    private Flux<AgentChatVo> aigcTypeOperation(AgentChatRequest request) {
        // Initialise AIGC scenario ids from configuration.
        // NOTE(review): this runs on every request and presumably updates static state in AIGCIdEnum;
        // consider initialising once at startup.
        AIGCIdEnum.initId(aiChatAIGCIDProperties);
        Optional<AiChatModelEnum> chatModelEnumOpt = AiChatModelEnum.from(request.getModel());
        if (chatModelEnumOpt.isEmpty()) {
            return emptyResponse();
        }
        // Build the AIGC payload.
        AIGCChatRequest aigcChatRequest = getAigcApiRequest(request, chatModelEnumOpt.get());
        // NOTE(review): the bean name suggests TLS certificate checks are disabled for AIGC; confirm intended.
        return getAigcApiStreamResponse(resolveMediaType(request), request, aigcChatRequest, ignoreSSLWebClient);
    }

    /** Maps the chat request to the AIGC format. Only the last message is sent, as the question. */
    private AIGCChatRequest getAigcApiRequest(AgentChatRequest request, AiChatModelEnum chatModelEnum) {
        // Each model maps to its own scenario uuid.
        String token = AIGCIdEnum.getUUID(request.getContext(), chatModelEnum);
        // Take the last message; fall back to an empty message when there are none.
        Message message = request.getMessages() == null
                ? new Message()
                : request.getMessages().stream().reduce((first, second) -> second).orElseGet(Message::new);
        message.initRole();
        return new AIGCChatRequest().setQuestion(message.getContent())
                .setScenarioUuid(token)
                .setStop(request.getStop())
                .setTemperature(request.getTemperature())
                .setTopP(request.getTopP())
                .setMaxTokens(request.getMaxTokens())
                .setUserId(ThreadLocalUtil.getAccountId());
    }

    /**
     * POSTs to the AIGC app URL using the IAM access token and converts each returned
     * {@link AIGCChatModel} into an {@link AgentChatVo}. Returns an empty Flux (now logged) when no
     * token could be obtained.
     */
    @NotNull
    private Flux<AgentChatVo> getAigcApiStreamResponse(MediaType mediaType, AgentChatRequest request,
            AIGCChatRequest aigcChatRequest, WebClient client) {
        Optional<IamTokenModel> tokenOpt = commonService.getToken();
        if (tokenOpt.isEmpty()) {
            log.warn("no IAM token available, AIGC request for model {} not sent", request.getModel());
            return Flux.empty();
        }
        IamTokenModel iamTokenModel = tokenOpt.get();
        return client.post()
                .uri(aiChatAIGCProperties.getAppUrl())
                .header(AUTHORIZATION, iamTokenModel.getAccessToken())
                .accept(mediaType)
                .body(BodyInserters.fromValue(aigcChatRequest))
                .retrieve()
                .bodyToFlux(AIGCChatModel.class)
                .map(node -> converterAgentChatVo(request, node));
    }

    /** Wraps the AIGC answer text in a single-choice {@link AgentChatVo} tagged with the request model. */
    private AgentChatVo converterAgentChatVo(AgentChatRequest request, AIGCChatModel aigcChatModel) {
        String response = getAIGCApiResponseContent(aigcChatModel);
        Choice choice = new Choice().setMessage(new Choice.Message().setContent(response));
        choice.initDelta();
        return new AgentChatVo().setModel(request.getModel()).setChoices(List.of(choice));
    }

    /** Content of the first AIGC choice, or "" when the model or its choices are absent. */
    @NotNull
    private String getAIGCApiResponseContent(AIGCChatModel aigcChatModel) {
        return Optional.ofNullable(aigcChatModel)
                .map(AIGCChatModel::getChoices)
                .orElseGet(ArrayList::new)
                .stream()
                .findFirst()
                .map(AIGCChatModel.Choice::getContent)
                .orElse("");
    }

    /**
     * Handles a self-hosted ("owns") type request.
     *
     * @param request request; messages, context and model are updated in place before forwarding
     * @return Flux<AgentChatVo>
     */
    @SneakyThrows
    private Flux<AgentChatVo> ownsTypeOperation(AgentChatRequest request) {
        ChatApi chatApi = requireChatApi(request.getModel());
        initMessageRoles(request);
        // The context field is dropped before forwarding to the self-hosted API.
        request.setContext(null);
        request.setModel(getModelIfOriginalModelNonNull(chatApi));
        return getOwnsApiStreamResponse(resolveMediaType(request), request, chatApi);
    }

    /** POSTs the request unchanged in shape to the self-hosted endpoint (no extra auth headers). */
    @NotNull
    private Flux<AgentChatVo> getOwnsApiStreamResponse(MediaType mediaType, AgentChatRequest request, ChatApi chatApi) {
        WebClient.ResponseSpec retrieve = webClient.post()
                .uri(chatApi.getUrl())
                .accept(mediaType)
                .body(BodyInserters.fromValue(request))
                .retrieve();
        return toAgentChatVoFlux(mediaType, retrieve);
    }
}
3. 数据库实体（ChatApi）
@Data
@Schema(name = "ChatApi", description = "chat api 接口")
@TableName(value = "chat_api", autoResultMap = true)
@Accessors(chain = true)
public class ChatApi implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
/**
* 主键 自增
*/
@TableId(type = IdType.AUTO)
@Schema(allowableValues = "ID")
private String id;
private String tenantId;
private String url;
private String requestType;
@TableField(typeHandler = MapTypeHandler.class)
private Map<String, String> header = new HashMap<>();
@TableField(typeHandler = MapTypeHandler.class)
private Map<String, String> param = new HashMap<>();
private String type;
private String model;
private String originalModel;
@TableLogic(value = "0", delval = "1")
private Boolean delFlag;
/**
* 创建人
*/
@Schema(description = "创建人")
@TableField(fill = FieldFill.INSERT)
private String createBy;
/**
* 创建时间
*/
@JsonFormat(timezone = "GMT+8", pattern = "yyyy-MM-dd HH:mm:ss")
@DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss")
@Schema(description = "创建时间")
@TableField(fill = FieldFill.INSERT)
private Date createTime;
/**
* 更新人
*/
@Schema(description = "更新人")
@TableField(fill = FieldFill.INSERT_UPDATE)
private String updateBy;
/**
* 更新时间
*/
@JsonFormat(timezone = "GMT+8", pattern = "yyyy-MM-dd HH:mm:ss")
@DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss")
@Schema(description = "更新时间")
@TableField(fill = FieldFill.INSERT_UPDATE)
private Date updateTime;
}
蓝天和白云是标配。