java:Druid工具类解析sql获取表名
1 前言
阿里巴巴(alibaba)的 druid 连接池除了执行 SQL 之外,还提供了 SQL 语法解析工具,依赖如下:
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>1.2.15</version>
</dependency>
2 使用
参考druid的工具类:com.alibaba.druid.sql.parser.SQLParserUtils#getTables(String sql, DbType dbType)方法,可以用于获取sql的表名:
比如针对mysql的select语句:
package com.xiaoxu.parser;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import java.util.List;
/**
 * Demo: extracts the table name of a MySQL SELECT with druid's
 * {@code SQLParserUtils#getTables} and then swaps it with a plain {@link String#replace}.
 *
 * @author xiaoxu
 * @date 2024-03-11
 * java_demo2:com.xiaoxu.parser.SQLParserTest
 */
public class SQLParserTest {
    public static void main(String[] args) {
        final String query = "select * from my_fuitrs_99 where id = ?";
        final List<String> tableNames = SQLParserUtils.getTables(query, DbType.mysql);
        System.out.println(tableNames);
        // String#replace rewrites every occurrence of the text, not only the table position.
        final String firstTable = tableNames.get(0);
        final String rewritten = query.replace(firstTable, "my_fuitrs");
        System.out.println(rewritten);
    }
}
执行结果如下:
[my_fuitrs_99]
select * from my_fuitrs where id = ?
getTables 方法本质上是逐个读取 Lexer 产生的 token。为了判断一个单词是关键字还是用户自定义的标识符,Lexer 会对标识符计算 64 位 FNV-1a 哈希值 hashLCase(计算前先把 A-Z 转成小写,因此关键字匹配不区分大小写),再用 keywords.getKeyword(hashLCase) 到关键字表中查找。这种哈希在理论上不能保证绝对唯一,但冲突概率极低。例如,当 lexer.token 为 com.alibaba.druid.sql.parser.Token 类中的 FROM(“FROM”)时,先调用 lexer.nextToken() 读取 from 之后的单词(中间的空格等空白字符会被跳过)。如果 keywords.getKeyword(hashLCase) 查不到对应的关键字,就把该 token 设置为 Token.IDENTIFIER,即用户自定义的标识符(比如表名)。
参考com.alibaba.druid.sql.parser.Lexer的scanIdentifier()方法:
/**
 * Scans one identifier or keyword starting at the current character (excerpt of druid's
 * com.alibaba.druid.sql.parser.Lexer#scanIdentifier, as quoted in this article).
 *
 * While scanning, two 64-bit FNV-1a hashes are computed: {@code hash} over the raw
 * characters and {@code hashLCase} over the characters with ASCII 'A'-'Z' folded to lower
 * case (+32). {@code hashLCase} drives the case-insensitive keyword lookup, while
 * {@code stringVal} is always cut from the original text, so the identifier keeps its case.
 *
 * On return {@code token} holds the scanned token, {@code pos} points just past the
 * scanned text and {@code this.ch} is the character at {@code pos}.
 *
 * @throws ParserException if a back-quoted identifier has no closing back quote, or the
 *                         current character cannot start an identifier
 */
public void scanIdentifier() {
this.hashLCase = 0;
this.hash = 0;
final char first = ch;
// Back-quoted identifier, e.g. `my_table`.
if (ch == '`') {
// mark points at the opening back quote; bufPos stays 1 on this path, so
// mark + bufPos does NOT span the quoted name (pos, set below, does).
mark = pos;
bufPos = 1;
// Local variable shadowing the field; the field is updated explicitly via this.ch below.
char ch;
int startPos = pos + 1;
int quoteIndex = text.indexOf('`', startPos);
if (quoteIndex == -1) {
throw new ParserException("illegal identifier. " + info());
}
// FNV-1a 64-bit offset basis.
hashLCase = 0xcbf29ce484222325L;
hash = 0xcbf29ce484222325L;
// Hash only the characters between the back quotes (FNV-1a: xor, then multiply by the prime).
for (int i = startPos; i < quoteIndex; ++i) {
ch = text.charAt(i);
hashLCase ^= ((ch >= 'A' && ch <= 'Z') ? (ch + 32) : ch);
hashLCase *= 0x100000001b3L;
hash ^= ch;
hash *= 0x100000001b3L;
}
// stringVal keeps the back quotes: it spans text[pos .. quoteIndex] inclusive.
stringVal = MySqlLexer.quoteTable.addSymbol(text, pos, quoteIndex + 1 - pos, hash);
//stringVal = text.substring(mark, pos);
pos = quoteIndex + 1;
this.ch = charAt(pos);
// Quoted names are never looked up as keywords.
token = Token.IDENTIFIER;
return;
}
final boolean firstFlag = isFirstIdentifierChar(first);
if (!firstFlag) {
throw new ParserException("illegal identifier. " + info());
}
// Seed both FNV-1a hashes with the first character.
hashLCase = 0xcbf29ce484222325L;
hash = 0xcbf29ce484222325L;
hashLCase ^= ((ch >= 'A' && ch <= 'Z') ? (ch + 32) : ch);
hashLCase *= 0x100000001b3L;
hash ^= ch;
hash *= 0x100000001b3L;
// mark = start offset of the identifier; bufPos = number of characters consumed so far.
mark = pos;
bufPos = 1;
// Local variable shadowing the field; it starts at 0, so c0 is 0 (not `first`) on the first iteration.
char ch = 0;
for (; ; ) {
// c0 = the previously scanned character.
char c0 = ch;
ch = charAt(++pos);
if (!isIdentifierChar(ch)) {
// A parenthesis right after a character whose code is above 256 is absorbed into the
// identifier (counted in bufPos, but not hashed).
// NOTE(review): upstream druid presumably compares against the full-width parentheses
// (U+FF08 / U+FF09) here and in the switch below; the ASCII '(' / ')' may be a
// transcription artifact of this article - confirm against the druid 1.2.15 source.
if ((ch == '(' || ch == ')') && c0 > 256) {
bufPos++;
continue;
}
break;
}
hashLCase ^= ((ch >= 'A' && ch <= 'Z') ? (ch + 32) : ch);
hashLCase *= 0x100000001b3L;
hash ^= ch;
hash *= 0x100000001b3L;
bufPos++;
continue;
}
// pos now points at the first character after the identifier; sync the field.
this.ch = charAt(pos);
// Single-character identifier: parentheses become LPAREN/RPAREN, anything else an IDENTIFIER.
if (bufPos == 1) {
switch (first) {
case '(':
token = Token.LPAREN;
return;
case ')':
token = Token.RPAREN;
return;
default:
break;
}
token = Token.IDENTIFIER;
// Prefer the cached one-character string when CharTypes has one.
stringVal = CharTypes.valueOf(first);
if (stringVal == null) {
stringVal = Character.toString(first);
}
return;
}
// Case-insensitive keyword lookup by the lower-cased hash.
Token tok = keywords.getKeyword(hashLCase);
if (tok != null) {
token = tok;
if (token == Token.IDENTIFIER) {
stringVal = SymbolTable.global.addSymbol(text, mark, bufPos, hash);
} else {
// Real keywords carry no string value.
stringVal = null;
}
} else {
// Not a keyword: a user-defined identifier (e.g. a table name), text[mark, mark + bufPos).
token = Token.IDENTIFIER;
stringVal = SymbolTable.global.addSymbol(text, mark, bufPos, hash);
}
}
获取表名的工具类方法:
/**
 * Collects the table names referenced by {@code sql} by walking its token stream
 * (druid's SQLParserUtils#getTables, as quoted in this article).
 *
 * A name is taken from the tokens right after CREATE/DROP/ALTER TABLE [IF [NOT] EXISTS],
 * and after every FROM or JOIN that is not followed by '(' or VALUES. There is no INTO
 * case, so the target of an INSERT INTO is not collected by this loop.
 *
 * @param sql    the SQL text to scan
 * @param dbType dialect used to pick the lexer and the expression parser; switching on a
 *               null value throws NullPointerException
 * @return the distinct table names, in the order they were first encountered
 */
public static List<String> getTables(String sql, DbType dbType) {
// LinkedHashSet: de-duplicates while keeping first-seen order.
Set<String> tables = new LinkedHashSet<>();
// true after a SET token (reset by CREATE/DROP/ALTER or ';'); used for the ODPS "SET k = v" case.
boolean set = false;
Lexer lexer = createLexer(sql, dbType);
lexer.nextToken();
SQLExprParser exprParser;
switch (dbType) {
case odps:
exprParser = new OdpsExprParser(lexer);
break;
case mysql:
exprParser = new MySqlExprParser(lexer);
break;
default:
exprParser = new SQLExprParser(lexer);
break;
}
for_:
for (; lexer.token != Token.EOF; ) {
switch (lexer.token) {
case CREATE:
case DROP:
case ALTER:
set = false;
lexer.nextToken();
if (lexer.token == Token.TABLE) {
lexer.nextToken();
// Skip an optional IF [NOT] EXISTS clause.
if (lexer.token == Token.IF) {
lexer.nextToken();
if (lexer.token == Token.NOT) {
lexer.nextToken();
}
if (lexer.token == Token.EXISTS) {
lexer.nextToken();
}
}
// name() consumes the (possibly qualified) name and advances the lexer past it.
SQLName name = exprParser.name();
tables.add(name.toString());
if (lexer.token == Token.AS) {
lexer.nextToken();
}
}
// The current token has not been examined yet, so skip the trailing nextToken().
continue for_;
case FROM:
case JOIN:
lexer.nextToken();
// '(' starts a derived table / sub-query, VALUES a row constructor: no table name here.
if (lexer.token != Token.LPAREN
&& lexer.token != Token.VALUES
) {
SQLName name = exprParser.name();
tables.add(name.toString());
}
continue for_;
case SEMI:
set = false;
break;
case SET:
set = true;
break;
case EQ:
// NOTE(review): presumably reads the raw value of an ODPS "SET key = value" so that it
// is not tokenized as SQL - confirm against Lexer#nextTokenForSet.
if (set && dbType == DbType.odps) {
lexer.nextTokenForSet();
continue for_;
}
break;
default:
break;
}
lexer.nextToken();
}
return new ArrayList<>(tables);
}
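以上面的 from 关键字为例:执行 lexer.nextToken() 后,scanIdentifier() 会记录标识符的起始下标 mark 和长度 bufPos,并通过 SymbolTable.global.addSymbol(text, mark, bufPos, hash) 从原始 sql 中截取出该标识符作为 stringVal。lexer.stringVal() 返回的就是这个值,也就是我们自定义的表名。需要注意,源码中把大写字母转换为小写字母(大写字母 A-Z 的 ASCII 码值范围是 65-90,小写字母 a-z 是 97-122,在大写字母的码值上加 32 即得到对应的小写字母),只是为了计算 hashLCase,从而不区分大小写地匹配关键字。stringVal 截取的仍是原始文本,所以表名中的大小写会被原样保留。至于哪些字符可以组成标识符,则由 isFirstIdentifierChar / isIdentifierChar 判断:除了大小写字母外,下划线(_)、美元符号($)、数字(0-9,ASCII 码值范围是 48-57)等也都可以作为标识符的组成部分。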
将表名改为包含大写字母的形式,再次获取 sql 中的表名:
// Same lookup as before, then a second lookup whose table name starts in upper case.
String lowerCaseSql = "select * from my_fuitrs_99 where id = ?";
List<String> lowerTables = SQLParserUtils.getTables(lowerCaseSql, DbType.mysql);
System.out.println(lowerTables);
String renamedSql = lowerCaseSql.replace(lowerTables.get(0), "my_fuitrs");
System.out.println(renamedSql);
// The name is reported exactly as written: case is only folded for the keyword hash.
String upperCaseSql = "select * from MY_fuitrs_99 where id = ?";
List<String> upperTables = SQLParserUtils.getTables(upperCaseSql, DbType.mysql);
System.out.println(upperTables);
重新执行,结果如下:
[my_fuitrs_99]
select * from my_fuitrs where id = ?
[MY_fuitrs_99]
3 举一反三
那么我们可以根据上面的工具,简单自定义实现一个替换sql表名的工具,工具类如下:
SQLParserUtil:
package com.xiaoxu.parser;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlExprParser;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlLexer;
import com.alibaba.druid.sql.parser.Lexer;
import com.alibaba.druid.sql.parser.SQLExprParser;
import com.alibaba.druid.sql.parser.SQLParserFeature;
import com.alibaba.druid.sql.parser.Token;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
/**
 * Token-based helper that rewrites the table name of a SQL statement with druid's lexer.
 *
 * @author xiaoxu
 * @date 2024-03-12
 * java_demo2:com.xiaoxu.parser.SQLParserUtil
 */
@SuppressWarnings("all")
public class SQLParserUtil {

    /**
     * Replaces the table name that follows the first FROM or INTO token of {@code sql}.
     *
     * <p>The statement is tokenized rather than searched as text, so other occurrences of the
     * same characters (column names, literals) are left untouched. Only the first FROM/INTO
     * is examined; if it is followed by {@code (} or {@code VALUES}, nothing is replaced.
     * For a qualified name such as {@code db.tbl} only the first segment is replaced.
     *
     * <p>The original table name and the resulting SQL are printed to standard output.
     *
     * @param sql              the SQL text to rewrite
     * @param dbType           dialect; {@code null} is treated as {@link DbType#mysql}
     * @param replaceTableName the new table name; when {@code null} or blank, nothing is replaced
     * @return the rewritten SQL, or {@code sql} itself when nothing was replaced
     */
    public static String getSqlFromReplaceNameIfNeccessary(String sql, DbType dbType, @Nullable String replaceTableName) {
        // createLexer() falls back to mysql for null; do the same here so that choosing the
        // expression parser cannot throw a NullPointerException.
        DbType effectiveType = dbType == null ? DbType.mysql : dbType;
        String table = null;
        String newSql = sql;
        Lexer lexer = createLexer(sql, effectiveType);
        lexer.nextToken();
        SQLExprParser exprParser = effectiveType == DbType.mysql
                ? new MySqlExprParser(lexer)
                : new SQLExprParser(lexer);
        for_:
        for (; lexer.token() != Token.EOF; ) {
            switch (lexer.token()) {
                case FROM:
                case INTO:
                    lexer.nextToken();
                    if (lexer.token() != Token.LPAREN && lexer.token() != Token.VALUES) {
                        if (StringUtils.hasText(replaceTableName)) {
                            // createLexer() only returns PositionAware lexers, for every dialect.
                            PositionAware positions = (PositionAware) lexer;
                            // mark = first character of the identifier, pos = first character
                            // after it. pos (rather than mark + bufPos) also covers back-quoted
                            // names, for which Lexer#scanIdentifier leaves bufPos at 1.
                            int start = positions.getMark();
                            int end = positions.getPos();
                            newSql = sql.substring(0, start) + replaceTableName + sql.substring(end);
                        }
                        SQLName name = exprParser.name();
                        table = name.toString();
                    }
                    break for_;
                default:
                    break;
            }
            lexer.nextToken();
        }
        System.out.println("原本的表名:" + table);
        System.out.println("替换表名为:" + replaceTableName + "后的sql:" + newSql);
        return newSql;
    }

    /**
     * Creates a lexer for {@code sql} without extra parser features.
     *
     * @param sql    the SQL text
     * @param dbType dialect; {@code null} is treated as {@link DbType#mysql}
     * @return a lexer that also implements {@link PositionAware}
     */
    public static Lexer createLexer(String sql, DbType dbType) {
        return createLexer(sql, dbType, new SQLParserFeature[0]);
    }

    /**
     * Creates a lexer for {@code sql} whose scan offsets can be read through {@link PositionAware}.
     *
     * @param sql      the SQL text
     * @param dbType   dialect; {@code null} is treated as {@link DbType#mysql}
     * @param features parser features to enable; only honoured for mysql
     * @return a lexer that also implements {@link PositionAware}
     */
    public static Lexer createLexer(String sql, DbType dbType, SQLParserFeature... features) {
        DbType effectiveType = dbType == null ? DbType.mysql : dbType;
        if (effectiveType == DbType.mysql) {
            // Use the plain constructor when no features were requested.
            return features == null || features.length == 0 ? new MLexer(sql) : new MLexer(sql, features);
        }
        // A PositionAware subclass (not a plain Lexer), so callers can read mark/pos for any dialect.
        return new GLexer(sql, effectiveType);
    }

    /** Read access to the scan offsets that druid's Lexer keeps in protected fields. */
    private interface PositionAware {
        /** Start offset of the most recently scanned identifier. */
        int getMark();

        /** Length counter maintained while scanning an unquoted identifier. */
        int getBufPos();

        /** Offset of the first character not consumed yet. */
        int getPos();
    }

    /** MySQL lexer exposing its scan offsets. */
    private static class MLexer extends MySqlLexer implements PositionAware {
        public MLexer(char[] input, int inputLength, boolean skipComment) {
            super(input, inputLength, skipComment);
        }

        public MLexer(String input) {
            super(input);
        }

        public MLexer(String input, SQLParserFeature... features) {
            super(input, features);
        }

        public MLexer(String input, boolean skipComment, boolean keepComments) {
            super(input, skipComment, keepComments);
        }

        @Override
        public int getBufPos() {
            return this.bufPos;
        }

        @Override
        public int getMark() {
            return this.mark;
        }

        @Override
        public int getPos() {
            return this.pos;
        }
    }

    /** Generic-dialect lexer exposing the same scan offsets as {@link MLexer}. */
    private static class GLexer extends Lexer implements PositionAware {
        GLexer(String input, DbType dbType) {
            super(input, null, dbType);
        }

        @Override
        public int getBufPos() {
            return this.bufPos;
        }

        @Override
        public int getMark() {
            return this.mark;
        }

        @Override
        public int getPos() {
            return this.pos;
        }
    }
}
测试下我们自定义的SQLParserUtil工具类:
package com.xiaoxu.parser;
import com.alibaba.druid.DbType;
/**
 * Demo of {@code SQLParserUtil#getSqlFromReplaceNameIfNeccessary}: rewrites the table name
 * of a SELECT and of an INSERT, printing each result.
 *
 * @author xiaoxu
 * @date 2024-03-12
 * java_demo2:com.xiaoxu.parser.SQLParserTest2
 */
public class SQLParserTest2 {
    public static void main(String[] args) {
        String selectSql = "select * from my_fruits_99 where id = ?";
        String rewrittenSelect = SQLParserUtil.getSqlFromReplaceNameIfNeccessary(
                selectSql, DbType.mysql, "xiaoxu_88");
        System.out.println(rewrittenSelect);

        System.out.println("\n");

        // '$' is accepted inside identifiers, both in the old and the new name.
        String insertSql = "insert into apple_$66 values()";
        String rewrittenInsert = SQLParserUtil.getSqlFromReplaceNameIfNeccessary(
                insertSql, DbType.mysql, "Pear$_88");
        System.out.println(rewrittenInsert);
    }
}
执行结果如下:
可以看到,上面的工具在扫描到 FROM 关键字(如 select * from 语句)或 INTO 关键字(如 insert into 语句)时,可以替换紧随其后的 sql 表名;它只处理遇到的第一个 FROM/INTO。其余类似功能可参考 druid 的工具类自行实现。
再来举个栗子,新增方法getSqlInHoldCountIfNeccessary如下:
package com.xiaoxu.parser;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlExprParser;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlLexer;
import com.alibaba.druid.sql.parser.Lexer;
import com.alibaba.druid.sql.parser.SQLExprParser;
import com.alibaba.druid.sql.parser.SQLParserFeature;
import com.alibaba.druid.sql.parser.Token;
import com.google.common.collect.Lists;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import java.util.List;
/**
 * Token-based SQL rewriting helpers built on druid's lexer: replacing a table name and
 * resizing IN (...) placeholder lists.
 *
 * @author xiaoxu
 * @date 2024-03-12
 * java_demo2:com.xiaoxu.parser.SQLParserUtil
 */
@SuppressWarnings("all")
public class SQLParserUtil {

    /**
     * Rewrites the contents of the IN (...) lists of {@code sql} to a requested number of
     * {@code ?} placeholders, copying every other character of the statement unchanged.
     *
     * <p>IN lists are numbered in the order they appear. For the i-th list, when
     * {@code count[i]} is a positive number, the list contents are replaced by that many
     * comma-separated {@code ?}. Otherwise (no array, index out of range, {@code null}, or a
     * value &lt;= 0) the list is copied as is. An IN that is not followed by {@code (}
     * (e.g. {@code LOCK IN SHARE MODE}) is ignored. Parentheses nested inside a list belong
     * to that list, so the IN of an inner sub-query is not numbered separately. If a list is
     * never closed, the rest of the statement is copied unchanged.
     *
     * @param sql    the SQL text to rewrite
     * @param dbType dialect; {@code null} is treated as {@link DbType#mysql}
     * @param count  desired placeholder count per IN list, by position; may be {@code null}
     * @return the rewritten SQL
     */
    public static String getSqlInHoldCountIfNeccessary(String sql, DbType dbType, @Nullable Integer[] count) {
        StringBuilder result = new StringBuilder(sql.length());
        Lexer lexer = createLexer(sql, dbType);
        // createLexer() only returns PositionAware lexers, for every dialect.
        PositionAware positions = (PositionAware) lexer;
        lexer.nextToken();
        int listIndex = 0;  // ordinal of the next IN (...) list
        int copiedUpTo = 0; // sql[0, copiedUpTo) has already been appended to result
        scan:
        while (lexer.token() != Token.EOF) {
            if (lexer.token() != Token.IN) {
                lexer.nextToken();
                continue;
            }
            lexer.nextToken();
            if (lexer.token() != Token.LPAREN) {
                // Not a value list: nothing to rewrite, re-examine this token on the next pass.
                continue;
            }
            // pos is just past '(' - the list body starts here.
            int listStart = positions.getPos();
            int listEnd;
            int depth = 1;
            do {
                // Offset just past the previous token, i.e. before any whitespace preceding ')'.
                listEnd = positions.getPos();
                lexer.nextToken();
                if (lexer.token() == Token.EOF) {
                    // Unbalanced parentheses: leave this list and the rest of the SQL untouched.
                    break scan;
                }
                if (lexer.token() == Token.LPAREN) {
                    depth++;
                } else if (lexer.token() == Token.RPAREN) {
                    depth--;
                }
            } while (depth > 0);
            result.append(sql, copiedUpTo, listStart);
            Integer wanted = count != null && listIndex < count.length ? count[listIndex] : null;
            if (wanted != null && wanted > 0) {
                result.append(placeholders(wanted));
            } else {
                result.append(sql, listStart, listEnd);
            }
            copiedUpTo = listEnd;
            listIndex++;
            // The current token is the closing ')'; the loop advances past it.
        }
        result.append(sql, copiedUpTo, sql.length());
        return result.toString();
    }

    /** Returns {@code n} question marks joined by commas, e.g. {@code "?,?,?"} for 3. */
    private static String placeholders(int n) {
        StringBuilder sb = new StringBuilder(2 * n);
        for (int i = 0; i < n; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append('?');
        }
        return sb.toString();
    }

    /**
     * Replaces the table name that follows the first FROM or INTO token of {@code sql}.
     *
     * <p>The statement is tokenized rather than searched as text, so other occurrences of the
     * same characters (column names, literals) are left untouched. Only the first FROM/INTO
     * is examined; if it is followed by {@code (} or {@code VALUES}, nothing is replaced.
     * For a qualified name such as {@code db.tbl} only the first segment is replaced.
     *
     * <p>The original table name and the resulting SQL are printed to standard output.
     *
     * @param sql              the SQL text to rewrite
     * @param dbType           dialect; {@code null} is treated as {@link DbType#mysql}
     * @param replaceTableName the new table name; when {@code null} or blank, nothing is replaced
     * @return the rewritten SQL, or {@code sql} itself when nothing was replaced
     */
    public static String getSqlFromReplaceNameIfNeccessary(String sql, DbType dbType, @Nullable String replaceTableName) {
        // createLexer() falls back to mysql for null; do the same here so that choosing the
        // expression parser cannot throw a NullPointerException.
        DbType effectiveType = dbType == null ? DbType.mysql : dbType;
        String table = null;
        String newSql = sql;
        Lexer lexer = createLexer(sql, effectiveType);
        lexer.nextToken();
        SQLExprParser exprParser = effectiveType == DbType.mysql
                ? new MySqlExprParser(lexer)
                : new SQLExprParser(lexer);
        for_:
        for (; lexer.token() != Token.EOF; ) {
            switch (lexer.token()) {
                case FROM:
                case INTO:
                    lexer.nextToken();
                    if (lexer.token() != Token.LPAREN && lexer.token() != Token.VALUES) {
                        if (StringUtils.hasText(replaceTableName)) {
                            PositionAware positions = (PositionAware) lexer;
                            // mark = first character of the identifier, pos = first character
                            // after it. pos (rather than mark + bufPos) also covers back-quoted
                            // names, for which Lexer#scanIdentifier leaves bufPos at 1.
                            int start = positions.getMark();
                            int end = positions.getPos();
                            newSql = sql.substring(0, start) + replaceTableName + sql.substring(end);
                        }
                        SQLName name = exprParser.name();
                        table = name.toString();
                    }
                    break for_;
                default:
                    break;
            }
            lexer.nextToken();
        }
        System.out.println("原本的表名:" + table);
        System.out.println("替换表名为:" + replaceTableName + "后的sql:" + newSql);
        return newSql;
    }

    /**
     * Creates a lexer for {@code sql} without extra parser features.
     *
     * @param sql    the SQL text
     * @param dbType dialect; {@code null} is treated as {@link DbType#mysql}
     * @return a lexer that also implements {@link PositionAware}
     */
    public static Lexer createLexer(String sql, DbType dbType) {
        return createLexer(sql, dbType, new SQLParserFeature[0]);
    }

    /**
     * Creates a lexer for {@code sql} whose scan offsets can be read through {@link PositionAware}.
     *
     * @param sql      the SQL text
     * @param dbType   dialect; {@code null} is treated as {@link DbType#mysql}
     * @param features parser features to enable; only honoured for mysql
     * @return a lexer that also implements {@link PositionAware}
     */
    public static Lexer createLexer(String sql, DbType dbType, SQLParserFeature... features) {
        DbType effectiveType = dbType == null ? DbType.mysql : dbType;
        if (effectiveType == DbType.mysql) {
            // Use the plain constructor when no features were requested.
            return features == null || features.length == 0 ? new MLexer(sql) : new MLexer(sql, features);
        }
        // A PositionAware subclass (not a plain Lexer), so callers can read mark/pos for any dialect.
        return new GLexer(sql, effectiveType);
    }

    /** Read access to the scan offsets that druid's Lexer keeps in protected fields. */
    private interface PositionAware {
        /** Start offset of the most recently scanned identifier. */
        int getMark();

        /** Length counter maintained while scanning an unquoted identifier. */
        int getBufPos();

        /** Offset of the first character not consumed yet. */
        int getPos();
    }

    /** MySQL lexer exposing its scan offsets. */
    private static class MLexer extends MySqlLexer implements PositionAware {
        public MLexer(char[] input, int inputLength, boolean skipComment) {
            super(input, inputLength, skipComment);
        }

        public MLexer(String input) {
            super(input);
        }

        public MLexer(String input, SQLParserFeature... features) {
            super(input, features);
        }

        public MLexer(String input, boolean skipComment, boolean keepComments) {
            super(input, skipComment, keepComments);
        }

        @Override
        public int getBufPos() {
            return this.bufPos;
        }

        @Override
        public int getMark() {
            return this.mark;
        }

        @Override
        public int getPos() {
            return this.pos;
        }
    }

    /** Generic-dialect lexer exposing the same scan offsets as {@link MLexer}. */
    private static class GLexer extends Lexer implements PositionAware {
        GLexer(String input, DbType dbType) {
            super(input, null, dbType);
        }

        @Override
        public int getBufPos() {
            return this.bufPos;
        }

        @Override
        public int getMark() {
            return this.mark;
        }

        @Override
        public int getPos() {
            return this.pos;
        }
    }
}
getSqlInHoldCountIfNeccessary 方法的作用如下:在 mysql 的 IN 条件中,一条 sql 可能有多处 IN 列表,例如 in (?)。如果我们需要自定义每个 IN 列表中的参数个数,比如把它改成 in (?,?,?),同时保持 sql 的其余部分不变,就可以使用这个自定义方法。测试如下:
package com.xiaoxu.parser;
import com.alibaba.druid.DbType;
/**
 * Demo of {@code SQLParserUtil#getSqlInHoldCountIfNeccessary}: resizes the IN (...) lists
 * of a few statements and prints the results.
 *
 * @author xiaoxu
 * @date 2024-03-12
 * java_demo2:com.xiaoxu.parser.SQLParserTest2
 */
public class SQLParserTest2 {
    public static void main(String[] args) {
        // Three IN lists; only the second one gets a new size (4 placeholders).
        String multiInSql = "select * from my where id in( ?,?) and status is not null and name in (?) and ot ='N' and pr in ()";
        String expanded = SQLParserUtil.getSqlInHoldCountIfNeccessary(
                multiInSql, DbType.mysql, new Integer[]{null, 4});
        System.out.println("最终结果是:");
        System.out.println(expanded);

        // A single IN list grown to three placeholders.
        System.out.println(SQLParserUtil.getSqlInHoldCountIfNeccessary(
                "select * from my where id in (?)", DbType.mysql, new Integer[]{3}));

        // No IN list at all: the statement comes back unchanged.
        System.out.println(SQLParserUtil.getSqlInHoldCountIfNeccessary(
                "select * from my where id = ?", DbType.mysql, new Integer[]{3}));
    }
}
执行结果如下:
其中参数 new Integer[]{null, 4} 的效果是:第一个 IN 列表保持不变,第二个 IN 列表被更新为 in (?,?,?,?)。该方法按照 IN 在 sql 中出现的顺序,依次设置各个 IN 列表的参数个数。如果数组中对应位置为 null、值不大于 0,或者超出了数组长度,该 IN 列表保持原样;sql 的其余部分也不做任何改动。