package com.abcxin.smart.core.persistence.interceptor;

import com.abcxin.smart.core.persistence.IDialect;
import com.abcxin.smart.core.persistence.dialect.MySqlDialect;
import com.abcxin.smart.core.persistence.dialect.OracleDialect;
import com.abcxin.smart.core.persistence.utils.Constants;
import com.abcxin.smart.validator.annotation.ModelAnnotation;
import com.abcxin.smart.validator.annotation.ModelTableAnnotation;
import com.abcxin.smart.validator.util.JsqlparserUtil;
import com.abcxin.smart.validator.util.SingletonMapCopyUtils;
import com.alibaba.fastjson.JSON;
import com.bcxin.ars.dto.PropertiesDTO;
import com.bcxin.ars.model.User;
import com.com.bcxin.ars.com.abcxin.smart.core.web.validate.AjaxPageResponse;
import com.mysql.jdbc.StringUtils;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.session.Session;
import org.apache.shiro.subject.Subject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.text.DateFormat;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * MyBatis pagination interceptor.
 *
 * <p>Handles two interception points:
 * <ul>
 *   <li>{@code StatementHandler.prepare}: when non-default {@link RowBounds} are supplied, the
 *       statement SQL is rewritten into a dialect-specific paged query and MyBatis' in-memory
 *       paging is disabled. When the bounds are an {@link AjaxPageResponse}, an optional
 *       (whitelisted) ORDER BY is appended and the total row count is queried if still 0.</li>
 *   <li>{@code Executor.query}: when the bounds are an {@link AjaxPageResponse}, the returned
 *       list is stored into it via {@code setData}.</li>
 * </ul>
 *
 * <p>Dialect detection may run concurrently, so {@link #dialect} is {@code volatile} and is
 * initialised under a lock.
 *
 * @author subinghui
 * @date 2017/1/13 16:32
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})})
public class PaginationInterceptor implements Interceptor {

    private static final Logger log = LoggerFactory.getLogger(PaginationInterceptor.class);

    /** Whitelist for the client-supplied sort direction (case-insensitive ASC / DESC). */
    private static final Pattern SORT_ORDER_PATTERN = Pattern.compile("(?i)(asc|desc)");

    /** Whitelist for the client-supplied sort column(s): comma-separated, optionally qualified identifiers. */
    private static final Pattern SORT_COLUMNS_PATTERN =
            Pattern.compile("[A-Za-z_][A-Za-z0-9_.]*(\\s*,\\s*[A-Za-z_][A-Za-z0-9_.]*)*");

    /** Attachment URL prefix stripped from inlined parameter values ('.' and '?' escaped). */
    private static final Pattern RESOURCE_PREFIX_PATTERN = Pattern.compile("getResource\\.do\\?path=");

    /** Dialect name from the "dialectType" plugin property, or the one detected from the connection. */
    private String dialectType;

    /**
     * Dialect used to build paged SQL. It is either configured in {@link #setProperties(Properties)}
     * or lazily detected in {@link #resolveDialect(Connection)} with double-checked locking.
     * It is volatile so the lazily created instance is safely published to other threads.
     */
    private volatile IDialect dialect;


    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String methodName = invocation.getMethod().getName();
        if ("prepare".equals(methodName)) {
            return interceptPrepare(invocation);
        }
        if ("query".equals(methodName)) {
            return interceptQuery(invocation);
        }
        return invocation.proceed();
    }

    /**
     * Handles {@code StatementHandler.prepare}. When non-default row bounds were supplied, it
     * rewrites the SQL into a paged statement.
     *
     * @param invocation the intercepted prepare call; its first argument is the JDBC connection
     * @return the result of the original prepare call
     */
    private Object interceptPrepare(Invocation invocation) throws Throwable {
        MetaObject metaStatementHandler = SystemMetaObject.forObject(invocation.getTarget());

        // Read through the delegate handler so the concrete bounds object (e.g. an AjaxPageResponse)
        // is preserved. The original comment says this fixes properties lost when treating the
        // subclass as its parent type.
        Object bounds = metaStatementHandler.getValue("delegate.rowBounds");
        RowBounds rowBounds = (RowBounds) bounds;
        MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
        Connection connection = (Connection) invocation.getArgs()[0];
        BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
        String originalSql = boundSql.getSql();

        // Audit logging of UPDATE statements is currently disabled. It was wired as:
        // for SQL starting with "update", call saveOperationSqlLog(connection, mappedStatement,
        // getSql(mappedStatement.getId(), mappedStatement.getConfiguration(), boundSql)) and log any exception.

        if (rowBounds == null || rowBounds == RowBounds.DEFAULT) {
            // No paging requested.
            return invocation.proceed();
        }

        IDialect currentDialect = resolveDialect(connection);

        AjaxPageResponse pagination = bounds instanceof AjaxPageResponse ? (AjaxPageResponse) bounds : null;
        String paginationSql;
        if (pagination != null) {
            if (!pagination.isPagination()) {
                // The caller explicitly asked for an unpaged query.
                return invocation.proceed();
            }
            // Clamp so that a page number <= 0 cannot produce a negative offset.
            int offset = Math.max(0, (pagination.getPageNumber() - 1) * pagination.getPageSize());
            originalSql = appendOrderBy(originalSql, pagination.getSort(), pagination.getOrder());
            paginationSql = currentDialect.buildPaginationSql(originalSql, offset, pagination.getPageSize());
        } else {
            paginationSql = currentDialect.buildPaginationSql(originalSql, rowBounds.getOffset(), rowBounds.getLimit());
        }
        metaStatementHandler.setValue("delegate.boundSql.sql", paginationSql);

        // Paging is now done in SQL; disable MyBatis' in-memory row skipping.
        metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
        metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);

        // Only query the total when it is not already known.
        if (pagination != null && pagination.getTotal() == 0) {
            count(originalSql, connection, mappedStatement, boundSql, pagination);
        }
        return invocation.proceed();
    }

    /**
     * Handles {@code Executor.query}. When the row bounds are an {@link AjaxPageResponse}, it stores
     * the returned list in them.
     *
     * @param invocation the intercepted query call; argument 2 is the row bounds
     * @return the unchanged query result
     */
    private Object interceptQuery(Invocation invocation) throws Throwable {
        Object rowBounds = invocation.getArgs()[2];
        Object data = invocation.proceed();
        if (rowBounds instanceof AjaxPageResponse && data instanceof List) {
            ((AjaxPageResponse) rowBounds).setData((List) data);
        }
        return data;
    }

    /**
     * Returns the configured dialect. If none is set, detects one from the connection's database
     * product name.
     *
     * @param connection connection whose metadata is inspected on first use
     * @return the dialect, never {@code null}
     * @throws IllegalArgumentException if the database is neither MySQL nor Oracle
     * @throws SQLException             if the connection metadata cannot be read
     */
    private IDialect resolveDialect(Connection connection) throws SQLException {
        IDialect current = dialect;
        if (current != null) {
            return current;
        }
        synchronized (this) {
            if (dialect == null) {
                String productName = connection.getMetaData().getDatabaseProductName();
                if (log.isTraceEnabled()) {
                    log.trace("数据库产品名称: " + productName);
                }
                productName = productName.toLowerCase(Locale.ROOT);
                IDialect detected = null;
                if (productName.contains(IDialect.MYSQL)) {
                    dialectType = IDialect.MYSQL;
                    detected = new MySqlDialect();
                } else if (productName.contains(IDialect.ORACLE)) {
                    dialectType = IDialect.ORACLE;
                    detected = new OracleDialect();
                }
                // No dialect for this database: fail loudly.
                if (detected == null) {
                    throw new IllegalArgumentException("没有适合数据库" + productName + "的方言类（用于自动分页等）");
                }
                dialect = detected;
                if (log.isInfoEnabled()) {
                    log.info("自动检测到的数据库类型为: " + dialectType);
                }
            }
            return dialect;
        }
    }

    /**
     * Appends {@code ORDER BY sort order} to {@code sql} when both values are present and valid.
     *
     * <p>The values typically come from the client. Sort must be a comma-separated list of
     * (optionally qualified) identifiers, and order must be ASC or DESC. Anything else is ignored
     * with a warning, to prevent SQL injection.
     *
     * @param sql   the statement SQL
     * @param sort  sort column(s), may be null/empty
     * @param order sort direction, may be null/empty
     * @return {@code sql}, possibly with an ORDER BY clause appended
     */
    private static String appendOrderBy(String sql, String sort, String order) {
        if (isNullOrEmpty(sort) || isNullOrEmpty(order)) {
            return sql;
        }
        String sortColumns = sort.trim();
        String direction = order.trim();
        if (!SORT_COLUMNS_PATTERN.matcher(sortColumns).matches() || !SORT_ORDER_PATTERN.matcher(direction).matches()) {
            log.warn("忽略非法的排序参数: sort=" + sort + ", order=" + order);
            return sql;
        }
        return sql + " ORDER BY " + sortColumns + " " + direction;
    }

    /**
     * Queries the total number of rows matched by {@code sql} and stores it via
     * {@code page.setTotal}.
     *
     * <p>The count statement is built by {@code JsqlparserUtil.getCountSql}, which receives a copy
     * of the parameter mappings. It presumably adjusts them to the rewritten SQL (TODO confirm).
     * If it cannot handle the SQL, the statement is wrapped in a derived table instead. SQL errors
     * while executing the count are logged and leave the total unchanged.
     *
     * @param sql             the un-paged statement SQL
     * @param connection      connection used to run the count query (not closed here)
     * @param mappedStatement mapped statement supplying the configuration for parameter binding
     * @param boundSql        original bound SQL, supplying parameter mappings, the parameter
     *                        object and additional (dynamic-SQL) parameters
     * @param page            receives the total
     */
    public void count(String sql, Connection connection, MappedStatement mappedStatement, BoundSql boundSql, AjaxPageResponse page) {
        List<ParameterMapping> originalMappings = boundSql.getParameterMappings();
        List<ParameterMapping> countMappings = copyMappings(originalMappings);
        String countSql;
        try {
            countSql = JsqlparserUtil.getCountSql(sql, countMappings);
        } catch (Exception e) {
            log.error("现有分页不支持复杂查询sql:" + sql + "使用默认分页", e);
            // No "AS" before the alias: Oracle rejects it for derived tables, and MySQL accepts both forms.
            countSql = "SELECT COUNT(0) FROM (" + sql + ") total";
            // The parser may have modified the list before failing. The fallback keeps every
            // placeholder, so it needs the complete original mappings.
            countMappings = copyMappings(originalMappings);
        }
        log.debug("分页时, 生成countSql: " + countSql);

        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, countMappings, boundSql.getParameterObject());
        // Dynamic SQL (<foreach>, <bind>) keeps values as additional parameters outside the
        // parameter object. Without them, binding the count statement fails.
        copyAdditionalParameters(boundSql, countBoundSql);
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), countBoundSql);

        try (PreparedStatement pstmt = connection.prepareStatement(countSql)) {
            parameterHandler.setParameters(pstmt);
            long total = 0;
            try (ResultSet rs = pstmt.executeQuery()) {
                if (rs.next()) {
                    total = rs.getLong(1);
                }
            }
            page.setTotal(total);
        } catch (SQLException e) {
            log.error("查询总数出错", e);
        }
    }

    /** Returns a mutable copy of {@code mappings} (empty when {@code null}). */
    private static List<ParameterMapping> copyMappings(List<ParameterMapping> mappings) {
        if (mappings == null) {
            return new ArrayList<>();
        }
        return new ArrayList<>(mappings);
    }

    /**
     * Copies the private {@code additionalParameters} map of {@code source} into {@code target}
     * through MyBatis reflection.
     */
    private static void copyAdditionalParameters(BoundSql source, BoundSql target) {
        MetaObject metaBoundSql = SystemMetaObject.forObject(source);
        if (!metaBoundSql.hasGetter("additionalParameters")) {
            return;
        }
        Object additional = metaBoundSql.getValue("additionalParameters");
        if (additional instanceof Map) {
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) additional).entrySet()) {
                target.setAdditionalParameter(String.valueOf(entry.getKey()), entry.getValue());
            }
        }
    }


    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads the optional {@code dialectType} plugin property ("mysql" or "oracle",
     * case-insensitive, surrounding whitespace ignored). Unknown or missing values leave the
     * dialect to be detected from the connection on first use.
     */
    @Override
    public void setProperties(Properties properties) {
        String configured = properties.getProperty("dialectType");
        if (configured == null || configured.trim().isEmpty()) {
            return;
        }
        String type = configured.trim();
        this.dialectType = type;

        // Choose the database dialect.
        switch (type.toLowerCase(Locale.ROOT)) {
            case "mysql":
                dialect = new MySqlDialect();
                break;
            case "oracle":
                dialect = new OracleDialect();
                break;
            default:
                log.warn("未知的 dialectType 配置: " + type + "，将根据数据库连接自动检测");
                break;
        }
    }

    /**
     * Renders a parameter value as a SQL literal for logging and auditing.
     *
     * <p>Strings and dates are wrapped in single quotes; embedded quotes are NOT escaped. null
     * becomes {@code null}. Attachment URLs contain '?', which breaks SQL parsing and placeholder
     * substitution. Therefore the {@code getResource.do?path=} prefix and any remaining '?' are
     * stripped (added 2019-02-04, subh).
     */
    private static String getParameterValue(Object obj) {
        String value;
        if (obj instanceof String) {
            value = "'" + obj + "'";
        } else if (obj instanceof Date) {
            DateFormat formatter = DateFormat.getDateTimeInstance(DateFormat.DEFAULT, DateFormat.DEFAULT, Locale.CHINA);
            value = "'" + formatter.format(obj) + "'";
        } else if (obj != null) {
            value = obj.toString();
        } else {
            value = "null";
        }
        return RESOURCE_PREFIX_PATTERN.matcher(value).replaceAll("").replace("?", "");
    }


    /**
     * Records an audit trail for a single-table UPDATE.
     *
     * <p>Only statements that touch exactly one table, and whose WHERE clause has a non-empty
     * {@code Constants.ID} value, are logged. For those, the method inserts:
     * <ul>
     *   <li>a header row ({@code Constants.INSERT_TRILOG_SQL});</li>
     *   <li>one detail row per SET column whose current database value differs from the new
     *       value ({@code Constants.INSERT_TRILOGDETAIL_SQL});</li>
     *   <li>the SQL itself ({@code Constants.INSERT_OPERATIONSQLLOG_SQL}).</li>
     * </ul>
     * Columns configured for translation are compared via the code dictionary in
     * {@code SingletonMapCopyUtils.mapCode}.
     *
     * @param connection      connection used for all audit statements (not closed here)
     * @param mappedStatement statement whose id is stored in the SQL log
     * @param sql             the UPDATE statement with parameter values inlined (see getSql)
     * @throws Throwable non-SQL failures such as parse errors or an unknown model class;
     *                   SQLExceptions are logged instead of thrown
     */
    private void saveOperationSqlLog(Connection connection, MappedStatement mappedStatement, String sql) throws Throwable {
        log.info("保存Sql: " + sql);
        try {
            Update updateStatement = JsqlparserUtil.getUpdateStatement(sql);
            List<String> tableNames = JsqlparserUtil.getTableNames(updateStatement, sql);
            // Multi-table updates are too complex to attribute; only single-table updates are logged.
            if (tableNames == null || tableNames.size() != 1) {
                return;
            }
            String table = tableNames.get(0);
            // SET columns and WHERE columns with their (inlined) values.
            Map<String, String> setColumns = JsqlparserUtil.getSetColumns(updateStatement, sql);
            Map<String, String> whereColumns = JsqlparserUtil.getWhereColumns(updateStatement, sql);
            String rowId = whereColumns.get(Constants.ID);
            // Without a row id the change cannot be attributed to a record (added subh 2019-02-05).
            if (rowId == null || Constants.ID_EMPTY.equals(rowId)) {
                return;
            }

            DateTimeFormatter format = DateTimeFormatter.ofPattern(Constants.FORMAT_TIME);
            String triLogId = nextTriLogId(connection);

            // Display names and dictionary-coded columns from the model class that the
            // trilog properties map to this table.
            Map<String, String> columnLabels = new HashMap<>();
            Map<String, String> codeColumns = new HashMap<>();
            String tableLabel = table;
            PropertiesDTO propertiesDTO = JSON.parseObject(SingletonMapCopyUtils.mapProperties.get(table), PropertiesDTO.class);
            if (propertiesDTO != null) {
                tableLabel = table + collectModelMetadata(Class.forName(propertiesDTO.getModelPath()), columnLabels, codeColumns);
            }

            String[] operator = currentOperator();
            try (PreparedStatement insertLog = connection.prepareStatement(Constants.INSERT_TRILOG_SQL)) {
                insertLog.setString(1, triLogId);                                 // log primary key
                insertLog.setString(2, LocalDateTime.now().format(format));        // creation time
                insertLog.setString(3, Constants.BEFORE);                          // operation flags
                insertLog.setString(4, Constants.UPDATE);
                insertLog.setString(5, tableLabel);                                // table (+ display name)
                insertLog.setString(6, rowId);                                     // audited row id
                insertLog.setString(7, operator[0]);                               // user id
                insertLog.setString(8, operator[1]);                               // user real name
                insertLog.execute();
            }

            // Read the current row values and record every SET column that changes.
            String querySql = JsqlparserUtil.getQuerySql(updateStatement, sql);
            try (PreparedStatement query = connection.prepareStatement(querySql);
                 ResultSet current = query.executeQuery();
                 PreparedStatement insertDetail = connection.prepareStatement(Constants.INSERT_TRILOGDETAIL_SQL)) {
                while (current.next()) {
                    for (Map.Entry<String, String> entry : setColumns.entrySet()) {
                        String column = entry.getKey();
                        String rawNewValue = entry.getValue();
                        String rawDbValue = current.getString(column);
                        String dbValue = rawDbValue;
                        String newValue = rawNewValue;

                        String dictName = codeColumns.get(column);
                        if (!isNullOrEmpty(dictName)) {
                            // Look up the code dictionary by column name first.
                            dbValue = SingletonMapCopyUtils.mapCode.get(column.toUpperCase() + rawDbValue);
                            newValue = SingletonMapCopyUtils.mapCode.get(column.toUpperCase() + rawNewValue.replaceAll(Constants.SINGLE_QUOTES, ""));
                            // If that fails (column name differs from the dictionary type), retry with dictName.
                            if (isNullOrEmpty(dbValue)) {
                                dbValue = SingletonMapCopyUtils.mapCode.get(dictName.toUpperCase() + rawDbValue);
                            }
                            if (isNullOrEmpty(newValue)) {
                                newValue = SingletonMapCopyUtils.mapCode.get(dictName.toUpperCase() + rawNewValue.replaceAll("'", ""));
                            }
                        }

                        dbValue = isNullOrEmpty(dbValue) ? "" : dbValue;
                        newValue = normalizeNewValue(newValue);

                        if (!dbValue.equals(newValue)) {
                            String label = columnLabels.get(column);
                            insertDetail.setString(1, LocalDateTime.now().format(format));            // creation time
                            insertDetail.setString(2, column + (label == null ? "" : label));        // column (+ display name)
                            insertDetail.setString(3, dbValue);                                        // old value
                            insertDetail.setString(4, newValue);                                       // new value
                            insertDetail.setString(5, triLogId);                                       // header key
                            insertDetail.execute();
                        }
                    }
                }
            }

            // Store the executed SQL in the operationSqlLog table.
            String[] args = {LocalDateTime.now().format(format), mappedStatement.getId(), sql, triLogId};
            try (PreparedStatement insertSql = connection.prepareStatement(Constants.INSERT_OPERATIONSQLLOG_SQL)) {
                for (int i = 0; i < args.length; i++) {
                    insertSql.setString(i + 1, args[i]);
                }
                insertSql.execute();
            }
        } catch (SQLException e) {
            log.error("记录操作日志出错，请检查数据库中表operationSqlLog、tri_log_detail、tri_log是否存在，及函数nextval、currval是否存在", e);
        }
    }

    /**
     * Fetches the next audit-log key by running {@code Constants.SEQ_TRILOGID_SQL}. This requires
     * the database functions nextval/currval. Returns the last row's first column, or "" if there
     * are no rows.
     */
    private static String nextTriLogId(Connection connection) throws SQLException {
        String triLogId = "";
        try (PreparedStatement sequence = connection.prepareStatement(Constants.SEQ_TRILOGID_SQL);
             ResultSet rs = sequence.executeQuery()) {
            while (rs.next()) {
                triLogId = rs.getString(1);
            }
        }
        return triLogId;
    }

    /**
     * Reads display names from a model class annotated with {@code ModelTableAnnotation} and
     * {@code ModelAnnotation}.
     *
     * @param modelClass   model class mapped to the audited table
     * @param columnLabels receives column name to display name, for every annotated field
     * @param codeColumns  receives column name to dictionary key prefix, for translated columns
     * @return the table display name, or "" when the class has no {@code ModelTableAnnotation}
     */
    private static String collectModelMetadata(Class<?> modelClass, Map<String, String> columnLabels, Map<String, String> codeColumns) {
        String tableLabel = "";
        ModelTableAnnotation tableAnnotation = modelClass.getAnnotation(ModelTableAnnotation.class);
        if (tableAnnotation != null) {
            tableLabel = tableAnnotation.getName();
        }
        for (Field field : modelClass.getDeclaredFields()) {
            ModelAnnotation column = field.getAnnotation(ModelAnnotation.class);
            if (column == null) {
                continue;
            }
            columnLabels.put(column.column(), column.getName());
            if (column.needTranslate()) {
                codeColumns.put(column.column(), column.getName());
            }
            // Annotation members are never null, so test for a non-empty dictName. Otherwise an
            // empty default would overwrite the needTranslate entry and disable translation.
            if (!isNullOrEmpty(column.dictName())) {
                codeColumns.put(column.column(), column.dictName());
            }
        }
        return tableLabel;
    }

    /**
     * Returns {@code {id, realName}} of the user stored in the current Shiro session under
     * {@code Constants.CURRENTUSER}. Falls back to the system operator when there is no such user,
     * e.g. in scheduled jobs without a bound SecurityManager.
     */
    private static String[] currentOperator() {
        String userId = Constants.OPERATION_SYSTEM;
        String realName = Constants.OPERATION_SYSTEM_REALNAME;
        Subject subject;
        try {
            subject = SecurityUtils.getSubject();
        } catch (RuntimeException ignored) {
            // Shiro throws UnavailableSecurityManagerException when no SecurityManager is bound.
            subject = null;
        }
        // getSession(false): do not create a session merely to look up the user.
        Session session = subject == null ? null : subject.getSession(false);
        Object attribute = session == null ? null : session.getAttribute(Constants.CURRENTUSER);
        if (attribute != null) {
            User user = (User) attribute;
            userId = user.getId().toString();
            realName = user.getRealname();
        }
        return new String[]{userId, realName};
    }

    /**
     * Normalises a SET value from inlined SQL before comparing it with the database value. It
     * strips one leading and one trailing single quote and maps true/false/null literals
     * (case-insensitive) to "1"/"0"/"". A null or empty input becomes "".
     */
    private static String normalizeNewValue(String value) {
        if (isNullOrEmpty(value)) {
            return "";
        }
        String normalized = value;
        if (normalized.startsWith("'")) {
            normalized = normalized.substring(1);
        }
        if (normalized.endsWith("'")) {
            normalized = normalized.substring(0, normalized.length() - 1);
        }
        String lower = normalized.toLowerCase();
        if ("true".equals(lower)) {
            return "1";
        }
        if ("false".equals(lower)) {
            return "0";
        }
        if ("null".equals(lower)) {
            return "";
        }
        return normalized;
    }

    /** Returns true when {@code value} is null or has length 0. */
    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }

    /**
     * Builds the statement SQL for logging. Whitespace is collapsed, and each '?' placeholder is
     * replaced with its parameter value rendered as a literal (see getParameterValue). The result
     * is logged at INFO.
     *
     * <p>String values are quoted but embedded single quotes are not escaped. Treat the result as
     * diagnostic text.
     *
     * @param sqlId         mapped statement id, used as log prefix
     * @param configuration MyBatis configuration (type handlers, meta objects)
     * @param boundSql      bound SQL supplying the SQL text, parameter mappings and values
     * @return the SQL with inlined parameter values
     */
    public static String getSql(String sqlId, Configuration configuration, BoundSql boundSql) {
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
        if (parameterMappings != null && !parameterMappings.isEmpty() && parameterObject != null) {
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                // A single simple parameter (String, number, ...).
                sql = replaceFirstPlaceholder(sql, parameterObject);
            } else {
                MetaObject metaObject = configuration.newMetaObject(parameterObject);
                for (ParameterMapping parameterMapping : parameterMappings) {
                    String propertyName = parameterMapping.getProperty();
                    Object value;
                    if (metaObject.hasGetter(propertyName)) {
                        value = metaObject.getValue(propertyName);
                    } else if (boundSql.hasAdditionalParameter(propertyName)) {
                        value = boundSql.getAdditionalParameter(propertyName);
                    } else if (parameterObject instanceof Map) {
                        // MetaObject is not a Map; read from the underlying parameter map instead.
                        value = ((Map<?, ?>) parameterObject).get(propertyName);
                    } else {
                        value = null;
                    }
                    sql = replaceFirstPlaceholder(sql, value);
                }
            }
        }
        log.info(sqlId + ":" + sql);
        return sql;
    }

    /**
     * Replaces the first '?' in {@code sql} with the literal form of {@code value}. The
     * replacement is quoted, so '$' and '\' in values are not treated as regex group references.
     */
    private static String replaceFirstPlaceholder(String sql, Object value) {
        return sql.replaceFirst("\\?", Matcher.quoteReplacement(getParameterValue(value)));
    }

}
