package com.xjrsoft.common.xss; import cn.hutool.core.collection.ListUtil; import cn.hutool.core.util.StrUtil; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.util.AntPathMatcher; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; /** * XSS过滤处理 * * @author tzx */ public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper { /** * 没被包装过的HttpServletRequest(特殊场景,需要自己过滤) */ HttpServletRequest orgRequest; public static final String HTTP_METHOD_OVERRIDE = "x-http-method-override"; private String method; private final List ignoreXssUrl = ListUtil.toList("/magic-api/**", "/magic/web/**", "/workflow/execute/*", "/oa/wfMeetingApply/update_meetingSummary", "/base/baseClassDynamics"); //html过滤 private final static HTMLFilter HTML_FILTER = new HTMLFilter(); public XssHttpServletRequestWrapper(HttpServletRequest request) { super(request); orgRequest = request; // 判断请求方式是否需要转换 String methodOverride = this.getHeader(HTTP_METHOD_OVERRIDE); this.method = request.getMethod(); if (StrUtil.isNotBlank(methodOverride) && (methodOverride.equals("PUT") || methodOverride.equals("DELETE"))) { method = methodOverride; } } @Override public ServletInputStream getInputStream() throws IOException { //非json类型,直接返回 if (!MediaType.APPLICATION_JSON_VALUE.equalsIgnoreCase(super.getHeader(HttpHeaders.CONTENT_TYPE))) { return super.getInputStream(); } //为空,直接返回 String json = IOUtils.toString(super.getInputStream(), StandardCharsets.UTF_8); if (StringUtils.isBlank(json)) { return super.getInputStream(); } AntPathMatcher matcher = new AntPathMatcher(); if (ignoreXssUrl.stream().noneMatch(url -> matcher.matchStart(url, orgRequest.getRequestURI()))) { //xss过滤 orgRequest.getRequestURI() json = xssEncode(json); } final ByteArrayInputStream bis = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)); return new ServletInputStream() { @Override public boolean isFinished() { return true; } @Override public boolean isReady() { return true; } @Override public void setReadListener(ReadListener readListener) { } @Override public int read() throws IOException { return bis.read(); } }; } @Override public String getParameter(String name) { String value = super.getParameter(xssEncode(name)); if (StringUtils.isNotBlank(value)) { value = xssEncode(value); } return value; } @Override public String[] getParameterValues(String name) { String[] parameters = super.getParameterValues(name); if (parameters == null || parameters.length == 0) { return null; } for (int i = 0; i < parameters.length; i++) { parameters[i] = xssEncode(parameters[i]); } return parameters; } @Override public Map getParameterMap() { Map map = new LinkedHashMap<>(); Map parameters = super.getParameterMap(); for (String key : parameters.keySet()) { String[] values = parameters.get(key); for (int i = 0; i < values.length; i++) { values[i] = xssEncode(values[i]); } map.put(key, values); } return map; } @Override public String getHeader(String name) { String value = super.getHeader(xssEncode(name)); if (StringUtils.isNotBlank(value)) { value = xssEncode(value); } return value; } @Override public String getMethod() { return method; } private String xssEncode(String input) { return HTML_FILTER.filter(input); } /** * 获取最原始的request */ public HttpServletRequest getOrgRequest() { return orgRequest; } /** * 获取最原始的request */ public static HttpServletRequest getOrgRequest(HttpServletRequest request) { if (request instanceof XssHttpServletRequestWrapper) { return ((XssHttpServletRequestWrapper) request).getOrgRequest(); } return request; } }