1、ParseUtil
package com.example.util;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static com.alibaba.druid.util.JdbcConstants.*;
/**
 * SQL validation and table-lineage helpers built on Alibaba Druid's SQL parser.
 *
 * <p>The nested {@code OdpsParseUtil}, {@code HiveParseUtil} and {@code MysqlParseUtil} classes
 * bind these helpers to a fixed Druid {@code dbType}.
 */
@Slf4j
public abstract class ParseUtil {

    /**
     * Whole-word write/DDL keywords used by {@link #checkSymbol} to cut upper-cased SQL into segments.
     * The {@code \b} boundaries keep identifiers such as {@code CREATE_TIME} or {@code UPDATED_AT}
     * from being treated as keywords.
     */
    private static final Pattern STATEMENT_KEYWORDS =
            Pattern.compile("\\b(?:DROP|CREATE|INSERT|UPDATE)\\b");

    /**
     * Validates that {@code sql} is non-blank and parses into at least one statement for
     * {@code dbType}. Despite its name, nothing is returned or formatted.
     *
     * <p>Validation parses the SQL directly instead of comparing the input with
     * {@code SQLUtils.format(sql, dbType)}: an input that is already in formatted form formats to
     * itself, so the comparison rejected perfectly valid SQL.
     *
     * @param sql    SQL text to validate
     * @param dbType Druid database type (e.g. {@code JdbcConstants.MYSQL})
     * @throws RuntimeException with message "SQL格式错误" if the SQL is blank, cannot be parsed, or
     *                          contains no statement
     */
    public static void format(String sql, String dbType) {
        if (!StringUtils.hasText(sql) || parseStatements(sql, dbType).isEmpty()) {
            throw new RuntimeException("SQL格式错误");
        }
    }

    /**
     * Rough check that statements are terminated with ';'.
     *
     * <p>The upper-cased SQL is split at every whole-word DROP/CREATE/INSERT/UPDATE keyword and each
     * non-blank segment must contain a ';'. SQL containing none of these keywords forms a single
     * segment, so it must contain a ';' as well. Blank segments (e.g. whitespace before the first
     * keyword) are ignored. Keywords inside string literals or comments are still matched.
     *
     * @param sql SQL text to check; must not be null
     * @throws NullPointerException if {@code sql} is null
     * @throws RuntimeException     with message "SQL缺失;" if a segment has no ';'
     */
    public static void checkSymbol(String sql) {
        Objects.requireNonNull(sql, "sql");
        // Locale.ROOT: upper-casing must not depend on the JVM locale (e.g. Turkish dotted i).
        String upperSql = sql.toUpperCase(Locale.ROOT);
        boolean missingSemicolon = Arrays.stream(STATEMENT_KEYWORDS.split(upperSql))
                .filter(segment -> StringUtils.hasText(segment))
                .anyMatch(segment -> !segment.contains(";"));
        if (missingSemicolon) {
            throw new RuntimeException("SQL缺失;");
        }
    }

    /**
     * Parses {@code sql} into Druid statements.
     *
     * @param sql    SQL text
     * @param dbType Druid database type
     * @return the parsed statements
     * @throws RuntimeException with message "SQL格式错误" if Druid fails to parse the SQL; the parser
     *                          exception is kept as the cause
     */
    public static List<SQLStatement> parseStatements(String sql, String dbType) {
        try {
            return SQLUtils.parseStatements(sql, dbType);
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            throw new RuntimeException("SQL格式错误", e);
        }
    }

    /**
     * Full validation pipeline shared by the dialect-specific helpers: checks statement terminators
     * with {@link #checkSymbol}, then parses once. Blank or statement-less SQL is rejected, as in
     * {@link #format}.
     *
     * @param sql    SQL text
     * @param dbType Druid database type
     * @return the parsed, non-empty statement list
     * @throws RuntimeException "SQL缺失;" or "SQL格式错误" as described on the called helpers
     */
    public static List<SQLStatement> checkSqlFormat(String sql, String dbType) {
        checkSymbol(sql);
        List<SQLStatement> statements = parseStatements(sql, dbType);
        if (statements.isEmpty()) {
            throw new RuntimeException("SQL格式错误");
        }
        return statements;
    }

    /**
     * Collects the tables read by each statement and rejects any statement that does more than read.
     *
     * <p>For every statement whose visitor reports more than one table, a single-entry map
     * {@code {firstTable -> remainingTables}} is appended to {@code tableList}. Tables are taken in
     * the visitor's iteration order. Statements touching at most one table add nothing.
     *
     * @param sqlStatements parsed statements to inspect
     * @param tableList     list that results are appended to (mutated in place) and then returned
     * @param dbType        Druid database type used to create the schema visitor
     * @return {@code tableList}
     * @throws RuntimeException with message "禁止进行其他DML操作!" if any table's
     *                          {@code TableStat} is not exactly "Select"
     */
    public static List<Map<String, List<String>>> bloodRelationship(List<SQLStatement> sqlStatements,
                                                                    List<Map<String, List<String>>> tableList,
                                                                    String dbType) {
        for (SQLStatement sqlStatement : sqlStatements) {
            SchemaStatVisitor visitor = SQLUtils.createSchemaStatVisitor(dbType);
            sqlStatement.accept(visitor);
            Map<TableStat.Name, TableStat> tableStats = visitor.getTables();

            // TableStat#toString() is expected to read exactly "Select" for a table that is only read;
            // any other (or additional) recorded operation is rejected.
            for (TableStat stat : tableStats.values()) {
                if (!"Select".equals(stat.toString())) {
                    throw new RuntimeException("禁止进行其他DML操作!");
                }
            }

            // toCollection(ArrayList::new): the list is mutated below, and Collectors.toList() makes
            // no mutability guarantee.
            List<String> tables = tableStats.keySet()
                    .stream()
                    .map(TableStat.Name::getName)
                    .collect(Collectors.toCollection(ArrayList::new));
            if (tables.size() > 1) {
                String firstTable = tables.remove(0);
                Map<String, List<String>> tableMap = new HashMap<>();
                tableMap.put(firstTable, tables);
                tableList.add(tableMap);
            }
        }
        return tableList;
    }

    /** ParseUtil helpers bound to Druid's ODPS dialect. */
    static class OdpsParseUtil extends ParseUtil {
        /** Validates {@code sql} as ODPS and returns its table lineage (see {@code bloodRelationship}). */
        public static List<Map<String, List<String>>> parse(String sql) {
            return bloodRelationship(checkSqlFormatByOdps(sql), new ArrayList<>(), ODPS);
        }

        /** Validates {@code sql} as ODPS and returns the parsed statements (see {@code checkSqlFormat}). */
        public static List<SQLStatement> checkSqlFormatByOdps(String sql) {
            return checkSqlFormat(sql, ODPS);
        }
    }

    /** ParseUtil helpers bound to Druid's Hive dialect. */
    static class HiveParseUtil extends ParseUtil {
        /** Validates {@code sql} as Hive and returns its table lineage (see {@code bloodRelationship}). */
        public static List<Map<String, List<String>>> parse(String sql) {
            return bloodRelationship(checkSqlFormatByHive(sql), new ArrayList<>(), HIVE);
        }

        /** Validates {@code sql} as Hive and returns the parsed statements (see {@code checkSqlFormat}). */
        public static List<SQLStatement> checkSqlFormatByHive(String sql) {
            return checkSqlFormat(sql, HIVE);
        }
    }

    /** ParseUtil helpers bound to Druid's MySQL dialect. */
    static class MysqlParseUtil extends ParseUtil {
        /** Validates {@code sql} as MySQL and returns its table lineage (see {@code bloodRelationship}). */
        public static List<Map<String, List<String>>> parse(String sql) {
            return bloodRelationship(checkSqlFormatByMysql(sql), new ArrayList<>(), MYSQL);
        }

        /** Validates {@code sql} as MySQL and returns the parsed statements (see {@code checkSqlFormat}). */
        public static List<SQLStatement> checkSqlFormatByMysql(String sql) {
            return checkSqlFormat(sql, MYSQL);
        }
    }
}
2、源码分析
// Excerpt of Druid's SchemaStatVisitor (com.alibaba.druid.sql.visitor.SchemaStatVisitor), the visitor
// ParseUtil runs over each parsed statement. "......" marks members omitted from this excerpt.
// The collections below presumably get filled while the visitor walks the AST (the visit methods
// are not shown here).
public class SchemaStatVisitor extends SQLASTVisitorAdapter {
// Schema metadata source; how it is used is not visible in this excerpt.
protected SchemaRepository repository;
// Table name -> per-table statistics. The instance is a LinkedHashMap, so iteration follows
// insertion order. NOTE(review): getTables() presumably exposes this map — confirm in Druid source.
protected final HashMap<TableStat.Name, TableStat> tableStats = new LinkedHashMap<TableStat.Name, TableStat>();
// Columns keyed by a Long (presumably a hash of the column name — TODO confirm), in insertion order.
protected final Map<Long, Column> columns = new LinkedHashMap<Long, Column>();
// Conditions recorded during the visit.
protected final List<Condition> conditions = new ArrayList<Condition>();
// Relationships recorded during the visit; a LinkedHashSet, so duplicates are dropped and order kept.
protected final Set<Relationship> relationships = new LinkedHashSet<Relationship>();
// Columns recorded from ORDER BY clauses.
protected final List<Column> orderByColumns = new ArrayList<Column>();
// Columns recorded from GROUP BY clauses (de-duplicated, insertion-ordered).
protected final Set<Column> groupByColumns = new LinkedHashSet<Column>();
// Aggregate function expressions encountered.
protected final List<SQLAggregateExpr> aggregateFunctions = new ArrayList<SQLAggregateExpr>();
// Method-invocation (function call) expressions encountered.
protected final List<SQLMethodInvokeExpr> functions = new ArrayList<SQLMethodInvokeExpr>(2);
......
}
- 可以根据自己的业务需求灵活地修改该工具类，也可以继承 SchemaStatVisitor 实现自己的 Visitor；具体 API 可参考 Druid 在 GitHub 上的文档。