项目场景:
springboot项目中,需要对所有请求数据进行过滤,拦截特殊字符。
如 `|`、`&`、`;`、`$`、`%`、`@`、`'`、`"`、`’`、`”`、`<>`、`()`、`+`、`\n`、`\r`、`,`、`\`、`../`、`/` 等(目录遍历类漏洞要求过滤 `../` 和 `/`)。
文章背景
大部分帖子涉及以下问题:
1.编写自定义的 HttpServletRequestWrapper(提前读取并缓存请求体)会导致所有文件上传的接口不可用
2.处理 POST 请求时只能处理简单的 JSON,无法处理复杂(嵌套对象、数组)的 JSON,这点扩展性不够好
例如
{"name":"名称1","algorithmId":"3815009226752","totalAmount":"100","effectTime":["2023-09-05","2023-10-05"],"description":"备注1","dataMarketList":[{"dataMarketUuid":"","origin":"2","datasetList":[{"nodeId":"ds02","datasetCode":["2705385631744"]}]}]}
大部分帖子中的工具类都无法处理这种嵌套结构
3.过滤的特殊字符种类过于简单,覆盖不全
版本
JDK17
SpringBoot3.0
解决方案:
1.注册拦截器
/**
 * Spring MVC configuration that registers the special-character interceptor.
 */
@Configuration
public class WebMvcConfiguration implements WebMvcConfigurer {

    /**
     * Applies the special-character interceptor to every request path ("/**").
     *
     * @param registry the interceptor registry supplied by Spring MVC
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(getCommonApiAuthFilter()).addPathPatterns("/**");
    }

    /**
     * Exposes the interceptor as a bean.
     * Despite the "Filter" in the method name, this returns a HandlerInterceptor.
     * The name is kept for compatibility.
     *
     * @return the interceptor that rejects requests containing special characters
     */
    @Bean
    public CommonApiAuthInterceptor getCommonApiAuthFilter() {
        return new CommonApiAuthInterceptor();
    }
}
2.注册过滤器
/**
 * Servlet filter that prepares requests for the special-character checks done later in the interceptor.
 *
 * <ul>
 *   <li>multipart/form-data requests are resolved into a multipart request, so upload endpoints keep working;</li>
 *   <li>all other HTTP requests are wrapped so their body can be read more than once;</li>
 *   <li>non-HTTP requests pass through untouched.</li>
 * </ul>
 *
 * <p>NOTE(review): {@code @Component} already registers this filter for all URLs. The
 * {@code @WebFilter} urlPatterns only apply if {@code @ServletComponentScan} is enabled, in which
 * case the filter may end up registered twice. Confirm which registration is intended.
 */
@Component
@WebFilter(filterName = "specialCharFilter", urlPatterns = "/*")
public class SpecialCharFilter implements Filter {

    /** Media type of file-upload requests; matched case-insensitively because media types are case-insensitive. */
    private static final String MULTIPART_FORM_DATA = "multipart/form-data";

    /** Shared resolver; it holds no per-request state. */
    private static final StandardServletMultipartResolver MULTIPART_RESOLVER = new StandardServletMultipartResolver();

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /**
     * Wraps or resolves the request as described on the class, then continues the chain.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest httpServletRequest)) {
            chain.doFilter(request, response);
            return;
        }
        if (isMultipartFormData(httpServletRequest.getContentType())) {
            // Do not buffer upload bodies; let the multipart resolver parse them instead.
            chain.doFilter(MULTIPART_RESOLVER.resolveMultipart(httpServletRequest), response);
        } else {
            chain.doFilter(new SpecialCharHttpServletRequestWrapper(httpServletRequest), response);
        }
    }

    @Override
    public void destroy() {
    }

    /**
     * Returns true when the Content-Type header denotes multipart/form-data, ignoring case.
     */
    private static boolean isMultipartFormData(String contentType) {
        return contentType != null
                && contentType.toLowerCase(java.util.Locale.ROOT).contains(MULTIPART_FORM_DATA);
    }
}
3.编写自定义的HttpServletRequestWrapper
/**
 * Request wrapper that buffers the request body once so it can be read repeatedly.
 * The parameter checkers read it for validation, and the controller reads it again afterwards.
 *
 * <p>The body is kept as raw bytes, so line breaks and non-UTF-8 content are replayed unchanged.
 *
 * <p>NOTE(review): because the underlying body stream is consumed eagerly, the servlet container
 * will typically no longer parse {@code application/x-www-form-urlencoded} bodies into request
 * parameters for a wrapped request. Confirm that form POSTs are not expected to pass through this
 * wrapper.
 */
public class SpecialCharHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** The original request being wrapped. */
    public final HttpServletRequest request;

    /** The complete request body, read once in the constructor. */
    private final byte[] body;

    /**
     * Wraps {@code request} and buffers its entire body.
     *
     * @param request the request to wrap
     * @throws IOException if the body cannot be read from the underlying request
     */
    public SpecialCharHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.request = request;
        // Propagate read failures instead of silently forwarding a truncated body.
        this.body = request.getInputStream().readAllBytes();
    }

    /**
     * Returns the buffered request body decoded as UTF-8.
     *
     * @return the request body text, never {@code null}
     */
    public String getBodyString() {
        return new String(body, java.nio.charset.StandardCharsets.UTF_8);
    }

    /**
     * Copies the remaining content of {@code inputStream} into an in-memory stream.
     *
     * @param inputStream the stream to copy; it is consumed by this call
     * @return an in-memory stream holding the copied bytes
     * @throws java.io.UncheckedIOException if reading {@code inputStream} fails
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        try {
            return new ByteArrayInputStream(inputStream.readAllBytes());
        } catch (IOException e) {
            throw new java.io.UncheckedIOException("Failed to copy request body", e);
        }
    }

    /**
     * Returns a reader over the buffered body.
     * The body is decoded with the request's character encoding, or UTF-8 when the request does not declare one.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        // The String-charset constructor reports unknown encodings as UnsupportedEncodingException (an IOException).
        InputStreamReader reader = encoding == null
                ? new InputStreamReader(getInputStream(), java.nio.charset.StandardCharsets.UTF_8)
                : new InputStreamReader(getInputStream(), encoding);
        return new BufferedReader(reader);
    }

    /**
     * Returns a fresh stream over the buffered body bytes; each call starts from the beginning.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The body is fully in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // Non-blocking (async) reads are not supported by this wrapper; the listener is ignored.
            }
        };
    }
}
4.校验特殊字符的工具类
/**
 * Static helpers that detect "special" characters in request data.
 * They cover the query string, form/multipart parameters and the JSON request body.
 */
public class ParamCheckUtils {
    /**
     * Special-character regular expression: any single character of this class is considered illegal.
     */
    private static final String REG_EX = "[`~!@#$%^*()+\n\r|{}\\[\\]<>/?!()【】‘;:”“’。,、\\\\]";

    /** Compiled once and reused; {@link Pattern} instances are immutable and thread-safe. */
    private static final Pattern SPECIAL_CHAR_PATTERN = Pattern.compile(REG_EX);

    /**
     * Checks whether a URL-encoded string contains a special character after URL-decoding it as UTF-8.
     * The input is typically the query string.
     *
     * @param urls 前端请求链接 (the raw, still-encoded request string); may be null or empty
     * @return true if a special character is present, or if the string has malformed percent-escapes
     */
    public static boolean checkSpecials(String urls) {
        if (StringUtils.isEmpty(urls)) {
            return false;
        }
        String decoded;
        try {
            decoded = URLDecoder.decode(urls, java.nio.charset.StandardCharsets.UTF_8);
        } catch (IllegalArgumentException e) {
            // Malformed escapes such as "%zz" cannot be decoded. Reject the request instead of failing with a 500.
            return true;
        }
        return containsSpecial(decoded);
    }

    /**
     * Checks whether any value of a form-data parameter map contains a special character.
     *
     * @param map parameter name to values; null maps, null arrays and null values are skipped
     * @return true if at least one value contains a special character
     */
    public static boolean checkSpecials(Map<String, String[]> map) {
        if (map == null) {
            return false;
        }
        for (String[] paraArray : map.values()) {
            if (paraArray == null) {
                continue;
            }
            for (String paraStr : paraArray) {
                if (paraStr != null && containsSpecial(paraStr)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Checks whether the request body contains a special character.
     * A JSON object or array body has its leaf values checked; any other body is checked as plain text.
     *
     * <p>The body is read through a {@link SpecialCharHttpServletRequestWrapper}. If the request was
     * not already wrapped by the filter, this consumes the original body stream.
     *
     * @param request 前端请求 (the request carrying the body)
     * @return true if a special character is found, or if the body cannot be read
     */
    public static boolean checkSpecials(HttpServletRequest request) {
        String context;
        try {
            SpecialCharHttpServletRequestWrapper wrapper = new SpecialCharHttpServletRequestWrapper(request);
            try (InputStream is = wrapper.getInputStream()) {
                context = new String(is.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
            }
        } catch (IOException e) {
            // Fail closed: a body that cannot be read cannot be verified.
            return true;
        }
        if (StringUtils.isBlank(context)) {
            return false;
        }
        return checkJsonIsIllegal(context);
    }

    /**
     * Flattens a JSON body and checks every leaf value. If the body cannot be flattened as JSON,
     * the raw text is checked instead.
     */
    private static boolean checkJsonIsIllegal(String objectStr) {
        Map<String, Object> result;
        try {
            result = JsonUtils.parse2Map(objectStr);
        } catch (RuntimeException e) {
            // Not parseable as a JSON object (e.g. plain text); previously this escaped as a 500.
            return containsSpecial(objectStr);
        }
        for (Object value : result.values()) {
            if (value == null) {
                continue;
            }
            String text = value.toString();
            // NOTE(review): any value containing "import" skips the check entirely. This is presumably a
            // business exemption, but it also lets a caller smuggle special characters by adding "import".
            // Confirm or narrow this exemption.
            if (text.contains("import")) {
                continue;
            }
            if (containsSpecial(text)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code text} contains at least one character from {@link #REG_EX}. */
    private static boolean containsSpecial(String text) {
        return SPECIAL_CHAR_PATTERN.matcher(text).find();
    }
}
5.Json工具类
使用fastjson
/**
 * JSON helpers built on fastjson.
 */
public class JsonUtils {

    /**
     * Pretty-prints a JSON object string.
     *
     * @param json the JSON text; blank input is returned unchanged
     * @return the formatted JSON, or the original text if it cannot be parsed as a JSON object
     */
    public static String prettyJson(String json) {
        if (StringUtils.isBlank(json)) {
            return json;
        }
        JSONObject jsonObject;
        try {
            jsonObject = JSONObject.parseObject(json);
        } catch (Exception e) {
            return json;
        }
        return JSONObject.toJSONString(jsonObject, true);
    }

    /**
     * Flattens a JSON document into an insertion-ordered map of dotted paths to leaf values.
     * Array elements use their index as the path segment, for example
     * {"a":{"b":[1,{"c":2}]}} becomes {"a.b.0"=1, "a.b.1.c"=2}.
     * A top-level array is flattened with index keys ("0", "1.c", ...).
     * A top-level scalar is stored under the empty key.
     *
     * @param jsonString the JSON text
     * @return the flattened map; empty when the parser yields null (e.g. empty input)
     * @throws RuntimeException propagated from fastjson when the text is not valid JSON
     */
    public static Map<String, Object> parse2Map(String jsonString) {
        Map<String, Object> result = new LinkedHashMap<>();
        // JSON.parse (unlike parseObject) also accepts a top-level array.
        Object root = JSON.parse(jsonString);
        if (root instanceof JSONObject jsonObject) {
            parseObject(jsonObject, "", result);
        } else if (root instanceof JSONArray jsonArray) {
            parseArray(jsonArray, "", result);
        } else if (root != null) {
            result.put("", root);
        }
        return result;
    }

    /** Flattens every entry of {@code jsonObject} under {@code parentKey} into {@code result}. */
    private static void parseObject(JSONObject jsonObject, String parentKey, Map<String, Object> result) {
        if (Objects.isNull(jsonObject)) {
            return;
        }
        for (Map.Entry<String, Object> entry : jsonObject.entrySet()) {
            String key = entry.getKey();
            Object value = entry.getValue();
            String fullKey = getParentKey(parentKey) + key;
            // getParentKey already appends the separator, so pass fullKey without a trailing ".".
            if (value instanceof JSONObject) {
                parseObject((JSONObject) value, fullKey, result);
            } else if (value instanceof JSONArray) {
                parseArray((JSONArray) value, fullKey, result);
            } else {
                result.put(fullKey, value);
            }
        }
    }

    /** Flattens every element of {@code jsonArray} under {@code parentKey}, using the index as the path segment. */
    private static void parseArray(JSONArray jsonArray, String parentKey, Map<String, Object> result) {
        for (int i = 0; i < jsonArray.size(); i++) {
            Object value = jsonArray.get(i);
            String fullKey = getParentKey(parentKey) + i;
            if (value instanceof JSONObject) {
                parseObject((JSONObject) value, fullKey, result);
            } else if (value instanceof JSONArray) {
                parseArray((JSONArray) value, fullKey, result);
            } else {
                result.put(fullKey, value);
            }
        }
    }

    /** Returns the prefix for a child key: "" at the root, otherwise {@code parentKey + "."}. */
    private static String getParentKey(String parentKey) {
        if (parentKey.isEmpty()) {
            return "";
        }
        return parentKey + ".";
    }
}
6.编写拦截器
@Slf4j
public class CommonApiAuthInterceptor implements HandlerInterceptor {
@Override
public boolean preHandle(@NotNull HttpServletRequest request, @NotNull HttpServletResponse response, @NotNull Object handler) throws Exception {
boolean paramIllegal = ParamCheckerChain.getInstance().check(request);
if (paramIllegal) {
//这里自定义异常,不再多阐述
HttpServletUtils.setResponseInfo(response, HttpStatus.OK.value(), CommonErrorCode.PARAM_ILLEGAL.getMsg());
return false;
}
}
7.责任链chain
/**
 * Lazily created singleton holding the checker chain: GET query check → multipart check → body check.
 */
public class ParamCheckerChain {
    /** Singleton instance; volatile so double-checked locking publishes a fully constructed object. */
    private static volatile ParamCheckerChain instance;

    /** First checker in the chain. */
    private final ParamChecker head;

    private ParamCheckerChain() {
        this.head = buildChain();
    }

    /**
     * Returns the shared chain, creating it on first use.
     *
     * @return the singleton chain
     */
    public static ParamCheckerChain getInstance() {
        ParamCheckerChain local = instance;
        if (local == null) {
            synchronized (ParamCheckerChain.class) {
                local = instance;
                if (local == null) {
                    local = new ParamCheckerChain();
                    instance = local;
                }
            }
        }
        return local;
    }

    /** Links the checkers in order and returns the head of the chain. */
    private static ParamChecker buildChain() {
        ParamChecker first = new GetRequestChecker();
        ParamChecker second = new FileRequestChecker();
        ParamChecker third = new PostRequestChecker();
        first.setNextHandler(second);
        second.setNextHandler(third);
        return first;
    }

    /**
     * Runs the request through the chain.
     *
     * @param request the incoming request
     * @return true if a checker found an illegal parameter
     */
    public boolean check(HttpServletRequest request) {
        return head.check(request);
    }
}
8.校验get请求
/**
 * Chain link for GET requests: validates the query string.
 * Any other HTTP method is delegated to the next checker, if one exists.
 */
public class GetRequestChecker extends ParamChecker {
    @Override
    public boolean check(HttpServletRequest request) {
        boolean isGet = RequestMethod.GET.name().equals(request.getMethod());
        if (!isGet) {
            // Not ours to check: defer to the rest of the chain (false when this is the last link).
            return nextHandler != null && nextHandler.check(request);
        }
        return ParamCheckUtils.checkSpecials(request.getQueryString());
    }
}
9.校验文件类型的请求
/**
 * Chain link for multipart/form-data (file upload) requests: validates the form-field parameters.
 * Other requests are delegated to the next checker, if one exists.
 */
public class FileRequestChecker extends ParamChecker {

    /** Shared resolver; it holds no per-request state. */
    private static final MultipartResolver RESOLVER = new StandardServletMultipartResolver();

    @Override
    public boolean check(HttpServletRequest request) {
        String contentType = request.getContentType();
        // Media types are case-insensitive, so normalise before matching.
        boolean multipart = contentType != null
                && contentType.toLowerCase(java.util.Locale.ROOT).contains("multipart/form-data");
        if (!multipart) {
            return nextHandler != null && nextHandler.check(request);
        }
        // Reuse the request if it is already a multipart request (e.g. resolved upstream) instead of resolving it again.
        MultipartHttpServletRequest multipartRequest = request instanceof MultipartHttpServletRequest alreadyResolved
                ? alreadyResolved
                : RESOLVER.resolveMultipart(request);
        return ParamCheckUtils.checkSpecials(multipartRequest.getParameterMap());
    }
}
10.校验post请求
/**
 * Final chain link: checks the request body for special characters.
 * When the body is clean, the next checker is consulted, if one exists.
 */
public class PostRequestChecker extends ParamChecker {
    @Override
    public boolean check(HttpServletRequest request) {
        boolean bodyIllegal = ParamCheckUtils.checkSpecials(request);
        return bodyIllegal || (nextHandler != null && nextHandler.check(request));
    }
}