SpringBoot拦截所有请求校验特殊字符

项目场景:

springboot项目中,需要对所有请求数据进行过滤,拦截特殊字符。
如 `|`、`&`、`;`、`$`、`%`、`@`、`'`、`"`、`\'`、`\"`、`<>`、`()`、`+`、`\n`、`\r`、`,`、`\`、`../`、`/` 等(目录遍历类漏洞要求过滤 `../` 和 `/`)。

文章背景

大部分帖子涉及以下问题:
1. 编写自定义的 HttpServletRequestWrapper 包装所有请求,会导致所有文件上传接口不可用;
2. 处理 POST 请求时只能处理简单的 JSON,无法处理嵌套对象、数组等复杂 JSON,扩展性不够好。
例如
{"name":"名称1","algorithmId":"3815009226752","totalAmount":"100","effectTime":["2023-09-05","2023-10-05"],"description":"备注1","dataMarketList":[{"dataMarketUuid":"","origin":"2","datasetList":[{"nodeId":"ds02","datasetCode":["2705385631744"]}]}]}
大部分帖子的工具类都无法处理
3. 校验的特殊字符范围过于简单,覆盖不全。

版本

JDK17
SpringBoot3.0

解决方案:

1.注册拦截器

/**
 * Spring MVC configuration that installs the special-character interceptor.
 */
@Configuration
public class WebMvcConfiguration implements WebMvcConfigurer {

    /**
     * Registers the special-character interceptor for every request path.
     *
     * @param registry the MVC interceptor registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(getCommonApiAuthFilter()).addPathPatterns("/**");
    }

    /**
     * Exposes the interceptor as a Spring bean.
     *
     * <p>The type name previously read {@code CommonpiAuthInterceptor}, which does not match the
     * interceptor class {@code CommonApiAuthInterceptor} and therefore did not compile.
     *
     * @return the interceptor that rejects requests containing special characters
     */
    @Bean
    public CommonApiAuthInterceptor getCommonApiAuthFilter() {
        return new CommonApiAuthInterceptor();
    }
}

2.注册过滤器

/**
 * Servlet filter that prepares every HTTP request for the special-character checks.
 *
 * <p>{@code multipart/form-data} requests are resolved into a multipart request up front, so
 * file uploads are not buffered by {@link SpecialCharHttpServletRequestWrapper}. All other HTTP
 * requests are wrapped so that their body can be read more than once.
 *
 * <p>NOTE(review): {@code @Component} already lets Spring Boot register this filter. {@code @WebFilter}
 * is only honored together with {@code @ServletComponentScan}, in which case the filter may end up
 * registered twice. Confirm that only one registration mechanism is active.
 */
@Component
@WebFilter(filterName = "specialCharFilter", urlPatterns = "/*")
public class SpecialCharFilter implements Filter {

    /** Reused for every multipart request instead of allocating a new resolver per request. */
    private static final StandardServletMultipartResolver MULTIPART_RESOLVER =
            new StandardServletMultipartResolver();

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // No initialization required.
    }

    /**
     * Wraps HTTP requests before passing them down the chain; non-HTTP requests pass through untouched.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest httpServletRequest)) {
            chain.doFilter(request, response);
            return;
        }
        if (isMultipartFormData(httpServletRequest.getContentType())) {
            // Let the multipart machinery parse the body; buffering it would consume the upload stream.
            chain.doFilter(MULTIPART_RESOLVER.resolveMultipart(httpServletRequest), response);
        } else {
            // NOTE(review): the wrapper buffers the entire body in memory with no size limit.
            chain.doFilter(new SpecialCharHttpServletRequestWrapper(httpServletRequest), response);
        }
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

    /**
     * Checks whether the content type is multipart/form-data. Media types are case-insensitive,
     * so the comparison is too.
     *
     * @param contentType the raw Content-Type header, possibly {@code null}
     * @return {@code true} for multipart/form-data requests
     */
    private static boolean isMultipartFormData(String contentType) {
        return contentType != null
                && contentType.toLowerCase(java.util.Locale.ROOT).contains("multipart/form-data");
    }
}

3.编写自定义的HttpServletRequestWrapper

/**
 * Request wrapper that reads the request body once and replays it from memory. This lets the
 * special-character checks read the body and still leaves it available to the controller.
 *
 * <p>The body bytes are replayed exactly as received, with no decode/re-encode round trip and no
 * lost line breaks. Text access via {@link #getReader()} and {@link #getBodyString()} uses the
 * request's declared character encoding, falling back to UTF-8.
 *
 * <p>NOTE(review): the whole body is buffered in memory without a size limit. Also, because the
 * container's input stream is consumed here, the container can presumably no longer parse
 * {@code application/x-www-form-urlencoded} body parameters for {@code getParameter*}. Confirm
 * whether form posts need to be supported.
 */
public class SpecialCharHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Encoding used when the request declares none, or declares an unsupported one. */
    private static final Charset DEFAULT_CHARSET = java.nio.charset.StandardCharsets.UTF_8;

    /** The wrapped request; kept public because existing callers may reference it. */
    public final HttpServletRequest request;

    /** Raw request body exactly as received. */
    private final byte[] body;

    /**
     * Wraps {@code request}, fully reading and caching its body.
     *
     * @param request the request to wrap
     * @throws IOException if the body cannot be read (for example, the client aborted the request)
     */
    public SpecialCharHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.request = request;
        // Deliberately not closed: the container owns its input stream.
        this.body = request.getInputStream().readAllBytes();
    }

    /**
     * Returns the cached request body as text.
     *
     * @return the body decoded with the request's character encoding (UTF-8 if none is usable)
     */
    public String getBodyString() {
        return new String(body, resolveCharset());
    }

    /**
     * Copies the remaining content of {@code inputStream} into an in-memory stream.
     *
     * <p>This method is no longer used internally and is kept for existing callers. On a read
     * failure it keeps the best-effort behavior: whatever was read before the error is returned.
     *
     * @param inputStream the stream to copy
     * @return an in-memory stream over the copied bytes
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        ByteArrayOutputStream copy = new ByteArrayOutputStream();
        try {
            inputStream.transferTo(copy);
        } catch (IOException e) {
            // NOTE(review): best-effort as before; route this to the project logger.
            e.printStackTrace();
        }
        return new ByteArrayInputStream(copy.toByteArray());
    }

    /**
     * Returns a reader over the cached body, decoded with the request's encoding (UTF-8 fallback)
     * rather than the platform default charset.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), resolveCharset()));
    }

    /**
     * Returns a fresh stream over the cached body bytes; each call starts from the beginning.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // NOTE(review): non-blocking reads are not supported; the listener is ignored (unchanged).
            }
        };
    }

    /**
     * Resolves the charset declared by the request.
     *
     * @return the request's charset, or UTF-8 if none is declared or the name is unsupported
     */
    private Charset resolveCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // Illegal or unsupported charset name: fall back to the default below.
            }
        }
        return DEFAULT_CHARSET;
    }
}


4.校验特殊字符的工具类

/**
 * Static helpers that detect special characters in request data (query strings, form parameter
 * values, and JSON body values) using {@link #REG_EX}.
 */
public class ParamCheckUtils {

    /**
     * Special-character regular expression: a single character class, so one match anywhere in a
     * value marks it as illegal.
     */
    private final static String REG_EX = "[`~!@#$%^*()+\n\r|{}\\[\\]<>/?!()【】‘;:”“’。,、\\\\]";

    /** {@link #REG_EX} compiled once; {@link Pattern} instances are immutable and thread-safe. */
    private static final Pattern SPECIAL_CHAR_PATTERN = Pattern.compile(REG_EX);

    /**
     * Checks whether a URL query string contains special characters after URL-decoding it (UTF-8).
     *
     * @param urls the raw query string sent by the client; may be {@code null} or empty
     * @return {@code true} if a special character is found or the percent-encoding is malformed
     */
    public static boolean checkSpecials(String urls) {
        if (StringUtils.isEmpty(urls)) {
            return false;
        }
        String decoded;
        try {
            decoded = URLDecoder.decode(urls, java.nio.charset.StandardCharsets.UTF_8);
        } catch (IllegalArgumentException malformed) {
            // A malformed escape such as "%zz" previously escaped as an uncaught exception.
            // Content that cannot be decoded cannot be inspected, so reject it.
            return true;
        }
        return containsSpecialChar(decoded);
    }

    /**
     * Checks whether any form-data parameter value contains special characters.
     * Parameter names are not checked.
     *
     * @param map parameter name to values, as returned by {@code getParameterMap()}; may be {@code null}
     * @return {@code true} if any non-null value contains a special character
     */
    public static boolean checkSpecials(Map<String, String[]> map) {
        if (map == null || map.isEmpty()) {
            return false;
        }
        for (String[] values : map.values()) {
            if (values == null) {
                continue;
            }
            for (String value : values) {
                if (value != null && containsSpecialChar(value)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Checks whether the JSON request body contains special characters in any leaf value,
     * including values nested in objects and arrays.
     *
     * <p>NOTE(review): a body that is not JSON (form-urlencoded, plain text, ...) presumably makes
     * {@code JsonUtils.parse2Map} throw a parser exception, which propagates out of this method.
     * Confirm how such requests should be handled.
     *
     * @param request the incoming request; its body is re-read through a
     *                {@link SpecialCharHttpServletRequestWrapper}
     * @return {@code true} if a special character is found; {@code false} for blank bodies or when
     *         the body cannot be read
     */
    public static boolean checkSpecials(HttpServletRequest request) {
        try {
            SpecialCharHttpServletRequestWrapper wrapper = new SpecialCharHttpServletRequestWrapper(request);
            String context;
            // Decode as UTF-8 explicitly instead of the platform default charset.
            try (InputStream is = wrapper.getInputStream()) {
                context = new String(is.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
            }
            if (StringUtils.isBlank(context)) {
                return false;
            }
            return checkJsonIsIllegal(context);
        } catch (IOException e) {
            // NOTE(review): fails open as before; route this to the project logger.
            e.printStackTrace();
        }
        return false;
    }

    /**
     * Flattens the JSON document and checks every non-null leaf value.
     *
     * @param objectStr the JSON text
     * @return {@code true} if any checked value contains a special character
     */
    private static boolean checkJsonIsIllegal(String objectStr) {
        Map<String, Object> result = JsonUtils.parse2Map(objectStr);
        for (Object value : result.values()) {
            if (Objects.isNull(value)) {
                continue;
            }
            String text = value.toString();
            // NOTE(review): any value containing "import" is exempt from the check, so a payload that
            // includes the word bypasses the filter. Confirm this exemption is intended.
            if (text.contains("import")) {
                continue;
            }
            if (containsSpecialChar(text)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks a single value against {@link #SPECIAL_CHAR_PATTERN}.
     *
     * @param value the text to inspect
     * @return {@code true} if the value contains at least one special character
     */
    private static boolean containsSpecialChar(String value) {
        return SPECIAL_CHAR_PATTERN.matcher(value).find();
    }
}

5.Json工具类

使用fastjson

/**
 * JSON helpers built on fastjson.
 */
public class JsonUtils {

    /**
     * Pretty-prints a JSON object string.
     *
     * @param json the JSON text; may be blank
     * @return the indented JSON, or the input unchanged if it is blank or not a parsable JSON object
     */
    public static String prettyJson(String json) {
        if (StringUtils.isBlank(json)) {
            return json;
        }
        JSONObject jsonObject;
        try {
            jsonObject = JSONObject.parseObject(json);
        } catch (Exception e) {
            return json;
        }
        return JSONObject.toJSONString(jsonObject, true);
    }

    /**
     * Flattens a JSON document into an insertion-ordered map from dotted path to leaf value.
     *
     * <p>Object fields contribute their key and array elements contribute their index. For example,
     * {@code {"a":{"b":[1,2]}}} becomes {@code {"a.b.0"=1, "a.b.1"=2}}. A top-level array yields keys
     * starting at {@code "0"}. A top-level scalar is stored under the key {@code ""}. A {@code null}
     * document yields an empty map.
     *
     * @param jsonString the JSON text
     * @return the flattened leaves
     */
    public static Map<String, Object> parse2Map(String jsonString) {
        Map<String, Object> result = new LinkedHashMap<>();
        Object root = JSON.parse(jsonString);
        if (root instanceof JSONObject object) {
            parseObject(object, "", result);
        } else if (root instanceof JSONArray array) {
            parseArray(array, "", result);
        } else if (root != null) {
            result.put("", root);
        }
        return result;
    }

    /**
     * Recursively flattens an object's fields under {@code parentKey}.
     *
     * @param parentKey the path of {@code jsonObject} without a trailing separator ("" at the root)
     */
    private static void parseObject(JSONObject jsonObject, String parentKey, Map<String, Object> result) {
        if (Objects.isNull(jsonObject)) {
            return;
        }
        for (Map.Entry<String, Object> entry : jsonObject.entrySet()) {
            String fullKey = getParentKey(parentKey) + entry.getKey();
            flattenValue(entry.getValue(), fullKey, result);
        }
    }

    /**
     * Recursively flattens an array's elements under {@code parentKey}, using element indexes as keys.
     *
     * @param parentKey the path of {@code jsonArray} without a trailing separator ("" at the root)
     */
    private static void parseArray(JSONArray jsonArray, String parentKey, Map<String, Object> result) {
        for (int i = 0; i < jsonArray.size(); i++) {
            String fullKey = getParentKey(parentKey) + i;
            flattenValue(jsonArray.get(i), fullKey, result);
        }
    }

    /**
     * Recurses into containers, or records a leaf value under {@code fullKey}.
     *
     * <p>The child path is passed without a trailing "."; previously a "." was appended here and
     * {@link #getParentKey} appended another, producing keys such as {@code "a..b"}.
     */
    private static void flattenValue(Object value, String fullKey, Map<String, Object> result) {
        if (value instanceof JSONObject object) {
            parseObject(object, fullKey, result);
        } else if (value instanceof JSONArray array) {
            parseArray(array, fullKey, result);
        } else {
            result.put(fullKey, value);
        }
    }

    /**
     * Returns the prefix for child keys.
     *
     * @return {@code ""} at the root, otherwise {@code parentKey + "."}
     */
    private static String getParentKey(String parentKey) {
        if (parentKey.isEmpty()) {
            return "";
        }
        return parentKey + ".";
    }
}

6.编写拦截器

@Slf4j
public class CommonApiAuthInterceptor implements HandlerInterceptor {

    @Override
    public boolean preHandle(@NotNull HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull Object handler) throws Exception {
        boolean paramIllegal = ParamCheckerChain.getInstance().check(request);
        if (paramIllegal) {
        //这里自定义异常,不再多阐述
            HttpServletUtils.setResponseInfo(response, HttpStatus.OK.value(), CommonErrorCode.PARAM_ILLEGAL.getMsg());
            return false;
        }
}

7.责任链chain

/**
 * Singleton chain of responsibility that decides whether a request carries illegal parameters.
 *
 * <p>Routing:
 * <ul>
 *   <li>GET requests: only the query string is checked.</li>
 *   <li>multipart/form-data requests: the parameter values are checked.</li>
 *   <li>All other requests: the body is checked as JSON.</li>
 * </ul>
 *
 * <p>NOTE(review): for non-GET requests that are not multipart, the query string is not inspected.
 * Confirm this is acceptable.
 */
public class ParamCheckerChain {

    /** First checker in the chain; fixed after construction. */
    private final ParamChecker head;

    private ParamCheckerChain() {
        this.head = buildChain();
    }

    /**
     * Lazy holder: the JVM initializes this class (and the instance) once, on first use, with safe
     * publication. This replaces the previous double-checked locking, whose non-volatile
     * {@code instance} field could expose a partially constructed object to other threads.
     */
    private static final class Holder {
        private static final ParamCheckerChain INSTANCE = new ParamCheckerChain();
    }

    /**
     * Returns the shared chain instance.
     *
     * @return the singleton chain
     */
    public static ParamCheckerChain getInstance() {
        return Holder.INSTANCE;
    }

    /**
     * Links GET, file (multipart), and body checkers in that order.
     *
     * @return the head of the chain
     */
    private static ParamChecker buildChain() {
        ParamChecker first = new GetRequestChecker();
        ParamChecker second = new FileRequestChecker();
        ParamChecker third = new PostRequestChecker();
        first.setNextHandler(second);
        second.setNextHandler(third);
        return first;
    }

    /**
     * Runs the request through the chain.
     *
     * @param request the request to check
     * @return {@code true} if an illegal parameter was found
     */
    public boolean check(HttpServletRequest request) {
        return head.check(request);
    }
}

8.校验get请求

/**
 * Chain link for GET requests: inspects the query string and stops the chain there.
 * Any other HTTP method is passed to the next checker.
 */
public class GetRequestChecker extends ParamChecker {

    /**
     * @param request the incoming request
     * @return for GET, whether the query string contains special characters; otherwise the next
     *         checker's verdict, or {@code false} if there is no next checker
     */
    @Override
    public boolean check(HttpServletRequest request) {
        boolean isGet = RequestMethod.GET.name().equals(request.getMethod());
        if (!isGet) {
            return nextHandler != null && nextHandler.check(request);
        }
        return ParamCheckUtils.checkSpecials(request.getQueryString());
    }
}

9.校验文件类型的请求

/**
 * Chain link for multipart/form-data (file upload) requests: checks the form parameter values
 * and stops the chain there. Other requests are passed to the next checker.
 *
 * <p>NOTE(review): uploaded file names are not inspected. Confirm whether path-traversal names
 * such as "../x" should be rejected as well.
 */
public class FileRequestChecker extends ParamChecker {

    /** Reused across calls instead of allocating a resolver per request. */
    private static final MultipartResolver MULTIPART_RESOLVER = new StandardServletMultipartResolver();

    /**
     * @param request the incoming request
     * @return for multipart requests, whether any parameter value contains special characters;
     *         otherwise the next checker's verdict, or {@code false} if there is no next checker
     */
    @Override
    public boolean check(HttpServletRequest request) {
        if (!isMultipartFormData(request.getContentType())) {
            return nextHandler != null && nextHandler.check(request);
        }
        // Reuse the request if it is already resolved (for example, by an upstream filter).
        MultipartHttpServletRequest multipartRequest = request instanceof MultipartHttpServletRequest resolved
                ? resolved
                : MULTIPART_RESOLVER.resolveMultipart(request);
        return ParamCheckUtils.checkSpecials(multipartRequest.getParameterMap());
    }

    /**
     * Checks whether the content type is multipart/form-data. Media types are case-insensitive,
     * so the comparison is too.
     *
     * @param contentType the raw Content-Type header, possibly {@code null}
     * @return {@code true} for multipart/form-data requests
     */
    private static boolean isMultipartFormData(String contentType) {
        return contentType != null
                && contentType.toLowerCase(java.util.Locale.ROOT).contains("multipart/form-data");
    }
}


10.校验post请求

/**
 * Chain link that inspects the request body (parsed as JSON) for special characters.
 * If the body is clean, the next checker (if any) gets the final say.
 */
public class PostRequestChecker extends ParamChecker {

    /**
     * @param request the incoming request
     * @return {@code true} if the body contains special characters, or if the next checker flags the request
     */
    @Override
    public boolean check(HttpServletRequest request) {
        boolean bodyIllegal = ParamCheckUtils.checkSpecials(request);
        return bodyIllegal || (nextHandler != null && nextHandler.check(request));
    }
}
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值