Java 实现大模型代理：chat completions 接口

1. Controller 层

    /**
     * Chat completions endpoint that proxies the request to the upstream model selected by {@code model}.
     *
     * <p>For streaming requests the upstream chunks are returned as a {@link Flux}. For non-streaming
     * requests the result is handed to {@code jsonOut} and {@code null} is returned.
     * NOTE(review): this assumes {@code jsonOut} writes the HTTP response body itself — confirm.
     *
     * @param request chat completion request; a missing {@code stream} flag is treated as {@code false}
     * @return the chunk stream for streaming requests, otherwise {@code null}
     */
    @SneakyThrows
    @PostMapping(value = "/v1/chat/completions", produces = {TEXT_EVENT_STREAM_VALUE, APPLICATION_JSON_VALUE})
    @Operation(summary = AGENT_SERVICE_CHAT, description = AGENT_SERVICE)
    @Logger
    @RateLimiter
    @Monitor
    public Flux<AgentChatVo> chat(@RequestBody AgentChatRequest request) {
        Flux<AgentChatVo> agentChatVoFlux = agentServiceService.chat(request);
        // Boolean.TRUE.equals avoids an NPE from unboxing when the client omits "stream"
        // (the chat completions API treats an absent flag as false).
        if (!Boolean.TRUE.equals(request.getStream())) {
            jsonOut(agentChatVoFlux);
            return null;
        }
        return agentChatVoFlux;
    }

2. Service 层

@Service
@Slf4j
public class AgentServiceServiceImpl implements AgentServiceService {
    private static final String BEARER = "Bearer ";

    // 根据不同的type走不同的类型接口 modelArts AIGC 自建
    private static final Map<AiChatApiTypeEnum, Function<AgentChatRequest, Flux<AgentChatVo>>> TYPE_NAP = new HashMap<>(
        3);

    private static final AgentChatVo EMPTY = new AgentChatVo();

    private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;

    @Resource
    private AiChatAIGCIDProperties aiChatAIGCIDProperties;

    @Resource
    private AiChatAIGCProperties aiChatAIGCProperties;

    @Resource
    private CommonService commonService;

    @Resource
    private ChatApiService chatApiService;

    @Resource
    private DynamicRouteService dynamicRouteService;

    @Resource(name = "ignoreSSLWebClient")
    private WebClient ignoreSSLWebClient;

    @Resource(name = "webClient")
    private WebClient webClient;

    {
        // modelArts 接口
        TYPE_NAP.put(MODEL_ARTS, this::modelArtsTypeOperation);
        // AIGC 接口
        TYPE_NAP.put(AIGC, this::aigcTypeOperation);
        // 自建 接口
        TYPE_NAP.put(OWNS, this::ownsTypeOperation);
    }

    /**
     * chat main 方法
     *
     * @param request request
     * @return Flux<AgentChatVo>
     */
    @Override
    public Flux<AgentChatVo> chat(AgentChatRequest request) {
        // 校验模型
        String type = checkModel(request);
        // 获取模型枚举
        Optional<AiChatApiTypeEnum> chatApiTypeEnumOptional = AiChatApiTypeEnum.from(type);
        if (chatApiTypeEnumOptional.isEmpty()) {
            return Flux.just(EMPTY);
        }
        // 根据不同模型选择不同策略
        return TYPE_NAP.get(chatApiTypeEnumOptional.get()).apply(request);
    }

    private String checkModel(AgentChatRequest request) {
        List<ChatApi> chatApis = chatApiService.getChatApis();

        for (ChatApi chatApi : chatApis) {
            if (StringUtils.equals(chatApi.getModel(), request.getModel())) {
                return chatApi.getType();
            }
        }
        List<String> allModels = chatApis.stream().map(ChatApi::getModel).distinct().collect(Collectors.toList());
        String typeStr = String.join(",", allModels);
        throw new AppException(HttpStatus.BAD_REQUEST.value(), "model 只能是 " + typeStr);
    }

    /**
     * modelarts 接口请求处理
     *
     * @param request request
     * @return Flux<AgentChatVo>
     */
    private Flux<AgentChatVo> modelArtsTypeOperation(AgentChatRequest request) {
        ChatApi chatApi = chatApiService.findByModel(request.getModel()).get();

        // 组装入参
        request.getMessages().forEach(Message::initRole);
        request.setModel(getModelIfOriginalModelNonNull(chatApi));

        if (request.getStream()) {
            return getModelArtsApiResponse(TEXT_EVENT_STREAM, chatApi, request);
        }

        return getModelArtsApiResponse(APPLICATION_JSON, chatApi, request);
    }

    @NotNull
    private Flux<AgentChatVo> getModelArtsApiResponse(MediaType mediaType, ChatApi chatApi, AgentChatRequest request) {
        WebClient.ResponseSpec retrieve = webClient.post()
            .uri(getModelArtsApiUrl(chatApi))
            .accept(mediaType)
            .header(CSB_TOKEN, chatApi.getHeader().get(CSB_TOKEN))
            .body(BodyInserters.fromValue(request))
            .retrieve();
        if (APPLICATION_JSON.equals(mediaType)) {
            return retrieve.bodyToFlux(AgentChatVo.class);
        }
        return getStreamFilterDone(retrieve);
    }

    @NotNull
    private Flux<AgentChatVo> getStreamFilterDone(WebClient.ResponseSpec retrieve) {
        return retrieve.bodyToFlux(String.class)
            // cancels the flux stream after the "[DONE]" is received.
            .takeUntil(SSE_DONE_PREDICATE)
            // filters out the "[DONE]" message.
            .filter(SSE_DONE_PREDICATE.negate())
            .map(content -> JsonUtils.parseJsonStringToEntity(content, AgentChatVo.class));
    }

    @NotNull
    private String getModelArtsApiUrl(ChatApi chatApi) {
        Map<String, String> param = chatApi.getParam();
        StringBuilder queryParams = new StringBuilder();
        for (Map.Entry<String, String> entry : param.entrySet()) {
            queryParams.append("&").append(entry.getKey()).append("=").append(entry.getValue());
        }
        return chatApi.getUrl() + "?" + queryParams.substring(1);
    }

    private String getModelIfOriginalModelNonNull(ChatApi chatApi) {
        return StringUtils.isBlank(chatApi.getOriginalModel()) ? chatApi.getModel() : chatApi.getOriginalModel();
    }

    /**
     * aigc 接口请求处理
     *
     * @param request request
     * @return Flux<AgentChatVo>
     */
    @SneakyThrows
    private Flux<AgentChatVo> aigcTypeOperation(AgentChatRequest request) {
        // modelArts uuid 初始化
        AIGCIdEnum.initId(aiChatAIGCIDProperties);

        Optional<AiChatModelEnum> chatModelEnumOpt = AiChatModelEnum.from(request.getModel());
        if (chatModelEnumOpt.isEmpty()) {
            return Flux.just(EMPTY);
        }
        // 封装api入参
        AIGCChatRequest aigcChatRequest = getAigcApiRequest(request, chatModelEnumOpt.get());

        if (request.getStream()) {
            return getAigcApiStreamResponse(TEXT_EVENT_STREAM, request, aigcChatRequest, ignoreSSLWebClient);
        }
        return getAigcApiStreamResponse(APPLICATION_JSON, request, aigcChatRequest, ignoreSSLWebClient);
    }

    private AIGCChatRequest getAigcApiRequest(AgentChatRequest request, AiChatModelEnum chatModelEnum) {
        // 根据不同的 model 获取不同的 uuid
        String token = AIGCIdEnum.getUUID(request.getContext(), chatModelEnum);
        // 获取最后一个
        Message message = request.getMessages().stream().reduce((first, second) -> second).orElseGet(Message::new);
        message.initRole();
        return new AIGCChatRequest().setQuestion(message.getContent())
            .setScenarioUuid(token)
            .setStop(request.getStop())
            .setTemperature(request.getTemperature())
            .setTopP(request.getTopP())
            .setMaxTokens(request.getMaxTokens())
            .setUserId(ThreadLocalUtil.getAccountId());
    }

    @NotNull
    private Flux<AgentChatVo> getAigcApiStreamResponse(MediaType mediaType, AgentChatRequest request,
        AIGCChatRequest aigcChatRequest, WebClient webClient) {
        Optional<IamTokenModel> tokenOpt = commonService.getToken();

        return tokenOpt.map(iamTokenModel -> webClient.post()
            .uri(aiChatAIGCProperties.getAppUrl())
            .header(AUTHORIZATION, iamTokenModel.getAccessToken())
            .accept(mediaType)
            .body(BodyInserters.fromValue(aigcChatRequest))
            .retrieve()
            .bodyToFlux(AIGCChatModel.class)
            .map(node -> converterAgentChatVo(request, node))).orElseGet(Flux::empty);
    }

    private AgentChatVo converterAgentChatVo(AgentChatRequest request, AIGCChatModel aigcChatModel) {
        String response = getAIGCApiResponseContent(aigcChatModel);
        Choice choice = new Choice().setMessage(new Choice.Message().setContent(response));
        choice.initDelta();
        return new AgentChatVo().setModel(request.getModel()).setChoices(List.of(choice));
    }

    @NotNull
    private String getAIGCApiResponseContent(AIGCChatModel aigcChatModel) {
        return Optional.ofNullable(aigcChatModel)
            .map(AIGCChatModel::getChoices)
            .orElseGet(ArrayList::new)
            .stream()
            .findFirst()
            .map(AIGCChatModel.Choice::getContent)
            .orElse("");
    }

    /**
     * 自有接口请求处理
     *
     * @param request request
     * @return Flux<AgentChatVo>
     */
    @SneakyThrows
    private Flux<AgentChatVo> ownsTypeOperation(AgentChatRequest request) {
        ChatApi chatApi = chatApiService.findByModel(request.getModel()).get();

        request.getMessages().forEach(Message::initRole);
        request.setContext(null);
        request.setModel(getModelIfOriginalModelNonNull(chatApi));

        if (request.getStream()) {
            return getOwnsApiStreamResponse(TEXT_EVENT_STREAM, request, chatApi);
        }
        return getOwnsApiStreamResponse(APPLICATION_JSON, request, chatApi);
    }

    @NotNull
    private Flux<AgentChatVo> getOwnsApiStreamResponse(MediaType mediaType, AgentChatRequest request, ChatApi chatApi) {
        WebClient.ResponseSpec retrieve = webClient.post()
            .uri(chatApi.getUrl())
            .accept(mediaType)
            .body(BodyInserters.fromValue(request))
            .retrieve();
        if (APPLICATION_JSON.equals(mediaType)) {
            return retrieve.bodyToFlux(AgentChatVo.class);
        }
        return getStreamFilterDone(retrieve);
    }
}

3. 数据库实体（Entity）

@Data
@Schema(name = "ChatApi", description = "chat api 接口")
@TableName(value = "chat_api", autoResultMap = true)
@Accessors(chain = true)
public class ChatApi implements Serializable {
    @Serial
    private static final long serialVersionUID = 1L;

    /**
     * Primary key, auto-increment.
     * NOTE(review): IdType.AUTO relies on a database auto-increment column while the field is a String —
     * confirm the column type and that generated ids map back correctly.
     */
    @TableId(type = IdType.AUTO)
    // description, not allowableValues: allowableValues would advertise "ID" as the only legal id value.
    @Schema(description = "ID")
    private String id;

    /** Tenant identifier. NOTE(review): not read by the chat service shown alongside this entity. */
    private String tenantId;

    /** Upstream endpoint URL that chat requests are POSTed to. */
    private String url;

    /** NOTE(review): not read by the chat service shown alongside this entity — purpose to be confirmed. */
    private String requestType;

    /** Extra request headers stored as a map column; the ModelArts path reads its CSB token from here. */
    @TableField(typeHandler = MapTypeHandler.class)
    private Map<String, String> header = new HashMap<>();

    /** Query parameters appended to {@link #url} as {@code key=value} pairs for ModelArts requests. */
    @TableField(typeHandler = MapTypeHandler.class)
    private Map<String, String> param = new HashMap<>();

    /** Upstream API type; resolved via {@code AiChatApiTypeEnum.from} to select the request handler. */
    private String type;

    /** Public model name that clients send in the request's {@code model} field. */
    private String model;

    /** Model name forwarded upstream; when blank, {@link #model} is forwarded instead. */
    private String originalModel;

    /** Logical-delete flag (MyBatis-Plus {@code @TableLogic}): "0" = active, "1" = deleted. */
    @TableLogic(value = "0", delval = "1")
    private Boolean delFlag;

    /**
     * Creator; filled on insert.
     */
    @Schema(description = "创建人")
    @TableField(fill = FieldFill.INSERT)
    private String createBy;

    /**
     * Creation time; filled on insert.
     */
    @JsonFormat(timezone = "GMT+8", pattern = "yyyy-MM-dd HH:mm:ss")
    @DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss")
    @Schema(description = "创建时间")
    @TableField(fill = FieldFill.INSERT)
    private Date createTime;

    /**
     * Last updater; filled on insert and update.
     */
    @Schema(description = "更新人")
    @TableField(fill = FieldFill.INSERT_UPDATE)
    private String updateBy;

    /**
     * Last update time; filled on insert and update.
     */
    @JsonFormat(timezone = "GMT+8", pattern = "yyyy-MM-dd HH:mm:ss")
    @DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss")
    @Schema(description = "更新时间")
    @TableField(fill = FieldFill.INSERT_UPDATE)
    private Date updateTime;

}
posted @ 2024-06-19 11:09  linzm14  阅读(26)  评论(0)  编辑  收藏  举报