思路
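- 配置过滤器(Servlet Filter,而不是 Spring MVC 的拦截器 Interceptor)
- 遇到的问题:过滤器运行在 DispatcherServlet 之前,在过滤器里直接抛出的异常不会被 @ControllerAdvice 全局异常处理类捕获
- 找到 HandlerExceptionResolver 这个接口
- 通过 Spring 获取容器中名为 handlerExceptionResolver 的 bean(它由 WebMvcConfigurationSupport 注册,是一个 HandlerExceptionResolverComposite,其中的 ExceptionHandlerExceptionResolver 会查找 @ControllerAdvice 里的 @ExceptionHandler 方法,所以能找到自己定义的全局异常处理类)
- 在校验不通过的地方创建/抛出 RuntimeException
- 原地 try catch,然后调用 resolver.resolveException(request 或 requestWrapper, response, null, e),把异常交给全局异常处理类处理(resolveException 本身并不会向外抛出异常,而是由匹配到的 @ExceptionHandler 写出响应)。第三个参数 handler 这里传 null,否则(据作者实测)会走 Spring 默认的错误处理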
具体配置
配置过滤器 CharCheckFilter
/**
 * Servlet filter that rejects requests whose parameters or body contain the character {@code %}.
 *
 * <ul>
 *   <li>URIs containing {@code /auth} are passed through unchecked.</li>
 *   <li>POST: the body is cached in an {@link XyRequestWrapper} (so downstream code can still read it),
 *       CORS headers are set, and the body is checked.</li>
 *   <li>GET: every parameter name and every parameter value is checked; CORS headers are only set
 *       when the request is rejected.</li>
 *   <li>Any other method is passed through unchecked.</li>
 * </ul>
 *
 * <p>A rejected request is not forwarded down the chain. Instead a {@link RuntimeException} is handed to
 * the Spring bean named {@code handlerExceptionResolver}, so that the application's global exception
 * handling writes the response. Exceptions thrown by downstream filters/servlets are not caught here and
 * propagate normally.
 */
public class CharCheckFilter implements Filter
{
    private final static Logger log = LoggerFactory.getLogger(CharCheckFilter.class);

    /** Message of the exception passed to the global exception handler when '%' is found. */
    private static final String ILLEGAL_CHAR_MESSAGE = "非法字符 %";

    @Override
    public void init(FilterConfig filterConfig) throws ServletException
    {
        log.info("init");
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException
    {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        String uri = request.getRequestURI();
        String method = request.getMethod();
        if (uri.contains("/auth")) { // authentication endpoints are not checked
            filterChain.doFilter(request, response);
            return;
        }
        if ("POST".equalsIgnoreCase(method)) {
            // Cache the body so it can be inspected here and still be read by the controller.
            XyRequestWrapper requestWrapper = new XyRequestWrapper(request);
            // POST responses get the CORS headers up front, whether the request is accepted or rejected.
            setCorsHeaders(response);
            String body = requestWrapper.getBody();
            if (StringUtils.isNotBlank(body) && body.contains("%")) {
                log.info("POST请求存在非法字符 % , uri={},入参={}",uri,body);
                rejectIllegalChar(requestWrapper, response);
                return;
            }
            // Only reached when validation passed; downstream exceptions are not swallowed here.
            filterChain.doFilter(requestWrapper, response);
            return;
        }
        if ("GET".equalsIgnoreCase(method)) {
            for (Map.Entry<String, String[]> entry : request.getParameterMap().entrySet()) {
                String key = entry.getKey();
                String[] values = entry.getValue();
                // Check every value, not just the first one (e.g. ?a=1&a=%25).
                if (containsPercent(key) || containsPercent(values)) {
                    String joinedValues = values == null ? null : String.join(",", values);
                    log.info("GET请求存在非法字符 % , uri={},入参key={},入参value={}",uri,key,joinedValues);
                    // Per the original author, setting CORS headers at the start of a GET caused problems,
                    // so they are only set once the request is known to be rejected.
                    setCorsHeaders(response);
                    rejectIllegalChar(request, response);
                    return;
                }
            }
            filterChain.doFilter(request, response);
            return;
        }
        filterChain.doFilter(request, response);
    }

    @Override
    public void destroy() {
    }

    /**
     * Hands a {@link RuntimeException} describing the illegal character to the Spring
     * {@code handlerExceptionResolver} bean so the global exception handler writes the response.
     * The resolver is looked up per call rather than cached in a field shared by all request threads.
     *
     * @param request  the (possibly wrapped) request being rejected
     * @param response the response to write to
     * @throws IOException if the fallback error response cannot be sent
     */
    private static void rejectIllegalChar(HttpServletRequest request, HttpServletResponse response) throws IOException
    {
        HandlerExceptionResolver resolver =
                (HandlerExceptionResolver) TestSpringUtil.getBean("handlerExceptionResolver");
        // handler is null: no handler has been selected yet at filter time.
        Object handled = resolver.resolveException(request, response, null, new RuntimeException(ILLEGAL_CHAR_MESSAGE));
        if (handled == null && !response.isCommitted()) {
            // No resolver handled the exception: fail explicitly instead of returning an empty 200.
            response.sendError(HttpServletResponse.SC_BAD_REQUEST, ILLEGAL_CHAR_MESSAGE);
        }
    }

    /**
     * Sets the CORS response headers used by this filter.
     * NOTE(review): browsers reject {@code Access-Control-Allow-Origin: *} combined with
     * {@code Access-Control-Allow-Credentials: true}; confirm whether credentials are actually needed.
     */
    private static void setCorsHeaders(HttpServletResponse response)
    {
        response.setHeader("Access-Control-Allow-Origin", "*");
        response.setHeader("Access-Control-Allow-Credentials", "true");
        response.addHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE");
        response.setHeader("Access-Control-Max-Age", "3600");
        response.setHeader("Access-Control-Allow-Headers", "X-Requested-With,Authorization,Content-Type");
    }

    /** Returns true when {@code text} is non-null and contains '%'. */
    private static boolean containsPercent(String text)
    {
        return text != null && text.contains("%");
    }

    /** Returns true when any element of {@code values} contains '%'; null arrays/elements are ignored. */
    private static boolean containsPercent(String[] values)
    {
        if (values == null) {
            return false;
        }
        for (String value : values) {
            if (containsPercent(value)) {
                return true;
            }
        }
        return false;
    }
}
配置 request 包装类 XyRequestWrapper:请求体的 InputStream 只能读取一次,在过滤器里读过之后,后续代码(如 Controller 的 @RequestBody)就读不到了,所以先把 body 缓存下来,之后可以重复读取
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Charsets;
import org.apache.commons.lang3.StringUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
/**
 * Request wrapper that reads the whole request body once (as UTF-8) and caches it, so that
 * {@link #getInputStream()} and {@link #getReader()} can be called any number of times.
 *
 * <p>NOTE(review): the body is consumed in the constructor. For
 * {@code application/x-www-form-urlencoded} POST requests the container may parse form parameters
 * lazily from that same stream, in which case {@code getParameter*} may no longer see them;
 * confirm against the servlet container in use.
 */
public class XyRequestWrapper extends HttpServletRequestWrapper {
    /** Cached request body (never null); replaced by {@link #setParamsMaps(Map)}. */
    private String body;

    /**
     * Wraps {@code request} and eagerly reads its body.
     *
     * @param request the request to wrap
     * @throws IOException if reading the request body fails
     */
    public XyRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.body = readBody(request);
    }

    /** Reads the full body of {@code request} as UTF-8; returns "" when the container gives no stream. */
    private static String readBody(HttpServletRequest request) throws IOException {
        InputStream inputStream = request.getInputStream();
        if (inputStream == null) {
            return "";
        }
        StringBuilder stringBuilder = new StringBuilder();
        // Closing the reader also closes the container's stream; that is fine because the body is cached.
        try (BufferedReader bufferedReader =
                     new BufferedReader(new InputStreamReader(inputStream, Charsets.UTF_8))) {
            char[] charBuffer = new char[1024];
            int charsRead;
            while ((charsRead = bufferedReader.read(charBuffer)) != -1) {
                stringBuilder.append(charBuffer, 0, charsRead);
            }
        }
        return stringBuilder.toString();
    }

    /**
     * Returns a fresh stream over the cached body (UTF-8 encoded) on every call.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(body.getBytes(Charsets.UTF_8));
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return byteArrayInputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                // All data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // NOTE(review): non-blocking (async) reading is not supported; the listener is ignored.
            }

            @Override
            public int read() throws IOException {
                return byteArrayInputStream.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return byteArrayInputStream.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return byteArrayInputStream.available();
            }
        };
    }

    /** Returns a UTF-8 reader over the cached body. */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), Charsets.UTF_8));
    }

    /** Returns the cached request body. */
    public String getBody() {
        return this.body;
    }

    // The parameter accessors below simply delegate to the wrapped request.

    @Override
    public String getParameter(String name) {
        return super.getParameter(name);
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        return super.getParameterMap();
    }

    @Override
    public Enumeration<String> getParameterNames() {
        return super.getParameterNames();
    }

    @Override
    public String[] getParameterValues(String name) {
        return super.getParameterValues(name);
    }

    /**
     * Sets custom POST parameters: merges {@code paramMaps} into the JSON object held in the body and
     * re-serializes it, so later {@link #getInputStream()}/{@link #getReader()} calls see the new body.
     * Entries in {@code paramMaps} override existing keys. A blank body, or a body that parses to JSON
     * {@code null}, is treated as an empty object; a {@code null} argument adds nothing.
     *
     * @param paramMaps the entries to add; may be {@code null}
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void setParamsMaps(Map paramMaps) {
        Map paramBodyMap = null;
        if (!StringUtils.isEmpty(body)) {
            paramBodyMap = JSONObject.parseObject(body, Map.class);
        }
        if (paramBodyMap == null) {
            paramBodyMap = new HashMap();
        }
        if (paramMaps != null) {
            paramBodyMap.putAll(paramMaps);
        }
        body = JSONObject.toJSONString(paramBodyMap);
    }
}
配置 Spring 工具类 TestSpringUtil,用于在过滤器中获取 HandlerExceptionResolver 对象,把异常转交给全局异常处理
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
/**
 * Static access point to the Spring {@link ApplicationContext}, so that objects not managed by
 * Spring (such as a plain servlet filter) can look up beans.
 *
 * <p>Only the first context handed in by Spring is kept; later calls to
 * {@link #setApplicationContext(ApplicationContext)} are ignored. Until a context has been
 * injected, {@link #getApplicationContext()} returns {@code null} and the {@code getBean}
 * methods fail with a {@link NullPointerException}.
 */
@Component
public class TestSpringUtil implements ApplicationContextAware {

    private static ApplicationContext applicationContext;

    /**
     * Called by Spring; stores {@code applicationContext} unless one has already been stored.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        boolean alreadyInitialised = TestSpringUtil.applicationContext != null;
        if (alreadyInitialised) {
            return;
        }
        TestSpringUtil.applicationContext = applicationContext;
    }

    /**
     * @return the stored context, or {@code null} if Spring has not injected one yet
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    /**
     * Looks up a bean by its name.
     */
    public static Object getBean(String name) {
        ApplicationContext ctx = getApplicationContext();
        return ctx.getBean(name);
    }

    /**
     * Looks up the single bean of the given type.
     */
    public static <T> T getBean(Class<T> clazz) {
        ApplicationContext ctx = getApplicationContext();
        return ctx.getBean(clazz);
    }

    /**
     * Looks up a bean by name, requiring it to be of the given type.
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        ApplicationContext ctx = getApplicationContext();
        return ctx.getBean(name, clazz);
    }
}