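I previously shared an article on converting SQL to DSL by pulling in an external Go program. Later testing showed that some conditions still converted incorrectly; for example, IS and IS NOT were not supported. So over the past while I decided to write a more complete conversion tool myself to support my own project.
Of the SQL parsing tools I know of, I tried both Apache Calcite and Druid and looked through the parsing APIs in both packages. Druid's API struck me as more flexible and richer, there is more material about it online, and problems that come up later will probably be easier to solve, so I decided to build the SQL parsing on Druid.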
1. Druid
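Many people think of the database connection pool as soon as they hear Druid, but the com.alibaba.druid.sql package also contains a large set of APIs for parsing SQL. It supports many dialects: common databases such as MySQL, Oracle and PostgreSQL, as well as databases widely used in big-data scenarios such as Hive, StarRocks and Presto.
References: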
https://github.com/alibaba/druid/wiki/Druid_SQL_AST#1-%E4%BB%80%E4%B9%88%E6%98%AFast
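2. Purpose of parsing
The point of parsing the SQL is to get at its individual parts (select, where, group by, order by, limit, function calls and so on) and then convert them into Elasticsearch DSL. This article therefore focuses on how to convert a standard MySQL query statement into DSL.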
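3. Abstract syntax tree (AST)
An abstract syntax tree (AST) is a tree-shaped data structure that represents the syntactic structure of source code. When SQL is converted to an AST, the statement is broken down into a series of nodes, each standing for one component of the statement: a table name, a column name, an operator, a function call and so on. The nodes are organized as a tree: the root represents the whole SQL statement, and the child nodes represent its clauses, expressions and so on.
Take this simple SQL query, for example: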
SELECT
  name,
  age
FROM
  users
WHERE
  age > 21;
The corresponding AST can be thought of as the structure below.
SelectStatement
├── SelectList
│   ├── Column (name: "name")
│   └── Column (name: "age")
├── FromClause
│   └── Table (name: "users")
└── WhereClause
    └── ComparisonExpression
        ├── Column (name: "age")
        ├── Operator (value: ">")
        └── Literal (value: 21)
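So you can pick the nodes you need out of the AST and read their details. Taking the AST above as an example: to get the queried fields, read the two columns from the SelectList node; to get the table name, read it from the FromClause node; and so on.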
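To make the idea of walking the tree concrete, here is a minimal sketch that models the example AST with plain Scala case classes. The types `Column`, `Literal`, `Comparison` and `SelectStatement` are hypothetical and exist only for illustration; they are not Druid's classes, whose real node types are richer.

```scala
// Hypothetical, simplified AST types; these are NOT Druid's classes.
sealed trait Expr
final case class Column(name: String) extends Expr
final case class Literal(value: Int) extends Expr
final case class Comparison(left: Expr, operator: String, right: Expr) extends Expr

final case class SelectStatement(
  selectList: List[Column],
  from: String,
  where: Option[Expr]
)

object AstSketch {
  // AST for: SELECT name, age FROM users WHERE age > 21
  val stmt: SelectStatement = SelectStatement(
    selectList = List(Column("name"), Column("age")),
    from = "users",
    where = Some(Comparison(Column("age"), ">", Literal(21)))
  )

  // Read the queried field names from the SelectList node.
  def selectedFields(s: SelectStatement): List[String] = s.selectList.map(_.name)

  // Read the table name from the FROM node.
  def tableName(s: SelectStatement): String = s.from

  def main(args: Array[String]): Unit = {
    println(selectedFields(stmt))
    println(tableName(stmt))
  }
}
```

With a real parser the shape is the same: find the node for the clause you care about, then read its children.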
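4. Mapping between SQL and DSL
4.1. index
The index name is simply the table name in the SQL, so reading the table name from the from clause is all that is needed.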
4.2. fields
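These are the fields to return, that is, the fields after select. In my case only user_id is needed; in fact, when querying through Spark, returning the document ID alone is enough, because user_id and the document ID are identical.
To extract the fields after select, call getSelectList to get the SQLSelectItem objects, iterate over them, and read each field name from sqlSelectItem.getExpr (cast to SQLIdentifierExpr) via getName. I won't go into more detail on that here.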
4.3. query
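My main task here is parsing the where condition and turning each kind of logical condition into DSL. In the AST that Druid produces, where nodes mainly come in three types: SQLBinaryOpExpr, SQLInListExpr and SQLBetweenExpr.
SQLInListExpr is SQL's in and not in; in the DSL it maps to a terms query whose value is an array (wrapped in must_not for not in). SQLBetweenExpr means the field lies within a range, which maps to the from and to bounds of range. SQLBinaryOpExpr is a binary expression, informally a left part and a right part: the field on the left and the value on the right, as in a = 1. The same node type covers =, !=, >, < and so on, and also AND and OR, where both sides are themselves conditions.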
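The mapping just described can be sketched without Druid at all. The snippet below is a simplified illustration: the `Cond` types are hypothetical stand-ins for the three Druid node types, only `=` and `>` are handled, and the output is compact JSON rather than the pretty-printed strings used in the full source.

```scala
// Hypothetical condition model standing in for Druid's expression nodes.
sealed trait Cond
final case class In(field: String, values: List[String], negated: Boolean) extends Cond
final case class Between(field: String, begin: String, end: String) extends Cond
final case class Compare(field: String, op: String, value: String) extends Cond

object WhereMappingSketch {
  private def quote(s: String): String = "\"" + s + "\""

  def toDsl(cond: Cond): String = cond match {
    // IN -> terms query with an array of values
    case In(f, vs, false) =>
      s"""{"terms":{${quote(f)}:[${vs.map(quote).mkString(",")}]}}"""
    // NOT IN -> the same terms query wrapped in bool/must_not
    case In(f, vs, true) =>
      s"""{"bool":{"must_not":${toDsl(In(f, vs, negated = false))}}}"""
    // BETWEEN -> range query with both bounds
    case Between(f, b, e) =>
      s"""{"range":{${quote(f)}:{"from":${quote(b)},"to":${quote(e)}}}}"""
    // field = value -> match_phrase
    case Compare(f, "=", v) =>
      s"""{"match_phrase":{${quote(f)}:${quote(v)}}}"""
    // field > value -> range query with gt
    case Compare(f, ">", v) =>
      s"""{"range":{${quote(f)}:{"gt":${quote(v)}}}}"""
    case Compare(_, op, _) =>
      throw new IllegalArgumentException(s"Unsupported operator: $op")
  }
}
```

For instance, `WhereMappingSketch.toDsl(Between("age", "18", "30"))` yields `{"range":{"age":{"from":"18","to":"30"}}}`.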
4.4. aggregations
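Aggregation conditions. My current requirements don't call for any aggregation yet; I will update this if that changes.
4.5. sort
Sort conditions. Same as above.
5. Source code
Enough talk, here is the source. The project is mostly Scala, so I wrote it in Scala directly; calling it from a Java project works just the same.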
5.1. SqlToDslTranslator
package com.dengdz.translator

import com.alibaba.druid.sql.ast.SQLStatement
import com.alibaba.druid.sql.ast.statement.{SQLDeleteStatement, SQLInsertStatement, SQLSelectStatement, SQLUpdateStatement}

/**
 * Translator interface for converting SQL into DSL.
 *
 * <p>Defines the operations of a SQL-to-DSL translator. Concrete implementations convert
 * SQL queries from a particular database into the corresponding DSL, for example:
 * <ul>
 * <li>{@link MysqlToDslTranslator} converts MySQL queries into DSL.</li>
 * <li>{@link HiveToDslTranslator} converts Hive queries into DSL.</li>
 * </ul>
 * Implementations for other databases can be added as needed.</p>
 *
 * @author dengdz
 */
trait SqlToDslTranslator {
  // Each handler returns a tuple; by default it throws NotImplementedError (???) until overridden.
  def handleSelectStatement(stmt: SQLSelectStatement): (String, String) = ???
  def handleInsertStatement(stmt: SQLInsertStatement): (String, String) = ???
  def handleUpdateStatement(stmt: SQLUpdateStatement): (String, String) = ???
  def handleDeleteStatement(stmt: SQLDeleteStatement): (String, String) = ???
  def handleOtherStatement(stat: SQLStatement): (String, String) = ???
}
5.2. SqlToDslHandler
package com.dengdz.translator

/**
 * <p>Syntax conversion interface that defines the logic of the conditional operators.
 * The concrete logic is implemented by subclasses, for example:
 * <ul>
 * <li>{@link MysqlToDslTranslator} converts MySQL into DSL</li>
 * <li>{@link HiveToDslTranslator} converts Hive into DSL</li>
 * </ul>
 * Implement further syntax converters as needed.</p>
 *
 * @author dengdz
 */
trait SqlToDslHandler {
  // Comparison operators
  def greaterThanOrEqual(field: String, value: String): Object
  def equality(field: String, value: String): Object
  def lessThanOrEqual(field: String, value: String): Object
  def greaterThan(field: String, value: String): Object
  def lessThan(field: String, value: String): Object
  def notEqual(field: String, value: String): Object
  def like(field: String, value: String): Object
  def notLike(field: String, value: String): Object
  def is(field: String, value: String): Object
  def isNot(field: String, value: String): Object

  // Logical operators, BETWEEN and IN
  def and(left: String, right: String): Object
  def or(left: String, right: String): Object
  def betweenExpr(field: String, beginValue: String, endValue: String): Object
  def inExpr(field: String, values: List[String]): Object
  def notInExpr(field: String, values: List[String]): Object
  def queryAll(): String

  // Wrappers applied at the root of the query or of a boolean group
  def root(dsl: String): Object
  def orRoot(dsl: String): Object
  def andRoot(dsl: String): Object
  def betweenRoot(dsl: String): Object
  def inRoot(dsl: String): Object
  def comparisonRoot(dsl: String): Object
}
5.3. MysqlToDslTranslator
package com.dengdz.translator.impl

import cn.hutool.core.util.StrUtil
import cn.hutool.json.JSONUtil
import com.alibaba.druid.sql.ast.SQLExpr
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator._
import com.alibaba.druid.sql.ast.expr._
import com.alibaba.druid.sql.ast.statement.{SQLExprTableSource, SQLSelectItem, SQLSelectStatement}
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock
import com.dengdz.translator.{SqlToDslHandler, SqlToDslTranslator}
import com.dengdz.utils.Escaping

import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

/**
 * Converts MySQL statements into DSL.
 * <p>
 * <li>1. Implements the CRUD statement handlers of the {@link `SqlToDslTranslator`} interface.</li>
 * <li>2. Implements the logical-operator syntax conversion of the {@link `SqlToDslHandler`} interface.</li>
 * <p>
 * Here standard MySQL SQL is converted into a DSL statement that is later run through Spark;
 * the same approach could also target the ES API.
 *
 * @author dengdz
 */
class MysqlToDslTranslator extends SqlToDslTranslator with SqlToDslHandler {

  /** Index name */
  private var index = ""

  /** Resulting query */
  private var resultQuery = ""

  /** ******************************** SqlToDslTranslator start *********************************** */
  override def handleSelectStatement(stmt: SQLSelectStatement): (String, String) = {
    try {
      // Get the query block
      val sqlQuery: MySqlSelectQueryBlock =
        stmt
          .getSelect
          .getQuery
          .asInstanceOf[MySqlSelectQueryBlock]
      println("sqlQuery -> " + sqlQuery)

      // Select list
      val select =
        sqlQuery.getSelectList
      // Where condition
      val where =
        sqlQuery.getWhere
      println("selectItems -> " + select)

      val fields = getQueryFieldsAndMethods(select)
      println("fields -> " + fields)

      val query =
        handleWhere(where, isRoot = true)
      println("query -> " + query)

      // The table name in the from clause becomes the index name
      val fromSqlIdentifierExpr =
        sqlQuery
          .getFrom
          .asInstanceOf[SQLExprTableSource]
          .getExpr
          .asInstanceOf[SQLIdentifierExpr]
      println("fromSqlIdentifierExpr -> " + fromSqlIdentifierExpr)

      index = fromSqlIdentifierExpr.getName
      println("index -> " + index)

      resultQuery = root(query)
    } catch {
      case e: Exception => throw new RuntimeException("SQL convert DSL failed. Please check the SQL", e)
    }
    (resultQuery, index)
  }
  /**
   * Converts the where condition into the query part of the DSL.
   *
   * @param sqlExpr the where expression (null when the SQL has no where clause)
   * @param isRoot  whether this is the root node
   * @return the DSL query string
   */
  private def handleWhere(sqlExpr: SQLExpr, isRoot: Boolean): String = {
    sqlExpr match {
      case binaryOpExpr: SQLBinaryOpExpr =>
        println("SQLBinaryOpExpr")
        binaryOpExpr.getOperator match {
          case SQLBinaryOperator.BooleanOr =>
            println("OR")
            handleOr(binaryOpExpr)
          case SQLBinaryOperator.BooleanAnd =>
            println("AND")
            handleAnd(binaryOpExpr)
          case _ =>
            println("_")
            handleValueCompare(binaryOpExpr, isRoot)
        }
      case inListExpr: SQLInListExpr =>
        println("SQLInListExpr")
        handleIn(inListExpr, isRoot)
      case betweenExpr: SQLBetweenExpr =>
        println("SQLBetweenExpr")
        handleBetween(betweenExpr, isRoot)
      case _ =>
        println("_")
        // No where clause (null) or an unhandled node type: match everything
        queryAll()
    }
  }
  private def handleBetween(sqlBetween: SQLBetweenExpr, isRoot: Boolean) = {
    // NOT BETWEEN has no mapping here; fail loudly rather than treat it as BETWEEN
    if (sqlBetween.isNot) {
      throw new RuntimeException("NOT BETWEEN is not supported")
    }
    val sqlIdentifierExpr = sqlBetween.getTestExpr.asInstanceOf[SQLIdentifierExpr]
    val field = sqlIdentifierExpr.getName
    val beginValue = getValue(sqlBetween.getBeginExpr)
    val endValue = getValue(sqlBetween.getEndExpr)
    var resultStr = betweenExpr(field, beginValue, endValue)
    if (isRoot) {
      resultStr = betweenRoot(resultStr)
    }
    resultStr
  }

  /** Extracts a literal value as a string; returns null for a SQL NULL literal. */
  private def getValue(sqlExpr: SQLExpr) = {
    sqlExpr match {
      case integerExpr: SQLIntegerExpr =>
        integerExpr.getNumber.toString
      case charExpr: SQLCharExpr =>
        charExpr.getText
      case numberExpr: SQLNumberExpr =>
        numberExpr.getNumber.toString
      case _: SQLNullExpr =>
        null
      case _ =>
        throw new RuntimeException("Unsupported SQLExpr type")
    }
  }

  private def handleIn(inListExpr: SQLInListExpr, isRoot: Boolean) = {
    val sqlIdentifierExpr = inListExpr.getExpr.asInstanceOf[SQLIdentifierExpr]
    val field = sqlIdentifierExpr.getName
    val list = inListExpr.getTargetList.asScala // Convert the Java list to a Scala collection
    val values = list.map(v => getValue(v))
    val resultStr = if (inListExpr.isNot) {
      notInExpr(field, values.toList)
    } else {
      inExpr(field, values.toList)
    }
    if (isRoot) {
      inRoot(resultStr)
    } else {
      resultStr
    }
  }
  /**
   * Handles field/value comparisons such as >, =, < and !=.
   *
   * @param binaryOpExpr binary expression: the field on the left, the value on the right
   * @param isRoot       whether this is the root node
   * @return the DSL for the comparison
   */
  private def handleValueCompare(binaryOpExpr: SQLBinaryOpExpr, isRoot: Boolean) = {
    // Field name
    val field =
      binaryOpExpr
        .getLeft
        .asInstanceOf[SQLIdentifierExpr]
        .getName
    // Field value
    var value = getValue(binaryOpExpr.getRight)
    val resultStr = binaryOpExpr.getOperator match {
      /** Equal */
      case Equality =>
        equality(field, value)
      /** Not equal */
      case NotEqual =>
        notEqual(field, value)
      /** Greater than */
      case GreaterThan =>
        greaterThan(field, value)
      /** Greater than or equal */
      case GreaterThanOrEqual =>
        greaterThanOrEqual(field, value)
      /** Less than */
      case LessThan =>
        lessThan(field, value)
      /** Less than or equal */
      case LessThanOrEqual =>
        lessThanOrEqual(field, value)
      /** LIKE (pattern match) */
      case Like =>
        // Escape special characters
        value = Escaping.escapeQueryChars(value)
        like(field, value)
      /** NOT LIKE */
      case NotLike =>
        notLike(field, value)
      /** is */
      case Is =>
        is(field, value)
      /** is not */
      case IsNot =>
        isNot(field, value)
      /** Anything else would otherwise surface as a bare MatchError */
      case other =>
        throw new RuntimeException(s"Unsupported operator: $other")
    }
    if (isRoot) {
      comparisonRoot(resultStr)
    } else {
      resultStr
    }
  }
  private def handleAnd(binaryOpExpr: SQLBinaryOpExpr) = {
    val (leftStr, rightStr) = handleLeftAndRight(binaryOpExpr)
    val resultStr = and(leftStr, rightStr)
    // Only wrap in bool/must when the parent is not itself an AND, so chained ANDs stay flat
    binaryOpExpr.getParent match {
      case parent: SQLBinaryOpExpr if parent.getOperator == SQLBinaryOperator.BooleanAnd =>
        resultStr
      case _ =>
        andRoot(resultStr)
    }
  }

  private def handleOr(binaryOpExpr: SQLBinaryOpExpr) = {
    val (left, right) = handleLeftAndRight(binaryOpExpr)
    val resultStr = or(left, right)
    // Only wrap in bool/should when the parent is not itself an OR
    binaryOpExpr.getParent match {
      case parentExpr: SQLBinaryOpExpr if parentExpr.getOperator == SQLBinaryOperator.BooleanOr =>
        resultStr
      case _ =>
        orRoot(resultStr)
    }
  }

  private def handleLeftAndRight(binaryOpExpr: SQLBinaryOpExpr) = {
    val leftExpr = binaryOpExpr.getLeft
    val rightExpr = binaryOpExpr.getRight
    val leftStr = handleWhere(leftExpr, isRoot = false)
    val rightStr = handleWhere(rightExpr, isRoot = false)
    (leftStr, rightStr)
  }
  /**
   * @param selectItems the items of the select list
   * @return the queried fields as (name, alias, method) tuples
   */
  private def getQueryFieldsAndMethods(selectItems: util.List[SQLSelectItem]): List[(String, String, String)] = {
    val fields: ListBuffer[(String, String, String)] =
      new ListBuffer[(String, String, String)]
    selectItems.asScala.foreach { sqlSelectItem =>
      // Alias
      val alias: String = sqlSelectItem.getAlias
      println("alias -> " + alias)
      // Expression type
      val sqlExpr: SQLExpr = sqlSelectItem.getExpr
      println("sqlExpr -> " + sqlExpr)
      sqlExpr match {
        /** All columns */
        case _: SQLAllColumnExpr =>
          println("SQLAllColumnExpr")
          fields += (("*", null, ""))
        /** A specific column */
        case sqlIdentifierExpr: SQLIdentifierExpr =>
          println("SQLIdentifierExpr")
          fields += ((sqlIdentifierExpr.getName, alias, ""))
        case _ =>
          throw new RuntimeException("Unknown select expression type!")
      }
    }
    fields.toList
  }

  /** ******************************** SqlToDslTranslator end *********************************** */
  /** ------------------------------- SqlToDslHandler start ------------------------------------ */
  override def and(left: String, right: String): String = tuple(left, right)

  override def or(left: String, right: String): String = tuple(left, right)

  // Joins two clauses with a comma so they can sit side by side in a must/should array
  private def tuple(left: String, right: String): String = {
    if (left.isEmpty || right.isEmpty) {
      left + right
    } else {
      s"$left,$right"
    }
  }

  override def betweenExpr(field: String, begin: String, end: String): String = {
    s"""{
       |  "range": {
       |    "$field": {
       |      "from": "$begin",
       |      "to": "$end"
       |    }
       |  }
       |}""".stripMargin
  }

  override def inExpr(field: String, values: List[String]): String = {
    s"""{
       |  "terms": {
       |    "$field": [${formatListValues(values)}]
       |  }
       |}""".stripMargin
  }

  override def notInExpr(field: String, values: List[String]): String = {
    s"""{
       |  "bool": {
       |    "must_not": {
       |      "terms": {
       |        "$field": [${formatListValues(values)}]
       |      }
       |    }
       |  }
       |}""".stripMargin
  }

  // Quotes every value and joins them with commas
  private def formatListValues(values: List[String]): String = {
    values.map(v => s""""$v"""").mkString(",")
  }
  override def greaterThanOrEqual(field: String, value: String): String = {
    s"""{
       |  "range": {
       |    "$field": {
       |      "from": "$value"
       |    }
       |  }
       |}""".stripMargin
  }

  override def equality(field: String, value: String): String = {
    s"""{
       |  "match_phrase": {
       |    "$field": "$value"
       |  }
       |}""".stripMargin
  }

  override def lessThanOrEqual(field: String, value: String): String = {
    s"""{
       |  "range": {
       |    "$field": {
       |      "to": "$value"
       |    }
       |  }
       |}""".stripMargin
  }

  override def greaterThan(field: String, value: String): String = {
    s"""{
       |  "range": {
       |    "$field": {
       |      "gt": "$value"
       |    }
       |  }
       |}""".stripMargin
  }

  override def lessThan(field: String, value: String): String = {
    s"""{
       |  "range": {
       |    "$field": {
       |      "lt": "$value"
       |    }
       |  }
       |}""".stripMargin
  }
  override def notEqual(field: String, value: String): String = {
    s"""{
       |  "bool": {
       |    "must_not": [
       |      {
       |        "match_phrase": {
       |          "$field": {
       |            "query": "$value"
       |          }
       |        }
       |      }
       |    ]
       |  }
       |}""".stripMargin
  }

  override def like(field: String, value: String): String = {
    s"""{
       |  "query_string": {
       |    "default_field": "$field",
       |    "query": "*$value*"
       |  }
       |}""".stripMargin
  }

  override def notLike(field: String, value: String): String = {
    s"""{
       |  "bool": {
       |    "must_not": {
       |      "match_phrase": {
       |        "$field": {
       |          "query": "$value"
       |        }
       |      }
       |    }
       |  }
       |}""".stripMargin
  }

  // IS NULL: the field must not exist (value is unused)
  override def is(field: String, value: String): String = {
    s"""{
       |  "bool": {
       |    "must_not": {
       |      "exists": {
       |        "field": "$field"
       |      }
       |    }
       |  }
       |}""".stripMargin
  }

  // IS NOT NULL: the field must exist (value is unused)
  override def isNot(field: String, value: String): String = {
    s"""{
       |  "bool": {
       |    "must": {
       |      "exists": {
       |        "field": "$field"
       |      }
       |    }
       |  }
       |}""".stripMargin
  }
  override def root(dsl: String): String = {
    println(" ROOT ------------------->")
    val builder = new StringBuilder(s"""{"query" : $dsl""")
    builder.append("}")
    // Parse and re-serialize so that malformed JSON fails here
    JSONUtil.parseObj(builder.toString()).toString
  }

  override def orRoot(dsl: String): String = {
    s"""{
       |  "bool": {
       |    "should": [$dsl]
       |  }
       |}""".stripMargin
  }

  override def andRoot(dsl: String): String = must(dsl)

  override def betweenRoot(dsl: String): String = must(dsl)

  override def inRoot(dsl: String): String = must(dsl)

  override def comparisonRoot(dsl: String): String = must(dsl)

  private def must(dsl: String): String = {
    s"""{
       |  "bool": {
       |    "must": [$dsl]
       |  }
       |}""".stripMargin
  }

  override def queryAll(): String = {
    """{
      |  "match_all": {}
      |}""".stripMargin
  }

  /** ------------------------------- SqlToDslHandler end ------------------------------------ */
}
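The way handleAnd and handleOr check the parent operator is easier to see in isolation. The following is a stripped-down sketch of that idea, assuming a hypothetical `Node` model (`Leaf`, `And`, `Or`) in place of Druid's SQLBinaryOpExpr and leaves that already carry rendered DSL: a group is wrapped in bool/must or bool/should only when its parent is not the same operator, so a AND b AND c ends up in a single must array.

```scala
// Hypothetical boolean-expression model; leaves carry already rendered DSL.
sealed trait Node
final case class Leaf(dsl: String) extends Node
final case class And(left: Node, right: Node) extends Node
final case class Or(left: Node, right: Node) extends Node

object BoolFlattenSketch {
  // parent is the enclosing node, or None at the top of the where clause.
  def render(node: Node, parent: Option[Node] = None): String = node match {
    case Leaf(dsl) =>
      // A lone top-level condition is still wrapped in bool/must.
      if (parent.isEmpty) wrap("must", dsl) else dsl
    case And(l, r) =>
      val joined = render(l, Some(node)) + "," + render(r, Some(node))
      parent match {
        case Some(_: And) => joined // same operator above: stay flat
        case _            => wrap("must", joined)
      }
    case Or(l, r) =>
      val joined = render(l, Some(node)) + "," + render(r, Some(node))
      parent match {
        case Some(_: Or) => joined // same operator above: stay flat
        case _           => wrap("should", joined)
      }
  }

  private def wrap(clause: String, body: String): String =
    "{\"bool\":{\"" + clause + "\":[" + body + "]}}"
}
```

Rendering `And(And(Leaf("A"), Leaf("B")), Leaf("C"))` gives `{"bool":{"must":[A,B,C]}}`, while an AND nested under an OR keeps its own must group inside the should array.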
5.4. SqlToDslConverter
package com.dengdz.utils

import com.alibaba.druid.sql.ast.SQLStatement
import com.alibaba.druid.sql.ast.statement.{SQLDeleteStatement, SQLInsertStatement, SQLSelectStatement, SQLUpdateStatement}
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser
import com.alibaba.druid.sql.parser.ParserException
import com.dengdz.translator.SqlToDslTranslator
import com.dengdz.translator.impl.MysqlToDslTranslator
import com.dengdz.utils.Kind.{Kind, MYSQL}

/**
 * @param sql  the SQL to convert
 * @param kind the SQL dialect
 *
 * @author dengdz
 */
class SqlToDslConverter(sql: String, kind: Kind) {
  private var sqlStatement: SQLStatement = _
  private var sqlToDslTranslator: SqlToDslTranslator = _

  // Parse eagerly so that an invalid SQL string fails at construction time
  parseSQL(sql, kind)

  /**
   * Converts the SQL.
   *
   * @return the tuple produced by the translator, (DSL, index) for a select
   */
  def convertSqlToDsl(): (String, String) = {
    println("sqlStatement ->" + sqlStatement)
    sqlStatement match {
      /** Select */
      case stmt: SQLSelectStatement =>
        sqlToDslTranslator.handleSelectStatement(stmt)
      /** Insert */
      case stmt: SQLInsertStatement =>
        sqlToDslTranslator.handleInsertStatement(stmt)
      /** Update */
      case stmt: SQLUpdateStatement =>
        sqlToDslTranslator.handleUpdateStatement(stmt)
      /** Delete */
      case stmt: SQLDeleteStatement =>
        sqlToDslTranslator.handleDeleteStatement(stmt)
      /** Any other statement */
      case _ =>
        sqlToDslTranslator.handleOtherStatement(sqlStatement)
    }
  }

  /**
   * Parses the SQL statement.
   *
   * @param sql  the SQL to parse
   * @param kind the SQL dialect
   * @throws ParserException if the SQL cannot be parsed
   */
  @throws[ParserException]
  def parseSQL(sql: String, kind: Kind): Unit = try {
    // Parse the SQL string into an AST
    kind match {
      case MYSQL =>
        sqlStatement = new MySqlStatementParser(sql).parseStatement
        sqlToDslTranslator = new MysqlToDslTranslator()
      case other =>
        // No translator exists for the remaining kinds yet
        throw new UnsupportedOperationException(s"Unsupported SQL kind: $other")
    }
  } catch {
    case e: ParserException =>
      println(s"Failed to parse SQL: $sql")
      e.printStackTrace()
      throw e
  }
}
5.5. Escaping
package com.dengdz.utils

/**
 * Character escaping.
 *
 * @author dengdz
 */
object Escaping {
  def escapeQueryChars(s: String): String = {
    if (s == null || s.trim.isEmpty) {
      return s
    }
    val specialChars = Set('\\', '+', '-', '!', '(', ')', ':', '^', '[', ']', '\"', '{', '}', '~', '*', '?', '|', '&', ';', '/', '.', '$')
    val sb = new StringBuilder
    s.foreach { c =>
      if (specialChars.contains(c) || Character.isWhitespace(c)) {
        // Two backslashes: the value is later embedded in a JSON string, where \\ decodes to a single \
        sb.append('\\').append('\\')
      }
      sb.append(c)
    }
    sb.toString()
  }
}
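Why two backslashes per special character? The source does not say; a plausible reading is that the escaped value is spliced into a JSON string, so one level of backslash is consumed by JSON decoding before query_string sees it. The sketch below illustrates that reading with a reduced, hypothetical character set and a helper (`jsonDecodeBackslashes`) that merely imitates what JSON decoding does to a doubled backslash.

```scala
object EscapingSketch {
  // Reduced, hypothetical special-character set, just for illustration.
  private val special = Set('+', '-', '*', '?', ':')

  // Same idea as Escaping.escapeQueryChars: prefix each special character with two backslashes.
  def escape(s: String): String =
    s.map(c => if (special.contains(c)) "\\\\" + c else c.toString).mkString

  // Imitates what JSON string decoding does to a doubled backslash.
  def jsonDecodeBackslashes(s: String): String = s.replace("\\\\", "\\")

  def main(args: Array[String]): Unit = {
    val escaped = escape("a+b")
    println(escaped)                        // a, two backslashes, +, b
    println(jsonDecodeBackslashes(escaped)) // a, one backslash, +, b
  }
}
```

After the JSON layer is decoded, the single remaining backslash is what tells query_string to treat `+` as a literal character.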
5.6. Kind
package com.dengdz.utils

/**
 * @author dengdz
 */
object Kind extends Enumeration {
  type Kind = Value
  val MYSQL, HIVE = Value
}
5.7. Demo
package com.dengdz

import com.dengdz.utils.{Kind, SqlToDslConverter}

object Demo {
  def main(args: Array[String]): Unit = {
    val sql = "select user_id from demo where user_id is not null"
    val converter =
      new SqlToDslConverter(sql = sql, kind = Kind.MYSQL)
    val tuple =
      converter.convertSqlToDsl()
    println(tuple)
  }
}