初始方案
- 不校验 `<script>` 包装的（动态）SQL 的语法

思路：对 Mapper 层注解中的 SQL 做语法校验，并在项目启动前执行；否则错误的 SQL 通常要等到对应的 Mapper 方法真正被执行时才会报错。
package ix.account.util;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.ibatis.annotations.*;
import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.context.ResourceLoaderAware;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternUtils;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.core.type.filter.TypeFilter;
import org.springframework.util.StringUtils;
import org.springframework.util.SystemPropertyUtils;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.*;
import java.util.stream.Collectors;
/**
* spring scaner
*/
@Slf4j
public class ClassScaner implements ResourceLoaderAware {
private final List<TypeFilter> includeFilters = new LinkedList<>();
private final List<TypeFilter> excludeFilters = new LinkedList<>();
private ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver();
private MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(this.resourcePatternResolver);
public static Set<Class> scan(String[] basePackages,
Class<? extends Annotation>... annotations) {
ClassScaner classScaner = new ClassScaner();
if (ArrayUtils.isNotEmpty(annotations)) {
for (Class annotation : annotations) {
classScaner.addIncludeFilter(new AnnotationTypeFilter(annotation));
}
}
Set<Class> classes = new HashSet<>();
for (String s : basePackages) {
classes.addAll(classScaner.doScan(s));
}
return classes;
}
/**
* spring 指定包扫描
*
* @param basePackages 扫描包基本路径
* @param annotations 具体扫描什么注解 例如{@link Mapper}
* @return
*/
public static Set<Class> scan(String basePackages, Class<? extends Annotation>... annotations) {
return ClassScaner.scan(StringUtils.tokenizeToStringArray(basePackages, ",; \t\n"), annotations);
}
public final ResourceLoader getResourceLoader() {
return this.resourcePatternResolver;
}
@Override
public void setResourceLoader(ResourceLoader resourceLoader) {
this.resourcePatternResolver = ResourcePatternUtils
.getResourcePatternResolver(resourceLoader);
this.metadataReaderFactory = new CachingMetadataReaderFactory(
resourceLoader);
}
public void addIncludeFilter(TypeFilter includeFilter) {
this.includeFilters.add(includeFilter);
}
public void addExcludeFilter(TypeFilter excludeFilter) {
this.excludeFilters.add(0, excludeFilter);
}
public void resetFilters(boolean defaultFilters) {
this.includeFilters.clear();
this.excludeFilters.clear();
}
public Set<Class> doScan(String basePackage) {
Set<Class> classes = new HashSet<>();
try {
String packageSearchPath = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX
+ org.springframework.util.ClassUtils
.convertClassNameToResourcePath(SystemPropertyUtils
.resolvePlaceholders(basePackage))
+ "/**/*.class";
Resource[] resources = this.resourcePatternResolver
.getResources(packageSearchPath);
for (int i = 0; i < resources.length; i++) {
Resource resource = resources[i];
if (resource.isReadable()) {
MetadataReader metadataReader = this.metadataReaderFactory.getMetadataReader(resource);
boolean b = (includeFilters.size() == 0 && excludeFilters.size() == 0)
|| matches(metadataReader);
if (b) {
try {
classes.add(Class.forName(metadataReader
.getClassMetadata().getClassName()));
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
}
}
} catch (IOException ex) {
throw new BeanDefinitionStoreException(
"I/O failure during classpath scanning", ex);
}
return classes;
}
protected boolean matches(MetadataReader metadataReader) throws IOException {
for (TypeFilter tf : this.excludeFilters) {
if (tf.match(metadataReader, this.metadataReaderFactory)) {
return false;
}
}
for (TypeFilter tf : this.includeFilters) {
if (tf.match(metadataReader, this.metadataReaderFactory)) {
return true;
}
}
return false;
}
public static boolean getMethodAnnotation(String basePackages,
Class<? extends Annotation>... annotations) {
Set<Class> scan = scan(basePackages, annotations);
List<SqlErrorInfo> sqlErrorInfos = new ArrayList<>();
for (Class mapperClass : scan) {
Method[] methods = mapperClass.getMethods();
for (Method method : methods) {
Annotation[] annotations1 = method.getAnnotations();
for (Annotation annotation : annotations1) {
if (annotation instanceof Insert) {
List<String> collect = Arrays.stream(((Insert) annotation).value()).collect(Collectors.toList());
String sql = sqlAnnotValue(collect);
boolean b = crudCheck(sql);
if (b == false) {
sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
}
} else if (annotation instanceof Select) {
List<String> collect = Arrays.stream(((Select) annotation).value()).collect(Collectors.toList());
String sql = sqlAnnotValue(collect);
boolean b = crudCheck(sql);
if (b == false) {
sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
}
} else if (annotation instanceof Update) {
List<String> collect = Arrays.stream(((Update) annotation).value()).collect(Collectors.toList());
String sql = sqlAnnotValue(collect);
boolean b = crudCheck(sql);
if (b == false) {
sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
}
} else if (annotation instanceof Delete) {
List<String> collect = Arrays.stream(((Delete) annotation).value()).collect(Collectors.toList());
String sql = sqlAnnotValue(collect);
boolean b = crudCheck(sql);
if (b == false) {
sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
}
}
}
}
}
// System.out.println(sqlErrorInfos.size());
sqlErrorInfos.forEach(
info -> {
log.error("不正确的sql,不校验<script>包装,错误sql : " + info);
}
);
if (sqlErrorInfos.size() == 0) {
return true;
} else {
return false;
}
}
/**
* 将Mapper层注解中的sql获取
*
* @param collect sqlCollect
* @return sql
*/
private static String sqlAnnotValue(List<String> collect) {
String sql;
if (collect.size() != 1) {
StringBuilder sbd = new StringBuilder();
collect.forEach(s -> {
sbd.append(s);
sbd.append(" ");
});
sql = sbd.toString();
} else {
sql = collect.get(0);
}
return sql;
}
/**
* crud sql校验
*
* @param sql
*/
private static boolean crudCheck(String sql) {
// System.out.println("准备校验的sql = " + sql);
if (sql.startsWith("<script>")) {
return true;
} else {
try {
MySqlStatementParser parser = new MySqlStatementParser(sql);
List<SQLStatement> stmtList = parser.parseStatementList();
int size = stmtList.size();
if (size != 0) {
return true;
} else {
return false;
}
} catch (Exception e) {
return false;
}
}
}
public static void main(String[] args) {
String basePackages = "ix.account.mapper";
Set<Class> scan = ClassScaner.scan(basePackages, Mapper.class);
getMethodAnnotation(basePackages, Mapper.class);
}
@Data
@Builder
private static class SqlErrorInfo {
private String method;
private String sql;
private String clazz;
}
}
继续改进
查看 MyBatis 源码中的以下方法：
org.apache.ibatis.scripting.xmltags.XMLLanguageDriver#createSqlSource(org.apache.ibatis.session.Configuration, java.lang.String, java.lang.Class<?>)

这段代码中的 `script` 参数就是注解中的 SQL 脚本。当它以 `<script>` 开头时，会先用 XPathParser 解析为 XML 节点，再交给下面的重载方法处理。
/**
 * (MyBatis library excerpt) Builds a SqlSource from the SQL text of a mapper annotation.
 * If the text starts with "<script>" it is parsed as XML, and the "/script" node is passed
 * to the XNode overload. This is a plain prefix match, so leading whitespace defeats it.
 * Otherwise configuration variables are substituted into the text via PropertyParser. The result
 * is a DynamicSqlSource when TextSqlNode reports it as dynamic, else a RawSqlSource.
 */
@Override
public SqlSource createSqlSource(Configuration configuration, String script, Class<?> parameterType) {
// issue #3
// Dynamic-SQL branch: the annotation value is an XML <script> document.
if (script.startsWith("<script>")) {
XPathParser parser = new XPathParser(script, false, configuration.getVariables(), new XMLMapperEntityResolver());
return createSqlSource(configuration, parser.evalNode("/script"), parameterType);
} else {
// issue #127
// Plain-text branch: substitute configuration variables first.
script = PropertyParser.parse(script, configuration.getVariables());
TextSqlNode textSqlNode = new TextSqlNode(script);
if (textSqlNode.isDynamic()) {
return new DynamicSqlSource(configuration, textSqlNode);
} else {
return new RawSqlSource(configuration, script, parameterType);
}
}
}
上面 `<script>` 分支调用的重载方法：
org.apache.ibatis.scripting.xmltags.XMLLanguageDriver#createSqlSource(org.apache.ibatis.session.Configuration, org.apache.ibatis.parsing.XNode, java.lang.Class<?>)
/**
 * (MyBatis library excerpt) Builds a SqlSource from an already-parsed {@code <script>} XML node
 * by delegating to XMLScriptBuilder#parseScriptNode().
 */
@Override
public SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
XMLScriptBuilder builder = new XMLScriptBuilder(configuration, script, parameterType);
return builder.parseScriptNode();
}

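此时 `XNode script` 是解析后的 `<script>` XML 节点，随后由 XMLScriptBuilder 构建成 SqlSource。因此可以先用 XPathParser 把注解中的 `<script>` 解析为 XNode，再从中提取文本作为 SQL 进行校验。
- 基本改造如下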
// Demo: parse an annotation-style <script> as XML and extract its plain text for validation.
String s = "<script>" +
        "SELECT * FROM tbl_order " +
        // trailing space keeps "1=1" and "AND" apart once the tags are stripped
        "WHERE 1=1 " +
        // the condition now tests the same parameter the fragment binds (was title!=null inside
        // a bare <when>; <when> belongs inside <choose>, <if> is the standalone tag)
        "<if test='mydate != null'>" +
        "AND mydate = #{mydate}" +
        "</if>" +
        "</script>";
XPathParser parser = new XPathParser(s, false);
XNode xNode = parser.evalNode("/script");
// NOTE(review): XNode#getStringBody() appears to return only the node's first text segment
// ("SELECT * FROM tbl_order WHERE 1=1 "), not text nested in <if>/<when> children — verify.
String stringBody = xNode.getStringBody();
System.out.println(stringBody);
// DOM getTextContent() concatenates all descendant text, so nested fragments are included.
// Tags are stripped but conditions are NOT evaluated, so every branch's text appears.
String fullText = xNode.getNode().getTextContent();
System.out.println(fullText);
859

被折叠的 条评论
为什么被折叠?



