最近在做大模型应用的开发,目前使用的是 Java 语言。我们与大模型交互的方式基于 SSE(Server-Sent Events)协议,目前的方案是引入 okhttp-sse 这个包。因为刚刚接触,所以写下这篇文章描述一下我的实现方式,希望业内的大佬指点一下:我这种方式如何?有没有标准的调用代码,有没有更优雅的写法?正确调用模型接口的方式是怎样的?
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
</dependency>
具体实现代码如下:
EventSourceListener 的实现:
// Base listener class holding state shared by all SSE listeners
@Data
public abstract class BaseSSEListener extends EventSourceListener {
    // Per-listener latch, counted down by subclasses when the stream ends (closed or failed).
    // Left non-final on purpose: @Data generates a setter for it that callers may rely on.
    private CountDownLatch countDownLatch = new CountDownLatch(1);

    // Registry of active sessions: session id -> upstream EventSource.
    // Declared final so the shared registry itself can never be swapped out at runtime;
    // entries are still added/removed through the concurrent map's own methods.
    public static final ConcurrentHashMap<String, EventSource> listenerHashMap = new ConcurrentHashMap<>();
}
// Concrete listener: relays SSE events received from the model service to the caller's HTTP response
@Slf4j
@Data
public class SSEPictureGeneratedSpeechListener extends BaseSSEListener {
    // Session id; written into every relayed event so the caller can pass it to the stop endpoint
    private String eventId;
    // Servlet response the events are relayed to; when null, nothing is relayed
    private HttpServletResponse response;

    public SSEPictureGeneratedSpeechListener(String eventId, HttpServletResponse response) {
        this.eventId = eventId;
        this.response = response;
    }

    /**
     * {@inheritDoc}
     * Called once the SSE connection to the model is established; registers the
     * event source under this session's id so it can be looked up and cancelled later.
     */
    @Override
    public void onOpen(final EventSource eventSource, final Response response) {
        listenerHashMap.put(eventId, eventSource);
        log.info("与模型建立sse连接...");
    }

    /**
     * Relays one upstream event to the caller as a {@code message} event.
     *
     * @param eventSource the upstream event source
     * @param id          upstream event id (not forwarded; the session id is sent instead)
     * @param type        upstream event type (not forwarded; always relayed as "message")
     * @param data        upstream event payload
     */
    @Override
    public void onEvent(EventSource eventSource, String id, String type, String data) {
        try {
            if (response != null) {
                // A payload that is exactly one newline is relayed as an empty data field.
                // This mirrors the original wire output so existing clients keep working.
                writeEvent("message", "\n".equals(data) ? "" : data);
            }
        } catch (Exception e) {
            log.error("消息错误]", e);
            getCountDownLatch().countDown();
            listenerHashMap.remove(eventId);
            throw new RuntimeException(e);
        }
    }

    /**
     * Called when the upstream stream is closed; releases the latch and unregisters the session.
     * NOTE(review): the original author observed that this is reached for both normal and
     * abnormal termination — confirm against the okhttp-sse version in use.
     */
    @Override
    public void onClosed(final EventSource eventSource) {
        try {
            log.info("连接关闭:{}", JsonUtils.toJsonString(eventSource));
        } finally {
            // Cleanup must not depend on the logging/serialization call succeeding,
            // otherwise a waiter on the latch could block forever.
            getCountDownLatch().countDown();
            listenerHashMap.remove(eventId);
        }
    }

    /**
     * Called when the upstream stream fails. Per the original author's observation this is
     * reached both when {@code EventSource.cancel()} is called and when the model service
     * itself fails; the two are told apart by the upstream status code: 200 is treated as a
     * user-initiated stop ("cancel" event), anything else as an "error" event.
     *
     * @param eventSource the upstream event source
     * @param t           the failure cause, may be null
     * @param response    the upstream HTTP response, null when no response was received
     */
    @Override
    public void onFailure(final EventSource eventSource, final Throwable t, final Response response) {
        log.error("使用事件源时出现异常...:{}", t != null ? t.getMessage() : "未知错误");
        listenerHashMap.remove(eventId);
        // The upstream response is null when the connection failed before any HTTP
        // response arrived; that case is reported as an error rather than crashing.
        final String eventName = (response != null && response.code() == 200) ? "cancel" : "error";
        try {
            if (this.response != null) {
                writeEvent(eventName, "");
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            // Release the latch only after the final event has been written and flushed;
            // presumably the request thread waits on it, and releasing it first would let
            // the response complete before the cancel/error event reaches the caller.
            getCountDownLatch().countDown();
        }
    }

    // Writes a single SSE event (id, event name, data) to the caller's response and flushes it.
    // Multi-line payloads are emitted as one "data:" line per line, as the SSE wire format
    // requires; an embedded blank line would otherwise terminate the event early.
    private void writeEvent(String eventName, String data) throws IOException {
        StringBuilder frame = new StringBuilder();
        frame.append("id:").append(eventId).append("\n");
        frame.append("event:").append(eventName).append("\n");
        for (String line : data.split("\n", -1)) {
            frame.append("data:").append(line).append("\n");
        }
        frame.append("\n");
        response.getWriter().write(frame.toString());
        response.getWriter().flush();
    }
}
下面是controller的代码
// Stop endpoint: cancels the upstream SSE stream registered under the given session id, if any.
// Always answers with success, whether or not a matching session was found.
@GetMapping("stopOutput/{eventId}")
public AjaxResult stopOutput(@PathVariable("eventId") String eventId) {
    final EventSource active = BaseSSEListener.listenerHashMap.get(eventId);
    if (active == null) {
        // Unknown or already finished session: nothing to cancel
        return AjaxResult.success();
    }
    active.cancel();
    return AjaxResult.success();
}
// Streaming endpoint: downloads the referenced picture, sends it with the prompt to the model
// service, and relays the model's SSE output to the caller through the servlet response.
@PostMapping("pictureGeneratedSpeech")
public void pictureGeneratedSpeech(@RequestBody PictureGeneratedSpeechAO pictureGeneratedSpeechAO, HttpServletResponse response) throws IOException {
    // Switch the response to streaming (SSE) mode
    response.setContentType("text/event-stream");
    response.setCharacterEncoding("UTF-8");
    response.setStatus(200);
    log.info("建立sse连接...");
    // Session id; the listener writes it into every event it relays
    Long id = IdGenerateUtil.nextId();
    HashMap<String, Object> map = new HashMap<>();
    // Download the picture and attach it as Base64. try-with-resources closes the stream
    // on every path (a null resource is skipped by the language).
    try (InputStream inputStream = FileUtils.downloadFile(storageUrl + pictureGeneratedSpeechAO.getUrl(), null)) {
        String base64String = Base64.getEncoder().encodeToString(FileUtils.inputStreamToByteArray(inputStream));
        map.put("imageBase64", base64String);
    } catch (Exception e) {
        // Best effort, as before: the request still goes out without the image,
        // but the failure now lands in the application log instead of stderr.
        log.error("下载或编码图片失败, url={}", pictureGeneratedSpeechAO.getUrl(), e);
    }
    map.put("prompt_text", pictureGeneratedSpeechAO.getPrompt());
    map.put("stream", true);
    try {
        SSEPictureGeneratedSpeechListener sseListener = new SSEPictureGeneratedSpeechListener(String.valueOf(id), response);
        ExecuteSSEUtil.executeListener(lagerModelToolUrl + "/internvl/infer", null, sseListener, JSON.toJSONString(map));
    } catch (Exception e) {
        log.error("请求SSE错误处理", e);
        // NOTE(review): the cause is only logged, not chained — BaseException's constructors
        // are not visible here; pass `e` along if a (message, cause) constructor exists.
        throw new BaseException("请求SSE错误处理");
    }
}