package cn.gtmap.secondaryMarket.common.utils.db;


import cn.gtmap.secondaryMarket.common.utils.ReflectUtil;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.MappedStatement.Builder;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.*;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * MyBatis interceptor that adds pagination to queries whose parameter object is a
 * {@code PageRequest}, or is a {@link Map} containing one as a value.
 *
 * <p>It intercepts two calls:
 * <ul>
 *   <li>{@code StatementHandler.prepare}: optionally runs a count query to fill in the total and
 *   the page count, then rewrites the statement's SQL with a dialect-specific paging clause;</li>
 *   <li>{@code ResultSetHandler.handleResultSets}: copies the returned rows into the request's
 *   {@code Page} and returns that page instead of the plain list.</li>
 * </ul>
 *
 * Created by JIFF on 2016/11/22.
 */
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}),
        @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})})
public class OraclePageHelper implements Interceptor {

    private static final Logger LOG = LoggerFactory.getLogger(OraclePageHelper.class);

    /** Plugin property that, when present, selects the SQL dialect (see {@link #setProperties}). */
    private static final String DIALECT_PROPERTY = "dialect";

    /**
     * SQL dialect used to build paging SQL: "oracle" (default), "mysql", "postgresql" or "hsqldb".
     * It is static, so it is shared by every instance of this class. It is volatile so that a value set
     * through {@link #setDialect(String)} is visible to every thread that prepares statements.
     */
    private static volatile String dialect = "oracle";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        if (target instanceof StatementHandler) {
            return interceptPrepare(invocation, (StatementHandler) target);
        }
        if (target instanceof ResultSetHandler) {
            return interceptResultSets(invocation, (ResultSetHandler) target);
        }
        // plugin() only wraps the two handler types above; still, never swallow a call we do not handle.
        return invocation.proceed();
    }

    /**
     * Handles {@code StatementHandler.prepare}. For paged queries it optionally fills in the total
     * and page count, then replaces the bound SQL with its paged form before the statement is prepared.
     */
    private Object interceptPrepare(Invocation invocation, StatementHandler statementHandler) throws Throwable {
        BoundSql boundSql = statementHandler.getBoundSql();
        LOG.debug("Preparing sql: {}", boundSql.getSql());
        PageRequest page = getPageFromBoundSql(boundSql);
        if (page != null) {
            // The SQL exactly as written in the mapper statement.
            String sql = boundSql.getSql();
            // NOTE(review): a positive total appears to act as the "count requested" flag; confirm
            // against PageRequest.
            if (page.getTotal() > 0) {
                // prepare(Connection, Integer): the first argument is the connection MyBatis is using.
                Connection connection = (Connection) invocation.getArgs()[0];
                MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(
                        ReflectUtil.getFieldValue(statementHandler, "delegate"), "mappedStatement");
                page.setTotal(getTotalCount(connection, mappedStatement, boundSql, sql));
                long pageSize = page.getPageSize();
                // A zero/negative page size would otherwise throw ArithmeticException here.
                if (pageSize > 0) {
                    page.setPages((int) ((page.getTotal() + pageSize - 1) / pageSize));
                }
            }
            String pageSql = getPageSql(sql, page);
            // BoundSql has no setter for its SQL, so the field is replaced reflectively.
            ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
            LOG.debug("Paged sql: {}", pageSql);
        }
        return invocation.proceed();
    }

    /**
     * Handles {@code ResultSetHandler.handleResultSets}. For paged queries it copies the row list
     * into the request's {@code Page} and returns the page; any other result is returned unchanged.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Object interceptResultSets(Invocation invocation, ResultSetHandler resultSetHandler) throws Throwable {
        Object result = invocation.proceed();
        if (!(result instanceof List)) {
            return result;
        }
        BoundSql boundSql = ReflectUtil.getFieldValue(resultSetHandler, "boundSql");
        PageRequest pageRequest = getPageFromBoundSql(boundSql);
        if (pageRequest == null) {
            return result;
        }
        Page<?> page = pageRequest.getPage();
        page.addAll((List) result);
        return page;
    }

    /**
     * Finds the paging request among the statement's parameters.
     *
     * @param boundSql the bound statement; may be {@code null}
     * @return the parameter object itself if it is a {@code PageRequest}, otherwise the first
     *         {@code PageRequest} value of a {@link Map} parameter, otherwise {@code null}
     */
    private PageRequest getPageFromBoundSql(BoundSql boundSql) {
        if (boundSql == null) {
            return null;
        }
        Object parameter = boundSql.getParameterObject();
        if (parameter instanceof PageRequest) {
            return (PageRequest) parameter;
        }
        if (parameter instanceof Map) {
            for (Object value : ((Map<?, ?>) parameter).values()) {
                if (value instanceof PageRequest) {
                    return (PageRequest) value;
                }
            }
        }
        return null;
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler || target instanceof ResultSetHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Reads the optional {@code dialect} plugin property. When it is absent or blank, the current
     * dialect is left unchanged.
     */
    @Override
    public void setProperties(Properties properties) {
        if (properties == null) {
            return;
        }
        String configured = properties.getProperty(DIALECT_PROPERTY);
        if (configured != null && !configured.trim().isEmpty()) {
            setDialect(configured);
        }
    }

    /**
     * Counts the rows the original query would return, using the MyBatis-managed connection.
     * The connection itself is not closed here.
     *
     * @return the row count, or 0 if the count query fails (the failure is logged)
     */
    private long getTotalCount(Connection connection, MappedStatement mappedStatement, BoundSql boundSql, String sql) {
        String countSql = getCountSql(sql);
        PreparedStatement countStmt = null;
        ResultSet rs = null;
        try {
            countStmt = connection.prepareStatement(countSql);
            // Bind the same parameters as the original statement; the count SQL keeps every '?' in order.
            BoundSql countBoundSql = copyFromBoundSql(mappedStatement, boundSql, countSql);
            DefaultParameterHandler parameterHandler =
                    new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), countBoundSql);
            parameterHandler.setParameters(countStmt);

            rs = countStmt.executeQuery();
            return rs.next() ? rs.getLong(1) : 0L;
        } catch (SQLException e) {
            LOG.error("Failed to count rows for paging, count sql: {}", countSql, e);
            return 0L;
        } finally {
            // Either may still be null if preparing or binding failed.
            closeResultSet(rs);
            closeStatement(countStmt);
        }
    }

    /** Closes {@code rs} if it is non-null, logging (not propagating) any failure. */
    private static void closeResultSet(ResultSet rs) {
        if (rs == null) {
            return;
        }
        try {
            rs.close();
        } catch (SQLException e) {
            LOG.error(e.getMessage(), e);
        }
    }

    /** Closes {@code stmt} if it is non-null, logging (not propagating) any failure. */
    private static void closeStatement(Statement stmt) {
        if (stmt == null) {
            return;
        }
        try {
            stmt.close();
        } catch (SQLException e) {
            LOG.error(e.getMessage(), e);
        }
    }

    /**
     * Wraps {@code sql} with the paging clause of the configured dialect.
     *
     * @throws IllegalStateException if the configured dialect is not supported. Previously an
     *         empty SQL string was produced in that case.
     */
    private String getPageSql(String sql, PageRequest page) {
        // Read the volatile field once so all branches see the same value.
        String currentDialect = dialect;
        StringBuilder pageSql = new StringBuilder(sql.length() + 128);
        if ("postgresql".equalsIgnoreCase(currentDialect)) {
            pageSql.append(sql)
                    .append(" limit ").append(page.getPageSize())
                    .append(" offset ").append(page.getStartRow());
        } else if ("mysql".equalsIgnoreCase(currentDialect)) {
            pageSql.append(sql)
                    .append(" limit ").append(page.getStartRow())
                    .append(",").append(page.getPageSize());
        } else if ("hsqldb".equalsIgnoreCase(currentDialect)) {
            pageSql.append(sql)
                    .append(" LIMIT ").append(page.getPageSize())
                    .append(" OFFSET ").append(page.getStartRow());
        } else if ("oracle".equalsIgnoreCase(currentDialect)) {
            // ROWNUM is assigned before the outer filter, so the rows up to the end row are numbered
            // in the inner view, and the outer query skips the rows before the start row.
            pageSql.append("select * from ( select temp.*, rownum row_id from ( ")
                    .append(sql)
                    .append(" ) temp where rownum <= ").append(page.getEndRow())
                    .append(") where row_id > ").append(page.getStartRow());
        } else {
            throw new IllegalStateException("Unsupported paging dialect: " + currentDialect);
        }
        return pageSql.toString();
    }

    /**
     * Builds a query that counts the rows returned by {@code sql}.
     *
     * <p>The original statement is wrapped as an inline view rather than having its select list cut
     * off. This preserves DISTINCT, GROUP BY, sub-queries in the select list, '?' placeholders
     * anywhere in the statement, and the letter case of string literals. The alias is written without
     * {@code AS} so the SQL is accepted by Oracle as well as MySQL, PostgreSQL and HSQLDB.
     *
     * @param sql the query to count; must not be {@code null}
     * @return {@code select count(1) from ( <sql> ) tmp_count}
     * @throws IllegalArgumentException if {@code sql} is {@code null}
     */
    public String getCountSql(String sql) {
        if (sql == null) {
            throw new IllegalArgumentException("sql must not be null");
        }
        return "select count(1) from ( " + sql + " ) tmp_count";
    }

    /**
     * Sets the SQL dialect used for paging for every instance of this interceptor.
     *
     * @param dialect "oracle", "mysql", "postgresql" or "hsqldb" (case-insensitive; surrounding
     *                whitespace is ignored)
     * @throws IllegalArgumentException if {@code dialect} is {@code null} or blank
     */
    public static void setDialect(String dialect) {
        if (dialect == null || dialect.trim().isEmpty()) {
            throw new IllegalArgumentException("dialect must not be blank");
        }
        OraclePageHelper.dialect = dialect.trim();
    }

    /**
     * Copies a MappedStatement, replacing its SQL source. This method is not currently called
     * within this class.
     */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        Builder builder = new Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());

        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());

        return builder.build();
    }

    /**
     * Copies a BoundSql with new SQL text. It keeps the original parameter mappings and parameter
     * object, and carries over the additional parameters (e.g. foreach items) those mappings refer to.
     */
    private BoundSql copyFromBoundSql(MappedStatement ms, BoundSql boundSql, String sql) {
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), sql, boundSql.getParameterMappings(), boundSql.getParameterObject());
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        return newBoundSql;
    }
}
