2022-10-31 使用过滤器(Filter)拦截 GET、POST 请求中的非法字符,并处理可能出现的跨域问题

思路

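  1. 配置过滤器(Filter)
  2. 遇到问题:过滤器在 DispatcherServlet 之前执行,在过滤器里直接抛出的异常无法被 @ControllerAdvice 捕获
  3. 找到 HandlerExceptionResolver 这个接口
  4. 通过 Spring 获取容器中注册的 bean:handlerExceptionResolver(它由 WebMvcConfigurationSupport 注册,是一个 HandlerExceptionResolverComposite,其中的 ExceptionHandlerExceptionResolver 会收集 @ControllerAdvice 中的 @ExceptionHandler 方法,所以能找到自己定义的全局异常处理类)
  5. 在校验不通过的地方抛出 RuntimeException
  6. 原地 try/catch,然后调用 resolver.resolveException(request 或 requestWrapper, response, null, e) 把异常交给全局异常处理类;第三个参数 handler 必须传 null,否则全局 @ControllerAdvice 可能不会被应用,最终走 Spring 默认的错误处理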

具体配置

配置过滤器 CharCheckFilter

/**
 * Servlet filter that rejects GET and POST requests containing the character {@code %}.
 *
 * <p>For POST the buffered request body is checked; for GET every parameter name and every
 * parameter value is checked. A rejected request is handed to Spring's
 * {@code handlerExceptionResolver} bean (with a {@code null} handler) so that the project's
 * global exception handler renders the error response. URIs containing {@code /auth} and all
 * other HTTP methods are passed through unchanged.
 */
public class CharCheckFilter implements Filter
{
    private static final Logger log = LoggerFactory.getLogger(CharCheckFilter.class);

    /** Character treated as illegal in request input. */
    private static final String ILLEGAL_CHAR = "%";

    /**
     * Spring exception resolver, looked up on the first rejection and then reused.
     * The filter instance is shared by all request threads, hence {@code volatile}.
     */
    private volatile HandlerExceptionResolver resolver = null;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException
    {
        log.info("init");
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException
    {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        String uri = request.getRequestURI();
        String method = request.getMethod();

        if (uri.contains("/auth")) { // authentication endpoints are not checked
            filterChain.doFilter(request, response);
            return;
        }

        if ("POST".equalsIgnoreCase(method)) {
            // Buffer the body so it can be inspected here and still be read downstream.
            XyRequestWrapper requestWrapper = new XyRequestWrapper(request);
            // POST responses always carry CORS headers, otherwise the browser blocks them.
            setCorsHeaders(response);

            String body = requestWrapper.getBody();
            if (StringUtils.isNotBlank(body) && body.contains(ILLEGAL_CHAR)) {
                log.info("POST请求存在非法字符 % , uri={},入参={}", uri, body);
                reject(requestWrapper, response);
                return;
            }
            // Exceptions thrown downstream propagate normally; they are not ours to resolve.
            filterChain.doFilter(requestWrapper, response);
            return;
        }

        if ("GET".equalsIgnoreCase(method)) {
            for (Map.Entry<String, String[]> entry : request.getParameterMap().entrySet()) {
                String key = entry.getKey();
                String[] values = entry.getValue();
                if (key.contains(ILLEGAL_CHAR) || anyContainsIllegal(values)) {
                    String value = values == null ? null : String.join(",", values);
                    log.info("GET请求存在非法字符 % , uri={},入参key={},入参value={}", uri, key, value);
                    // GET only sets CORS headers when the request is rejected: setting them up
                    // front caused CORS problems, but the error response must still be readable.
                    setCorsHeaders(response);
                    reject(request, response);
                    return;
                }
            }
        }

        filterChain.doFilter(request, response);
    }

    @Override
    public void destroy() {

    }

    /** Returns true if any non-null element of {@code values} contains {@link #ILLEGAL_CHAR}. */
    private static boolean anyContainsIllegal(String[] values)
    {
        if (values == null) {
            return false;
        }
        for (String value : values) {
            if (value != null && value.contains(ILLEGAL_CHAR)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Hands an "illegal character" error to Spring's exception resolver, which writes the response.
     * The handler argument is {@code null} so the global exception handler is used.
     */
    private void reject(HttpServletRequest request, HttpServletResponse response)
    {
        HandlerExceptionResolver current = resolver;
        if (current == null) {
            current = TestSpringUtil.getBean("handlerExceptionResolver", HandlerExceptionResolver.class);
            resolver = current;
        }
        current.resolveException(request, response, null, new RuntimeException("非法字符 %"));
    }

    /**
     * Adds the permissive CORS headers used by this filter.
     * NOTE(review): browsers refuse credentialed requests when Allow-Origin is "*" while
     * Allow-Credentials is "true"; confirm whether the request Origin should be echoed instead.
     */
    private static void setCorsHeaders(HttpServletResponse response)
    {
        response.setHeader("Access-Control-Allow-Origin", "*");
        response.setHeader("Access-Control-Allow-Credentials", "true");
        response.addHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE");
        response.setHeader("Access-Control-Max-Age", "3600");
        response.setHeader("Access-Control-Allow-Headers", "X-Requested-With,Authorization,Content-Type");
    }

}

配置 request 包装类:缓存请求体,避免 inputStream 只能读取一次、过滤器读取后下游就拿不到请求体的问题

import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Charsets;
import org.apache.commons.lang3.StringUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;

/**
 * Request wrapper that reads the whole body once in its constructor and then serves it
 * from memory, so the body can be read repeatedly (e.g. by a filter and then downstream).
 *
 * <p>The body is decoded and re-encoded with the request's declared character encoding
 * (UTF-8 when none is declared or it is unknown), so the bytes seen downstream match the
 * declared charset.
 */
public class XyRequestWrapper extends HttpServletRequestWrapper {

    /** Buffered request body; may be replaced by {@link #setParamsMaps(Map)}. */
    private String body;

    /** Charset used to decode the original body and to re-encode it for readers. */
    private final Charset charset;

    /**
     * Buffers the body of {@code request}. The original input stream is fully consumed
     * and closed here.
     *
     * @param request the request to wrap
     * @throws IOException if reading the original body fails
     */
    public XyRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.charset = resolveCharset(request.getCharacterEncoding());
        StringBuilder content = new StringBuilder();
        InputStream inputStream = request.getInputStream();
        if (inputStream != null) {
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, charset))) {
                char[] buffer = new char[1024];
                int charsRead;
                while ((charsRead = reader.read(buffer)) != -1) {
                    content.append(buffer, 0, charsRead);
                }
            }
        }
        body = content.toString();
    }

    /** Returns a fresh stream over the buffered body on every call. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body.getBytes(charset));
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // All data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking reads are not supported; intentionally a no-op as before.
            }

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }
        };
    }

    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), charset));
    }

    /** Returns the buffered body as text. */
    public String getBody() {
        return this.body;
    }

    /**
     * Merges {@code paramMaps} into the body, overwriting keys that already exist, and
     * stores the re-serialized JSON as the new body. Assumes a non-blank body is a JSON object.
     *
     * @param paramMaps entries to add; {@code null} adds nothing
     */
    public void setParamsMaps(Map<?, ?> paramMaps) {
        Map<Object, Object> merged = new HashMap<>();
        if (StringUtils.isNotBlank(body)) {
            Map<?, ?> parsed = JSONObject.parseObject(body, Map.class);
            if (parsed != null) { // e.g. a literal "null" body
                merged.putAll(parsed);
            }
        }
        if (paramMaps != null) {
            merged.putAll(paramMaps);
        }
        body = JSONObject.toJSONString(merged);
    }

    /** Resolves the declared charset name, falling back to UTF-8 when it is absent or unknown. */
    private static Charset resolveCharset(String name) {
        if (StringUtils.isBlank(name)) {
            return StandardCharsets.UTF_8;
        }
        try {
            return Charset.forName(name);
        } catch (IllegalArgumentException e) {
            // Illegal or unsupported charset name: keep the previous UTF-8 behaviour.
            return StandardCharsets.UTF_8;
        }
    }
}

配置 Spring 工具类,用于在过滤器中获取 HandlerExceptionResolver 对象,以便把异常交给全局异常处理类处理

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * Static access to the Spring {@link ApplicationContext} for code that is not itself a
 * Spring bean, such as a servlet filter.
 *
 * <p>Only the first context handed in by Spring is kept; later calls to
 * {@link #setApplicationContext(ApplicationContext)} are ignored. The {@code getBean}
 * helpers throw {@link NullPointerException} if called before a context has been set.
 */
@Component
public class TestSpringUtil implements ApplicationContextAware {

    private static ApplicationContext applicationContext;

    /** Stores {@code applicationContext} unless a context has already been stored. */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        boolean alreadySet = TestSpringUtil.applicationContext != null;
        if (alreadySet) {
            return;
        }
        TestSpringUtil.applicationContext = applicationContext;
    }

    /** Returns the stored context, or {@code null} if Spring has not set one yet. */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    /** Looks up a bean by name. */
    public static Object getBean(String name) {
        ApplicationContext context = getApplicationContext();
        return context.getBean(name);
    }

    /** Looks up a bean by type. */
    public static <T> T getBean(Class<T> clazz) {
        ApplicationContext context = getApplicationContext();
        return context.getBean(clazz);
    }

    /** Looks up a bean by name and checks it against the given type. */
    public static <T> T getBean(String name, Class<T> clazz) {
        ApplicationContext context = getApplicationContext();
        return context.getBean(name, clazz);
    }
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值