package com.abcxin.smart.validator.util;

import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.ibatis.mapping.ParameterMapping;

import java.io.StringReader;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * SQL script parsing utilities built on JSqlParser.
 * <p>
 * The UPDATE helpers come in two flavours: one that takes raw SQL text and parses it, and one that
 * takes an already-parsed {@link Update}. The raw-text flavours return {@code null} when the SQL is
 * not an UPDATE statement. {@link #getCountSql(String, List)} derives the row-count query used for
 * pagination.
 *
 * @author linqinglin
 * @date 2018/12/08 0008 17:46
 */
public final class JsqlparserUtil {

    /** Parser manager shared by every helper in this class. */
    private static final CCJSqlParserManager PARSER_MANAGER = new CCJSqlParserManager();

    /** Utility class: private constructor prevents instantiation. */
    private JsqlparserUtil() {
    }

    /**
     * Ad-hoc manual check: prints the count SQL generated for a sample query.
     */
    public static void main(String[] args) throws Exception {
        String sql = "select aa,(select bb from cc where dd >1) as ee from ff,dd left join aad on ada=adsfasd where gg=1";
        System.out.println(getCountSql(sql,null));
        // Sample UPDATE statements for the parser(...) timing helper:
        //parser("UPDATE user SET update_time = '2018-12-8 14:20:04', loginstate = 1, registrationID = '41186', loginnum = 1617, logindate = '2018-12-8 14:20:04', updateflag =true, uuid = null WHERE id=41186 and abc=2 and ddd=3");
        //parser("UPDATE approve_link_pro_detail SET active = FALSE, update_by = '111' WHERE approveAreaProId IN (SELECT b.id FROM approve_link_level a, approve_area_process b WHERE a.tempLinkRelaId = 1 AND b.approveLinkLevelId IN (a.id)) AND  tempLinkRelaId = 1 AND  active = TRUE");
    }

    /**
     * Debug helper: prints the SQL, parses it as an UPDATE and prints the elapsed milliseconds.
     *
     * @param sql SQL text to parse
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    public static void parser(String sql) throws JSQLParserException{
        long time1=System.currentTimeMillis();
        System.out.println(sql);
        // Only the parse itself is timed; the result is not needed here.
        getUpdateStatement(sql);
        long time2=System.currentTimeMillis();
        System.out.println("当前程序耗时："+(time2-time1)+"ms");
    }

    /**
     * Parses SQL text and returns it as an {@link Update}.
     *
     * @param sql SQL text
     * @return the parsed UPDATE statement, or {@code null} if the SQL is some other statement type
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    public static Update getUpdateStatement(String sql) throws JSQLParserException{
        Statement statement = PARSER_MANAGER.parse(new StringReader(sql));
        return statement instanceof Update ? (Update) statement : null;
    }

    /**
     * Parses an UPDATE statement and returns the table names it references.
     *
     * @param sql SQL text
     * @return the table names, or {@code null} if the SQL is not an UPDATE statement
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    public static List<String> getTableNames(String sql) throws JSQLParserException{
        Update updateStatement = getUpdateStatement(sql);
        return updateStatement == null ? null : getTableNames(updateStatement, sql);
    }

    /**
     * Returns the table names {@link TablesNamesFinder} reports for an UPDATE statement.
     * NOTE(review): TablesNamesFinder presumably also reports tables used only in sub-queries —
     * confirm against the JSqlParser version in use.
     *
     * @param updateStatement parsed UPDATE statement (must not be {@code null})
     * @param sql original SQL text (unused; kept for signature compatibility)
     * @return the table names
     * @throws JSQLParserException declared for compatibility with existing callers
     */
    public static List<String> getTableNames(Update updateStatement, String sql) throws JSQLParserException{
        TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
        return tablesNamesFinder.getTableList(updateStatement);
    }

    /**
     * Parses an UPDATE statement and maps each SET column to the text of its assigned value.
     *
     * @param sql SQL text
     * @return column name to value-expression text, or {@code null} if the SQL is not an UPDATE
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    public static Map<String,String> getSetColumnNames(String sql) throws JSQLParserException{
        return getSetColumns(getUpdateStatement(sql), sql);
    }

    /**
     * Maps each SET column of an UPDATE statement to the text of its assigned value expression.
     * Column names are kept exactly as written in the SQL (no case folding).
     *
     * @param updateStatement parsed UPDATE statement, may be {@code null}
     * @param sql original SQL text (unused; kept for signature compatibility)
     * @return column name to value-expression text, or {@code null} if {@code updateStatement} is null
     * @throws JSQLParserException declared for compatibility with existing callers
     */
    public static Map<String,String> getSetColumns(Update updateStatement, String sql) throws JSQLParserException{
        if (updateStatement == null) {
            return null;
        }
        List<Column> columns = updateStatement.getColumns();
        List<Expression> values = updateStatement.getExpressions();
        // NOTE(review): assumes one value expression per column, positionally aligned. An UPDATE whose
        // values come from a sub-select (SET (a, b) = (SELECT ...)) may break this — confirm.
        Map<String,String> allColumnNames = new HashMap<String,String>();
        for (int i = 0; i < columns.size(); i++) {
            allColumnNames.put(columns.get(i).getColumnName(), values.get(i).toString());
        }
        return allColumnNames;
    }

    /**
     * Parses an UPDATE statement and builds a SELECT of the columns it would modify.
     *
     * @param sql SQL text
     * @return the SELECT SQL, or {@code null} if the SQL is not an UPDATE statement
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    public static String getQuerySql(String sql) throws JSQLParserException{
        return getQuerySql(getUpdateStatement(sql), sql);
    }

    /**
     * Builds {@code select <SET columns> from <tables> [where <condition>]} for an UPDATE statement.
     * Table names and the WHERE condition are lower-cased; without a WHERE clause the query has none.
     *
     * @param updateStatement parsed UPDATE statement, may be {@code null}
     * @param sql original SQL text (passed through to {@link #getTableNames(Update, String)})
     * @return the SELECT SQL, or {@code null} if {@code updateStatement} is null
     * @throws JSQLParserException declared for compatibility with existing callers
     */
    public static String getQuerySql(Update updateStatement, String sql) throws JSQLParserException{
        if (updateStatement == null) {
            return null;
        }
        List<String> tableNames = getTableNames(updateStatement, sql);

        StringBuilder querySql = new StringBuilder("select ");
        querySql.append(joinWithComma(updateStatement.getColumns()));
        // NOTE(review): if tableNames includes tables that appear only in WHERE sub-queries, listing
        // them all here yields a cross join; consider restricting to the UPDATE target table(s).
        querySql.append(" from ").append(joinWithComma(tableNames).toLowerCase(Locale.ROOT));

        Expression where = updateStatement.getWhere();
        if (where != null) {
            // NOTE(review): lower-casing the whole condition also lower-cases string literals, which
            // changes the matched rows under case-sensitive collations — confirm this is intended.
            querySql.append(" where ").append(where.toString().toLowerCase(Locale.ROOT));
        }
        return querySql.toString();
    }

    /**
     * Parses an UPDATE statement and maps the columns in its WHERE condition to their compared values.
     *
     * @param sql SQL text
     * @return lower-cased column name to right-hand expression text, or {@code null} if the SQL is not
     *         an UPDATE or its WHERE condition is not a binary expression
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    public static Map<String,String> getWhereColumns(String sql) throws JSQLParserException{
        return getWhereColumns(getUpdateStatement(sql), sql);
    }

    /**
     * Maps the columns in an UPDATE's WHERE condition to the text of the expression they are compared
     * with. Only conditions that are a {@link BinaryExpression} (AND/OR/comparison trees) are handled.
     *
     * @param updateStatement parsed UPDATE statement, may be {@code null}
     * @param sql original SQL text (unused; kept for signature compatibility)
     * @return lower-cased column name to right-hand expression text, or {@code null} if there is no
     *         statement or its WHERE condition is not a binary expression
     * @throws JSQLParserException declared for compatibility with existing callers
     */
    public static Map<String,String> getWhereColumns(Update updateStatement, String sql) throws JSQLParserException{
        if (updateStatement == null) {
            return null;
        }
        Expression where = updateStatement.getWhere();
        if (where instanceof BinaryExpression) {
            return getColumnName(where, new HashMap<String,String>());
        }
        return null;
    }

    /**
     * Walks a WHERE condition tree and records {@code column -> right-hand expression text} for each
     * binary predicate whose left operand is a column. Keys are lower-cased; a later predicate on the
     * same column overwrites an earlier one.
     *
     * @param expression condition (sub)tree to walk
     * @param allColumnNames map that receives the entries
     * @return {@code allColumnNames}
     */
    private static Map<String,String> getColumnName(Expression expression, Map<String,String> allColumnNames) {
        if (!(expression instanceof BinaryExpression)) {
            return allColumnNames;
        }
        BinaryExpression binary = (BinaryExpression) expression;
        Expression leftExpression = binary.getLeftExpression();
        Expression rightExpression = binary.getRightExpression();

        // Left operand: either a column (a predicate such as "a = 1") or a nested tree to descend into.
        if (leftExpression instanceof Column) {
            putColumn((Column) leftExpression, rightExpression, allColumnNames);
        } else if (leftExpression instanceof BinaryExpression) {
            getColumnName(leftExpression, allColumnNames);
        }

        // Right operand: a simple predicate ("b = 2") is recorded directly; a compound condition
        // ("b = 2 AND c = 3", e.g. under an OR) is walked recursively so none of its columns are lost.
        if (rightExpression instanceof BinaryExpression) {
            BinaryExpression right = (BinaryExpression) rightExpression;
            Expression rightLeft = right.getLeftExpression();
            if (rightLeft instanceof Column) {
                putColumn((Column) rightLeft, right.getRightExpression(), allColumnNames);
            } else if (rightLeft instanceof BinaryExpression) {
                getColumnName(right, allColumnNames);
            }
        }
        return allColumnNames;
    }

    /** Records the lower-cased (locale-independent) column name mapped to the value's SQL text. */
    private static void putColumn(Column column, Expression value, Map<String,String> allColumnNames) {
        allColumnNames.put(column.getColumnName().toLowerCase(Locale.ROOT), value.toString());
    }

    /**
     * Builds the {@code SELECT COUNT(1)} query used to count the total rows of a paged query.
     * <p>
     * For a plain SELECT, the select list is replaced by {@code COUNT(1)}; the FROM item, joins and
     * WHERE clause are kept. Because the select list is dropped, the parameter mappings consumed by
     * its {@code ?} placeholders are removed from the front of {@code params} (modified in place).
     * Other SELECT bodies (for example UNIONs) are wrapped as a derived table, and {@code params} is
     * left untouched.
     * <p>
     * NOTE(review): GROUP BY, HAVING, DISTINCT, ORDER BY and LIMIT of a plain SELECT are not copied.
     * This has two consequences:
     * <ul>
     *   <li>Grouped or distinct queries are counted over ungrouped rows.</li>
     *   <li>Placeholders in the dropped ORDER BY/LIMIT clauses are not removed from {@code params}.</li>
     * </ul>
     * WITH items are likely not carried over either. Confirm callers never pass such queries.
     *
     * @param sql original query SQL
     * @param params parameter mappings of the original query; may be {@code null}
     * @return the count SQL
     * @throws IllegalArgumentException if {@code sql} is not a SELECT statement
     * @throws Exception if the SQL cannot be parsed
     */
    public static String getCountSql(String sql, List<ParameterMapping> params) throws Exception {
        Statement statement = PARSER_MANAGER.parse(new StringReader(sql));
        if (!(statement instanceof Select)) {
            throw new IllegalArgumentException(
                    "getCountSql only supports SELECT statements, got " + statement.getClass().getSimpleName());
        }
        Select select = (Select) statement;

        if (!(select.getSelectBody() instanceof PlainSelect)) {
            // e.g. UNION: count over the whole query as a derived table; all placeholders are kept.
            return "SELECT COUNT(1) FROM (" + select + ") tmp_count";
        }
        PlainSelect selectBody = (PlainSelect) select.getSelectBody();

        // The select list is not part of the count query, so drop the mappings its placeholders
        // consumed. They are the leading mappings because the select list precedes FROM in the SQL.
        if (params != null) {
            for (SelectItem selectItem : selectBody.getSelectItems()) {
                int placeholders = countPlaceholders(selectItem.toString());
                for (int i = 0; i < placeholders && !params.isEmpty(); i++) {
                    params.remove(0);
                }
            }
        }

        StringBuilder countSql = new StringBuilder("SELECT COUNT(1) FROM ");
        countSql.append(selectBody.getFromItem()).append(' ');
        if (selectBody.getJoins() != null) {
            for (Join join : selectBody.getJoins()) {
                // Simple joins are comma-separated FROM items ("FROM a, b"); prepend the comma here.
                if (join.isSimple()) {
                    countSql.append(", ");
                }
                countSql.append(join).append(' ');
            }
        }
        if (selectBody.getWhere() != null) {
            countSql.append("WHERE ").append(selectBody.getWhere());
        }
        return countSql.toString();
    }

    /**
     * Counts {@code ?} characters in {@code text}. Like the earlier split-based logic, this also counts
     * {@code ?} inside string literals.
     */
    private static int countPlaceholders(String text) {
        int count = 0;
        for (int i = 0; i < text.length(); i++) {
            if (text.charAt(i) == '?') {
                count++;
            }
        }
        return count;
    }

    /** Joins each element's {@code toString()} with {@code ", "} (List.toString() without brackets). */
    private static String joinWithComma(List<?> items) {
        StringBuilder joined = new StringBuilder();
        boolean first = true;
        for (Object item : items) {
            if (!first) {
                joined.append(", ");
            }
            joined.append(item);
            first = false;
        }
        return joined.toString();
    }
}
