Spring Boot 3 巧妙运用过滤器(Filter)阻断 XSS 攻击

什么是xss

跨站脚本攻击(Cross Site Scripting)按首字母本应缩写为 CSS,但这会与层叠样式表(Cascading Style Sheets,CSS)的缩写混淆,因此通常将其缩写为 XSS。

跨站脚本攻击是最常见的 Web 应用安全漏洞之一。这类漏洞使攻击者能够把恶意脚本代码嵌入到正常用户会访问的页面中;当正常用户访问该页面时,嵌入的恶意脚本就会在用户的浏览器中执行,从而达到攻击用户的目的。

类型

  1. 持久型(存储型)跨站:危害最大,跨站脚本被存储在服务器的数据持久层(如数据库)中;
  2. 非持久型(反射型)跨站:攻击者构造带有跨站脚本的链接,用户访问该链接时,服务器直接把跨站代码返回给浏览器;
  3. DOM 型跨站:通常是客户端 JavaScript 处理文档对象模型(DOM)时出现的逻辑安全问题。

简单示例

我自己做了一个存在xss漏洞的简单web网站
我需要的正常访问结果
正常访问的结果
意料之外的异常结果
非法访问的结果
可以发现,我们能够在链接后面拼接 JS 等脚本,这很可能给攻击者留下可乘之机。

解决方法

过滤器(Servlet Filter)

在每个请求到达 Controller 之前,先通过过滤器检查请求参数、请求头以及 JSON 请求体,把其中的非法内容过滤掉。

代码

代码结构
核心代码

package com.zxs.xss;

import com.zxs.xss.filter.XssFilter;
import jakarta.servlet.DispatcherType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;



/**
 * Registers {@link XssFilter} as a servlet filter, configured from {@link XssProperties} ({@code xss.*}).
 */
@Configuration
@EnableConfigurationProperties({XssProperties.class})
public class XssAutoConfiguration {

	@Autowired
	XssProperties xssProperties;

	/**
	 * Creates the filter registration; only active when {@code xss.enable=true}.
	 * <p>
	 * The filter is mapped to {@code xss.url-patterns}, registered under {@code xss.name} with order
	 * {@code xss.order}, and only runs for {@link DispatcherType#REQUEST} dispatches.
	 *
	 * @return the typed registration bean for {@link XssFilter}
	 */
	@Bean
	@ConditionalOnProperty(prefix = XssProperties.XSS, name = "enable", havingValue = "true")
	public FilterRegistrationBean<XssFilter> xssFilterRegistration() {
		// typed registration instead of the raw FilterRegistrationBean (no unchecked warnings)
		FilterRegistrationBean<XssFilter> registration = new FilterRegistrationBean<>();
		registration.setDispatcherTypes(DispatcherType.REQUEST);
		registration.setFilter(new XssFilter());
		registration.addUrlPatterns(xssProperties.getUrlPatterns());
		registration.setName(xssProperties.getName());
		registration.setOrder(xssProperties.getOrder());
		System.out.println("拦截器启动 :"+xssProperties.getName()+" 状态");
		return registration;
	}
}

package com.zxs.xss.filter;

import com.zxs.xss.servlet.XssHttpServletRequestWrapper;
import com.zxs.xss.util.HtmlFilterKit;
import jakarta.servlet.*;

import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;

/**
 * Servlet filter that passes every HTTP request down the chain wrapped in an
 * {@link XssHttpServletRequestWrapper}, so downstream code reads sanitized parameters, headers and JSON bodies.
 */
public class XssFilter implements Filter {

	@Override
	public void init(FilterConfig config) throws ServletException {
	}

	/**
	 * Wraps HTTP requests in the sanitizing wrapper. Non-HTTP requests are passed on unchanged instead of
	 * failing with a {@link ClassCastException}.
	 */
	@Override
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
			throws IOException, ServletException {
		if (!(request instanceof HttpServletRequest)) {
			chain.doFilter(request, response);
			return;
		}
		HttpServletRequest httpRequest = (HttpServletRequest) request;
		// a new HtmlFilterKit per request: it keeps per-call tag counts in an unsynchronized instance map,
		// so one shared instance would not be safe across concurrent requests
		XssHttpServletRequestWrapper xssRequest = new XssHttpServletRequestWrapper(httpRequest, new HtmlFilterKit());
		System.out.println("拦截到:" + httpRequest.getMethod());
		chain.doFilter(xssRequest, response);
	}

	@Override
	public void destroy() {
	}
}

package com.zxs.xss.servlet;

import com.zxs.xss.util.HtmlFilterKit;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;


/**
 * Request wrapper that runs request parameters, headers and JSON request bodies through an
 * {@link HtmlFilterKit} before the application sees them.
 * <p>
 * NOTE(review): {@code getReader()} and {@code getHeaders(String)} are not overridden, so anything read through
 * them is NOT filtered.
 * <p>
 * NOTE(review): the JSON body is filtered as one raw string, not value by value. If a JSON string contains a
 * whitelisted tag with attributes, the HTML filter re-emits the attribute in double quotes, which can break the
 * surrounding JSON. Confirm this is acceptable for your payloads.
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
	public static final String APPLICATION_JSON_VALUE = "application/json";
	public static final String CONTENT_TYPE = "Content-Type";

	/**
	 * The unwrapped request, for special cases that need the raw values and must filter them themselves.
	 */
	HttpServletRequest orgRequest;
	/**
	 * HTML sanitizer applied to parameters, headers and JSON bodies.
	 */
	private final  HtmlFilterKit htmlFilter;

	/**
	 * @param request    the request to wrap
	 * @param htmlFilter the sanitizer applied to every value this wrapper returns
	 */
	public XssHttpServletRequestWrapper(HttpServletRequest request,HtmlFilterKit htmlFilter) {
		super(request);
		this.htmlFilter = htmlFilter;
		orgRequest = request;
	}

	/**
	 * Returns the request body. For JSON requests the body is read fully as UTF-8, sanitized, and then served
	 * from memory; every other body is returned untouched.
	 */
	@Override
	public ServletInputStream getInputStream() throws IOException {
		// non-JSON bodies (forms, multipart, ...) are passed through untouched
		if (!isJsonContentType(super.getHeader(CONTENT_TYPE))) {
			return super.getInputStream();
		}

		String json = StreamUtils.copyToString(super.getInputStream(), StandardCharsets.UTF_8);
		// empty body: nothing to filter, hand back the (already exhausted) original stream
		if (!StringUtils.hasLength(json)) {
			return super.getInputStream();
		}

		json = xssEncode(json);
		return newCachedInputStream(json.getBytes(StandardCharsets.UTF_8));
	}

	/**
	 * Tells whether a Content-Type header denotes JSON: {@code application/json} or a {@code +json}
	 * structured-syntax type, ignoring case and parameters such as {@code ;charset=UTF-8}.
	 * The previous exact comparison let {@code application/json;charset=UTF-8} bodies bypass the filter.
	 */
	private static boolean isJsonContentType(String contentType) {
		if (contentType == null) {
			return false;
		}
		int semicolon = contentType.indexOf(';');
		String mediaType = (semicolon >= 0 ? contentType.substring(0, semicolon) : contentType).trim();
		return APPLICATION_JSON_VALUE.equalsIgnoreCase(mediaType)
				|| mediaType.toLowerCase(Locale.ROOT).endsWith("+json");
	}

	/**
	 * Exposes an in-memory byte array as a blocking {@link ServletInputStream}.
	 */
	private static ServletInputStream newCachedInputStream(byte[] bytes) {
		final ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
		return new ServletInputStream() {
			@Override
			public boolean isFinished() {
				// report the real state instead of always claiming the stream is exhausted
				return bis.available() == 0;
			}

			@Override
			public boolean isReady() {
				return true;
			}

			@Override
			public void setReadListener(ReadListener readListener) {
				// ignored: the whole body is already buffered in memory
			}

			@Override
			public int read() {
				return bis.read();
			}

			@Override
			public int read(byte[] b, int off, int len) {
				return bis.read(b, off, len);
			}
		};
	}

	/**
	 * Returns the sanitized value of a request parameter ({@code null} if absent).
	 */
	@Override
	public String getParameter(String name) {
		String value = super.getParameter(xssEncode(name));
		if (StringUtils.hasLength(value)) {
			value = xssEncode(value);
		}
		return value;
	}

	/**
	 * Returns sanitized copies of a parameter's values, or {@code null} when there are none.
	 */
	@Override
	public String[] getParameterValues(String name) {
		String[] parameters = super.getParameterValues(name);
		if (parameters == null || parameters.length == 0) {
			return null;
		}
		return encodeAll(parameters);
	}

	/**
	 * Returns a new map whose values are sanitized copies of the original parameter values. The container's own
	 * arrays are not modified, so repeated calls do not re-filter already filtered data.
	 */
	@Override
	public Map<String, String[]> getParameterMap() {
		Map<String, String[]> map = new LinkedHashMap<>();
		for (Map.Entry<String, String[]> entry : super.getParameterMap().entrySet()) {
			String[] values = entry.getValue();
			map.put(entry.getKey(), values == null ? null : encodeAll(values));
		}
		return map;
	}

	/**
	 * Returns the sanitized value of a request header ({@code null} if absent).
	 */
	@Override
	public String getHeader(String name) {
		String value = super.getHeader(xssEncode(name));
		if (StringUtils.hasLength(value)) {
			value = xssEncode(value);
		}
		return value;
	}

	/**
	 * Sanitizes every element into a fresh array; {@code null} elements stay {@code null}.
	 */
	private String[] encodeAll(String[] values) {
		String[] encoded = new String[values.length];
		for (int i = 0; i < values.length; i++) {
			encoded[i] = values[i] == null ? null : xssEncode(values[i]);
		}
		return encoded;
	}

	private String xssEncode(String input) {
		return htmlFilter.filter(input);
	}

	/**
	 * Returns the original, unwrapped request.
	 */
	public HttpServletRequest getOrgRequest() {
		return orgRequest;
	}

	/**
	 * Returns the original request behind an {@link XssHttpServletRequestWrapper}, or {@code request} itself
	 * when it is not wrapped by this class.
	 */
	public static HttpServletRequest getOrgRequest(HttpServletRequest request) {
		if (request instanceof XssHttpServletRequestWrapper) {
			return ((XssHttpServletRequestWrapper) request).getOrgRequest();
		}
		return request;
	}
}

package com.zxs.xss.util;

/**
 * Transforms a piece of untrusted text into a form that is safe to use downstream.
 * Implementations in this project (e.g. HtmlFilterKit) sanitize HTML.
 */
@FunctionalInterface
public interface CharacterFilter {

    /**
     * Filters the given text.
     *
     * @param input the text to filter
     * @return the filtered text
     */
    String filter(String input);
}

package com.zxs.xss.util;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Whitelist-based HTML sanitizer.
 * <p>
 * {@link #filter(String)} runs its input through a fixed pipeline:
 * <ul>
 *   <li>escape the body of HTML comments;</li>
 *   <li>repair stray angle brackets;</li>
 *   <li>keep only whitelisted tags and attributes, closing any whitelisted tag left open;</li>
 *   <li>remove empty whitelisted tags;</li>
 *   <li>escape every {@code &} that does not start a whitelisted, {@code ;}-terminated entity.</li>
 * </ul>
 * <p>
 * With the default constructor the whitelist is {@code <a href target>}, {@code <img src width height alt>},
 * {@code <b>}, {@code <strong>}, {@code <i>} and {@code <em>}. Values of {@code href}/{@code src} whose
 * protocol is not {@code http}, {@code https} or {@code mailto} are rewritten into a local {@code #} anchor.
 * <p>
 * Instances are not thread-safe: {@link #filter(String)} keeps per-call tag counts in an unsynchronized
 * instance map, so use one instance per call or per thread.
 */
public final class HtmlFilterKit implements CharacterFilter {

	/**
	 * regex flag union representing /si modifiers in php
	 **/
	private static final int REGEX_FLAGS_SI = Pattern.CASE_INSENSITIVE | Pattern.DOTALL;
	private static final Pattern P_COMMENTS = Pattern.compile("<!--(.*?)-->", Pattern.DOTALL);
	private static final Pattern P_COMMENT = Pattern.compile("^!--(.*)--$", REGEX_FLAGS_SI);
	private static final Pattern P_TAGS = Pattern.compile("<(.*?)>", Pattern.DOTALL);
	private static final Pattern P_END_TAG = Pattern.compile("^/([a-z0-9]+)", REGEX_FLAGS_SI);
	private static final Pattern P_START_TAG = Pattern.compile("^([a-z0-9]+)(.*?)(/?)$", REGEX_FLAGS_SI);
	private static final Pattern P_QUOTED_ATTRIBUTES = Pattern.compile("([a-z0-9]+)=([\"'])(.*?)\\2", REGEX_FLAGS_SI);
	private static final Pattern P_UNQUOTED_ATTRIBUTES = Pattern.compile("([a-z0-9]+)(=)([^\"\\s']+)", REGEX_FLAGS_SI);
	private static final Pattern P_PROTOCOL = Pattern.compile("^([^:]+):", REGEX_FLAGS_SI);
	private static final Pattern P_ENTITY = Pattern.compile("&#(\\d+);?");
	// hex digits may be written in either case ("&#x3A;" and "&#X3a;" are both a colon); the old
	// lower-case-only pattern decoded "&#x3A;" as "&#x3" + "A;"
	private static final Pattern P_ENTITY_UNICODE = Pattern.compile("&#x([0-9a-f]+);?", Pattern.CASE_INSENSITIVE);
	private static final Pattern P_ENCODE = Pattern.compile("%([0-9a-f]{2});?", Pattern.CASE_INSENSITIVE);
	private static final Pattern P_VALID_ENTITIES = Pattern.compile("&([^&;]*)(?=(;|&|$))");
	private static final Pattern P_VALID_QUOTES = Pattern.compile("(>|^)([^<]+?)(<|$)", Pattern.DOTALL);
	private static final Pattern P_END_ARROW = Pattern.compile("^>");
	private static final Pattern P_BODY_TO_END = Pattern.compile("<([^>]*?)(?=<|$)");
	private static final Pattern P_XML_CONTENT = Pattern.compile("(^|>)([^<]*?)(?=>)");
	private static final Pattern P_STRAY_LEFT_ARROW = Pattern.compile("<([^>]*?)(?=<|$)");
	private static final Pattern P_STRAY_RIGHT_ARROW = Pattern.compile("(^|>)([^<]*?)(?=>)");
	private static final Pattern P_AMP = Pattern.compile("&");
	/**
	 * NOTE(review): despite its name this matches {@code <}, not {@code "}.
	 * <p>
	 * Its only user is {@code encodeQuotes}, whose text segments ({@code [^<]+?}) can never contain {@code <},
	 * so quote encoding is effectively a no-op. Changing it to {@code "} would also rewrite every {@code "} of
	 * the raw JSON bodies that XssHttpServletRequestWrapper passes through {@link #filter(String)}. Confirm the
	 * intent before changing it.
	 */
	private static final Pattern P_QUOTE = Pattern.compile("<");
	/** Matches a literal double quote. */
	private static final Pattern P_DOUBLE_QUOTE = Pattern.compile("\"");
	private static final Pattern P_LEFT_ARROW = Pattern.compile("<");
	private static final Pattern P_RIGHT_ARROW = Pattern.compile(">");
	private static final Pattern P_BOTH_ARROWS = Pattern.compile("<>");

	/**
	 * @xxx could grow large... maybe use sesat's ReferenceMap
	 */
	private static final ConcurrentMap<String, Pattern> P_REMOVE_PAIR_BLANKS = new ConcurrentHashMap<String, Pattern>();
	private static final ConcurrentMap<String, Pattern> P_REMOVE_SELF_BLANKS = new ConcurrentHashMap<String, Pattern>();

	/**
	 * set of allowed html elements, along with allowed attributes for each element
	 **/
	private final Map<String, List<String>> vAllowed;
	/**
	 * counts of open tags for each (allowable) html element
	 **/
	private final Map<String, Integer> vTagCounts = new HashMap<String, Integer>();

	/**
	 * html elements which must always be self-closing (e.g. "<img />")
	 **/
	private final String[] vSelfClosingTags;
	/**
	 * html elements which must always have separate opening and closing tags (e.g. "<b></b>")
	 **/
	private final String[] vNeedClosingTags;
	/**
	 * set of disallowed html elements
	 **/
	private final String[] vDisallowed;
	/**
	 * attributes which should be checked for valid protocols
	 **/
	private final String[] vProtocolAtts;
	/**
	 * allowed protocols (compared case-insensitively)
	 **/
	private final String[] vAllowedProtocols;
	/**
	 * tags that are removed entirely when they have no content
	 **/
	private final String[] vRemoveBlanks;
	/**
	 * entities allowed within html markup
	 **/
	private final String[] vAllowedEntities;
	/**
	 * flag determining whether comments are allowed in input String.
	 */
	private final boolean stripComment;
	private final boolean encodeQuotes;
	private boolean vDebug = false;

	private final boolean alwaysMakeTags;

	/**
	 * Default constructor.
	 */
	public HtmlFilterKit() {
		vAllowed = new HashMap<>();

		final ArrayList<String> a_atts = new ArrayList<String>();
		a_atts.add("href");
		a_atts.add("target");
		vAllowed.put("a", a_atts);

		final ArrayList<String> img_atts = new ArrayList<String>();
		img_atts.add("src");
		img_atts.add("width");
		img_atts.add("height");
		img_atts.add("alt");
		vAllowed.put("img", img_atts);

		final ArrayList<String> no_atts = new ArrayList<String>();
		vAllowed.put("b", no_atts);
		vAllowed.put("strong", no_atts);
		vAllowed.put("i", no_atts);
		vAllowed.put("em", no_atts);

		vSelfClosingTags = new String[]{"img"};
		vNeedClosingTags = new String[]{"a", "b", "strong", "i", "em"};
		vDisallowed = new String[]{};
		vAllowedProtocols = new String[]{"http", "mailto", "https"}; // no ftp.
		vProtocolAtts = new String[]{"src", "href"};
		vRemoveBlanks = new String[]{"a", "b", "strong", "i", "em"};
		vAllowedEntities = new String[]{"amp", "gt", "lt", "quot"};
		stripComment = true;
		encodeQuotes = true;
		alwaysMakeTags = true;
	}

	/**
	 * Set debug flag to true. Otherwise use default settings. See the default constructor.
	 *
	 * @param debug turn debug on with a true argument
	 */
	public HtmlFilterKit(final boolean debug) {
		this();
		vDebug = debug;

	}

	/**
	 * Map-parameter configurable constructor.
	 *
	 * @param conf map containing configuration. keys match field names. {@code vAllowed} may be any
	 *             {@code Map<String, List<String>>} (previously it had to be a {@code HashMap}).
	 */
	public HtmlFilterKit(final Map<String, Object> conf) {
		assert conf.containsKey("vAllowed") : "configuration requires vAllowed";
		assert conf.containsKey("vSelfClosingTags") : "configuration requires vSelfClosingTags";
		assert conf.containsKey("vNeedClosingTags") : "configuration requires vNeedClosingTags";
		assert conf.containsKey("vDisallowed") : "configuration requires vDisallowed";
		assert conf.containsKey("vAllowedProtocols") : "configuration requires vAllowedProtocols";
		assert conf.containsKey("vProtocolAtts") : "configuration requires vProtocolAtts";
		assert conf.containsKey("vRemoveBlanks") : "configuration requires vRemoveBlanks";
		assert conf.containsKey("vAllowedEntities") : "configuration requires vAllowedEntities";

		// the configuration map is untyped; the caller is responsible for supplying Map<String, List<String>>
		@SuppressWarnings("unchecked")
		final Map<String, List<String>> allowed = (Map<String, List<String>>) conf.get("vAllowed");
		vAllowed = Collections.unmodifiableMap(allowed);
		vSelfClosingTags = (String[]) conf.get("vSelfClosingTags");
		vNeedClosingTags = (String[]) conf.get("vNeedClosingTags");
		vDisallowed = (String[]) conf.get("vDisallowed");
		vAllowedProtocols = (String[]) conf.get("vAllowedProtocols");
		vProtocolAtts = (String[]) conf.get("vProtocolAtts");
		vRemoveBlanks = (String[]) conf.get("vRemoveBlanks");
		vAllowedEntities = (String[]) conf.get("vAllowedEntities");
		stripComment = conf.containsKey("stripComment") ? (Boolean) conf.get("stripComment") : true;
		encodeQuotes = conf.containsKey("encodeQuotes") ? (Boolean) conf.get("encodeQuotes") : true;
		alwaysMakeTags = conf.containsKey("alwaysMakeTags") ? (Boolean) conf.get("alwaysMakeTags") : true;
	}

	private void reset() {
		vTagCounts.clear();
	}

	private void debug(final String msg) {
		if (vDebug) {
			Logger.getAnonymousLogger().info(msg);
		}
	}

	//---------------------------------------------------------------
	// my versions of some PHP library functions

	/**
	 * Returns the single UTF-16 char {@code (char) decimal}. Values above 0xFFFF are truncated; internal
	 * entity decoding uses {@link Character#toChars(int)} instead.
	 */
	public static String chr(final int decimal) {
		return String.valueOf((char) decimal);
	}

	/**
	 * Escapes {@code &}, {@code "}, {@code <} and {@code >} as HTML entities.
	 * {@code &} is replaced first so the entities produced afterwards are not escaped again.
	 * (Previously {@code <} was turned into {@code &quot;} and {@code "} was left as-is.)
	 */
	public static String htmlSpecialChars(final String s) {
		String result = s;
		result = regexReplace(P_AMP, "&amp;", result);
		result = regexReplace(P_DOUBLE_QUOTE, "&quot;", result);
		result = regexReplace(P_LEFT_ARROW, "&lt;", result);
		result = regexReplace(P_RIGHT_ARROW, "&gt;", result);
		return result;
	}

	/**
	 * Sanitizes {@code input} according to this instance's whitelist.
	 *
	 * @param input the untrusted text; {@code null} is returned unchanged
	 * @return the sanitized text
	 */
	@Override
	public String filter(final String input) {
		if (input == null) {
			return null;
		}
		reset();
		String s = input;

		debug("************************************************");
		debug("              INPUT: " + input);

		s = escapeComments(s);
		debug("     escapeComments: " + s);

		s = balanceHTML(s);
		debug("        balanceHTML: " + s);

		s = checkTags(s);
		debug("          checkTags: " + s);

		s = processRemoveBlanks(s);
		debug("processRemoveBlanks: " + s);

		s = validateEntities(s);
		debug("    validateEntites: " + s);

		debug("************************************************\n\n");
		return s;
	}

	public boolean isAlwaysMakeTags() {
		return alwaysMakeTags;
	}

	public boolean isStripComments() {
		return stripComment;
	}

	/**
	 * HTML-escapes the body of every {@code <!-- ... -->} comment (previously only the first one was escaped).
	 */
	private String escapeComments(final String s) {
		final Matcher m = P_COMMENTS.matcher(s);
		final StringBuffer buf = new StringBuffer();
		while (m.find()) {
			final String match = m.group(1); //(.*?)
			m.appendReplacement(buf, Matcher.quoteReplacement("<!--" + htmlSpecialChars(match) + "-->"));
		}
		m.appendTail(buf);

		return buf.toString();
	}

	/**
	 * Deals with unbalanced angle brackets.
	 * With {@code alwaysMakeTags} they are turned into tags; otherwise they are escaped.
	 */
	private String balanceHTML(String s) {
		if (alwaysMakeTags) {
			//
			// try and form html
			//
			s = regexReplace(P_END_ARROW, "", s);
			s = regexReplace(P_BODY_TO_END, "<$1>", s);
			s = regexReplace(P_XML_CONTENT, "$1<$2", s);

		} else {
			//
			// escape stray brackets
			//
			s = regexReplace(P_STRAY_LEFT_ARROW, "&lt;$1", s);
			s = regexReplace(P_STRAY_RIGHT_ARROW, "$1$2&gt;<", s);

			//
			// the last regexp causes '<>' entities to appear
			// (we need to do a lookahead assertion so that the last bracket can
			// be used in the next pass of the regexp)
			//
			s = regexReplace(P_BOTH_ARROWS, "", s);
		}

		return s;
	}

	/**
	 * Replaces every {@code <...>} by its sanitized form and appends closing tags for whitelisted tags left open.
	 */
	private String checkTags(String s) {
		final Matcher m = P_TAGS.matcher(s);
		final StringBuffer buf = new StringBuffer();
		while (m.find()) {
			m.appendReplacement(buf, Matcher.quoteReplacement(processTag(m.group(1))));
		}
		m.appendTail(buf);

		// these get tallied in processTag
		// (remember to reset before subsequent calls to filter method)
		for (Map.Entry<String, Integer> entry : vTagCounts.entrySet()) {
			for (int ii = 0; ii < entry.getValue(); ii++) {
				buf.append("</").append(entry.getKey()).append('>');
			}
		}

		return buf.toString();
	}

	/**
	 * Removes empty pairs ({@code <b></b>}) and self-closed forms ({@code <b/>}) of the tags in
	 * {@code vRemoveBlanks}, compiling and caching one pattern per tag.
	 */
	private String processRemoveBlanks(final String s) {
		String result = s;
		for (String tag : vRemoveBlanks) {
			final Pattern pairBlank = P_REMOVE_PAIR_BLANKS.computeIfAbsent(tag,
					t -> Pattern.compile("<" + t + "(\\s[^>]*)?></" + t + ">"));
			result = regexReplace(pairBlank, "", result);
			final Pattern selfBlank = P_REMOVE_SELF_BLANKS.computeIfAbsent(tag,
					t -> Pattern.compile("<" + t + "(\\s[^>]*)?/>"));
			result = regexReplace(selfBlank, "", result);
		}

		return result;
	}

	private static String regexReplace(final Pattern regex_pattern, final String replacement, final String s) {
		Matcher m = regex_pattern.matcher(s);
		return m.replaceAll(replacement);
	}

	/**
	 * Sanitizes the inside of one {@code <...>} tag (brackets excluded) and returns its replacement:
	 * a normalized whitelisted tag, or {@code ""} for anything not allowed.
	 */
	private String processTag(final String s) {
		// ending tags
		Matcher m = P_END_TAG.matcher(s);
		if (m.find()) {
			final String name = m.group(1).toLowerCase(Locale.ROOT);
			if (allowed(name) && !inArray(name, vSelfClosingTags)) {
				final int open = vTagCounts.getOrDefault(name, 0);
				// only close a tag that is actually open. A stray closer is dropped, so counts never go
				// negative; a negative count would leave a later opening tag unclosed.
				if (open > 0) {
					vTagCounts.put(name, open - 1);
					return "</" + name + ">";
				}
			}
		}

		// starting tags
		m = P_START_TAG.matcher(s);
		if (m.find()) {
			final String name = m.group(1).toLowerCase(Locale.ROOT);
			final String body = m.group(2);
			String ending = m.group(3);

			if (!allowed(name)) {
				return "";
			}

			final Matcher m2 = P_QUOTED_ATTRIBUTES.matcher(body);
			final Matcher m3 = P_UNQUOTED_ATTRIBUTES.matcher(body);
			final List<String> paramNames = new ArrayList<>();
			final List<String> paramValues = new ArrayList<>();
			while (m2.find()) {
				paramNames.add(m2.group(1)); //([a-z0-9]+)
				paramValues.add(m2.group(3)); //(.*?)
			}
			while (m3.find()) {
				paramNames.add(m3.group(1)); //([a-z0-9]+)
				paramValues.add(m3.group(3)); //([^\"\\s']+)
			}

			final StringBuilder params = new StringBuilder();
			for (int ii = 0; ii < paramNames.size(); ii++) {
				final String paramName = paramNames.get(ii).toLowerCase(Locale.ROOT);
				String paramValue = paramValues.get(ii);

				if (allowedAttribute(name, paramName)) {
					if (inArray(paramName, vProtocolAtts)) {
						paramValue = processParamProtocol(paramValue);
					}
					// the value is re-emitted inside double quotes. A '"' coming from a single-quoted value
					// (alt='x" onerror="...') or from entity/percent decoding (%22, &#34;) must not close it.
					params.append(' ').append(paramName).append("=\"")
							.append(escapeAttributeValue(paramValue)).append('"');
				}
			}

			if (inArray(name, vSelfClosingTags)) {
				ending = " /";
			}

			if (inArray(name, vNeedClosingTags)) {
				ending = "";
			}

			if (ending == null || ending.isEmpty()) {
				// remember the open tag so checkTags can close it if the input never does
				vTagCounts.merge(name, 1, Integer::sum);
			} else {
				ending = " /";
			}
			return "<" + name + params + ending + ">";
		}

		// comments
		m = P_COMMENT.matcher(s);
		if (!stripComment && m.find()) {
			return "<" + m.group() + ">";
		}

		return "";
	}

	/**
	 * Escapes the characters that could terminate a double-quoted attribute value or the surrounding tag.
	 * {@code &} is left alone; the final entity validation in {@link #filter(String)} handles it.
	 */
	private static String escapeAttributeValue(final String value) {
		return value.replace("\"", "&quot;").replace("<", "&lt;").replace(">", "&gt;");
	}

	/**
	 * Decodes entities in a URL-valued attribute, then rewrites a non-whitelisted protocol into a local anchor
	 * (e.g. {@code javascript:x} becomes {@code #x}). Protocols are compared case-insensitively, so
	 * {@code HTTPS://...} is kept rather than mangled.
	 */
	private String processParamProtocol(String s) {
		s = decodeEntities(s);
		final Matcher m = P_PROTOCOL.matcher(s);
		if (m.find()) {
			final String protocol = m.group(1);
			if (!inArrayIgnoreCase(protocol, vAllowedProtocols)) {
				// bad protocol, turn into local anchor link instead
				s = "#" + s.substring(protocol.length() + 1);
				if (s.startsWith("#//")) {
					s = "#" + s.substring(3);
				}
			}
		}

		return s;
	}

	/**
	 * Decodes decimal ({@code &#106;}) and hex ({@code &#x6a;}) character references and {@code %xx} escapes,
	 * then validates the remaining entities.
	 */
	private String decodeEntities(String s) {
		s = decodeNumeric(P_ENTITY, 10, s);
		s = decodeNumeric(P_ENTITY_UNICODE, 16, s);
		s = decodeNumeric(P_ENCODE, 16, s);
		return validateEntities(s);
	}

	/**
	 * Replaces every match of {@code pattern} (whose group 1 holds the digits) by the character it encodes.
	 * <p>
	 * Decimal digits are parsed in base 10. {@code Integer.decode} used to read a leading {@code 0} as octal,
	 * and an overlong number used to throw. Numbers that overflow or are not valid code points are left
	 * untouched, and entity validation later escapes their {@code &}.
	 */
	private static String decodeNumeric(final Pattern pattern, final int radix, final String s) {
		final Matcher m = pattern.matcher(s);
		final StringBuffer buf = new StringBuffer();
		while (m.find()) {
			String replacement = m.group();
			try {
				final int codePoint = Integer.parseInt(m.group(1), radix);
				if (Character.isValidCodePoint(codePoint)) {
					replacement = new String(Character.toChars(codePoint));
				}
			} catch (NumberFormatException ignored) {
				// too many digits for an int: keep the raw text
			}
			m.appendReplacement(buf, Matcher.quoteReplacement(replacement));
		}
		m.appendTail(buf);
		return buf.toString();
	}

	/**
	 * Keeps whitelisted {@code ;}-terminated entities, escapes every other {@code &} as {@code &amp;}, then
	 * applies quote encoding.
	 */
	private String validateEntities(final String s) {
		StringBuffer buf = new StringBuffer();

		// validate entities throughout the string
		Matcher m = P_VALID_ENTITIES.matcher(s);
		while (m.find()) {
			final String one = m.group(1); //([^&;]*)
			final String two = m.group(2); //(?=(;|&|$))
			m.appendReplacement(buf, Matcher.quoteReplacement(checkEntity(one, two)));
		}
		m.appendTail(buf);

		return encodeQuotes(buf.toString());
	}

	/**
	 * Replaces {@code P_QUOTE} matches inside text segments. With the current {@code P_QUOTE} this is a no-op
	 * (see the note on that field).
	 */
	private String encodeQuotes(final String s) {
		if (encodeQuotes) {
			StringBuffer buf = new StringBuffer();
			Matcher m = P_VALID_QUOTES.matcher(s);
			while (m.find()) {
				final String one = m.group(1); //(>|^)
				final String two = m.group(2); //([^<]+?)
				final String three = m.group(3); //(<|$)
				m.appendReplacement(buf, Matcher.quoteReplacement(one + regexReplace(P_QUOTE, "&quot;", two) + three));
			}
			m.appendTail(buf);
			return buf.toString();
		} else {
			return s;
		}
	}

	private String checkEntity(final String preamble, final String term) {

		return ";".equals(term) && isValidEntity(preamble)
				? '&' + preamble
				: "&amp;" + preamble;
	}

	private boolean isValidEntity(final String entity) {
		return inArray(entity, vAllowedEntities);
	}

	private static boolean inArray(final String s, final String[] array) {
		for (String item : array) {
			if (item != null && item.equals(s)) {
				return true;
			}
		}
		return false;
	}

	private static boolean inArrayIgnoreCase(final String s, final String[] array) {
		for (String item : array) {
			if (item != null && item.equalsIgnoreCase(s)) {
				return true;
			}
		}
		return false;
	}

	private boolean allowed(final String name) {
		return (vAllowed.isEmpty() || vAllowed.containsKey(name)) && !inArray(name, vDisallowed);
	}

	private boolean allowedAttribute(final String name, final String paramName) {
		return allowed(name) && (vAllowed.isEmpty() || vAllowed.get(name).contains(paramName));
	}
}

package com.zxs.xss;

import org.springframework.boot.context.properties.ConfigurationProperties;

/**
 * Settings of the XSS filter, bound from properties under the {@code xss} prefix.
 */
@ConfigurationProperties(prefix = XssProperties.XSS)
public class XssProperties {
	/** Property prefix shared by all XSS settings. */
	public static final String XSS = "xss";

	/** Whether XSS protection is switched on. */
	boolean enable = false;
	/** Name under which the XSS filter is registered. */
	String name = "xssFilter";
	/** URL patterns the XSS filter is mapped to. */
	String[] urlPatterns = new String[]{"/*"};
	/** Filter order: the smaller the value, the higher the priority. */
	int order = 0;

	public boolean isEnable() {
		return this.enable;
	}

	public void setEnable(boolean enable) {
		this.enable = enable;
	}

	public String getName() {
		return this.name;
	}

	public void setName(String name) {
		this.name = name;
	}

	public String[] getUrlPatterns() {
		return this.urlPatterns;
	}

	public void setUrlPatterns(String[] urlPatterns) {
		this.urlPatterns = urlPatterns;
	}

	public int getOrder() {
		return this.order;
	}

	public void setOrder(int order) {
		this.order = order;
	}
}

package com.zxs.xss.exception;

/**
 * Unchecked exception for XSS-related failures, carrying a message and a numeric code (500 unless changed).
 * <p>
 * Note: {@link #setMsg(String)} only changes {@link #getMsg()}; {@link #getMessage()} keeps the message
 * passed to the constructor.
 */
public class XSSException extends RuntimeException {
    private String msg;
    private int code = 500;

    public XSSException(String msg) {
        super(msg);
        this.msg = msg;
    }

    /**
     * Creates an exception that keeps the underlying cause for stack traces and logging.
     *
     * @param msg   the error message
     * @param cause the exception that triggered this one
     */
    public XSSException(String msg, Throwable cause) {
        super(msg, cause);
        this.msg = msg;
    }

    /**
     * Creates an exception with an explicit code instead of the default 500.
     *
     * @param msg  the error message
     * @param code the error code
     */
    public XSSException(String msg, int code) {
        this(msg);
        this.code = code;
    }

    public String getMsg() {
        return msg;
    }

    public void setMsg(String msg) {
        this.msg = msg;
    }

    public int getCode() {
        return code;
    }

    public void setCode(int code) {
        this.code = code;
    }
}

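在 `src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports` 文件中注册自动配置类(Spring Boot 3 不再从 `spring.factories` 中读取自动配置):

com.zxs.xss.XssAutoConfiguration

starter 的 pom.xml 如下: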
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <!-- Maven build descriptor of the zxs-xss-starter library. -->
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.zxs</groupId>
    <artifactId>zxs-xss-starter</artifactId>
    <version>0.0.1</version>
    <packaging>jar</packaging>

    <name>zxs-xss-starter</name>
    <description>为SpringbootWeb应用打造的Xss防护工具</description>

    <!-- Inherits dependency management (versions omitted below) and plugin configuration from Spring Boot 3.4.0. -->
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.4.0</version>
        <relativePath/>
    </parent>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
        <java.version>21</java.version>
    </properties>

    <dependencies>
        <!-- Optional: not passed on transitively; the consuming Spring Boot application brings its own. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
            <optional>true</optional>
        </dependency>
        <!-- Generates configuration metadata for @ConfigurationProperties classes such as XssProperties. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-configuration-processor</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- Jakarta Servlet API (needed for the filter and request wrapper); supplied by the servlet container at runtime. -->
        <dependency>
            <groupId>jakarta.servlet</groupId>
            <artifactId>jakarta.servlet-api</artifactId>
            <scope>provided</scope>
        </dependency>

        <!-- NOTE(review): <optional> is redundant here, since test-scoped dependencies are never transitive. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <optional>true</optional>
            <scope>test</scope>
        </dependency>
    </dependencies>
</project>

使用demo

将上面的 starter 打包(例如执行 `mvn install` 安装到本地仓库)后,在 demo 项目中引入依赖(groupId 需与 starter 的 pom 保持一致),并在 application.properties 中开启配置:

<dependency>
    <groupId>com.zxs</groupId>
    <artifactId>zxs-xss-starter</artifactId>
    <version>0.0.1</version>
</dependency>
# 开启xss防护
xss.enable=true
# 设置xss防护的url拦截路径
xss.url-patterns=/zxs/noxss-json
# 设置xss防护过滤器的优先级,值越小优先级越高
xss.order=0
package com.example.demo.controller;

import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.tags.Param;

/**
 * Demo endpoints used in the article.
 * <ul>
 *   <li>{@code GET /zxs/xss-string} echoes the {@code value} query parameter, which is unfiltered unless the
 *       XSS filter's URL patterns cover this path.</li>
 *   <li>{@code POST /zxs/noxss-json} echoes the JSON request body.</li>
 * </ul>
 */
@RequestMapping("/zxs")
@RestController
public class XSSTestController {

    /**
     * Prints and echoes the {@code value} request parameter ({@code null} when absent).
     * The parameter name is given explicitly so binding does not depend on compiling with {@code -parameters}.
     */
    @GetMapping("/xss-string")
    public Object xssFormTest(@RequestParam(name = "value", required = false) String value) {
        System.out.println(value);
        return value;
    }

    /**
     * Echoes the JSON body bound to {@code Param}.
     * NOTE(review): {@code Param} is Spring's JSP tag helper bean (name/value); a dedicated DTO would be clearer.
     */
    @PostMapping("/noxss-json")
    public Object xssJsonTest(@RequestBody Param param ) {
        return param;
    }
}

启动之后就可以了

通过例子可以看到,我这里实现了一个存在 XSS 漏洞的 API 和一个受过滤器保护的 API。

在这里插入图片描述
可以看到,受保护的 API 会过滤掉代码中定义的非法内容;

而未受保护的 API 仍然会出现 XSS。

在这里插入图片描述
然后我们把 `xss.url-patterns` 改为 `/*`,将所有接口都保护起来,再次访问之前存在 XSS 的 API:
在这里插入图片描述
发现恶意脚本已经被过滤掉了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

斗码士

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值