Spring Boot WebSocket介绍(二)之拦截器

本文深入探讨SpringBoot中WebSocket拦截器的功能,演示如何通过拦截器修改消息内容和验证连接,提供了详细的代码示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

介绍完Spring Boot WebSocket介绍(一)后,我们再来看看更高级的WebSocket拦截器功能。阅读本文前,请先阅读Spring Boot WebSocket介绍(一)。

注意事项:本文代码是为了介绍拦截过程,示例代码并非可直接用于生产的代码,仅供参考。

拦截器

顾名思义,拦截器可以拦截websocket请求,拦截后我们可以对消息做各种修改。本文以收到MESSAGE类型的消息后修改消息内容为例:首先假设发送到websocket的消息是JSON格式,然后我们在其中添加一个新字段,并读取、修改"myfield1"字段的内容(见下面的代码片段)。同时我们也可以对websocket连接进行验证,比如不符合要求的CONNECT请求直接拒绝,不允许建立websocket连接。

    // Add a new field whose name is logFlag + "ChannelContent2" with the fixed value "add to".
    jsonObj.put(logFlag + "ChannelContent2", "add to");
    // Read the current "myfield1" value and write it back prefixed with logFlag + " add to ".
    String value = jsonObj.getString("myfield1");
    jsonObj.put("myfield1", logFlag + " add to " + value);

代码

完整的代码在这里,欢迎加星和fork。

package com.yq.config;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.converter.SimpleMessageConverter;
import org.springframework.messaging.converter.SmartMessageConverter;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;

import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

@Configuration
@EnableWebSocketMessageBroker
@Service
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
    private static final Logger log = LoggerFactory.getLogger(WebSocketConfig.class);

    @Autowired
    private SimpMessagingTemplate messagingTemplate;

    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        //开启/myEndPoint端点
        registry.addEndpoint("/myEndPoint")
                //允许跨域访问
                .setAllowedOrigins("*")
                //使用sockJS
                .withSockJS();
    }

    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        //订阅topic的前缀,比订阅topic必须是这种格式// /myPrefixes/topic1/xxx
        registry.enableSimpleBroker("/myPrefixes");
    }


    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        ChannelInterceptor interceptor = new ChannelInterceptorAdapter() {
            @Override
            public Message<?> preSend(Message<?> message, MessageChannel channel) {
                log.info("Inbound preSend. message={}", message);
                StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);
                MessageHeaders header = message.getHeaders();
                String sessionId = (String)header.get("simpSessionId");
                if (accessor != null && accessor.getCommand() !=null && accessor.getCommand().getMessageType() != null) {
                    SimpMessageType type = accessor.getCommand().getMessageType();
                    if (accessor!= null && SimpMessageType.CONNECT.equals(type)) {
                        String jwtToken = accessor.getFirstNativeHeader("AuthToken");

                        if(StringUtils.isNotBlank(jwtToken)) {
                            log.info("Inbound preSend: sessionId={}, jwtToken={}", sessionId, jwtToken);
                        }
                        else {
                            log.error("no token, will be disallowed to connect.");
                            return null;
                        }
                    }else if (type == SimpMessageType.DISCONNECT) {
                        log.info("Inbound sessionId={} is disconnected", sessionId);
                    }else if (type == SimpMessageType.SUBSCRIBE) {
                        String topicDest = (String)header.get("simpDestination");
                        log.info("subscribe topicDest={}, message={} SUBSCRIBE", topicDest, message);
                    } else if (type == SimpMessageType.MESSAGE) {
                        String topicDest = (String)header.get("simpDestination");
                        log.info("之前的消息 topicDest={}, message={} MESSAGE", topicDest, message);
                        message = UpdateMessage(message, "Inbound");
                        log.info("之后的消息e topicDest={}, message={} MESSAGE", topicDest, message);
                    }
                }

                return message;
            }
            @Override
            public boolean preReceive(MessageChannel channel) {
                log.info("Inbound preReceive. channel={}", channel);
                return true;
            }

            @Override
            public Message<?> postReceive(Message<?> message, MessageChannel channel) {
                log.info("Inbound postReceive. message={}", message);
                return message;
            }

            @Override
            public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
                log.info("Inbound postSend. message={}", message);
            }

            @Override
            public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, @Nullable Exception ex) {
                log.info("Inbound afterSendCompletion. message={}", message);
                StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);
                MessageHeaders header = message.getHeaders();
                if (accessor != null && accessor.getCommand() !=null && accessor.getCommand().getMessageType() != null) {
                    SimpMessageType type = accessor.getCommand().getMessageType();
                    if (type == SimpMessageType.SUBSCRIBE) {
                        String topicDest = (String)header.get("simpDestination");
                       log.info("afterSenfCompletion. topicDest={}, message={} SUBSCRIBE", topicDest, message);

                        String payload = "{\"myfield1\":\"afterSendCompletion初始化消息\"}";
                        messagingTemplate.convertAndSend(topicDest, payload);
                        log.info("send complete. topic={}", topicDest);
                    }
                }
            }

            @Override
            public void afterReceiveCompletion(@Nullable Message<?> message, MessageChannel channel, @Nullable Exception ex) {
                log.info("Inbound afterReceiveCompletion. message={}", message);
            }
        };

        registration.interceptors(interceptor);
    }

    @Override
    public void configureClientOutboundChannel(ChannelRegistration registration) {
        ChannelInterceptor interceptor = new ChannelInterceptorAdapter() {
            @Override
            public boolean preReceive(MessageChannel channel) {
                log.info("Outbound preReceive: channel={}", channel);
                return true;
            }

            @Override
            public Message<?> preSend(Message<?> message, MessageChannel channel) {
                log.info("Outbound preSend: message={}", message);
                return message;
            }

            @Override
            public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
                log.info("Outbound postSend. message={}", message);
            }

            @Override
            public Message<?> postReceive(Message<?> message, MessageChannel channel) {
                log.info("Outbound postReceive. message={}", message);
                return message;
            }

            @Override
            public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, @Nullable Exception ex) {
                log.info("Outbound afterSendCompletion. message={}", message);
            }

            @Override
            public void afterReceiveCompletion(@Nullable Message<?> message, MessageChannel channel, @Nullable Exception ex) {
                log.info("Outbound afterReceiveCompletion. message={}", message);
            }
        };

        registration.interceptors(interceptor);
    }

    private Message<?> UpdateMessage(Message<?> message, String logFlag) {
        log.info(logFlag + " preSend: message={}", message);
        MessageHeaders header = message.getHeaders();
        Object obj = message.getPayload();
        //一般都是byte[]
        JSONObject jsonObj = null;
        String strUTF8 = null;
        String strJsonUTF8 = null;
        Message<?> msg = null;
        try {
            strUTF8 = new String((byte[])obj,"UTF-8");
            jsonObj = JSON.parseObject(strUTF8);
            jsonObj.put(logFlag + "ChannelContent2", "add to");
            String value = jsonObj.getString("myfield1");
            jsonObj.put("myfield1", logFlag + " add to " + value);
            strJsonUTF8 = jsonObj.toJSONString();
            byte[] msgToByte = strJsonUTF8.getBytes("UTF-8");
            msg = new GenericMessage<>(msgToByte, header);
        }
        catch (Exception ex) {
            log.info("(byte[] to string exception. ex={}", ex.getLocalizedMessage());
        }

        if (msg != null) {
            log.info(logFlag + " preSend Modified: message={}, strUTF8={}, strJsonUTF8={}", msg, strUTF8, strJsonUTF8);
            return msg;
        }
        else {
            log.info(logFlag + " preSend Original: message={}, strUTF8={}", message, strUTF8);
            return message;
        }
    }
//"simpMessageType" -> "MESSAGE"
    Message sendInitMsg(Message<?> oldMessage, String dest, Map<String, Object> headers, String payload) {
        MessageHeaders messageHeaders = null;
        Object conversionHint = headers != null?headers.get("conversionHint"):null;
        Map<String, Object> headersToUse = new HashMap<>();
        headersToUse.put("simpMessageType", SimpMessageType.MESSAGE);
        headersToUse.put("destination", dest);
        headersToUse.put("contentType", "text/plain;charset=UTF-8");
        headersToUse.put("stompCommand", "SEND");

        Map<String, Object> nativeHeaders = new LinkedHashMap<>();
        nativeHeaders.put("id", "sub-0");
        nativeHeaders.put("destination", dest);
        headersToUse.put("nativeHeaders", nativeHeaders);
        messageHeaders = new MessageHeaders(headersToUse);
        MessageHeaders oldHeaders = oldMessage.getHeaders()

        MessageConverter converter = new SimpleMessageConverter();
        Message<?> message = converter instanceof SmartMessageConverter ?((SmartMessageConverter)converter).toMessage(payload, messageHeaders, conversionHint):converter.toMessage(payload, messageHeaders);
        return message;
    }
}

效果截图

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值