SQL 解析器可以将 SQL 语句解析成一棵抽象语法树(AST)。遍历 AST,从所有的叶子节点中可以找到本条SQL 语句中需要的所有表。

在 Calcite 中,解析出的 AST 是以 SqlNode 的形式表现的,一个 SqlNode 即是 AST 中的一个节点。SqlNode 有众多的子类,但是因为我们的目标只是为了找出语句中涉及到的表,因而我们重点关注会出现对表的引用的节点。表名在 AST 中会是一个 SqlIdentifier 的叶子结点,但并非所有 SqlIdentifier 叶子结点都对应表,列名也对应 SqlIdentifier

在一条 SQL 中,最终出现表的引用的情况归结于以下两种情况:

  1. SELECT 语句的 FROM clause 中的直接引用
  2. JOIN 语句中 LEFT 和 RIGHT clause 中的直接引用

嵌套子查询的 SQL 语句中,最终进入到子查询的 AST 子树中,只要出现了对表的引用,一定会分解出以上两种结构。因此,对于一个 SqlIdentifier 类型的叶子节点,在:

  1. 父节点是 SqlSelect,且当前节点是父节点的 FROM 子句派生出的子节点
  2. 父节点是 SqlJoin

这两种情况下,该叶子结点就是一个表的引用。

另外,一种特殊的情况需要加以考虑。在 SQL 中 AS 常用作起别名,因而可能 SqlIdentifier 的父节点是 AS,而 AS 的父节点是 SELECTJOIN。这种情况下,我们可以将 AS 看作一种 “转发” 结点,即 AS 的父节点和子节点忽略掉 AS 结点,直接构成父子关系。

从根结点开始遍历 AST,解析所有的子查询,找到符合上述两种情况的子结构,就可以提取出所有对表的引用。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
   /**
    * get source tables and target tables from AST (SqlNode)
    *
    * @param sqlNode a top level SqlNode, aka Root of AST, can only be SELECT OR INSERT
    * @return a tuple contains source tables and target tables name
    */
  def extractTablesInSql(sqlNode: SqlNode): (Set[String], Set[String]) = {
    sqlNode.getKind match {
      case SqlKind.INSERT =>
        extractTablesInSqlInsert(sqlNode.asInstanceOf[SqlInsert])
      case SqlKind.SELECT =>
        extractTablesInSqlSelect(sqlNode.asInstanceOf[SqlSelect])
      case _ =>
        throw new IllegalArgumentException(s"Can not parse tables in $sqlNode")
    }
  }

  private def extractTablesInSqlSelect(sqlSelect: SqlSelect): (Set[String], Set[String]) = {
    val sourceTables = extractSourceTableInSql(sqlSelect,false)
    val targetTables = Set[String]() //target tables is empty
    (sourceTables,targetTables)
  }

  private def extractTablesInSqlInsert(sqlInsert: SqlInsert): (Set[String],Set[String]) = {
    val sourceTables = extractSourceTableInSql(sqlInsert.getSource,false)
    require(sqlInsert.getTargetTable.isInstanceOf[SqlIdentifier])
    val targetTables = Set(sqlInsert.getTargetTable.asInstanceOf[SqlIdentifier].toString)
    (sourceTables,targetTables)
  }

  /**
    * parse source tables recursively
    *
    * @param sqlNode      a sqlNode may contains source table
    * @param fromOrJoin   a boolean value indicate this node is
    *                     1. A FROM child of a SELECT node,
    *                     or
    *                     2. A child of a JOIN node
    *                     source table only occurred in SubQuery
    * @return a set of source table names
    */
  private def extractSourceTableInSql(sqlNode: SqlNode, fromOrJoin: Boolean): Set[String] = {
    if (sqlNode == null) {
      //no source table
      Set[String]()
    } else {
      sqlNode.getKind match {
        case SqlKind.SELECT =>
          //may have subQuery in FROM, SELECT list, WHERE, HAVING
          val selectNode = sqlNode.asInstanceOf[SqlSelect]
          val sourceInFromClause = extractSourceTableInSql(selectNode.getFrom,true)
          val sourceInSelectListClause = selectNode.getSelectList.getList.asScala
            .filter(_.isInstanceOf[SqlCall]) //if not SqlCall, just ignore it
            .foldLeft[Set[String]](Set())((s,node) => {
            s ++ extractSourceTableInSql(node,false)
          })
          val sourceInWhereClause = extractSourceTableInSql(selectNode.getWhere,false)
          val sourceInHavingClause = extractSourceTableInSql(selectNode.getHaving,false)
          sourceInFromClause ++ sourceInSelectListClause ++ sourceInWhereClause ++ sourceInHavingClause
        case SqlKind.JOIN =>
          val left = extractSourceTableInSql(sqlNode.asInstanceOf[SqlJoin].getLeft,true)
          val right = extractSourceTableInSql(sqlNode.asInstanceOf[SqlJoin].getRight,true)
          left ++ right
        case SqlKind.AS =>
          //AS node should at least 2 operand
          require(sqlNode.asInstanceOf[SqlCall].operandCount() >= 2)
          //AS only consider operand[0], forward fromOrJoin relation to next level
          extractSourceTableInSql(sqlNode.asInstanceOf[SqlCall].operand(0),fromOrJoin)
        case SqlKind.IDENTIFIER =>
          if (fromOrJoin) {
            //If this IDENTIFIER is one of
            // 1. A FROM child of a SELECT node,
            // 2. A child of a JOIN node
            // then it is a table name.
            Set(sqlNode.asInstanceOf[SqlIdentifier].toString)
          } else {
            //NOT a table name, may be column name or other identifier
            Set()
          }
        case _ if sqlNode.isInstanceOf[SqlCall] =>
          //If is a SqlCall, find tables in all child node.
          sqlNode.asInstanceOf[SqlCall].getOperandList.asScala
            .foldLeft[Set[String]](Set())((s,node) => {
            s ++ extractSourceTableInSql(node,false)
          })
        case _ =>
          //all other kind of SqlNode, no source table
          Set[String]()
      }
    }
  }