从 SQL 语句中解析出源表和结果表

SQL 解析器可以将 SQL 语句解析成一棵抽象语法树(AST)。遍历 AST,从所有的叶子节点中可以找到本条SQL 语句中需要的所有表。

在 Calcite 中,解析出的 AST 是以 SqlNode 的形式表现的,一个 SqlNode 即是 AST 中的一个节点。SqlNode 有众多的子类,但是因为我们的目标只是为了找出语句中涉及到的表,因而我们重点关注会出现对表的引用的节点。表名在 AST 中会是一个 SqlIdentifier 的叶子结点,但并非所有 SqlIdentifier 叶子结点都对应表,列名也对应 SqlIdentifier

在一条 SQL 中,最终出现表的引用的情况归结于以下两种情况:

  1. SELECT 语句的 FROM clause 中的直接引用
  2. JOIN 语句中 LEFT 和 RIGHT clause 中的直接引用

嵌套子查询的 SQL 语句中,最终进入到子查询的 AST 子树中,只要出现了对表的引用,一定会分解出以上两种结构。因此,对于一个 SqlIdentifier 类型的叶子节点,在:

  1. 父节点是 SqlSelect,且当前节点是父节点的 FROM 子句派生出的子节点
  2. 父节点是 SqlJoin

这两种情况下,该叶子结点就是一个表的引用。

另外,一种特殊的情况需要加以考虑。在 SQL 中 AS 常用作起别名,因而可能 SqlIdentifier 的父节点是 AS,而 AS 的父节点是 SELECTJOIN。这种情况下,我们可以将 AS 看作一种 “转发” 结点,即 AS 的父节点和子节点忽略掉 AS 结点,直接构成父子关系。

从根结点开始遍历 AST,解析所有的子查询,找到符合上述两种情况的子结构,就可以提取出所有对表的引用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
 /**
* get source tables and target tables from AST (SqlNode)
*
* @param sqlNode a top level SqlNode, aka Root of AST, can only be SELECT OR INSERT
* @return a tuple contains source tables and target tables name
*/
def extractTablesInSql(sqlNode: SqlNode): (Set[String], Set[String]) = {
sqlNode.getKind match {
case SqlKind.INSERT =>
extractTablesInSqlInsert(sqlNode.asInstanceOf[SqlInsert])
case SqlKind.SELECT =>
extractTablesInSqlSelect(sqlNode.asInstanceOf[SqlSelect])
case _ =>
throw new IllegalArgumentException(s"Can not parse tables in $sqlNode")
}
}

private def extractTablesInSqlSelect(sqlSelect: SqlSelect): (Set[String], Set[String]) = {
val sourceTables = extractSourceTableInSql(sqlSelect,false)
val targetTables = Set[String]() //target tables is empty
(sourceTables,targetTables)
}

private def extractTablesInSqlInsert(sqlInsert: SqlInsert): (Set[String],Set[String]) = {
val sourceTables = extractSourceTableInSql(sqlInsert.getSource,false)
require(sqlInsert.getTargetTable.isInstanceOf[SqlIdentifier])
val targetTables = Set(sqlInsert.getTargetTable.asInstanceOf[SqlIdentifier].toString)
(sourceTables,targetTables)
}

/**
* parse source tables recursively
*
* @param sqlNode a sqlNode may contains source table
* @param fromOrJoin a boolean value indicate this node is
* 1. A FROM child of a SELECT node,
* or
* 2. A child of a JOIN node
* source table only occurred in SubQuery
* @return a set of source table names
*/
private def extractSourceTableInSql(sqlNode: SqlNode, fromOrJoin: Boolean): Set[String] = {
if (sqlNode == null) {
//no source table
Set[String]()
} else {
sqlNode.getKind match {
case SqlKind.SELECT =>
//may have subQuery in FROM, SELECT list, WHERE, HAVING
val selectNode = sqlNode.asInstanceOf[SqlSelect]
val sourceInFromClause = extractSourceTableInSql(selectNode.getFrom,true)
val sourceInSelectListClause = selectNode.getSelectList.getList.asScala
.filter(_.isInstanceOf[SqlCall]) //if not SqlCall, just ignore it
.foldLeft[Set[String]](Set())((s,node) => {
s ++ extractSourceTableInSql(node,false)
})
val sourceInWhereClause = extractSourceTableInSql(selectNode.getWhere,false)
val sourceInHavingClause = extractSourceTableInSql(selectNode.getHaving,false)
sourceInFromClause ++ sourceInSelectListClause ++ sourceInWhereClause ++ sourceInHavingClause
case SqlKind.JOIN =>
val left = extractSourceTableInSql(sqlNode.asInstanceOf[SqlJoin].getLeft,true)
val right = extractSourceTableInSql(sqlNode.asInstanceOf[SqlJoin].getRight,true)
left ++ right
case SqlKind.AS =>
//AS node should at least 2 operand
require(sqlNode.asInstanceOf[SqlCall].operandCount() >= 2)
//AS only consider operand[0], forward fromOrJoin relation to next level
extractSourceTableInSql(sqlNode.asInstanceOf[SqlCall].operand(0),fromOrJoin)
case SqlKind.IDENTIFIER =>
if (fromOrJoin) {
//If this IDENTIFIER is one of
// 1. A FROM child of a SELECT node,
// 2. A child of a JOIN node
// then it is a table name.
Set(sqlNode.asInstanceOf[SqlIdentifier].toString)
} else {
//NOT a table name, may be column name or other identifier
Set()
}
case _ if sqlNode.isInstanceOf[SqlCall] =>
//If is a SqlCall, find tables in all child node.
sqlNode.asInstanceOf[SqlCall].getOperandList.asScala
.foldLeft[Set[String]](Set())((s,node) => {
s ++ extractSourceTableInSql(node,false)
})
case _ =>
//all other kind of SqlNode, no source table
Set[String]()
}
}
}