记录spring4中websocket的使用方式
pom jar包配置
<!-- Spring WebSocket module (WebSocketConfigurer, TextWebSocketHandler, HandshakeInterceptor, ...) -->
<dependency>
    <groupId>org.springframework</groupId>
    <artifactId>spring-websocket</artifactId>
    <version>${spring.version}</version>
</dependency>
<!-- Spring messaging module; kept on the same ${spring.version} as spring-websocket -->
<dependency>
    <groupId>org.springframework</groupId>
    <artifactId>spring-messaging</artifactId>
    <version>${spring.version}</version>
</dependency>
其中spring.version的配置是:
<properties>
    <!-- encoding used for project source files -->
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <!-- single version shared by the org.springframework artifacts declared above -->
    <spring.version>4.0.0.RELEASE</spring.version>
    <java.version>1.8</java.version>
    <druid.version>1.1.6</druid.version>
</properties>
JSON 消息的处理使用阿里巴巴提供的 fastjson(注意:fastjson 1.2.83 之前的版本存在已知的 autoType 反序列化安全漏洞,而本文服务端会直接解析客户端发来的 JSON,建议使用 1.2.83 或更高版本):
<!-- json -->
<!-- fastjson parses JSON sent by WebSocket clients (JSONObject.parseObject on untrusted input).
     Releases before 1.2.83 have known autoType deserialization vulnerabilities
     (e.g. CVE-2022-25845), so do not go below this version. -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.83</version>
</dependency>
配置websocket服务
在 Spring WebSocket 中有两种方式配置 WebSocket 服务:一种是在 XML 中配置,另一种是在代码中实现 WebSocketConfigurer 接口。这里使用第二种:
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
 * Spring WebSocket configuration.
 *
 * <p>Maps the message handler returned by {@link #webSocketHandler()} to {@code /ws} and attaches
 * {@link WebSocketInterceptor} so that every handshake on that path passes through it.
 *
 * @author ThatWay
 * 2018-5-8
 */
@Configuration
@EnableWebMvc
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** URL path the WebSocket endpoint is registered under. */
    private static final String ENDPOINT_PATH = "/ws";

    /**
     * Registers the handler at {@link #ENDPOINT_PATH} together with its handshake interceptor.
     *
     * @param registry registry supplied by Spring
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        WebSocketHandler handler = webSocketHandler();
        WebSocketInterceptor interceptor = webSocketInterceptor();
        registry.addHandler(handler, ENDPOINT_PATH)
                .addInterceptors(interceptor);
    }

    /**
     * Bean that processes WebSocket messages and lifecycle callbacks.
     *
     * <p>NOTE(review): WebScoketHandler is also annotated {@code @Service}. If component scanning
     * picks it up, two instances exist (they share the class's static session registry). Confirm
     * which one is intended.
     *
     * @return the message handler
     */
    @Bean
    public WebSocketHandler webSocketHandler() {
        return new WebScoketHandler();
    }

    /**
     * Bean that validates handshake parameters before a connection is accepted.
     *
     * @return the handshake interceptor
     */
    @Bean
    public WebSocketInterceptor webSocketInterceptor() {
        return new WebSocketInterceptor();
    }
}
WebSocket 请求过滤
在上一步的服务配置中,使用的 webSocketInterceptor 是实现了 HandshakeInterceptor 接口的过滤处理类。它会拦截所有到达服务端的 WebSocket 握手请求,可以在握手前和握手后插入处理逻辑。
这里主要做的事是:取出客户端创建连接时传递的参数,放入连接建立后生成的 session 属性中,这样服务端下发消息时就可以通过这些参数来区分 session。下面代码中作为 session 标识的是 pageFlag 参数。客户端请求的地址形如:ws://localhost:8080/integrate_pipe/ws?pageFlag=p1&actionFlag=simple
import java.util.Map;

import javax.servlet.http.HttpServletRequest;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
 * WebSocket handshake interceptor.
 *
 * <p>Accepts a handshake only when it arrives as a servlet request carrying non-blank
 * {@code pageFlag} and {@code actionFlag} query parameters. It copies their trimmed values into the
 * handshake attributes, where the message handler reads them to identify the session.
 *
 * @author ThatWay
 * 2018-5-8
 */
public class WebSocketInterceptor implements HandshakeInterceptor {
    private static final Logger logger = LoggerFactory.getLogger(WebSocketInterceptor.class);

    /**
     * Called after the handshake completes; logs the callback and, if present, the failure.
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        logger.info("webscoket处理后过滤回调触发");
        if (exception != null) {
            logger.error("websocket handshake failed", exception);
        }
    }

    /**
     * Validates the handshake parameters and stores them as session attributes.
     *
     * @param attributes handshake attributes; receives "pageFlag" and "actionFlag" on success
     * @return {@code true} to continue the handshake, {@code false} to reject it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        logger.info("webscoket处理前过滤回调触发");
        // Query parameters can only be read from a servlet-backed request.
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        HttpServletRequest req = ((ServletServerHttpRequest) request).getServletRequest();
        // Page identifier: later used to group and address sessions.
        String pageFlag = req.getParameter("pageFlag");
        // Action that selects the data sent right after connecting.
        String actionFlag = req.getParameter("actionFlag");
        // hasText (not isEmpty) so whitespace-only values are rejected instead of being
        // stored as "" after trim().
        if (!StringUtils.hasText(pageFlag) || !StringUtils.hasText(actionFlag)) {
            logger.info("webscoket连接请求,页面标志pageFlag:"+pageFlag+",动作标志:"+actionFlag+",参数不正确,请求拒绝");
            return false;
        }
        logger.info("webscoket连接请求,页面标志pageFlag:"+pageFlag+",动作标志:"+actionFlag);
        attributes.put("pageFlag", pageFlag.trim());
        attributes.put("actionFlag", actionFlag.trim());
        return true;
    }
}
消息处理
在服务配置中使用的 WebSocketHandler 是继承了 TextWebSocketHandler 的消息处理类,消息由这个类负责处理,Spring 也把 WebSocket 相关的生命周期回调(连接建立、收到消息、传输异常、连接关闭)封装在这里。另外,通过 @Service 把此类注册为 Spring 组件后,其他业务 controller 就可以注入它,并调用其方法来下发消息。
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import cn.qingk.entity.User;
/**
 * WebSocket message handler.
 *
 * <p>Every accepted session is registered in a static map under the key
 * {@code pageFlag + "_" + sequence}. On connect and on every client request the handler replies
 * with a JSON document built by {@link #makeInfoResponseJson} or {@link #makeListResponseJson}.
 * The class is annotated {@code @Service}, so other beans (e.g. controllers) can inject it and push
 * messages through the {@code sendMessageTo*} methods.
 *
 * @author ThatWay
 * 2018-5-5
 */
@Service
public class WebScoketHandler extends TextWebSocketHandler {
    private static final Logger logger = LoggerFactory.getLogger(WebScoketHandler.class);
    // Handshake attribute holding the page identifier.
    private static final String CLIENT_ID = "pageFlag";
    // Handshake attribute holding the initial action.
    private static final String ACTION_INIT = "actionFlag";
    // Registered sessions, keyed by pageFlag + "_" + sequence number.
    private static final Map<String, WebSocketSession> clients = new ConcurrentHashMap<String, WebSocketSession>();
    // Number of sessions currently registered in clients.
    private static final AtomicInteger connectCount = new AtomicInteger(0);
    // Source of key suffixes. It only ever increases, so keys stay unique after connections close.
    // The live count above goes back down, so using it for keys would repeat a key and overwrite
    // another open session.
    private static final AtomicInteger connectSequence = new AtomicInteger(0);

    /**
     * Registers a newly opened session and sends it the reply for its handshake action.
     * A session without a page flag is told so and closed.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        logger.info("wescoket成功建立连接");
        String pageFlag = getAttributeFlag(session, CLIENT_ID);
        String reqAction = getAttributeFlag(session, ACTION_INIT);
        if (StringUtils.isEmpty(pageFlag)) {
            sendSynchronized(session, new TextMessage("无页面标识,连接关闭!"));
            session.close();
            return;
        }
        String key = pageFlag + "_" + connectSequence.incrementAndGet();
        clients.put(key, session);
        int onlineCount = addOnlineCount();
        logger.info("在线屏数:" + onlineCount);
        sendSynchronized(session, new TextMessage(buildResponse(pageFlag, reqAction)));
    }

    /**
     * Answers a client request. Expected payload format:
     * <pre>
     * {
     *   "pageFlag": "p1",
     *   "actionFlag": "simple/detail"
     * }
     * </pre>
     * Empty, malformed or incomplete payloads get a FAIL reply instead of an empty frame.
     */
    @Override
    public void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        long start = System.currentTimeMillis();
        String reqMsg = message.getPayload();
        String returnJson = null;
        // Check the payload itself: an empty payload would parse to null.
        if (StringUtils.isEmpty(reqMsg)) {
            logger.error("客户端请求的消息为空");
        } else {
            JSONObject terminalMsg = parseRequest(reqMsg);
            if (terminalMsg == null || terminalMsg.isEmpty()) {
                logger.error("客户端请求的消息转换json为空");
            } else if (terminalMsg.containsKey("pageFlag") && terminalMsg.containsKey("actionFlag")) {
                returnJson = buildResponse(terminalMsg.getString("pageFlag"), terminalMsg.getString("actionFlag"));
            } else {
                logger.error("客户端请求缺少pageFlag或actionFlag:" + reqMsg);
            }
        }
        if (returnJson == null) {
            returnJson = makeInfoResponseJson(WebSocketStatus.CODE_FAIL, null, null, WebSocketStatus.MSG_FAIL, null);
        }
        TextMessage returnMessage = new TextMessage(returnJson);
        long pass = System.currentTimeMillis() - start;
        logger.info("接收终端请求返回:" + returnMessage.toString() + ",耗时:" + pass + "ms");
        sendSynchronized(session, returnMessage);
    }

    /**
     * Transport error callback: logs the error and closes the session if it is still open.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable thrwbl) throws Exception {
        logger.error("websocket 连接出现异常准备关闭", thrwbl);
        if (session.isOpen()) {
            session.close();
        }
    }

    /**
     * Connection closed callback: unregisters the session and decrements the online count.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus cs) throws Exception {
        for (Entry<String, WebSocketSession> entry : clients.entrySet()) {
            String clientKey = entry.getKey();
            // remove(key, value) succeeds only once per mapping, so concurrent close callbacks
            // cannot decrement the counter twice for the same session.
            if (entry.getValue() == session && clients.remove(clientKey, session)) {
                logger.info("移除clientKey:" + clientKey);
                int leftOnlineCount = decOnlineCount();
                logger.info("剩余在线屏数:" + leftOnlineCount);
            }
        }
        logger.info("websocket 连接关闭了");
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Sends a message to every open session registered under exactly the given page flag.
     *
     * @param pageFlag page identifier; matched exactly, so "p1" does not reach "p10" sessions
     * @param message  message to send
     * @return {@code true} if at least one session received the message
     */
    public boolean sendMessageToPage(String pageFlag, TextMessage message) {
        int allCounter = 0;
        int sendCounter = 0;
        long start = System.currentTimeMillis();
        if (!StringUtils.isEmpty(pageFlag)) {
            for (Entry<String, WebSocketSession> entry : clients.entrySet()) {
                String clientKey = entry.getKey();
                if (!pageFlag.equals(pageFlagOf(clientKey))) {
                    continue;
                }
                allCounter++;
                if (sendQuietly(clientKey, entry.getValue(), message)) {
                    sendCounter++;
                    logger.info("sendMessageToPage:[clientKey:"+clientKey+"],flag:"+true);
                }
            }
        }
        boolean flag = sendCounter > 0;
        long pass = System.currentTimeMillis() - start;
        logger.info("sendMessageToPage:"+pageFlag+",flag:"+flag+",all_counter:"+allCounter+",send_counter:"+sendCounter+",pass:"+pass+"ms");
        return flag;
    }

    /**
     * Sends a message to every open registered session.
     *
     * @param message message to send
     * @return {@code true} if at least one session received the message
     */
    public boolean sendMessageToAll(TextMessage message) {
        int allCounter = 0;
        int sendCounter = 0;
        long start = System.currentTimeMillis();
        for (Entry<String, WebSocketSession> entry : clients.entrySet()) {
            allCounter++;
            String clientKey = entry.getKey();
            if (sendQuietly(clientKey, entry.getValue(), message)) {
                sendCounter++;
                logger.info("sendMessageToAll:[clientKey:"+clientKey+"],flag:"+true);
            }
        }
        boolean flag = sendCounter > 0;
        long pass = System.currentTimeMillis() - start;
        logger.info("sendMessageToAll,flag:"+flag+",all_counter:"+allCounter+",send_counter:"+sendCounter+",pass:"+pass+"ms");
        return flag;
    }

    /**
     * Sends a message to the single session registered under the given key.
     *
     * @param clientId registry key ({@code pageFlag + "_" + sequence})
     * @param message  message to send
     * @return {@code true} if the session exists, is open and the send succeeded
     * @throws IOException never thrown; send failures are logged and reported as {@code false}.
     *                     Kept in the signature for source compatibility with existing callers.
     */
    public boolean sendMessageToId(String clientId, TextMessage message) throws IOException {
        int allCounter = 0;
        int sendCounter = 0;
        long start = System.currentTimeMillis();
        if (!StringUtils.isEmpty(clientId)) {
            WebSocketSession session = clients.get(clientId);
            // Unknown ids used to cause an NPE here.
            if (session != null) {
                allCounter++;
                if (sendQuietly(clientId, session, message)) {
                    sendCounter++;
                }
            }
        }
        boolean flag = sendCounter > 0;
        long pass = System.currentTimeMillis() - start;
        logger.info("sendMessageToId:"+clientId+",flag:"+flag+",all_counter:"+allCounter+",send_counter:"+sendCounter+",pass:"+pass+"ms");
        return flag;
    }

    /**
     * Builds the JSON reply for a page/action pair. An unknown or missing action yields a FAIL reply.
     */
    private String buildResponse(String pageFlag, String action) {
        // Placeholder: the data type should come from the database.
        String type = WebSocketStatus.TYPE_BDXW;
        // equalsIgnoreCase is null-safe and locale-independent (toLowerCase() is neither).
        if (WebSocketStatus.ACTION_SIMPLE.equalsIgnoreCase(action)) {
            // Placeholder for the basic data read from the database.
            logger.info("数据库查询【" + pageFlag + "】的基本数据");
            Map<String, Object> infoMap = new HashMap<String, Object>();
            infoMap.put("type", "qwzx");
            infoMap.put("title", "全网资讯");
            return makeInfoResponseJson(WebSocketStatus.CODE_SUCCESS, type, action, WebSocketStatus.MSG_SUCCESS, infoMap);
        }
        if (WebSocketStatus.ACTION_DETAIL.equalsIgnoreCase(action)) {
            // Placeholder for the list data read from the database.
            logger.info("数据库查询【" + pageFlag + "】的列表数据");
            List<Object> userList = new ArrayList<Object>();
            User user1 = new User();
            user1.setAddress("address 1");
            user1.setAge(18);
            user1.setId(1);
            user1.setName("name 1");
            userList.add(user1);
            return makeListResponseJson(WebSocketStatus.CODE_SUCCESS, type, action, WebSocketStatus.MSG_SUCCESS, userList.size(), userList);
        }
        logger.error("客户端请求的action为:" + action);
        return makeInfoResponseJson(WebSocketStatus.CODE_FAIL, type, action, WebSocketStatus.MSG_FAIL, null);
    }

    /**
     * Parses a client payload.
     *
     * @return the parsed object, or {@code null} if fastjson rejects the payload
     */
    private JSONObject parseRequest(String reqMsg) {
        try {
            return JSONObject.parseObject(reqMsg);
        } catch (RuntimeException e) {
            // fastjson reports malformed input with unchecked exceptions.
            logger.error("客户端请求的消息不是合法的json:" + reqMsg, e);
            return null;
        }
    }

    /** Extracts the page flag from a registry key of the form {@code pageFlag + "_" + sequence}. */
    private static String pageFlagOf(String clientKey) {
        int sep = clientKey.lastIndexOf('_');
        return sep < 0 ? clientKey : clientKey.substring(0, sep);
    }

    /** Sends to one session; returns {@code false} (and logs) if it is closed or the send fails. */
    private static boolean sendQuietly(String clientKey, WebSocketSession session, TextMessage message) {
        if (!session.isOpen()) {
            return false;
        }
        try {
            sendSynchronized(session, message);
            return true;
        } catch (IOException e) {
            logger.error("send failed, clientKey:" + clientKey, e);
            return false;
        }
    }

    /**
     * Sends while holding the session's monitor. The underlying WebSocket session does not allow
     * concurrent sends, and this handler sends both from WebSocket callbacks and from caller
     * threads (e.g. controllers), so writes to one session are serialized here.
     */
    private static void sendSynchronized(WebSocketSession session, TextMessage message) throws IOException {
        synchronized (session) {
            session.sendMessage(message);
        }
    }

    /**
     * Reads a String handshake attribute.
     *
     * @return the attribute value, or {@code null} if absent or not a String
     */
    private String getAttributeFlag(WebSocketSession session, String flagName) {
        Map<String, Object> attributes = session.getHandshakeAttributes();
        if (attributes == null) {
            return null;
        }
        Object value = attributes.get(flagName);
        return value instanceof String ? (String) value : null;
    }

    /** Increments the online count and returns the new value. */
    private int addOnlineCount() {
        return connectCount.incrementAndGet();
    }

    /** Decrements the online count and returns the new value. */
    private int decOnlineCount() {
        return connectCount.decrementAndGet();
    }

    /**
     * Builds a list reply JSON.
     *
     * @param code       status code
     * @param type       data type
     * @param action     action name
     * @param msg        status message
     * @param totalCount total number of items
     * @param dataList   items; {@code null} is serialized as an empty list
     * @return JSON text of the form {@code {code,type,action,msg,body:{totalCount,list}}}
     */
    public String makeListResponseJson(int code,String type,String action,String msg,int totalCount,List<Object> dataList){
        JSONObject jsonObj = new JSONObject();
        jsonObj.put("code", code);
        jsonObj.put("type", type);
        jsonObj.put("action", action);
        jsonObj.put("msg", msg);
        JSONObject contentObj = new JSONObject();
        contentObj.put("totalCount", totalCount);
        JSONArray listArray = dataList == null ? new JSONArray() : new JSONArray(dataList);
        contentObj.put("list", listArray);
        jsonObj.put("body", contentObj);
        logger.info("生成list json:" + jsonObj.toString());
        return jsonObj.toString();
    }

    /**
     * Builds a detail reply JSON.
     *
     * @param code   status code
     * @param type   data type
     * @param action action name
     * @param msg    status message
     * @param info   payload placed under "body"
     * @return JSON text of the form {@code {code,type,action,msg,body}}
     */
    public String makeInfoResponseJson(int code,String type,String action,String msg,Object info){
        JSONObject jsonObj = new JSONObject();
        jsonObj.put("code", code);
        jsonObj.put("type", type);
        jsonObj.put("action", action);
        jsonObj.put("msg", msg);
        jsonObj.put("body", info);
        logger.info("生成info json:" + jsonObj.toString());
        return jsonObj.toString();
    }
}
状态辅助类
消息处理类中用到的状态码、提示信息、数据类型、动作类型等静态常量,主要用于和客户端约定交互的消息格式。这个类不是必需的。
/**
 * Constants shared with WebSocket clients: status codes, status messages, data-type tags and
 * action names that appear in the JSON replies.
 */
public class WebSocketStatus {
    /********************* Status codes: begin **********************/
    // Extend with business-specific codes as needed.
    // Request handled successfully.
    public static final int CODE_SUCCESS = 200;
    // Request failed. Must differ from CODE_SUCCESS so clients can tell failures apart by code
    // (it used to be 200 as well).
    public static final int CODE_FAIL = 500;
    /********************* Status codes: end **********************/
    /********************* Messages: begin **********************/
    // Extend with business-specific messages as needed.
    // Request handled successfully.
    public static final String MSG_SUCCESS = "OK";
    // Request failed.
    public static final String MSG_FAIL = "FAIL";
    /********************* Messages: end **********************/
    /********************* Data types: begin **********************/
    // Nationwide hot topics
    public static final String TYPE_QWRD = "qwrd";
    // Local news
    public static final String TYPE_BDXW = "bdxw";
    // Trending online searches
    public static final String TYPE_WLRS = "wlrs";
    // Local public opinion
    public static final String TYPE_DFYL = "dfyl";
    // News topic selection
    public static final String TYPE_XWXT = "xwxt";
    // Field reporting dispatch
    public static final String TYPE_WCDD = "wcdd";
    // Productivity statistics
    public static final String TYPE_SCLTJ = "scltj";
    // Influence statistics
    public static final String TYPE_YXLTJ = "yxltj";
    // Task statistics
    public static final String TYPE_RWTJ = "rwtj";
    // News hot list
    public static final String TYPE_ZXRB = "zxrb";
    // Video hot list
    public static final String TYPE_SPRB = "sprb";
    // Custom list
    public static final String TYPE_LBZDY = "lbzdy";
    // Custom chart
    public static final String TYPE_TBZDY = "tbzdy";
    /********************* Data types: end **********************/
    /********************* Action types: begin **********************/
    // Basic information
    public static final String ACTION_SIMPLE = "simple";
    // Detailed information
    public static final String ACTION_DETAIL = "detail";
    /********************* Action types: end **********************/
}
控制器中调用
这里主要是模拟了控制器中由于某个动作需要触发给指定的session发送消息。
/**
 * Demo controller. It pushes an info message and a list message to the WebSocket clients of page
 * "p1", then writes both send results to the HTTP response.
 */
@Controller
@RequestMapping("/testController")
public class TestController {
    public static final Logger LOGGER = Logger.getLogger(TestController.class);
    @Autowired
    private TestService testService;
    @Autowired
    private WebScoketHandler handler;

    /**
     * Builds the two demo messages, sends each to page "p1" via the handler, and prints the boolean
     * returned by each send to the response.
     */
    @RequestMapping("/test")
    public void test(HttpServletRequest request, HttpServletResponse response) {
        try {
            Map<String, Object> infoMap = new HashMap<String, Object>();
            infoMap.put("type", "qwzx");
            infoMap.put("title", "全网资讯");
            TextMessage infoMessage = new TextMessage(handler.makeInfoResponseJson(WebSocketStatus.CODE_SUCCESS, WebSocketStatus.TYPE_QWRD,WebSocketStatus.ACTION_SIMPLE, WebSocketStatus.MSG_SUCCESS, infoMap));
            List<Object> userList = new ArrayList<Object>();
            // Distinct ids (users 2 and 3 previously reused id 1).
            userList.add(newUser(1, "name 1", "address 1"));
            userList.add(newUser(2, "name 2", "address 2"));
            userList.add(newUser(3, "name 3", "address 3"));
            int totalCount = userList.size();
            TextMessage listMessage = new TextMessage(handler.makeListResponseJson(WebSocketStatus.CODE_SUCCESS, WebSocketStatus.TYPE_QWRD,WebSocketStatus.ACTION_DETAIL, WebSocketStatus.MSG_SUCCESS, totalCount,userList));
            String pageFlag = "p1";
            // Push to every browser that opened page p1.
            boolean sendFlag1 = this.handler.sendMessageToPage(pageFlag, infoMessage);
            LOGGER.info("sendFlag1:" + sendFlag1);
            response.getWriter().print(sendFlag1);
            boolean sendFlag2 = this.handler.sendMessageToPage(pageFlag, listMessage);
            LOGGER.info("sendFlag2:" + sendFlag2);
            response.getWriter().print(sendFlag2);
        } catch (Exception e) {
            // Request boundary: record the failure in the application log instead of stderr.
            LOGGER.error("testController/test failed", e);
        }
    }

    /** Creates a demo user aged 18. */
    private static User newUser(int id, String name, String address) {
        User user = new User();
        user.setAddress(address);
        user.setAge(18);
        user.setId(id);
        user.setName(name);
        return user;
    }
}