Alibaba Druid之SQL 解析器

15

1.组成部分

  • Parser
    • 词法分析
    • 语法分析
  • AST(Abstract Syntax Tree,抽象语法树)
  • Visitor

parser 由两部分组成

  • 词法分析 Lexer
  • 语法分析 Parser

词法分析:当拿到一条形如 select id, name from user 的 SQL 语句后,首先需要解析出每个独立的单词,select,id,name,from,user,也叫作 Lexer。

通过词法分析后,便要进行语法分析。

语法分析的职责就是依据语法规则,明确一个语句的语义,表达的是什意思。

AST 是 Parser 的产物

语句经过词法分析,语法分析后,它的结构需要以一种计算机能读懂的方式表达出来,最常用的就是抽象语法树。

树的概念很接近于一个语句结构的表示,一个语句:它由哪些部分组成?其中一个组成部分又有哪些部分组成?例如一条 select 语句,它由 select 列表、where 子句、排序字段、分组字段等组成,而 select 列表则由一个或多个 select 项组成,where 子句又由一个或者多个 where条件组成。

这种组成结构就是一个总分的逻辑结构,用树来表达,最合适不过。

AST详细内容:Druid_SQL_AST · alibaba/druid Wiki (github.com)

visitor模式访问

AST 仅仅是语义的表示,但如何对这个语义进行表达,便需要去访问这棵 AST,看它到底表达什么含义。通常遍历语法树,使用 visitor模式去遍历,从根节点开始遍历,一直到最后一个叶子节点,在遍历的过程中,便不断地收集信息到一个上下文中,整个遍历过程完成后,对这棵树所表达的语法含义,已经被保存到上下文了。有时候一次遍历还不够,需要二次遍历。遍历的方式,广度优先的遍历方式是最常见的。

2.快速使用

  1. 新建一个 Parser
  2. 使用 Parser 解析 SQL,生成 AST
  3. 使用 Visitor 访问 AST - 修改SQL时直接对AST进行操作

使用案例一(推荐使用案例二)

public class ParserMain {
    public static void main(String[] args) {
        String sql = "select id,name from user";
        // 新建 MySQL Parser
        SQLStatementParser parser = new MySqlStatementParser(sql);
        // 使用Parser解析生成AST,这里SQLStatement就是AST
        SQLStatement statement = parser.parseStatement();
        // 使用visitor来访问AST
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        statement.accept(visitor);
        // 从visitor中拿出你所关注的信息
        System.out.println(visitor.getColumns());
    }
}

使用案例二

public class SqlParser {
    public static void main(String[] args) {
        String sql = "select age a,name n from student s inner join (select id,name from score where sex='女') temp on sex='男' and temp.id in(select id from score where sex='男') where student.name='zhangsan' group by student.age order by student.id ASC;";
        System.out.println("SQL语句为:" + sql);
        //格式化输出
        String result = SQLUtils.format(sql, JdbcConstants.MYSQL);
        System.out.println("格式化后输出:\n" + result);
        System.out.println("*********************");
        // 使用工具类直接获取到AST
        List<SQLStatement> sqlStatementList = SQLUtils.parseStatements(sql,JdbcConstants.MYSQL);
        SQLStatement stmt = sqlStatementList.get(0);
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        stmt.accept(visitor);
        
        System.out.println("数据库类型\t\t" + visitor.getDbType());
        //获取字段名称
        System.out.println("查询的字段\t\t" + visitor.getColumns());
        //获取表名称
        System.out.println("表名\t\t\t" + visitor.getTables().keySet());
        System.out.println("条件\t\t\t" + visitor.getConditions());
        System.out.println("group by\t\t" + visitor.getGroupByColumns());
        System.out.println("order by\t\t" + visitor.getOrderByColumns());
    }
}

应用案例一

在where后拼接新的查询条件。

public class SqlParserUtilTest {
​
  @Test
  public void getParsedSqlList() {
    String testSql = "select c1,c2,c3 from   t1,t2  order by o1,o2";
//    String testSql = "select c1,c2,c3 from   t1,t2 where 3=3  order by o1,o2";
​
//    String testSql = "delete from t1 , t2 where id = 1 ";
//    String testSql = "delete from t1,t2  ";
​
//    String testSql = "update t1 set c1='ll', c2='11'  where c2=2";
//    String testSql = "update t1 set c1='ll' , c2='11' ";
​
// error
//    String testSql = "select c2 from t1 group by c2 having c3 limit 1,10";
//    String testSql = "select c1,c2 from t1 left join t2 on t1.a = t2.a where t1.c = 1 and t1.b = 1 group by c1,c2 having c1=2 order by c1 desc limit 1,10";
​
//    String testSql = "insert into t1(c1,c2) select c1,c2 from t2 left join t1 on t1.a = t2.a";
//    String testSql = "insert into t1(c1,c2) values('lz',20),('wl',20)";
    
    HashMap<String, Object> map = new HashMap<>();
    map.put("role", 123);
    map.put("systemId", "aaa");
    System.err.println(contactConditions(testSql, map));
  }
​
  private static String contactConditions(String sql, Map<String, Object> columnMap) {
    SQLStatement stmt = parser(sql);
    // 新增查询条件的拼接
    StringBuilder constraintsBuffer = new StringBuilder();
    Set<String> keys = columnMap.keySet();
    Iterator<String> keyIter = keys.iterator();
    if (keyIter.hasNext()) {
      String key = keyIter.next();
      constraintsBuffer.append(key).append(" = ").append(getSqlByClass(columnMap.get(key)));
    }
    while (keyIter.hasNext()) {
      String key = keyIter.next();
      constraintsBuffer.append(" AND ").append(key).append(" = ")
          .append(getSqlByClass(columnMap.get(key)));
    }
    SQLExprParser constraintsParser = SQLParserUtils.createExprParser(
        constraintsBuffer.toString(), JdbcUtils.MYSQL);
    SQLExpr constraintsExpr = constraintsParser.expr();
    // -----------------条件拼接完成----------------
    // -----------------条件拼接完成----------------
    // select 语句
    if (stmt instanceof SQLSelectStatement) {
      return select((SQLSelectStatement) stmt, constraintsExpr);
    }
    // delete  语句
    if (stmt instanceof SQLDeleteStatement) {
      return delete((SQLDeleteStatement) stmt, constraintsExpr);
    }
    // update  语句
    if (stmt instanceof SQLUpdateStatement) {
      return update((SQLUpdateStatement) stmt, constraintsExpr);
    }
    // insert 语句
    if (stmt instanceof SQLInsertStatement) {
      return insert(constraintsExpr, sql);
    }
    return sql;
  }
  
  /**
   * 获取到SQL的AST
   * @param sql SQL
   * @return AST
   */
  private static SQLStatement parser(String sql) {
    List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, JdbcUtils.MYSQL);;
    return stmtList.get(0);
  }
​
  private static String select(SQLSelectStatement selectStmt, SQLExpr constraintsExpr) {
    // 拿到SQLSelect
    SQLSelect sqlselect = selectStmt.getSelect();
    SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
    SQLExpr whereExpr = query.getWhere();
    // 修改where表达式
    if (whereExpr == null) {
      query.setWhere(constraintsExpr);
    } else {
      SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd,
          constraintsExpr);
      query.setWhere(newWhereExpr);
    }
    sqlselect.setQuery(query);
    return sqlselect.toString();
  }
​
  private static String delete(SQLDeleteStatement deleteStat, SQLExpr constraintsExpr) {
    SQLExpr whereExpr = deleteStat.getWhere();
    if (whereExpr == null) {
      deleteStat.setWhere(constraintsExpr);
    } else {
      SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd,
          constraintsExpr);
      deleteStat.setWhere(newWhereExpr);
    }
    return deleteStat.toString();
  }
​
  private static String update(SQLUpdateStatement updateStat, SQLExpr constraintsExpr) {
    SQLExpr whereExpr = updateStat.getWhere();
    if (whereExpr == null) {
      updateStat.setWhere(constraintsExpr);
    } else {
      SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd,
          constraintsExpr);
      updateStat.setWhere(newWhereExpr);
    }
    return updateStat.toString();
  }
  
  private static String insert(SQLExpr constraintsExpr, String sql) {
    if (sql.contains("select")) {
      int select = sql.indexOf("select");
      String insertSql = sql.substring(0, select);
      String selectSql = sql.substring(select);
      selectSql = select((SQLSelectStatement) parser(selectSql), constraintsExpr);
      return insertSql + selectSql;
    } else {
      return sql;
    }
  }
​
  private static String getSqlByClass(Object value) {
    if (value instanceof Number) {
      return value + "";
    } else if (value instanceof String) {
      return "'" + value + "'";
    }
    return "'" + value.toString() + "'";
  }
}

看我完这些后你应该大体上了解了Alibaba Druid如何对SQL进行的解析,这些还无法满足你的需求的话,请再看官方文档:SQL Parser · alibaba/druid Wiki (github.com),你会有新的收获的。