// 这是一个插件的实现原理案例
public class InterceptorDemo {
public static void main(String[] args) {
// 创建SuperMan的代理对象,通过FlyInterceptor对它进行增强
Flyable flyable = (Flyable) Plugin.wrap(new SuperMan(), new FlyInterceptor());
flyable.fly();
}
@Intercepts({
@Signature(type = Flyable.class, method = "fly", args = {})
})
static class FlyInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
System.out.println("begin fly");
Object target = invocation.getTarget();
if (target instanceof Flyable) {
Method method = invocation.getMethod();
method.invoke(target);
}
System.out.println("stop fly");
return null;
}
}
interface Flyable {
void fly();
}
static class SuperMan implements Flyable {
@Override
public void fly() {
System.out.println("flying");
}
}
}
// 自定义一个慢SQL查询插件,拦截StatementHandler的查询方法
@Intercepts({
@Signature(type = StatementHandler.class, method = "update", args = {Statement.class}),
@Signature(type = StatementHandler.class, method = "query", args = {Statement.class, ResultHandler.class})
})
@Slf4j
public class SlowSqlInterceptor implements Interceptor {
private final long slowSqlTime = 10;
@Override
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
BoundSql boundSql = statementHandler.getBoundSql();
String sql = boundSql.getSql();
long startTime = System.currentTimeMillis();
Object result = invocation.proceed();
long endTime = System.currentTimeMillis();
long executeTime = endTime - startTime;
if (executeTime > slowSqlTime) {
log.debug("{}", sql);
log.warn("slow sql,execute time:{}ms", executeTime);
}
return result;
}
}
// 拦截mybatis四大组件的类
public interface Interceptor {
// 拦截方法
Object intercept(Invocation invocation) throws Throwable;
/**
* 将目标对象包装,根据目当前拦截器中的注解信息来给目标对象创建代理对象
* 目标对象的类型为: ParameterHandler ResultSetHandler StatementHandler Executor
*/
default Object plugin(Object target) {
return Plugin.wrap(target, this);
}
// 给拦截器提供的属性
default void setProperties(Properties properties) {
}
// 方法调用器
public class Invocation {
// 目标对象
private final Object target;
// 目标方法
private final Method method;
// 方法参数
private final Object[] args;
public Object proceed() throws InvocationTargetException, IllegalAccessException {
return method.invoke(target, args);
}
}
}
// 插件的工具类
public class Plugin implements InvocationHandler {
// 包装的目标对象
private final Object target;
// 拦截器对象
private final Interceptor interceptor;
// 所有方法的签名信息
private final Map<Class<?>, Set<Method>> signatureMap;
// 使用指定的拦截器增强目标对象
public static Object wrap(Object target, Interceptor interceptor) {
// 获取拦截器中标注的方法签名信息
// 获取@Intercepts注解中标注的方法签名信息(具体拦截的方法信息)
Map<Class<?>, Set<Method>> signatureMap = this.getSignatureMap(interceptor);
// 获取目标类型
Class<?> type = target.getClass();
// 获取增强的目标类实现的接口,如果方法签名中包含这些接口,表示要给这些接口生成代理对象
Class<?>[] interfaces = this.getAllInterfaces(type, signatureMap);
// 生成代理对象
if (interfaces.length > 0) {
return Proxy.newProxyInstance(type.getClassLoader(), interfaces, new Plugin(target, interceptor, signatureMap));
}
// 如果不存在接口,则不需要创建代理对象
return target;
}
// 代理对象需要执行的方法,当前类是一个InvocationHandler
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
try {
// 获取该目标类要增强的所有方法
Set<Method> methods = signatureMap.get(method.getDeclaringClass());
// 如果需要增强的方法中包含当前正在执行的方法
if (methods != null && methods.contains(method)) {
// 就要对当前方法进行增强拦截
return interceptor.intercept(new Invocation(target, method, args));
}
// 如果当前执行的方法不需要增强,直接直接方法
return method.invoke(target, args);
} catch (Exception e) {
throw ExceptionUtil.unwrapThrowable(e);
}
}
// 获取@Intercepts注解中标注的方法签名信息(具体拦截的方法信息)
private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
// 获取拦截器中的注解信息
Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
// 拦截器中必须包含@Intercepts注解
if (interceptsAnnotation == null) {
throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
}
// 获取注解中的签名信息
Signature[] sigs = interceptsAnnotation.value();
Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
// 遍历所有方法签名
for (Signature sig : sigs) {
// 为什么要用Set集合,因为@Signature注解可以重复,需要拦截相同类的多个方法的情况就需要Set
Set<Method> methods = MapUtil.computeIfAbsent(signatureMap, sig.type(), k -> new HashSet<>());
try {
// 获取到具体的方法对象
Method method = sig.type().getMethod(sig.method(), sig.args());
// 保存拦截的方法对象
methods.add(method);
}
// 如果该对象中不存在指定的方法,抛出异常
catch (NoSuchMethodException e) {
throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
}
}
return signatureMap;
}
// 获取增强的目标类实现的接口,如果方法签名中包含这些接口,表示要给这些接口生成代理对象
private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
Set<Class<?>> interfaces = new HashSet<>();
while (type != null) {
// 获取需要拦截的目标类实现的所有接口
for (Class<?> c : type.getInterfaces()) {
// 在方法签名Map中存在这个接口(表示要对这个接口增强)
if (signatureMap.containsKey(c)) {
// 将该接口保存下来,用给它生成代理对象
interfaces.add(c);
}
}
// 继续找父接口
type = type.getSuperclass();
}
return interfaces.toArray(new Class<?>[0]);
}
// 方法签名注解,就是用于描述一个具体的方法
public @interface Signature {
// 拦截器的目标类
Class<?> type();
// 拦截的目标类中的方法
String method();
// 拦截的方法的参数类型
Class<?>[] args();
}
}
// 拦截器,执行链
public class InterceptorChain {
// 所有的拦截器
public List<Interceptor> interceptors = new ArrayList<>();
// 执行所有拦截器,最终返回方法执行结果
public Object pluginAll(Object target) {
for (Interceptor interceptor : interceptors) {
// 将目标方法进行增强,创建目标对象的代理对象
// 通过代理对象进行拦截目标方法,符合条件则执行拦截器
// 执行目标方法,返回最终结果
target = interceptor.plugin(target);
}
// 执行的结果
return target;
}
// 添加拦截器
public void addInterceptor(Interceptor interceptor) {
interceptors.add(interceptor);
}
// 获取所有拦截器
public List<Interceptor> getInterceptors() {
return Collections.unmodifiableList(interceptors);
}
}
// 核心配置类组件
public class Configuration {
// 拦截器的执行链
public InterceptorChain interceptorChain = new InterceptorChain();
// 由下面的方法可知,插件可以拦截的对象由四个
// 就是Mybatis的四大组件,ParameterHandler,ResultSetHandler,StatementHandler,Executor
// 创建Parameter参数处理器
public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
// 执行所有拦截器
return (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
}
// 创建ResultSet结果集处理器
public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler, ResultHandler resultHandler, BoundSql boundSql) {
ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
// 执行所有拦截器
return (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
}
// 创建Statement处理器
public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
// 执行所有拦截器
return (StatementHandler) interceptorChain.pluginAll(statementHandler);
}
// 创建默认的执行器
public Executor newExecutor(Transaction transaction) {
return newExecutor(transaction, defaultExecutorType);
}
// 创建指定类型的执行器
public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
executorType = executorType == null ? defaultExecutorType : executorType;
Executor executor;
if (ExecutorType.BATCH == executorType) {
executor = new Executor.BatchExecutor(this, transaction);
} else if (ExecutorType.REUSE == executorType) {
executor = new Executor.ReuseExecutor(this, transaction);
} else {
executor = new Executor.SimpleExecutor(this, transaction);
}
// 如果开启了二级缓存,使用了装饰器模式
if (cacheEnabled) {
executor = new Executor.CachingExecutor(executor);
}
// 执行所有拦截器
return (Executor) interceptorChain.pluginAll(executor);
}
}