public static Object sseSendMsg(StandardReqBody reqBody) {
final StringBuilder result = new StringBuilder();
final StringBuilder standardRsp = new StringBuilder();
if (reqBody.getStream()) {
// 创建新的SseEmitter实例,设置超时时间为60秒
SseEmitterUTF8 emitter = new SseEmitterUTF8(180_000L);
String sessionId = UUID.randomUUID().toString();
// 添加到活跃列表
emitters.add(emitter);
// 设置回调
emitter.onCompletion(() -> {
log.info("SSE connection completed for session: {}", sessionId);
emitters.remove(emitter);
activeWebSockets.remove(sessionId);
});
emitter.onTimeout(() -> {
log.info("SSE connection timed out for session: {}", sessionId);
emitters.remove(emitter);
activeWebSockets.remove(sessionId);
emitter.complete();
});
emitter.onError(e -> {
log.error("SSE error for session {}: {}", sessionId, e.getMessage());
emitters.remove(emitter);
activeWebSockets.remove(sessionId);
emitter.completeWithError(e);
});
try {
// 启动WebSocket连接
CallAiUtil.standardFlowApi(sessionId, emitter, reqBody, result, standardRsp);
} catch (Exception e) {
log.error("Error in req() method: {}", e.getMessage(), e);
emitters.remove(emitter);
activeWebSockets.remove(sessionId);
emitter.completeWithError(e);
}
return emitter;
} else { //非流式
try {
// 启动WebSocket连接
CallAiUtil.standardFlowApi(reqBody, result, standardRsp);
} catch (Exception e) {
log.error("Error in req() method: {}", e.getMessage(), e);
}
return standardRsp.toString();
}
}public static String standardFlowApi(StandardReqBody reqBody, StringBuilder result, StringBuilder standardRsp) {
CountDownLatch latch = new CountDownLatch(1);
// 1. 构建请求client
OkHttpClient.Builder builder = new OkHttpClient.Builder()
.connectionPool(new ConnectionPool(5, 100, TimeUnit.SECONDS))
.readTimeout(100000, TimeUnit.MILLISECONDS)
.connectTimeout(100000, TimeUnit.MILLISECONDS)
.writeTimeout(100000, TimeUnit.MILLISECONDS)
.addInterceptor(new ApiGatewayAuthInterceptor(appId, appSecret, assistantCodeMap.get(reqBody.getModel())));
OkHttpClient okHttpClient = builder.build();
// 2. 构建消息的请求体
//构建请求体
ReqBody chatRequest = ReqBody.builder()
.header(ReqBody.ReqHeader.builder()
.traceId(UUID.randomUUID().toString())
.appId(appId)
.assistantCode(assistantCodeMap.get(reqBody.getModel()))
.build())
.payload(ReqBody.ReqPayload.builder()
.sessionId(reqBody.getSessionId())
.text(reqBody.getMessages())
.build())
.build();
Request request = new Request.Builder().url(baseUrl + File.separator + chatEndpoint).build();
WebSocket webSocket = okHttpClient.newWebSocket(request, new WebSocketListener() {
@Override
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
}
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
SseEmitterUtil.safelySendEvent(text, reqBody, result, standardRsp);
}
@Override
public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
webSocket.close(1000, "Closing");
}
@Override
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
}
@Override
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, Response response) {
String errorMsg = "WebSocket failure: " + t.getMessage();
if (response != null && response.body() != null) {
try {
errorMsg += "; Response: " + response.body().string();
} catch (IOException e) {
log.error(errorMsg, e);
}
}
log.error("webSocket-onFailure-msg:{}", errorMsg);
}
});
try {
latch.await(100 * 1000, TimeUnit.MILLISECONDS); // 等待流结束
} catch (InterruptedException e) {
throw new BException(ConstantsResultCode.PUB.SYSTEM_ERR, "等待ai接口输出数据异常!");
}
webSocket.close(1000, "NORMAL_CLOSURE");
okHttpClient.dispatcher().executorService().shutdown();
return standardRsp.toString();
}public static void standardFlowApi(String sessionId, SseEmitterUTF8 emitter, StandardReqBody reqBody, StringBuilder result, StringBuilder standardRsp) {
// 1. 构建请求client
OkHttpClient.Builder builder = new OkHttpClient.Builder()
.connectionPool(new ConnectionPool(5, 100, TimeUnit.SECONDS))
.readTimeout(100000, TimeUnit.MILLISECONDS)
.connectTimeout(100000, TimeUnit.MILLISECONDS)
.writeTimeout(100000, TimeUnit.MILLISECONDS)
.addInterceptor(new ApiGatewayAuthInterceptor(appId, appSecret, assistantCodeMap.get(reqBody.getModel())));
OkHttpClient okHttpClient = builder.build();
// 2. 构建消息的请求体
//构建请求体
ReqBody chatRequest = ReqBody.builder()
.header(ReqBody.ReqHeader.builder()
.traceId(UUID.randomUUID().toString())
.appId(appId)
.assistantCode(assistantCodeMap.get(reqBody.getModel()))
.build())
.payload(ReqBody.ReqPayload.builder()
.sessionId(reqBody.getSessionId())
.text(reqBody.getMessages())
.build())
.build();
Request request = new Request.Builder().url(baseUrl + File.separator + chatEndpoint).build();
WebSocket webSocket = okHttpClient.newWebSocket(request, new WebSocketListener() {
@Override
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
log.info("WebSocket opened for session: {}", sessionId);
SseEmitterUtil.activeWebSockets.put(sessionId, webSocket);
String jsonString = JSONObject.toJSONString(chatRequest);
webSocket.send(jsonString);
}
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
SseEmitterUtil.safelySendEvent(sessionId, emitter, text, reqBody, result, standardRsp);
}
@Override
public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
webSocket.close(1000, "Closing");
SseEmitterUtil.activeWebSockets.remove(sessionId);
}
@Override
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
if (emitter != null) {
emitter.complete();
}
SseEmitterUtil.activeWebSockets.remove(sessionId);
}
@Override
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, Response response) {
String errorMsg = "WebSocket failure: " + t.getMessage();
if (response != null && response.body() != null) {
try {
errorMsg += "; Response: " + response.body().string();
} catch (IOException e) {
}
}
log.error("webSocket-onFailure-msg:{}", errorMsg);
SseEmitterUtil.activeWebSockets.remove(sessionId);
emitter.completeWithError(t);
}
});
} // 安全发送SSE事件的辅助方法
public static void safelySendEvent(String sessionId, SseEmitterUTF8 emitter, String data, StandardReqBody reqBody, StringBuilder result, StringBuilder standardRsp) {
try {
StandardRspBody rspBody = buildStandRsp(JSONObject.parseObject(data, RspBody.class), reqBody);
// 检查emitter是否还在活动列表中
if (emitters.contains(emitter)) {
emitter.send(SseEmitterUTF8.event()
.data(JSONObject.toJSONString(rspBody)));
}
} catch (Exception e) {
// 发送失败时,从列表中移除并完成emitter
safelyCompleteEmitter(emitter, sessionId);
}
}
public static void safelySendEvent(String data, StandardReqBody reqBody, StringBuilder result, StringBuilder standardRsp) {
StandardRspBody rspBody = buildStandRsp(JSONObject.parseObject(data, RspBody.class), reqBody);
buildResult(rspBody, result, standardRsp);
} public static void buildResult(StandardRspBody rspBody, StringBuilder result, StringBuilder standardRsp) {
if (!"chat.completion".equals(rspBody.getObject()) && (rspBody.getChoices() == null || rspBody.getChoices().isEmpty() || rspBody.getChoices().get(0).getDelta() == null || !StringUtils.hasLength(rspBody.getChoices().get(0).getDelta().getContent()))) {
return;
}
result.append(rspBody.getChoices().get(0).getDelta().getContent());
StandardRspBody.MessageData messageData = rspBody.getChoices().get(0).getDelta();
messageData.setContent(result.toString());
rspBody.getChoices().get(0).setMessage(messageData);
rspBody.getChoices().get(0).setDelta(messageData);
standardRsp.setLength(0);
standardRsp.append(JSONObject.toJSONString(rspBody));
}我一个方法可能放回sse对象给前端显示流式输出,也有可能是非流式拿到websocket的数据拼接完结果再返回给前端数据,你看看有什么问题
最新发布