package cn.gtmap.estateplat.filter;

import cn.gtmap.estateplat.utils.CommonUtil;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
 * @author <a href="mailto:dingweiwei@gtmap.cn">dingweiwei</a>
 * @description sql注入过滤器
 */
public class SqlInjectFilter implements Filter {
    private Logger logger = LoggerFactory.getLogger(this.getClass());
    //sql关键字
    private String sqlInjectStrList = new String();

    private String invalidPage = "/common/404.ftl";

    //白名单页面
    private Map<String, String> whitePageList = new HashMap();

    public void destroy() {
    }

    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws ServletException, IOException {
        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;
        String currentURL = request.getRequestURI();
        //不需要过滤的特定页面
        if (MapUtils.isNotEmpty(whitePageList)) {
            for (String key : whitePageList.keySet()) {
                if (StringUtils.indexOf(currentURL, key) > -1) {
                    chain.doFilter(request, response);
                    return;
                }
            }
        }
        // 获得所有请求参数名
        Enumeration<?> params = request.getParameterNames();
        String sql = "";
        while (params.hasMoreElements()) {
            // 得到参数名
            String name = params.nextElement().toString();
            // 得到参数对应值
            String[] value = request.getParameterValues(name);
            for (int i = 0; i < value.length; i++) {
                sql = sql + value[i];
            }
        }
        // 过滤掉的SQL关键字，可以手动添加
        if (sqlValidate(sql, sqlInjectStrList)) {
            // 重定向或跳转
            response.sendRedirect(request.getContextPath() + this.invalidPage);
            logger.error(request.getRequestURI() + ":" + sql + "存在非法字符字符，请检查！");
        } else {
            chain.doFilter(request, response);
        }
    }

    public void init(FilterConfig config) throws ServletException {

        this.sqlInjectStrList = config.getInitParameter("sqlInjectStrList");

        String whitePage = config.getInitParameter("WhitePageList");
        if (StringUtils.isNotBlank(whitePage)) {
            String[] arrPage = whitePage.split(";");
            if (ArrayUtils.isNotEmpty(arrPage)) {
                for (int i = 0; i < arrPage.length; i++) {
                    String key = arrPage[i];
                    this.whitePageList.put(key, (String) null);
                }
            }
        }
    }

    // 校验SQL
    protected static boolean sqlValidate(String str, String sqlInjectStrList) {
        // 统一转为小写
        str = str.toLowerCase();
        /*if (str.contains(" ")) {
            // 转换为数组
            String[] badStrs = sqlInjectStrList.split("\\|");
            for (int i = 0; i < badStrs.length; i++) {
                // 检索
                if (str.indexOf(badStrs[i]) >= 0) {
                    return true;
                }
            }
        }*/
        // 转换为数组
        String[] badStrs = sqlInjectStrList.split("\\|");
        for (int i = 0; i < badStrs.length; i++) {
            // 检索
            if (str.indexOf(badStrs[i]) >= 0) {
                System.out.println("str = " + str);
                System.out.println("非法字符: " + badStrs[i]);
                return true;
            }
        }
        return false;
    }
}
