package cn.gtmap.hlw.core.config;

import cn.gtmap.estateplat.register.common.util.PublicUtil;
import cn.gtmap.hlw.core.util.encryption.sm2.Sm2lib;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

/**
 * Request wrapper that buffers and SM2-decrypts the request body, and HTML-escapes (XSS-filters)
 * the body, request parameter values and header values handed to downstream code.
 *
 * <p>The raw body is read once in the constructor, decrypted with {@code Sm2lib.decode}, and kept
 * in memory, so it can be re-read any number of times through {@link #getInputStream()} or
 * {@link #getReader()}.
 *
 * @Author admin
 * @Date  2024/6/25 17:24
 */
public class RequestWrapper extends HttpServletRequestWrapper {
    /**
     * Decrypted request body (holds the JSON data body).
     */
    private final String body;

    /** Content type always reported by {@link #getContentType()}. */
    private final String customContentType;

    private static final Logger logger = LoggerFactory.getLogger(RequestWrapper.class);

    /**
     * Wraps {@code request}, reading its whole body as UTF-8 and decrypting it.
     *
     * <p>If reading the body fails, the error is logged and whatever was read so far is decrypted.
     *
     * @param request the request to wrap; its input stream is fully consumed and closed
     */
    public RequestWrapper(HttpServletRequest request) {
        super(request);
        this.customContentType = "application/json;charset=UTF-8";
        // Decrypt once; every later read of the body is served from this cached value.
        this.body = Sm2lib.decode(readBody(request));
    }

    /**
     * Reads the request body as a UTF-8 string.
     *
     * @param request the request whose input stream is consumed and closed
     * @return the body text; empty when the request has no stream; possibly partial when an
     *     {@link IOException} occurred (the exception is logged, not rethrown)
     */
    private static String readBody(HttpServletRequest request) {
        StringBuilder content = new StringBuilder();
        // try-with-resources skips null resources, so a missing stream simply yields "".
        try (InputStream in = request.getInputStream();
             Reader reader = in == null ? null : new InputStreamReader(in, StandardCharsets.UTF_8)) {
            if (reader != null) {
                char[] buffer = new char[1024];
                int read;
                while ((read = reader.read(buffer)) != -1) {
                    content.append(buffer, 0, read);
                }
            }
        } catch (IOException ex) {
            // Best effort: keep what was read so far, but log the full stack trace.
            logger.error(ex.getMessage(), ex);
        }
        return content.toString();
    }

    /**
     * Returns the fixed JSON content type, falling back to the wrapped request's type only if the
     * custom one is blank.
     */
    @Override
    public String getContentType() {
        return StringUtils.isNotBlank(customContentType) ? customContentType : super.getContentType();
    }

    /**
     * Returns a fresh stream over the XSS-escaped, UTF-8 encoded body.
     *
     * <p>Never returns {@code null}. A blank body yields an empty stream, because callers such as
     * {@link #getReader()} dereference the result.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        String encodeBody = xssEncode(body);
        byte[] bytes = StringUtils.isBlank(encodeBody)
                ? new byte[0]
                : encodeBody.getBytes(StandardCharsets.UTF_8);
        return new CachedBodyInputStream(bytes);
    }

    /**
     * Returns a reader over the XSS-escaped body.
     *
     * <p>The reader decodes as UTF-8, matching the encoding used in {@link #getInputStream()}.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), StandardCharsets.UTF_8));
    }

    /**
     * Returns the named parameter's value with XSS filtering applied. Only the value is escaped,
     * not the parameter name.
     *
     * <p>Use {@code super.getParameterValues(name)} to obtain raw values.
     * {@code getParameterNames}, {@code getParameterValues} and {@code getParameterMap} are not
     * overridden and still return unfiltered data.
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        if (value != null) {
            value = xssEncode(value);
        }
        return value;
    }

    /**
     * Returns the named header's value with XSS filtering applied. Only the value is escaped, not
     * the header name.
     *
     * <p>Use {@code super.getHeaders(name)} to obtain raw values. {@code getHeaderNames} is not
     * overridden.
     */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        if (value != null) {
            value = xssEncode(value);
        }
        return value;
    }

    /**
     * HTML-escapes text that could be used for XSS.
     *
     * <p>Non-JSON text is escaped as a whole with {@code escapeHtml4}. JSON text is parsed into a
     * map, its string values are escaped in place (see {@link #foreachMap(HashMap)}), and the map is
     * re-serialized.
     *
     * @param body text to filter; may be {@code null}
     * @return the filtered text, or {@code body} itself when it is {@code null} or empty
     */
    private String xssEncode(String body) {
        if (body == null || body.isEmpty()) {
            return body;
        }
        if (!PublicUtil.isJson(body)) {
            return StringEscapeUtils.escapeHtml4(body);
        }
        // NOTE(review): if PublicUtil.isJson also accepts top-level JSON arrays, mapping them to a
        // HashMap may fail here — confirm against PublicUtil.
        @SuppressWarnings("unchecked")
        HashMap<String, Object> hashMap = PublicUtil.getBeanByJsonObj(body, HashMap.class);
        foreachMap(hashMap);
        return JSON.toJSONString(hashMap);
    }

    /**
     * HTML-escapes, in place, every string value in {@code hashMap} and in nested
     * {@link JSONObject}s.
     *
     * <p>{@link JSONArray}s, {@code null}s and non-string scalars (numbers, booleans) are left
     * untouched, so their JSON types survive re-serialization.
     *
     * @param hashMap parsed JSON object; {@code null} is ignored
     */
    private void foreachMap(HashMap<String, Object> hashMap) {
        if (hashMap != null) {
            escapeStringValues(hashMap.entrySet());
        }
    }

    /** Recursively escapes string values of the given entries; see {@link #foreachMap(HashMap)}. */
    private static void escapeStringValues(Iterable<Entry<String, Object>> entries) {
        for (Entry<String, Object> entry : entries) {
            Object value = entry.getValue();
            if (value instanceof String) {
                entry.setValue(StringEscapeUtils.escapeHtml4((String) value));
            } else if (value instanceof JSONObject) {
                // Recurse instead of stringifying, which would turn nested objects into escaped text.
                escapeStringValues(((JSONObject) value).entrySet());
            }
        }
    }

    /**
     * {@link ServletInputStream} over an in-memory byte array.
     *
     * <p>All data is already in memory, so the stream is always ready. It is finished once every
     * byte has been consumed.
     */
    private static final class CachedBodyInputStream extends ServletInputStream {
        private final ByteArrayInputStream delegate;

        CachedBodyInputStream(byte[] bytes) {
            this.delegate = new ByteArrayInputStream(bytes);
        }

        @Override
        public int read() {
            return delegate.read();
        }

        @Override
        public int read(byte[] b, int off, int len) {
            // Bulk read avoids the default byte-by-byte InputStream implementation.
            return delegate.read(b, off, len);
        }

        @Override
        public int available() {
            return delegate.available();
        }

        @Override
        public boolean isFinished() {
            return delegate.available() == 0;
        }

        @Override
        public boolean isReady() {
            return true;
        }

        @Override
        public void setReadListener(ReadListener listener) {
            // Intentionally a no-op (as before): the body is fully buffered, so async reads are not
            // supported by this wrapper.
        }
    }
}
