首先我们需要了解一下 MyBatis 的 org.apache.ibatis.plugin.Interceptor，以及 Handler 和 Executor。MyBatis 的处理器主要有 3 个 Handler 和 1 个 Executor，分别为 StatementHandler、ResultSetHandler、ParameterHandler 与 Executor。
在 MyBatis 中，我们可以使用 @Intercepts（配合 @Signature 指定要拦截的类型、方法和参数）来声明拦截器，具体用法可以参考官方文档。
铺垫得差不多了，直接上代码。
这里把问题拆成两部分来解决：查数据与查总数。这样更容易理解，也更容易扩展。
先看查总数的拦截器：MybatisTotalInterceptor
/**
* 拦截器分页
*
* 使用自动拦截器分页,满足如下两个条件:
*
* 1、参数必须是{@link PaginableParameter}的子类或实现类
* 2、查询SQL ID必须以PagedList结尾 或是查询参数pagingable=ture
*/
@Intercepts({@Signature(type = StatementHandler.class, method = "query", args ={Statement.class, ResultHandler.class})})
@Repository
public class MybatisTotalInterceptor implements Interceptor {
private static final Logger logger = LoggerFactory.getLogger(MybatisTotalInterceptor.class);
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object resultObject = invocation.proceed();
int currSize = ((List) resultObject).size();
Object classObject = invocation.getTarget();
if (classObject instanceof RoutingStatementHandler) {
RoutingStatementHandler statementHandler = (RoutingStatementHandler) classObject;
StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");
//参数判断
BoundSql boundSql = delegate.getBoundSql();
Object obj = boundSql.getParameterObject();
if (obj instanceof PaginableParameter) {
PaginableParameter page = (PaginableParameter) obj;
//通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
MappedStatement mappedStatement =
(MappedStatement) ReflectHelper.getFieldValue(delegate, "mappedStatement");
//SQLID正则判断 或是确定可以分页
if (page.getPagingable()) {
//logger.info("分页总条查了");
if ((currSize == 0 && page.getPageIndex() == 0) || (currSize != 0 && currSize < page.getPageSize())) {
page.setTotal(page.getPageIndex() * page.getPageSize() + currSize);
} else {
//拦截到的prepare方法参数是一个Connection对象
Statement statement = (Statement) invocation.getArgs()[0];
//给当前的page参数对象设置总记录数
this.setTotalRecord(page, mappedStatement, statement.getConnection());
}
}
}
}
return resultObject;
}
/**
* 给当前的参数对象page设置总记录数
*
* @param page Mapper映射语句对应的参数对象
* @param mappedStatement Mapper映射语句
* @param connection 当前的数据库连接
*/
private void setTotalRecord(PaginableParameter page, MappedStatement mappedStatement, Connection connection) {
BoundSql boundSql = mappedStatement.getBoundSql(page);
String sql = boundSql.getSql();
String countSql = this.getCountSql(sql, page.getCountField());
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);
//这儿会涉及到一个列表参数查询无法正常赋值问题,那么我自己重写了DefaultParameterHandler。生成MybatisParameterHandlerExtends(在下面会有代码)类。
ParameterHandler parameterHandler = new MybatisParameterHandlerExtends(mappedStatement, page, countBoundSql);
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
logger.debug("select paged list count sql = {}", countSql);
pstmt = connection.prepareStatement(countSql);
parameterHandler.setParameters(pstmt);
rs = pstmt.executeQuery();
if (rs.next()) {
int totalRecord = rs.getInt(1);
//给当前的参数page对象设置总记录数
page.setTotal(totalRecord);
}
} catch (SQLException e) {
logger.error("paging set parameter error,msg:{}", e);
} finally {
try {
if (rs != null)
rs.close();
if (pstmt != null)
pstmt.close();
} catch (SQLException e1) {
logger.error("paging set parameter error,msg:{}", e1);
}
}
}
/**
* 根据原Sql语句获取对应的查询总记录数的Sql语句
*
* @param sql
* @return
*/
private String getCountSql(String sql, String countField) {
if (StringUtils.isEmpty(countField)) {
countField = "count(1)";
}
return "select ".concat(countField).concat(" from (").concat(sql).concat(" ) as ")
.concat(MybatisPagedTableNameConstants.TABLE_NAME);
}
/**
* 拦截器对应的封装原始对象的方法
*/
public Object plugin(Object arg0) {
if (arg0 instanceof StatementHandler) {
return Plugin.wrap(arg0, this);
} else {
return arg0;
}
}
/**
* 设置注册拦截器时设定的属性
*/
@Override
public void setProperties(Properties p) {
logger.info("set properties parameter,p={}", p);
}
再来看看查数据的拦截器：MybatisPagedInterceptor
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
@Repository
public class MybatisPagedInterceptor implements Interceptor {
private static final Logger logger = LoggerFactory.getLogger(MybatisPagedInterceptor.class);
/**
* 默认分页index
*/
private static final Integer DEFAULT_PAGE_INDEX = 0;
/**
* 默认分页sizes
*/
private static final Integer DEFAULT_PAGE_SIZE = 10;
@Override
public Object intercept(Invocation invocation) throws Throwable {
if (invocation.getTarget() instanceof RoutingStatementHandler) {
RoutingStatementHandler statementHandler = (RoutingStatementHandler) invocation.getTarget();
StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");
//参数判断
BoundSql boundSql = delegate.getBoundSql();
Object obj = boundSql.getParameterObject();
if (obj instanceof PaginableParameter) {
PaginableParameter page = (PaginableParameter) obj;
//SQLID正则判断 或是确定可以分页
if (page.getPagingable()) {
String sql = boundSql.getSql();
//获取分页Sql语句
String pageSql = this.getPageSql(page, sql);
//利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
ReflectHelper.setFieldValue(boundSql, "sql", pageSql);
//logger.info("分页数据执行了");
}
}
}
return invocation.proceed();
}
/**
* 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle
* 其它的数据库都 没有进行分页
*
* @param page 分页对象
* @param sql 原sql语句
* @return
*/
private String getPageSql(PaginableParameter page, String sql) {
//设置超过的pageIndex
/* if (page.getTotal() <= 0) {
page.setPageIndex(0);
}*/
//默认分页值
if (page.getPageIndex() == null) {
page.setPageIndex(DEFAULT_PAGE_INDEX);
}
if (page.getPageSize() == null || page.getPageSize() == 0) {
page.setPageSize(DEFAULT_PAGE_SIZE);
}
StringBuilder sqlBuffer = new StringBuilder(sql);
return getMysqlPageSql(page, sqlBuffer);
}
/**
* 获取Mysql数据库的分页查询语句
*
* @param page 分页对象
* @param sqlBuffer 包含原sql语句的StringBuffer对象
* @return Mysql数据库分页语句
*/
private String getMysqlPageSql(PaginableParameter page, StringBuilder sqlBuffer) {
//计算第一条记录的位置,Mysql中记录的位置是从0开始的。
int offset = (page.getPageIndex()) * page.getPageSize();
sqlBuffer.append(" limit ").append(offset + page.getStartOffset()).append(",").append(page.getPageSize());
return sqlBuffer.toString();
}
/**
* 拦截器对应的封装原始对象的方法
*/
public Object plugin(Object arg0) {
if (arg0 instanceof StatementHandler) {
return Plugin.wrap(arg0, this);
} else {
return arg0;
}
}
/**
* 设置注册拦截器时设定的属性
*/
@Override
public void setProperties(Properties p) {
logger.info("set properties parameter,p={}", p);
}
下面是 MybatisParameterHandlerExtends 的代码：
/**
 * Parameter handler used to bind the paging count query, adapted from MyBatis' DefaultParameterHandler.
 *
 * <p>Values are looked up in this order: additional parameters of the {@link BoundSql}; the parameter object itself
 * when a type handler exists for its class; otherwise a property of the parameter object. Property names of the form
 * {@code __frch_<item>_<index>} (generated for {@code <foreach>} items) are resolved against a {@code List}-valued
 * getter of the parameter object whose name starts with {@code <item>}.
 *
 * @author: luole
 * @date: 2018/6/26 10:54
 */
@Slf4j
public class MybatisParameterHandlerExtends implements ParameterHandler {
    /** Prefix of the placeholder names this handler resolves specially. */
    private static final String FRCH_PREFIX = "__frch_";

    private final TypeHandlerRegistry typeHandlerRegistry;
    private final MappedStatement mappedStatement;
    private final Object parameterObject;
    private final BoundSql boundSql;
    private final Configuration configuration;

    public MybatisParameterHandlerExtends(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
        this.mappedStatement = mappedStatement;
        this.configuration = mappedStatement.getConfiguration();
        this.typeHandlerRegistry = mappedStatement.getConfiguration().getTypeHandlerRegistry();
        this.parameterObject = parameterObject;
        this.boundSql = boundSql;
    }

    @Override
    public Object getParameterObject() {
        return parameterObject;
    }

    /**
     * Binds every non-OUT parameter mapping of the {@link BoundSql} onto {@code ps}, in mapping order.
     *
     * @param ps the statement to bind; JDBC parameter indexes start at 1
     * @throws TypeException if a type handler fails to set a value
     */
    @Override
    public void setParameters(PreparedStatement ps) {
        ErrorContext.instance().activity("setting parameters").object(mappedStatement.getParameterMap().getId());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        // Tracks foreach-placeholder state across the mappings of this one binding pass.
        FrchPropertyParam frchPropertyParam = new FrchPropertyParam();
        if (parameterMappings != null) {
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    String propertyName = parameterMapping.getProperty();
                    Object value = getFrchPropertyObject(propertyName, frchPropertyParam);
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    JdbcType jdbcType = parameterMapping.getJdbcType();
                    if (value == null && jdbcType == null) {
                        log.debug("set parameters ,propertyName = {},value=null", propertyName);
                        jdbcType = configuration.getJdbcTypeForNull();
                    }
                    try {
                        log.debug("set parameters,propertyName = {},value={}", propertyName, value);
                        typeHandler.setParameter(ps, i + 1, value, jdbcType);
                    } catch (SQLException e) {
                        throw new TypeException(
                                "Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
                    }
                }
            }
        }
    }

    /**
     * Resolves the value bound to {@code propertyName}.
     *
     * @param propertyName      the parameter mapping's property name
     * @param frchPropertyParam foreach-placeholder state of the current binding pass
     * @return the resolved value, possibly {@code null}
     */
    private Object getFrchPropertyObject(String propertyName, FrchPropertyParam frchPropertyParam) {
        if (boundSql.hasAdditionalParameter(propertyName)) { // issue #448 ask first for additional params
            return boundSql.getAdditionalParameter(propertyName);
        }
        if (parameterObject == null) {
            return null;
        }
        if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
            return parameterObject;
        }
        MetaObject metaObject = configuration.newMetaObject(parameterObject);
        return getValue(propertyName, metaObject, frchPropertyParam);
    }

    /**
     * Reads {@code propertyName} from the parameter object, resolving {@code __frch_<item>_<index>} names specially.
     *
     * @param propertyName      the property name
     * @param metaObject        meta object over the parameter object
     * @param frchPropertyParam foreach-placeholder state of the current binding pass
     * @return the property value
     */
    private Object getValue(String propertyName, MetaObject metaObject, FrchPropertyParam frchPropertyParam) {
        if (!propertyName.startsWith(FRCH_PREFIX)) {
            return metaObject.getValue(propertyName);
        }
        String itemized = propertyName.substring(FRCH_PREFIX.length());
        // The index is the segment after the LAST underscore, so item names that themselves contain underscores
        // (e.g. "user_id_3") still split into item "user_id" and index "3".
        int separator = itemized.lastIndexOf('_');
        if (separator < 0) {
            return metaObject.getValue(propertyName);
        }
        String[] propertys = {itemized.substring(0, separator), itemized.substring(separator + 1)};
        // Remember the first index seen for each new item name; later indexes are made relative to it.
        if (StringUtils.isEmpty(frchPropertyParam.getLastParameter())
                || !frchPropertyParam.getLastParameter().equals(propertys[0])) {
            frchPropertyParam.setLastParameter(propertys[0]);
            frchPropertyParam.setCurrStartIndex(propertys[1]);
        }
        return getFrchPropertyObject(metaObject, propertys, frchPropertyParam);
    }

    /**
     * Fetches the list element addressed by a parsed foreach placeholder.
     *
     * @param metaObject        meta object over the parameter object
     * @param propertys         {@code [itemName, index]}
     * @param frchPropertyParam foreach-placeholder state of the current binding pass
     * @return the element, or {@code ""} when the computed index is past the end of the list
     */
    private Object getFrchPropertyObject(MetaObject metaObject, String[] propertys,
            FrchPropertyParam frchPropertyParam) {
        // Heuristic: bind against the last List-valued getter whose name starts with the item name.
        String proName = propertys[0];
        for (String name : metaObject.getGetterNames()) {
            if (name.startsWith(propertys[0]) && metaObject.getValue(name) instanceof List) {
                proName = name;
            }
        }
        log.debug(
                "query paged list,select count parameter getfrch,proName = {},propertyIndex = {},frchPropertyParam = {}",
                proName, propertys[1], frchPropertyParam);
        List paramList = (List) metaObject.getValue(proName);
        int propertyIndex = Integer.parseInt(propertys[1]);
        if (StringUtils.isNotEmpty(frchPropertyParam.getCurrStartIndex())) {
            propertyIndex = propertyIndex - Integer.parseInt(frchPropertyParam.getCurrStartIndex());
            if (propertyIndex < 0) {
                frchPropertyParam.setCurrStartIndex("" + propertyIndex);
                propertyIndex = 0;
            }
        }
        // Strictly less than: index == size is already out of bounds.
        if (propertyIndex < paramList.size()) {
            return paramList.get(propertyIndex);
        }
        return "";
    }
}
上面这三个主要的类还依赖以下几个辅助类：
- PaginableParameter：查询参数接口
- FrchPropertyParam：对 list（foreach）参数取值状态的封装
- Plugin：非常重要的一个类。它是 MyBatis Plugin 的改写版本，会剥掉多层代理，使拦截器拿到真实的 RoutingStatementHandler
- ReflectHelper：反射工具类
- PagedResult：结果返回类
代码如下:
/**
 * JDK-proxy based plugin wrapper. Interceptors see the innermost, non-proxy target, even when several plugins
 * wrap the same object.
 *
 * <p>Each proxy remembers two objects:
 * <ul>
 *   <li>{@code target}: the innermost object, reached by unwrapping proxies created by this class. It is passed
 *       to the interceptor inside the {@link Invocation}.
 *   <li>{@code lastTarget}: the object that was actually wrapped, possibly another proxy. Calls that are not
 *       intercepted are forwarded here, so the proxy chain is preserved for them.
 * </ul>
 *
 * <p>NOTE(review): an intercepted call proceeds on the innermost {@code target}. Any inner proxies (and the
 * interceptors they apply to that same method) are therefore skipped for that call.
 */
public class Plugin implements InvocationHandler {
    private final Object lastTarget;
    private final Object target;
    private final Interceptor interceptor;
    private final Map<Class<?>, Set<Method>> signatureMap;

    private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap, Object lastTarget) {
        this.target = target;
        this.interceptor = interceptor;
        this.signatureMap = signatureMap;
        this.lastTarget = lastTarget;
    }

    /**
     * Wraps {@code target} in a proxy that routes the methods declared in the interceptor's {@link Intercepts}
     * annotation to {@code interceptor}.
     *
     * @param target      the object to wrap (may already be a proxy produced by this class)
     * @param interceptor the interceptor to apply
     * @return a proxy, or {@code target} itself when it implements none of the intercepted interfaces
     * @throws PluginException if the interceptor lacks {@code @Intercepts} or names a missing method
     */
    public static Object wrap(Object target, Interceptor interceptor) {
        Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
        Class<?> type = target.getClass();
        Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
        if (interfaces.length == 0) {
            return target;
        }
        Object currTarget = unwrap(target);
        return Proxy.newProxyInstance(
                type.getClassLoader(),
                interfaces,
                new Plugin(currTarget, interceptor, signatureMap, target));
    }

    /**
     * Follows proxies created by this class down to the object they ultimately wrap.
     *
     * @param proxy any object; non-proxies are returned unchanged
     * @return the innermost object reachable through this class's proxies. If a JDK proxy with a foreign
     *         invocation handler is met, that proxy is returned.
     * @throws Exception retained for source compatibility; no checked exception is actually thrown
     */
    public static Object getTarget(Object proxy) throws Exception {
        return unwrap(proxy);
    }

    /** Exception-free implementation of {@link #getTarget(Object)}. */
    private static Object unwrap(Object candidate) {
        Object current = candidate;
        while (current instanceof Proxy) {
            // Public API instead of reflecting into Proxy.h, which strong encapsulation on newer JDKs forbids.
            InvocationHandler handler = Proxy.getInvocationHandler(current);
            if (!(handler instanceof Plugin)) {
                break;
            }
            current = ((Plugin) handler).target;
        }
        return current;
    }

    /**
     * Dispatches a proxied call: intercepted methods go to the interceptor, everything else to the wrapped object.
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            Set<Method> methods = signatureMap.get(method.getDeclaringClass());
            if (methods != null && methods.contains(method)) {
                return interceptor.intercept(new Invocation(target, method, args));
            }
            // Forward to the directly wrapped object so the rest of the chain still runs.
            return method.invoke(lastTarget, args);
        } catch (Exception e) {
            throw ExceptionUtil.unwrapThrowable(e);
        }
    }

    /**
     * Reads the interceptor's {@link Intercepts} annotation into a map of intercepted type to intercepted methods.
     */
    private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
        Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
        // issue #251
        if (interceptsAnnotation == null) {
            throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
        }
        Signature[] sigs = interceptsAnnotation.value();
        Map<Class<?>, Set<Method>> signatureMap = new HashMap<Class<?>, Set<Method>>();
        for (Signature sig : sigs) {
            Set<Method> methods = signatureMap.computeIfAbsent(sig.type(), k -> new HashSet<>());
            try {
                Method method = sig.type().getMethod(sig.method(), sig.args());
                methods.add(method);
            } catch (NoSuchMethodException e) {
                throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
            }
        }
        return signatureMap;
    }

    /**
     * Collects the interfaces of {@code type} and its superclasses that appear as keys in {@code signatureMap}.
     */
    private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
        Set<Class<?>> interfaces = new HashSet<Class<?>>();
        while (type != null) {
            for (Class<?> c : type.getInterfaces()) {
                if (signatureMap.containsKey(c)) {
                    interfaces.add(c);
                }
            }
            type = type.getSuperclass();
        }
        return interfaces.toArray(new Class<?>[interfaces.size()]);
    }
}
/**
 * Reflection helpers for reading and writing (possibly private, possibly inherited) instance fields.
 *
 * <p>Failures are logged, not thrown: reads return {@code null} and writes become no-ops.
 */
public class ReflectHelper {
    private static final Logger logger = LoggerFactory.getLogger(ReflectHelper.class);

    private ReflectHelper() {
    }

    /**
     * Reads the field {@code fieldName} of {@code obj}, searching its class hierarchy up to (not including)
     * {@link Object}.
     *
     * @param obj       the object to read from; may be {@code null}
     * @param fieldName the declared field name
     * @return the field value, or {@code null} if {@code obj} is null, the field does not exist, or it cannot be read
     */
    public static Object getFieldValue(Object obj, String fieldName) {
        if (obj == null) {
            return null;
        }
        Field targetField = getTargetField(obj.getClass(), fieldName);
        if (targetField == null) {
            // FieldUtils rejects a null field with IllegalArgumentException; report and return null instead.
            logger.warn("field {} not found on {}", fieldName, obj.getClass().getName());
            return null;
        }
        try {
            return FieldUtils.readField(targetField, obj, true);
        } catch (IllegalAccessException e) {
            logger.error("", e);
        }
        return null;
    }

    /**
     * Finds a declared field by walking from {@code targetClass} up the superclass chain, stopping before
     * {@link Object}.
     *
     * @return the field (made accessible), or {@code null} if not found or the lookup failed
     */
    private static Field getTargetField(Class<?> targetClass, String fieldName) {
        try {
            for (Class<?> current = targetClass; current != null && !Object.class.equals(current);
                    current = current.getSuperclass()) {
                Field field = FieldUtils.getDeclaredField(current, fieldName, true);
                if (field != null) {
                    return field;
                }
            }
        } catch (Exception e) {
            logger.error("", e);
        }
        return null;
    }

    /**
     * Writes {@code value} into the field {@code fieldName} of {@code obj}. The field is looked up through the
     * class hierarchy the same way as in {@link #getFieldValue(Object, String)}.
     *
     * @param obj       the object to modify; a {@code null} object is ignored
     * @param fieldName the declared field name; an unknown field is logged and ignored
     * @param value     the new value
     */
    public static void setFieldValue(Object obj, String fieldName, Object value) {
        if (null == obj) {
            return;
        }
        Field targetField = getTargetField(obj.getClass(), fieldName);
        if (targetField == null) {
            logger.warn("field {} not found on {}", fieldName, obj.getClass().getName());
            return;
        }
        try {
            // forceAccess = true, matching the read path, so private fields are writable.
            FieldUtils.writeField(targetField, obj, value, true);
        } catch (IllegalAccessException e) {
            logger.error("", e);
        }
    }
}
/**
 * Mutable state shared by all parameter mappings of one count-query binding pass. A new instance is created
 * per {@code setParameters} call of the parameter handler. It is used to resolve
 * {@code __frch_<item>_<index>} placeholder names to list elements.
 *
 * <p>Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public class FrchPropertyParam {
    /** Item name of the most recently seen {@code __frch_} placeholder. */
    private String lastParameter;
    /**
     * Index (kept as a string) recorded when a new item name was first seen. Later indexes for the same item
     * are made relative to it.
     */
    private String currStartIndex;
}
/**
 * Contract for query parameter objects that the paging interceptors recognize.
 */
public interface PaginableParameter {
    /**
     * @return the page index; the first page is 0
     */
    Integer getPageIndex();

    /**
     * @param pageIndex the page index; the first page is 0
     */
    void setPageIndex(Integer pageIndex);

    /**
     * @return the number of rows per page
     */
    Integer getPageSize();

    /**
     * @param pageSize the number of rows per page
     */
    void setPageSize(Integer pageSize);

    /**
     * @return the total number of rows matched by the un-paged query
     */
    Integer getTotal();

    /**
     * @param total the total number of rows matched by the un-paged query
     */
    void setTotal(Integer total);

    /**
     * @return whether the query should be paged; defaults to {@code true}
     */
    boolean getPagingable();

    /**
     * @param pagingable whether the query should be paged; defaults to {@code true}
     */
    void setPagingable(boolean pagingable);

    /**
     * @param countField select expression used by the count query; {@code count(1)} is used when empty.
     *                   It is spliced into SQL verbatim, so it must not come from user input.
     */
    void setCountField(String countField);

    /**
     * @return select expression used by the count query; {@code count(1)} is used when empty
     */
    String getCountField();

    /**
     * @return extra number of rows added to the computed {@code LIMIT} offset
     */
    int getStartOffset();

    /**
     * @param offset extra number of rows added to the computed {@code LIMIT} offset
     */
    void setStartOffset(int offset);
}
/**
 * One page of query results plus its paging metadata.
 *
 * @param <T> row type
 */
public class PagedResult<T> implements Serializable {
    private static final long serialVersionUID = -1589842116452346939L;

    /**
     * Maximum page size.
     */
    public static final int MAX_PAGE_SIZE = 2000;

    /**
     * Page index; the first page is 0.
     */
    @ApiModelProperty("页索引,0为起始")
    private int pageIndex;

    /**
     * Rows per page.
     */
    @ApiModelProperty("页行数")
    private int pageSize;

    /**
     * Total number of rows.
     */
    @ApiModelProperty("总行数")
    private long total;

    /**
     * Number of pages.
     */
    @ApiModelProperty("总页数")
    private int pageCount;

    /**
     * Rows of the current page.
     */
    @ApiModelProperty("当前页数据")
    private List<T> list;//NOSONAR

    public PagedResult() {
    }

    /**
     * Builds a result from the paging parameter of the query and the fetched rows.
     *
     * @param paginableParameter the query's paging parameter; its metadata is copied only when paging is enabled
     * @param list               the rows of the current page
     */
    public PagedResult(PaginableParameter paginableParameter, List<T> list) {
        this.list = list;
        // Paging metadata is only computed when the query was actually paged.
        if (paginableParameter.getPagingable()) {
            Integer index = paginableParameter.getPageIndex();
            Integer size = paginableParameter.getPageSize();
            Integer totalRows = paginableParameter.getTotal();
            // The getters return boxed values that may be null (e.g. the total is left unset when the count query
            // fails). Treat missing values as 0 instead of throwing NullPointerException on unboxing.
            this.pageIndex = index == null ? 0 : index;
            this.total = totalRows == null ? 0L : totalRows;
            // Clamp to at least 1 so the page-count division below is safe.
            this.pageSize = (size == null || size < 1) ? 1 : size;
            this.pageCount = (int) (total / pageSize);
            if (total % pageSize != 0) {
                this.pageCount++;
            }
        }
    }
}
如果不想直接复制代码，也可以点击原文链接进入仓库复制，或者直接去仓库 clone 项目。