package com.gtis.ibatis;

import com.gtis.common.util.ReflectUtil;
import com.ibatis.common.jdbc.exception.NestedSQLException;
import com.ibatis.sqlmap.client.event.RowHandler;
import com.ibatis.sqlmap.engine.impl.ExtendedSqlMapClient;
import com.ibatis.sqlmap.engine.mapping.parameter.ParameterMap;
import com.ibatis.sqlmap.engine.mapping.result.AutoResultMap;
import com.ibatis.sqlmap.engine.mapping.result.ResultMap;
import com.ibatis.sqlmap.engine.mapping.sql.Sql;
import com.ibatis.sqlmap.engine.mapping.statement.ExecuteListener;
import com.ibatis.sqlmap.engine.mapping.statement.MappedStatement;
import com.ibatis.sqlmap.engine.mapping.statement.RowHandlerCallback;
import com.ibatis.sqlmap.engine.mapping.statement.SelectStatement;
import com.ibatis.sqlmap.engine.scope.ErrorContext;
import com.ibatis.sqlmap.engine.scope.StatementScope;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import java.util.regex.Pattern;

/**
 * Helpers for deriving a row-counting statement from an existing iBATIS select statement.
 *
 * @author C4ISR
 */
public class CountStatementUtil {

    /** Text placed before the select statement id to form the count statement id. */
    private static final String ID_PREFIX = "__";

    /** Text placed after the select statement id to form the count statement id. */
    private static final String ID_SUFFIX = "Count__";

    /**
     * Builds a statement that counts the rows produced by {@code selectStatement}.
     *
     * @param selectStatement the statement to count; it must actually be a
     *                        {@link SelectStatement}, otherwise a {@link ClassCastException} is thrown
     * @return a new count statement derived from {@code selectStatement}
     */
    public static MappedStatement createCountStatement(
            MappedStatement selectStatement) {
        SelectStatement source = (SelectStatement) selectStatement;
        return new CountStatement(source);
    }

    /**
     * Returns the id given to the count statement derived from {@code selectStatementId}.
     *
     * @param selectStatementId id of the original select statement
     * @return {@code "__" + selectStatementId + "Count__"}
     */
    public static String getCountStatementId(String selectStatementId) {
        return new StringBuilder(ID_PREFIX)
                .append(selectStatementId)
                .append(ID_SUFFIX)
                .toString();
    }
}

/**
 * Select statement that counts the rows another select statement would return.
 * <p>
 * The source statement's SQL is rendered as usual. Query-level {@code order by} clauses are
 * then removed, because ordering cannot change a row count. The result is executed as
 * {@code select count(*) as c from ( ... )}, and the single column is mapped to {@link Long}
 * through an {@link AutoResultMap}.
 */
class CountStatement extends SelectStatement {

    /**
     * ORDER BY keyword to strip. Matching is case-sensitive, so only lower-case clauses are
     * removed; others are left in the wrapped query.
     */
    private static final String ORDER_BY = "order by";

    /**
     * If this text occurs anywhere in the SQL, no ORDER BY is stripped at all. This conservative
     * guard is retained from the original implementation ("partition by need order by").
     */
    private static final String PARTITION_BY = "partition by";

    /** Matches the SELECT keyword as a whole word, case-insensitively. */
    private static final Pattern SELECT_KEYWORD =
            Pattern.compile("\\bselect\\b", Pattern.CASE_INSENSITIVE);

    /**
     * Creates the count statement for {@code selectStatement}.
     * <p>
     * The id is derived via {@link CountStatementUtil#getCountStatementId(String)}. The
     * following are copied from the source statement: SQL, parameter map and class, resource,
     * client, timeout, result set type and execute listeners. The result map is replaced by an
     * auto result map whose result class is {@link Long}.
     *
     * @param selectStatement the statement whose rows are to be counted
     */
    public CountStatement(SelectStatement selectStatement) {
        super();
        setId(CountStatementUtil.getCountStatementId(selectStatement.getId()));
        setResultSetType(selectStatement.getResultSetType());
        // The wrapping "select count(*) ..." query yields a single row.
        setFetchSize(1);
        setParameterMap(selectStatement.getParameterMap());
        setParameterClass(selectStatement.getParameterClass());
        setSql(selectStatement.getSql());
        setResource(selectStatement.getResource());
        setSqlMapClient(selectStatement.getSqlMapClient());
        setTimeout(selectStatement.getTimeout());

        // The listeners are read reflectively from the "executeListeners" field. Presumably the
        // base class exposes no accessor for them; TODO confirm when upgrading iBATIS.
        List<?> executeListeners = (List<?>) ReflectUtil.getFieldValue(
                selectStatement, "executeListeners", List.class);
        if (executeListeners != null) {
            for (Object listener : executeListeners) {
                addExecuteListener((ExecuteListener) listener);
            }
        }

        ResultMap resultMap = new AutoResultMap(
                ((ExtendedSqlMapClient) getSqlMapClient()).getDelegate(), false);
        resultMap.setId(getId() + "-AutoResultMap");
        resultMap.setResultClass(Long.class);
        resultMap.setResource(getResource());
        setResultMap(resultMap);
    }

    /**
     * Executes the count query and feeds the rows to {@code rowHandler}.
     * <p>
     * The steps follow the usual iBATIS sequence: validate parameters, resolve the parameter
     * and result maps, render SQL, execute, post-process output parameters, then clean up and
     * notify listeners. This presumably mirrors {@code MappedStatement}'s implementation;
     * re-check it on library upgrades. There are two differences:
     * <ul>
     * <li>the SQL is rewritten into a count query by {@link #getSqlString};</li>
     * <li>the fixed {@link Long} result map is always used.</li>
     * </ul>
     *
     * @throws SQLException a {@link NestedSQLException} wrapping any failure, carrying the
     *                      error context text
     */
    @Override
    protected void executeQueryWithCallback(StatementScope stateScope,
                                            Connection conn, Object parameterObject, Object resultObject, RowHandler rowHandler,
                                            int skipResults, int maxResults) throws SQLException {
        ErrorContext errorContext = stateScope.getErrorContext();
        errorContext
                .setActivity("preparing the mapped statement for execution");
        errorContext.setObjectId(this.getId());
        errorContext.setResource(this.getResource());

        try {
            parameterObject = validateParameter(parameterObject);

            Sql sql = getSql();

            errorContext.setMoreInfo("Check the parameter map.");
            ParameterMap parameterMap = sql.getParameterMap(stateScope,
                    parameterObject);

            errorContext.setMoreInfo("Check the result map.");
            ResultMap resultMap = getResultMap(stateScope, parameterObject, sql);

            stateScope.setResultMap(resultMap);
            stateScope.setParameterMap(parameterMap);

            errorContext.setMoreInfo("Check the parameter map.");
            Object[] parameters = parameterMap.getParameterObjectValues(
                    stateScope, parameterObject);

            errorContext.setMoreInfo("Check the SQL statement.");
            // Rewritten into "select count(*) as c from (...)".
            String sqlString = getSqlString(stateScope, parameterObject, sql);

            errorContext.setActivity("executing mapped statement");
            errorContext
                    .setMoreInfo("Check the SQL statement or the result map.");
            RowHandlerCallback callback = new RowHandlerCallback(resultMap,
                    resultObject, rowHandler);
            sqlExecuteQuery(stateScope, conn, sqlString, parameters, skipResults,
                    maxResults, callback);

            errorContext.setMoreInfo("Check the output parameters.");
            if (parameterObject != null) {
                postProcessParameterObject(stateScope, parameterObject, parameters);
            }

            errorContext.reset();
            sql.cleanup(stateScope);
            notifyListeners();
        } catch (SQLException e) {
            errorContext.setCause(e);
            throw new NestedSQLException(errorContext.toString(), e
                    .getSQLState(), e.getErrorCode(), e);
        } catch (Exception e) {
            errorContext.setCause(e);
            throw new NestedSQLException(errorContext.toString(), e);
        }
    }

    /**
     * Returns the SQL used for the count query.
     * <p>
     * The SQL is rendered from {@code sql}, trimmed, stripped of query-level ORDER BY clauses,
     * and wrapped in {@code select count(*) as c from ( ... )}.
     *
     * @param stateScope      current statement scope
     * @param parameterObject the (validated) parameter object
     * @param sql             the SQL source of the original select statement
     * @return the count query text
     */
    private String getSqlString(StatementScope stateScope, Object parameterObject,
                                Sql sql) {
        String sqlString = sql.getSql(stateScope, parameterObject).trim();
        return "select count(*) as c from (" + removeOrderBy(sqlString) + ")";
    }

    /**
     * Removes lower-case {@code order by} clauses that belong to a query, at top level or in a
     * sub-query. Each clause is removed up to the parenthesis that closes its enclosing
     * sub-query, or up to the end of the SQL.
     * <p>
     * ORDER BY clauses that do not belong to a query are kept, because removing them would
     * produce invalid SQL. Examples are the ones in {@code over (order by ...)},
     * {@code within group (order by ...)} and {@code keep (dense_rank first order by ...)}. If
     * the SQL contains {@code partition by} anywhere, it is returned unchanged.
     * <p>
     * The scan is purely textual: it does not skip string literals or comments.
     *
     * @param sqlString the SQL to process
     * @return the SQL without query-level ORDER BY clauses
     */
    static String removeOrderBy(String sqlString) {
        if (sqlString.indexOf(PARTITION_BY) > 0) {
            return sqlString;
        }
        String result = sqlString;
        int searchFrom = 0;
        int start;
        while ((start = result.indexOf(ORDER_BY, searchFrom)) > 0) {
            if (!isQueryLevelOrderBy(result, start)) {
                // Keep analytic / ordered-aggregate ORDER BY and continue after it.
                searchFrom = start + ORDER_BY.length();
                continue;
            }
            int end = findClauseEnd(result, start + ORDER_BY.length());
            result = result.substring(0, start) + result.substring(end);
            // Everything before start was already examined, and its classification depends
            // only on the text preceding it, so resume from start.
            searchFrom = start;
        }
        return result;
    }

    /**
     * Tells whether the ORDER BY at {@code orderByIndex} belongs to a query. It does if a
     * {@code select} keyword appears at the same parenthesis depth between the enclosing
     * opening parenthesis (or the start of the SQL) and the ORDER BY.
     */
    private static boolean isQueryLevelOrderBy(String sql, int orderByIndex) {
        StringBuilder sameLevel = new StringBuilder();
        int depth = 0;
        for (int i = orderByIndex - 1; i >= 0; i--) {
            char ch = sql.charAt(i);
            if (ch == ')') {
                if (depth == 0) {
                    // Keep the words on either side of a nested group apart.
                    sameLevel.append(' ');
                }
                depth++;
            } else if (ch == '(') {
                if (depth == 0) {
                    // Reached the parenthesis that encloses this ORDER BY.
                    break;
                }
                depth--;
            } else if (depth == 0) {
                sameLevel.append(ch);
            }
        }
        return SELECT_KEYWORD.matcher(sameLevel.reverse()).find();
    }

    /**
     * Returns the index of the first unmatched {@code ')'} at or after {@code from}. If there
     * is none, returns {@code sql.length()}.
     */
    private static int findClauseEnd(String sql, int from) {
        int depth = 0;
        for (int i = from; i < sql.length(); i++) {
            char ch = sql.charAt(i);
            if (ch == '(') {
                depth++;
            } else if (ch == ')') {
                if (depth == 0) {
                    return i;
                }
                depth--;
            }
        }
        return sql.length();
    }

    /**
     * Returns the fixed {@link Long} result map configured in the constructor. The arguments
     * are ignored.
     */
    private ResultMap getResultMap(StatementScope stateScope,
                                   Object parameterObject, Sql sql) {
        return getResultMap();
    }

    /**
     * Ad-hoc manual check: prints the result of stripping ORDER BY from a sample query.
     */
    public static void main(String[] param) {
        String sqlString = "select count(*) as c   from tbl_syqzs a    left join tbl_project b on substr(a.projectid,0,32) = b.projectid    left join tbl_djk c on c.projectid = a.projectid   where                   b.endtime is not null                  and           a.dwdm like ?                                                                    order by to_date(nvl(a.fzrq,'19980101'),'yyyy-mm-dd') desc";
        System.out.println(removeOrderBy(sqlString));
    }
}
