1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
|
/**
 * Gets the source tables and target tables referenced by a SQL statement's AST (SqlNode).
 *
 * @param sqlNode the root of the AST. Supported roots are:
 *                INSERT, SELECT, set operations (UNION / INTERSECT / EXCEPT), and
 *                ORDER_BY nodes (a query carrying a top-level ORDER BY may arrive as an
 *                ORDER_BY node wrapping the actual query).
 * @return a tuple of (source table names, target table names); the target set is empty
 *         for every root except INSERT
 * @throws IllegalArgumentException if `sqlNode` is of any other kind
 */
def extractTablesInSql(sqlNode: SqlNode): (Set[String], Set[String]) = {
  sqlNode.getKind match {
    case SqlKind.INSERT =>
      extractTablesInSqlInsert(sqlNode.asInstanceOf[SqlInsert])
    case SqlKind.SELECT =>
      extractTablesInSqlSelect(sqlNode.asInstanceOf[SqlSelect])
    case SqlKind.ORDER_BY | SqlKind.UNION | SqlKind.INTERSECT | SqlKind.EXCEPT =>
      // Read-only queries built around one or more SELECTs: the recursive extractor
      // reaches the nested SELECTs through the call operands. Nothing is written.
      (extractSourceTableInSql(sqlNode, false), Set[String]())
    case _ =>
      throw new IllegalArgumentException(s"Can not parse tables in $sqlNode")
  }
}
/**
 * Collects the tables touched by a stand-alone SELECT statement.
 *
 * @param sqlSelect the SELECT node
 * @return (tables read by the query, tables written by it); a SELECT never writes,
 *         so the second element is always empty
 */
private def extractTablesInSqlSelect(sqlSelect: SqlSelect): (Set[String], Set[String]) =
  (extractSourceTableInSql(sqlSelect, false), Set.empty[String])
/**
 * Collects the source and target tables of an INSERT statement.
 *
 * @param sqlInsert the INSERT node
 * @return (tables read by the INSERT's source query, the single table written to)
 * @throws IllegalArgumentException if the target table is not a plain identifier
 */
private def extractTablesInSqlInsert(sqlInsert: SqlInsert): (Set[String], Set[String]) = {
  val sourceTables = extractSourceTableInSql(sqlInsert.getSource, false)
  val targetTables = sqlInsert.getTargetTable match {
    case target: SqlIdentifier =>
      Set(target.toString)
    case other =>
      // Same exception type as the former bare `require`, but the message now says
      // what was rejected instead of just "requirement failed".
      throw new IllegalArgumentException(
        s"Unsupported INSERT target table (expected a plain identifier): $other")
  }
  (sourceTables, targetTables)
}
/**
 * Recursively collects the names of source tables referenced at or below `sqlNode`.
 *
 * An identifier counts as a table name only when it occupies a table position (see
 * `fromOrJoin`). In every other position (select list, WHERE, HAVING, JOIN ... ON,
 * function / operator arguments, node lists), tables can only come from nested
 * sub-queries, which are searched recursively.
 *
 * Not inspected: the GROUP BY, ORDER BY and window clauses of a SELECT.
 * Note: any identifier in a FROM / JOIN position is reported, so a CTE name used in
 * FROM is reported like a table name.
 *
 * @param sqlNode    the node to search; `null` (an absent optional clause) yields no tables
 * @param fromOrJoin true when this node is
 *                     1. the FROM child of a SELECT node, or
 *                     2. the left or right input of a JOIN node,
 *                   possibly wrapped in an AS alias (the flag is forwarded through AS)
 * @return the set of source table names
 */
private def extractSourceTableInSql(sqlNode: SqlNode, fromOrJoin: Boolean): Set[String] =
  sqlNode match {
    case null =>
      // optional clause not present (no FROM / WHERE / HAVING / join condition, ...)
      Set[String]()
    case nodeList: SqlNodeList =>
      // A SqlNodeList is not a SqlCall, yet it can hold sub-queries (e.g. the select list,
      // CASE WHEN/THEN lists, the item list of a WITH). Its elements are never tables.
      nodeList.getList.asScala.flatMap(extractSourceTableInSql(_, false)).toSet
    case _ =>
      sqlNode.getKind match {
        case SqlKind.SELECT =>
          // sub-queries may appear in FROM, the select list, WHERE and HAVING
          val select = sqlNode.asInstanceOf[SqlSelect]
          extractSourceTableInSql(select.getFrom, true) ++
            extractSourceTableInSql(select.getSelectList, false) ++
            extractSourceTableInSql(select.getWhere, false) ++
            extractSourceTableInSql(select.getHaving, false)
        case SqlKind.JOIN =>
          // Both inputs are table positions. The ON condition is not, but it may contain
          // a sub-query, e.g. `ON a.id IN (SELECT id FROM t)`.
          val join = sqlNode.asInstanceOf[SqlJoin]
          extractSourceTableInSql(join.getLeft, true) ++
            extractSourceTableInSql(join.getRight, true) ++
            extractSourceTableInSql(join.getCondition, false)
        case SqlKind.AS =>
          // An AS node has at least 2 operands; only operand 0 (the aliased expression)
          // can reference tables, and it inherits this node's position.
          val aliased = sqlNode.asInstanceOf[SqlCall]
          require(aliased.operandCount() >= 2)
          extractSourceTableInSql(aliased.operand(0), fromOrJoin)
        case SqlKind.IDENTIFIER =>
          // In a table position the identifier is a table name; elsewhere it is a column,
          // alias or other name.
          if (fromOrJoin) Set(sqlNode.asInstanceOf[SqlIdentifier].toString) else Set[String]()
        case _ if sqlNode.isInstanceOf[SqlCall] =>
          // Any other call: tables can only come from sub-queries among its operands.
          sqlNode.asInstanceOf[SqlCall].getOperandList.asScala
            .flatMap(extractSourceTableInSql(_, false))
            .toSet
        case _ =>
          // e.g. literals and dynamic parameters: no source table
          Set[String]()
      }
  }
|