为了熟悉下 OKhttp 和 ChatGLM 接口,写几个 demo 试试
1. 准备工作
从 ChatGLM 的接口文档可知,每次 HTTP 调用都需要带上一个鉴权 token,而组装这个 token,我们需要先获取一个 API Key,这个可从智谱AI开放平台 API Keys 页面获得,API Key 包含 “用户标识 id” 和 “签名密钥 secret”,即格式为 {id}.{secret}
获取 token 和接口请求参数的代码在最后的附录中
2. SSE 调用
SSE(Sever-Sent Event),就是浏览器向服务器发送一个HTTP请求,保持长连接,服务器不断单向地向浏览器推送“信息”(message),这么做是为了节约网络资源,不用一直发请求,建立新连接。
// 设置请求参数
RequestParam requestParam = new RequestParam();
List<RequestParam.Prompt> prompts = new ArrayList<>();
prompts.add(RequestParam.Prompt.builder()
.role(Role.user.getCode())
.content("你好,我想问你一些 Java 相关的问题")
.build());
requestParam.setPrompt(prompts);
// 创建请求体
MediaType json = MediaType.parse("application/json; charset=utf-8");
RequestBody requestBody = RequestBody.create(json, requestParam.toString());
// 创建请求对象
Request request = new Request.Builder()
.url("https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/sse-invoke")
.post(requestBody) // 请求体
.addHeader("Authorization", "Bearer " + token)
.addHeader("Accept", "text/event-stream")
.build();
// 开启 Http 客户端
OkHttpClient okHttpClient = new OkHttpClient.Builder()
.connectTimeout(10, TimeUnit.SECONDS) // 建立连接的超时时间
.readTimeout(10, TimeUnit.MINUTES) // 建立连接后读取数据的超时时间
.build();
// 创建一个 CountDownLatch 对象,其初始计数为1,表示需要等待一个事件发生后才能继续执行。
CountDownLatch eventLatch = new CountDownLatch(1);
// 实例化EventSource,注册EventSource监听器 -- 创建一个用于处理服务器发送事件的实例,并定义处理事件的回调逻辑
RealEventSource realEventSource = new RealEventSource(request, new EventSourceListener() {
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
System.out.println(data); // 请求到的数据
if ("finish".equals(type)) { // 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
eventLatch.countDown();
}
}
});
// 与服务器建立连接
realEventSource.connect(okHttpClient);
// await() 方法被调用来阻塞当前线程,直到 CountDownLatch 的计数变为0。
eventLatch.await();
3. 异步调用
根据文档描述,首先得通过异步 POST 请求获得 task_id ,再根据 task_id 发送 GET 请求获得最终结果
// TODO 设置请求参数,同 SSE 调用
// 开启 Http 客户端
OkHttpClient okHttpClient = new OkHttpClient();
// 创建请求体
MediaType json = MediaType.parse("application/json; charset=utf-8");
RequestBody requestBody = RequestBody.create(json, requestParam.toString());
// 第一步:发送异步请求(POST)获取 task_id,并存放到 taskIdFuture 中
CompletableFuture<String> taskIdFuture = new CompletableFuture<>();
Request requestForTaskId = new Request.Builder()
.url("https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/async-invoke")
.post(requestBody)
.addHeader("Authorization", "Bearer " + token)
.build();
// 创建一个新的异步 HTTP 请求,并指定请求的回调函数
okHttpClient.newCall(requestForTaskId).enqueue(new Callback() {
// 在请求成功并返回响应时被调用
@Override
public void onResponse(Call call, Response response) throws IOException {
if (response.isSuccessful()) {
String responseBody = response.body().string();
System.out.println("requestForTaskId: " + responseBody);
// 解析 JSON 响应获取 task_id
JSONObject jsonObject = JSON.parseObject(responseBody);
String taskId = jsonObject.getJSONObject("data").getString("task_id");
// 将结果设置到 CompletableFuture
taskIdFuture.complete(taskId);
} else {
taskIdFuture.completeExceptionally(new Exception("Request for task_id failed"));
}
}
// 在请求失败时被调用
@Override
public void onFailure(Call call, IOException e) {
taskIdFuture.completeExceptionally(e);
}
});
// 阻塞主线程,等待 CompletableFuture 的结果,设置了最大等待时间
String taskId = taskIdFuture.get(10, TimeUnit.SECONDS);
System.out.println("Task ID: " + taskId);
// TODO 第二步,使用 task_id 发送同步请求(GET)获取最终响应结果(和第四节基本一样)
4. 同步调用
// TODO 设置请求参数,同 SSE 调用
// 开启 Http 客户端
OkHttpClient client = new OkHttpClient();
// 创建请求体
MediaType json = MediaType.parse("application/json; charset=utf-8");
RequestBody requestBody = RequestBody.create(json, requestParam.toString());
// 创建请求对象
Request request = new Request.Builder()
.url("https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/invoke")
.post(requestBody)
.addHeader("Authorization", "Bearer " + token)
.build();
// 发送请求
Response response = client.newCall(request).execute();
// 处理响应
if (response.isSuccessful()) {
String responseBody = response.body().string();
System.out.println("Response: " + responseBody);
} else {
System.out.println("Request failed: " + response.code() + " " + response.message());
}
5. 附录
5.1 组装鉴权 token
// 这里的 secret 是 API Key 中的 {secret} 部分
Algorithm algorithm = Algorithm.HMAC256(secret.getBytes(StandardCharsets.UTF_8));
Map<String, Object> payload = new HashMap<>();
// 这里的 id 是 API Key 中的 {id} 部分
payload.put("api_key", id);
payload.put("exp", System.currentTimeMillis() + 30 * 60 * 1000L); // 过期时间, 30分钟
payload.put("timestamp", Calendar.getInstance().getTimeInMillis()); // 时间戳
Map<String, Object> headerClaims = new HashMap<>();
headerClaims.put("alg", "HS256");
headerClaims.put("sign_type", "SIGN");
String token = JWT.create().withPayload(payload).withHeader(headerClaims).sign(algorithm);
5.2 接口请求参数
@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class RequestParam {
@JsonProperty("request_id")
private String requestId = String.format("gpt-%d", System.currentTimeMillis());
private float temperature = 0.9f;
@JsonProperty("top_p")
private float topP = 0.7f;
/**
* 输入给模型的会话信息
* 用户输入的内容;role=user
* 挟带历史的内容;role=assistant
*/
private List<RequestParam.Prompt> prompt;
private boolean incremental = true;
private String sseFormat = "data";
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Prompt {
private String role;
private String content;
}
@Override
public String toString() {
Map<String, Object> paramsMap = new HashMap<>();
paramsMap.put("request_id", requestId);
paramsMap.put("prompt", prompt);
paramsMap.put("incremental", incremental);
paramsMap.put("temperature", temperature);
paramsMap.put("top_p", topP);
paramsMap.put("sseFormat", sseFormat);
try {
return new ObjectMapper().writeValueAsString(paramsMap);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}