package com.bcxin.ars.filter;

import com.bcxin.ars.exception.ArsException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.*;

/**
 * @author panxianwei
 * @date 2023/4/13 13:56
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

    private static Logger logger = LoggerFactory.getLogger(XssHttpServletRequestWrapper.class);
    private static String key = "and|exec|insert|select|delete|update|count|*|%|chr|mid|master|truncate|char|declare|;|or|-|+";
    private static Set<String> notAllowedKeyWords = new HashSet<String>(0);
    private static String replacedString = "INVALID";

    static {
        String keyStr[] = key.split("\\|");
        for (String str : keyStr) {
            notAllowedKeyWords.add(str);
        }
    }

    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }
    /**
     * <b> 覆盖getParameter方法，将参数和参数值做xss过滤 </b>
     * @author ZXF
     * @create 2022/08/08 0008 14:16
     * @version
     * @注意事项 </b>
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (value == null){
            return null;
        }
        return cleanXss(value);
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> map = super.getParameterMap();
        Map<String, String[]> encodedMap = new HashMap<String, String[]>();
        encodedMap.putAll(map);

        for (Map.Entry<String, String[]> entry : encodedMap.entrySet()) {
            String[] value = entry.getValue();
            String[] encodedValues = new String[value.length];
            for (int i = 0; i < value.length; i++) {
                encodedValues[i] = cleanXss(value[i]);
            }
            encodedMap.put(entry.getKey(), encodedValues);
        }
        return encodedMap;
    }

    /**
     * 在获取所有的参数名,必须重写此方法，否则对象中参数值映射不上
     * @return
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    @Override
    public Enumeration<String> getParameterNames() {
        return new Vector(super.getParameterMap().keySet()).elements();
    }

    @Override
    public String[] getParameterValues(String name) {
        String[] values = super.getParameterValues(name);
        if (values != null)
        {
            int length = values.length;
            String[] escapseValues = new String[length];
            for (int i = 0; i < length; i++)
            {
                // 防xss攻击和过滤前后空格
                escapseValues[i] = cleanXss(values[i]);
            }
            return escapseValues;
        }
        return super.getParameterValues(name);
    }

//    @Override
//    public ServletInputStream getInputStream() throws IOException {
//        final ByteArrayInputStream bais = new ByteArrayInputStream(body);
//        return new ServletInputStream() {
//
//            @Override
//            public int read() throws IOException {
//                return bais.read();
//            }
//
//            @Override
//            public boolean isFinished() {
//                return false;
//            }
//
//            @Override
//            public boolean isReady() {
//                return false;
//            }
//
//            @Override
//            public void setReadListener(ReadListener arg0) {
//
//            }
//        };
//    }
//
//    @Override
//    public BufferedReader getReader() throws IOException {
//        return new BufferedReader(new InputStreamReader(getInputStream()));
//    }

    /**
     * <b> 参数转换 </b>
     * @author ZXF
     * @create 2022/08/08 0008 14:16
     * @version
     * @注意事项 </b>
     */
    private String cleanXss(String valueP){
        /*String value = EscapeUtil.clean(valueP).trim();
        String value = valueP.replaceAll("<","＜").replaceAll(">","＞");
        value = value.replaceAll("\\(","& #40;").replaceAll("\\)","& #41;");
        value = value.replaceAll("'","&#39;");
        value = value.replaceAll("\"","&#34;");
        value = value.replaceAll("eval\\((.*)\\)","");
        value = value.replaceAll("select", "seleᴄt");// "c"→"ᴄ"
        value = value.replaceAll("truncate", "trunᴄate");// "c"→"ᴄ"
        value = value.replaceAll("exec", "exeᴄ");// "c"→"ᴄ"
        value = value.replaceAll("join", "jᴏin");// "o"→"ᴏ"
        value = value.replaceAll("union", "uniᴏn");// "o"→"ᴏ"
        value = value.replaceAll("drop", "drᴏp");// "o"→"ᴏ"
        value = value.replaceAll("count", "cᴏunt");// "o"→"ᴏ"
        value = value.replaceAll("insert", "ins℮rt");// "e"→"℮"
        value = value.replaceAll("update", "updat℮");// "e"→"℮"
        value = value.replaceAll("delete", "delet℮");// "e"→"℮"
        value = value.replaceAll("script", "sᴄript");// "c"→"ᴄ"
        value = value.replaceAll("cookie", "cᴏᴏkie");// "o"→"ᴏ"
        value = value.replaceAll("iframe", "ifram℮");// "e"→"℮"
        value = value.replaceAll("onmouseover", "onmouseov℮r");// "e"→"℮"
        value = value.replaceAll("onmousemove", "onmousemov℮");// "e"→"℮"*/

        String[] SpecialCharacters=new String[] { "<", ">", "select ", "truncate ", "exec ","join ","union ","drop ","insert ",
        "update ","delete ","script","cookie","iframe","onmouseover","onmousemove"};

        for (String keyword : SpecialCharacters) {
            if (valueP.length() > keyword.length() + 4
                    && (valueP.contains(" "+keyword)||valueP.contains(keyword+" ")||valueP.contains(" "+keyword+" ")||
                    valueP.contains(keyword))) {
                throw new ArsException("提交的数据包含特殊字符，请修改后重新提交！");
            }
        }

        return valueP;
    }

    /**
     * @Description 解析参数SQL关键字
     * @Date 2020/5/20 10:01
     */
    private String cleanSqlKeyWords(String value){
        String paramValue = value.toLowerCase();
        int sign=0;
        for (String keyWord : notAllowedKeyWords) {
            if(paramValue.length() > keyWord.length() + 4
                    && (paramValue.contains(" " + keyWord)
                    || paramValue.contains(keyWord + " ")
                    || paramValue.contains(" " + keyWord + " "))){
                paramValue = StringUtils.replace(paramValue,keyWord,replacedString);
                sign=1;
                logger.error( "链接已被过滤，因为参数中包含不允许sql的关键词(" + keyWord + ");参数：" + value + "；过滤后的参数：" + paramValue);
            }
        }
        if(value.toLowerCase().equals(paramValue)){
            return value;
        }
        if(sign==1){
            paramValue = value;
        }
        return paramValue;
    }

    public boolean checkSqlKeyWords(String value){
        String paramValue = value;
        for (String keyword : notAllowedKeyWords) {
            if (paramValue.length() > keyword.length() + 4
                    && (paramValue.contains(" "+keyword)||paramValue.contains(keyword+" ")||paramValue.contains(" "+keyword+" "))) {
                logger.error(this.getRequestURI()+ "参数中包含不允许sql的关键词(" + keyword
                        + ")");
                return true;
            }
        }
        return false;
    }
}
