mybatis分页,结合mybatis的Interceptor做的自动分页

本文介绍了一种基于MyBatis的自定义分页插件实现方式,该插件能够自动拦截SQL查询并实现数据分页及总数统计。通过使用@Intercepts注解和自定义ParameterHandler,有效地解决了列表参数查询时的参数赋值问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

首先我们需要了解一下 MyBatis 的 org.apache.ibatis.plugin.Interceptor,以及各个 Handler 和 Executor。MyBatis 中可以被拦截的对象主要是 3 个 Handler 和 1 个 Executor,分别为 StatementHandler、ResultSetHandler、ParameterHandler 与 Executor。
(此处原为一张示意图)

在 MyBatis 中,我们可以通过 @Intercepts 注解(配合 @Signature 指定要拦截的类型、方法和参数)来声明拦截器。具体用法可以参考官方文档中插件(plugins)一节。
铺垫差不多了,直接上代码

我们把这个问题拆成两部分来解决:查数据与查总数。这样既更容易理解,也更容易扩展。
先看查总数的:MybatisTotalInterceptor

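/**
 * Interceptor-based pagination: fills in the total record count after a paged query.
 *
 * Automatic pagination applies when both of the following hold:
 *
 * 1. The parameter object implements {@link PaginableParameter}.
 * 2. {@code getPagingable()} on that parameter returns true. (The SQL ID is not checked by this
 *    class; an ID ending in "PagedList" is not required.)
 */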
@Intercepts({@Signature(type = StatementHandler.class, method = "query", args ={Statement.class, ResultHandler.class})})
@Repository
public class MybatisTotalInterceptor implements Interceptor {

    private static final Logger logger = LoggerFactory.getLogger(MybatisTotalInterceptor.class);

    /**
     * Runs the intercepted {@code StatementHandler.query} and then records the total row count on a
     * pageable parameter object.
     *
     * <p>If the returned page is empty on page 0, or is non-empty but shorter than the page size, the
     * total is computed from the page position and the returned row count. Otherwise a separate COUNT
     * query is executed on the same connection.
     *
     * @param invocation the intercepted {@code query(Statement, ResultHandler)} call
     * @return the rows returned by the real query, unchanged
     * @throws Throwable anything thrown by the real query
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        final Object rows = invocation.proceed();
        final int fetched = ((List) rows).size();

        final Object target = invocation.getTarget();
        if (!(target instanceof RoutingStatementHandler)) {
            return rows;
        }
        // RoutingStatementHandler keeps the real handler in its "delegate" field.
        final StatementHandler delegate =
                (StatementHandler) ReflectHelper.getFieldValue((RoutingStatementHandler) target, "delegate");
        final Object parameter = delegate.getBoundSql().getParameterObject();
        if (!(parameter instanceof PaginableParameter)) {
            return rows;
        }
        final PaginableParameter page = (PaginableParameter) parameter;
        // The delegate's "mappedStatement" field is read reflectively (it is declared on a superclass).
        final MappedStatement mappedStatement =
                (MappedStatement) ReflectHelper.getFieldValue(delegate, "mappedStatement");
        if (!page.getPagingable()) {
            return rows;
        }

        final boolean emptyFirstPage = fetched == 0 && page.getPageIndex() == 0;
        final boolean partialPage = fetched != 0 && fetched < page.getPageSize();
        if (emptyFirstPage || partialPage) {
            // No further rows follow this page, so the total is derived without a COUNT query.
            // NOTE(review): startOffset is not added here, although the data query's LIMIT includes it
            // and the COUNT query counts every row — confirm which total is intended.
            page.setTotal(page.getPageIndex() * page.getPageSize() + fetched);
        } else {
            // The intercepted method is query(Statement, ResultHandler): args[0] is the Statement.
            final Statement statement = (Statement) invocation.getArgs()[0];
            setTotalRecord(page, mappedStatement, statement.getConnection());
        }
        return rows;
    }


    /**
     * Runs a COUNT query derived from the mapped statement's SQL and stores the result on {@code page}.
     *
     * <p>Best effort: a {@link SQLException} is logged and swallowed, leaving the page's total unchanged.
     * The connection is owned by MyBatis and is not closed here; only the statement and result set are.
     *
     * @param page            parameter object of the mapped statement; receives the total
     * @param mappedStatement the mapped statement whose SQL is counted
     * @param connection      the connection the intercepted query runs on
     */
    private void setTotalRecord(PaginableParameter page, MappedStatement mappedStatement, Connection connection) {
        BoundSql boundSql = mappedStatement.getBoundSql(page);
        String countSql = this.getCountSql(boundSql.getSql(), page.getCountField());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);
        // Dynamic SQL (<foreach>, <bind>) keeps generated values such as "__frch_item_0" as additional
        // parameters of the original BoundSql. The new BoundSql starts without them, which is what broke
        // list parameters in the count query, so copy every one the mappings refer to.
        if (parameterMappings != null) {
            for (ParameterMapping mapping : parameterMappings) {
                String name = rootPropertyName(mapping.getProperty());
                if (boundSql.hasAdditionalParameter(name)) {
                    countBoundSql.setAdditionalParameter(name, boundSql.getAdditionalParameter(name));
                }
            }
        }
        // Custom handler (see MybatisParameterHandlerExtends) that can also resolve foreach placeholders
        // from List properties of the parameter object.
        ParameterHandler parameterHandler = new MybatisParameterHandlerExtends(mappedStatement, page, countBoundSql);

        logger.debug("select paged list count sql = {}", countSql);
        try (PreparedStatement pstmt = connection.prepareStatement(countSql)) {
            parameterHandler.setParameters(pstmt);
            try (ResultSet rs = pstmt.executeQuery()) {
                if (rs.next()) {
                    page.setTotal(rs.getInt(1));
                }
            }
        } catch (SQLException e) {
            // Pass the exception as the last argument so SLF4J logs its stack trace.
            logger.error("paging set parameter error,msg:{}", e.getMessage(), e);
        }
    }

    /**
     * Returns the leading segment of a property path, e.g. {@code "__frch_item_0.id"} ->
     * {@code "__frch_item_0"} and {@code "list[0]"} -> {@code "list"}.
     */
    private static String rootPropertyName(String property) {
        int end = property.length();
        int dot = property.indexOf('.');
        if (dot >= 0) {
            end = dot;
        }
        int bracket = property.indexOf('[');
        if (bracket >= 0 && bracket < end) {
            end = bracket;
        }
        return property.substring(0, end);
    }

    /**
     * Wraps {@code sql} in an outer SELECT that counts its rows.
     *
     * @param sql        the original query
     * @param countField select expression for the count; {@code count(1)} when empty
     * @return {@code select <countField>  from (<sql> ) as <TABLE_NAME>}
     */
    private String getCountSql(String sql, String countField) {
        final String selectExpr = StringUtils.isEmpty(countField) ? "count(1)" : countField;
        final String derivedTable = "  from (".concat(sql).concat(" ) as ");
        return "select ".concat(selectExpr)
                .concat(derivedTable)
                .concat(MybatisPagedTableNameConstants.TABLE_NAME);
    }

    /**
     * Wraps {@link StatementHandler} instances in a plugin proxy; any other object is returned as is.
     *
     * @param arg0 the object MyBatis offers for wrapping
     * @return a proxy for statement handlers, otherwise {@code arg0} itself
     */
    @Override
    public Object plugin(Object arg0) {
        return arg0 instanceof StatementHandler ? Plugin.wrap(arg0, this) : arg0;
    }

    /**
     * Receives the properties configured for this interceptor; they are only logged.
     *
     * @param properties interceptor properties from the MyBatis configuration
     */
    @Override
    public void setProperties(Properties properties) {
        logger.info("set properties parameter,p={}", properties);
    }

再来看看查数据的 MybatisPagedInterceptor

	@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
@Repository
public class MybatisPagedInterceptor implements Interceptor {

    private static final Logger logger = LoggerFactory.getLogger(MybatisPagedInterceptor.class);

    /**
     * 默认分页index
     */
    private static final Integer DEFAULT_PAGE_INDEX = 0;

    /**
     * 默认分页sizes
     */
    private static final Integer DEFAULT_PAGE_SIZE = 10;


    /**
     * Rewrites the SQL of a pageable query into its paged form, then lets the original
     * {@code StatementHandler.prepare} proceed.
     *
     * @param invocation the intercepted {@code prepare(Connection, Integer)} call
     * @return whatever the real {@code prepare} returns
     * @throws Throwable anything thrown by the real call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        if (target instanceof RoutingStatementHandler) {
            rewriteSqlForPaging((RoutingStatementHandler) target);
        }
        return invocation.proceed();
    }

    /**
     * Appends the paging clause to the delegate's BoundSql when its parameter object requests paging.
     *
     * @param routing the routing handler whose "delegate" field holds the real statement handler
     */
    private void rewriteSqlForPaging(RoutingStatementHandler routing) {
        StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(routing, "delegate");
        BoundSql boundSql = delegate.getBoundSql();
        Object parameter = boundSql.getParameterObject();
        if (!(parameter instanceof PaginableParameter)) {
            return;
        }
        PaginableParameter page = (PaginableParameter) parameter;
        if (!page.getPagingable()) {
            return;
        }
        // The paged SQL is written back into the BoundSql's "sql" field via reflection.
        ReflectHelper.setFieldValue(boundSql, "sql", this.getPageSql(page, boundSql.getSql()));
    }

    /**
     * Builds the paged SQL for {@code sql} after normalising the paging values on {@code page}.
     *
     * <p>A missing or negative page index is replaced by {@link #DEFAULT_PAGE_INDEX}, and a missing or
     * non-positive page size by {@link #DEFAULT_PAGE_SIZE]. Both are written back to {@code page}, so
     * later consumers see the values actually used. Only MySQL-style {@code LIMIT offset,size} SQL is
     * produced.
     *
     * @param page paging parameters; may be modified
     * @param sql  the original SQL
     * @return {@code sql} with a LIMIT clause appended
     */
    private String getPageSql(PaginableParameter page, String sql) {
        Integer pageIndex = page.getPageIndex();
        // A negative index would produce a negative LIMIT offset, which MySQL rejects.
        if (pageIndex == null || pageIndex < 0) {
            page.setPageIndex(DEFAULT_PAGE_INDEX);
        }
        Integer pageSize = page.getPageSize();
        // A negative row count in LIMIT is invalid too, so treat it like "not set".
        if (pageSize == null || pageSize <= 0) {
            page.setPageSize(DEFAULT_PAGE_SIZE);
        }
        return getMysqlPageSql(page, new StringBuilder(sql));
    }

    /**
     * Appends a MySQL {@code LIMIT offset,size} clause for the page described by {@code page}.
     *
     * <p>The offset is {@code pageIndex * pageSize + startOffset}; MySQL row positions start at 0.
     *
     * @param page      paging parameters (index and size must be non-null)
     * @param sqlBuffer builder holding the original SQL; the clause is appended to it
     * @return the paged SQL
     */
    private String getMysqlPageSql(PaginableParameter page, StringBuilder sqlBuffer) {
        // Computed in long: pageIndex * pageSize can exceed Integer.MAX_VALUE for large page indexes,
        // which previously wrapped to a negative (invalid) offset.
        long offset = (long) page.getPageIndex() * page.getPageSize() + page.getStartOffset();
        sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize());
        return sqlBuffer.toString();
    }

    /**
     * Wraps {@link StatementHandler} instances in a plugin proxy; any other object is returned as is.
     *
     * @param arg0 the object MyBatis offers for wrapping
     * @return a proxy for statement handlers, otherwise {@code arg0} itself
     */
    @Override
    public Object plugin(Object arg0) {
        return arg0 instanceof StatementHandler ? Plugin.wrap(arg0, this) : arg0;
    }

    /**
     * Receives the properties configured for this interceptor; they are only logged.
     *
     * @param properties interceptor properties from the MyBatis configuration
     */
    @Override
    public void setProperties(Properties properties) {
        logger.info("set properties parameter,p={}", properties);
    }

下面是MybatisParameterHandlerExtends代码

/**
 * {@link ParameterHandler} that binds the parameters of the count query built during pagination.
 *
 * <p>Each placeholder value is looked up in this order: additional parameters of the BoundSql, then
 * the parameter object itself when a type handler is registered for its class, and otherwise a
 * property of the parameter object. Foreach placeholders ({@code __frch_<item>_<index>}) that reach
 * the property lookup are resolved as an element of a {@code List} property of the parameter object.
 *
 * @author: luole
 * @date: 2018/6/26 10:54
 */
@Slf4j
public class MybatisParameterHandlerExtends implements ParameterHandler {

    /** Prefix of the placeholder names generated for items inside a {@code <foreach>} element. */
    private static final String FRCH_PREFIX = "__frch_";

    private final TypeHandlerRegistry typeHandlerRegistry;

    private final MappedStatement mappedStatement;
    private final Object parameterObject;
    private final BoundSql boundSql;
    private final Configuration configuration;

    /**
     * @param mappedStatement statement whose configuration supplies type handlers and meta objects
     * @param parameterObject the query parameter object
     * @param boundSql        BoundSql of the count query whose placeholders are bound
     */
    public MybatisParameterHandlerExtends(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
        this.mappedStatement = mappedStatement;
        this.configuration = mappedStatement.getConfiguration();
        this.typeHandlerRegistry = mappedStatement.getConfiguration().getTypeHandlerRegistry();
        this.parameterObject = parameterObject;
        this.boundSql = boundSql;
    }

    @Override
    public Object getParameterObject() {
        return parameterObject;
    }

    /**
     * Binds every non-OUT parameter mapping of the count query to {@code ps}, in mapping order.
     *
     * @param ps the prepared count statement
     * @throws TypeException if a type handler fails to set a value
     */
    @Override
    public void setParameters(PreparedStatement ps) {
        ErrorContext.instance().activity("setting parameters").object(mappedStatement.getParameterMap().getId());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        // Foreach bookkeeping shared by all placeholders of this single call.
        FrchPropertyParam frchPropertyParam = new FrchPropertyParam();
        if (parameterMappings != null) {
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    String propertyName = parameterMapping.getProperty();
                    Object value = getFrchPropertyObject(propertyName, frchPropertyParam);
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    JdbcType jdbcType = parameterMapping.getJdbcType();
                    if (value == null && jdbcType == null) {
                        log.debug("set parameters ,propertyName = {},value=null", propertyName);
                        jdbcType = configuration.getJdbcTypeForNull();
                    }
                    try {
                        log.debug("set parameters,propertyName = {},value={}", propertyName, value);
                        typeHandler.setParameter(ps, i + 1, value, jdbcType);
                    } catch (SQLException e) {
                        throw new TypeException(
                                "Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
                    }
                }
            }
        }
    }

    /**
     * Resolves the value of one placeholder.
     *
     * @param propertyName      placeholder property name
     * @param frchPropertyParam foreach bookkeeping for the current {@link #setParameters} call
     * @return the value to bind (may be null)
     */
    private Object getFrchPropertyObject(String propertyName, FrchPropertyParam frchPropertyParam) {
        if (boundSql.hasAdditionalParameter(propertyName)) { // issue #448 ask first for additional params
            return boundSql.getAdditionalParameter(propertyName);
        }
        if (parameterObject == null) {
            return null;
        }
        if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
            return parameterObject;
        }
        MetaObject metaObject = configuration.newMetaObject(parameterObject);
        return getValue(propertyName, metaObject, frchPropertyParam);
    }

    /**
     * Reads {@code propertyName} from the parameter object, resolving foreach placeholders of the form
     * {@code __frch_<item>_<index>} against a List property.
     *
     * @param propertyName      placeholder property name
     * @param metaObject        meta object of the parameter object
     * @param frchPropertyParam foreach bookkeeping for the current call
     * @return the resolved value
     */
    private Object getValue(String propertyName, MetaObject metaObject, FrchPropertyParam frchPropertyParam) {
        if (!propertyName.startsWith(FRCH_PREFIX)) {
            return metaObject.getValue(propertyName);
        }
        String placeholder = propertyName.substring(FRCH_PREFIX.length());
        // The index follows the LAST underscore, so item names may themselves contain '_'
        // (splitting on every '_' broke names like "user_id").
        int separator = placeholder.lastIndexOf('_');
        if (separator <= 0) {
            // Not of the form <item>_<index>; let MyBatis resolve (or report) it as a normal property.
            return metaObject.getValue(propertyName);
        }
        String itemName = placeholder.substring(0, separator);
        String index = placeholder.substring(separator + 1);

        // The first index seen for an item becomes the base that later indexes are rebased against.
        if (StringUtils.isEmpty(frchPropertyParam.getLastParameter())
                || !frchPropertyParam.getLastParameter().equals(itemName)) {
            frchPropertyParam.setLastParameter(itemName);
            frchPropertyParam.setCurrStartIndex(index);
        }

        return getFrchListElement(metaObject, itemName, index, frchPropertyParam);
    }

    /**
     * Returns the element of the List property matching {@code itemName} at the (rebased) index.
     *
     * @param metaObject        meta object of the parameter object
     * @param itemName          foreach item name taken from the placeholder
     * @param index             foreach index taken from the placeholder
     * @param frchPropertyParam foreach bookkeeping for the current call
     * @return the element, or {@code ""} when the index is beyond the end of the list
     */
    private Object getFrchListElement(MetaObject metaObject, String itemName, String index,
                                      FrchPropertyParam frchPropertyParam) {
        // Heuristic: use a List-valued property whose name starts with the item name
        // (e.g. item "id" -> property "ids"); if several match, the last getter wins.
        String proName = itemName;
        for (String name : metaObject.getGetterNames()) {
            if (name.startsWith(itemName) && metaObject.getValue(name) instanceof List) {
                proName = name;
            }
        }

        log.debug(
                "query paged list,select count parameter getfrch,proName = {},propertyIndex = {},frchPropertyParam = {}",
                proName, index, frchPropertyParam);

        List<?> paramList = (List<?>) metaObject.getValue(proName);
        int propertyIndex = Integer.parseInt(index);
        if (StringUtils.isNotEmpty(frchPropertyParam.getCurrStartIndex())) {
            propertyIndex -= Integer.parseInt(frchPropertyParam.getCurrStartIndex());
            if (propertyIndex < 0) {
                // Seen an index below the current base: make this index the new base.
                // (Previously the negative difference itself was stored as the base.)
                frchPropertyParam.setCurrStartIndex(index);
                propertyIndex = 0;
            }
        }
        // Valid positions are 0..size-1; "<=" used to call get(size) and throw IndexOutOfBoundsException.
        if (propertyIndex < paramList.size()) {
            return paramList.get(propertyIndex);
        }
        return "";
    }

}

上面这三个主要的类还依赖以下几个类:

  • PaginableParameter 查询参数接口
  • FrchPropertyParam 对于list参数的封装
  • Plugin 非常重要的一个类
  • ReflectHelper 反射类
  • PagedResult 结果返回类
    代码如下:
	/**
 * {@link InvocationHandler} used by the pagination interceptors to wrap MyBatis components.
 *
 * <p>When the object being wrapped is itself a proxy created by this class, the chain is unwrapped so
 * that intercepted calls receive the innermost real object as the invocation target. Calls that this
 * interceptor does not intercept are forwarded to the wrapped object, so the rest of the chain still runs.
 *
 * <p>NOTE(review): because {@code intercept} is invoked against the innermost target,
 * {@code invocation.proceed()} bypasses any plugin layers between this one and the real object for the
 * intercepted method.
 */
public class Plugin implements InvocationHandler {

    /** The object that was actually wrapped (possibly another proxy); non-intercepted calls go here. */
    private final Object lastTarget;
    /** The innermost unwrapped target, passed to the interceptor in each Invocation. */
    private final Object target;
    private final Interceptor interceptor;
    private final Map<Class<?>, Set<Method>> signatureMap;

    private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap, Object lastTarget) {
        this.target = target;
        this.interceptor = interceptor;
        this.signatureMap = signatureMap;
        this.lastTarget = lastTarget;
    }

    /**
     * Wraps {@code target} in a proxy when it implements an interface named in the interceptor's
     * {@code @Intercepts} signatures.
     *
     * @param target      object to wrap (may already be a proxy)
     * @param interceptor interceptor whose signatures select the intercepted methods
     * @return the proxy, or {@code target} itself when nothing is intercepted
     */
    public static Object wrap(Object target, Interceptor interceptor) {
        Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
        Class<?> type = target.getClass();
        Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
        if (interfaces.length == 0) {
            return target;
        }
        Object currTarget = target;
        if (target instanceof Proxy) {
            try {
                currTarget = getTarget(target);
            } catch (Exception e) {
                // Best effort: if the proxy chain cannot be unwrapped (e.g. it contains a proxy not
                // created by this class), intercept against the proxy itself.
                currTarget = target;
            }
        }
        return Proxy.newProxyInstance(
                type.getClassLoader(),
                interfaces,
                new Plugin(currTarget, interceptor, signatureMap, target));
    }

    /**
     * Follows a chain of proxies created by this class down to the innermost target.
     *
     * <p>Uses the public {@link Proxy#getInvocationHandler} instead of reflectively opening the
     * protected {@code Proxy.h} field, which is refused by strong module encapsulation on JDK 16+.
     *
     * @param proxy a proxy (or any object; non-proxies are returned unchanged)
     * @return the innermost non-proxy target
     * @throws IllegalStateException if a proxy in the chain was not created by this class
     */
    public static Object getTarget(Object proxy) throws Exception {
        Object current = proxy;
        while (current instanceof Proxy) {
            InvocationHandler handler = Proxy.getInvocationHandler(current);
            if (!(handler instanceof Plugin)) {
                throw new IllegalStateException(
                        "Cannot unwrap proxy whose handler is " + handler.getClass().getName());
            }
            current = ((Plugin) handler).target;
        }
        return current;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            Set<Method> methods = signatureMap.get(method.getDeclaringClass());
            if (methods != null && methods.contains(method)) {
                return interceptor.intercept(new Invocation(target, method, args));
            }
            // Not intercepted here: continue down the chain through the wrapped object.
            return method.invoke(lastTarget, args);
        } catch (Exception e) {
            throw ExceptionUtil.unwrapThrowable(e);
        }
    }

    private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
        Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
        // issue #251
        if (interceptsAnnotation == null) {
            throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
        }
        Signature[] sigs = interceptsAnnotation.value();
        Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
        for (Signature sig : sigs) {
            Set<Method> methods = signatureMap.computeIfAbsent(sig.type(), k -> new HashSet<>());
            try {
                Method method = sig.type().getMethod(sig.method(), sig.args());
                methods.add(method);
            } catch (NoSuchMethodException e) {
                throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
            }
        }
        return signatureMap;
    }

    /** Collects every interface of {@code type} and its superclasses that appears in {@code signatureMap}. */
    private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
        Set<Class<?>> interfaces = new HashSet<>();
        for (Class<?> c = type; c != null; c = c.getSuperclass()) {
            for (Class<?> candidate : c.getInterfaces()) {
                if (signatureMap.containsKey(candidate)) {
                    interfaces.add(candidate);
                }
            }
        }
        return interfaces.toArray(new Class<?>[0]);
    }

}
/**
 * Reflection helpers used by the pagination interceptors to read and write fields regardless of their
 * access modifiers, searching the object's class and its superclasses (excluding {@link Object}).
 */
public final class ReflectHelper {

    private static final Logger logger = LoggerFactory.getLogger(ReflectHelper.class);

    private ReflectHelper() {
    }

    /**
     * Reads a field of {@code obj}, ignoring access modifiers.
     *
     * @param obj       the object to read from; {@code null} yields {@code null}
     * @param fieldName the field name
     * @return the field value, or {@code null} when {@code obj} is null or the read is refused
     * @throws IllegalArgumentException if no field with that name exists in the class hierarchy
     */
    public static Object getFieldValue(Object obj, String fieldName) {
        if (obj == null) {
            return null;
        }
        Field targetField = requireTargetField(obj, fieldName);
        try {
            return FieldUtils.readField(targetField, obj, true);
        } catch (IllegalAccessException e) {
            logger.error("cannot read field {} of {}", fieldName, obj.getClass().getName(), e);
            return null;
        }
    }

    /**
     * Writes a field of {@code obj}, ignoring access modifiers. Does nothing when {@code obj} is null.
     *
     * @param obj       the object to modify
     * @param fieldName the field name
     * @param value     the new value
     * @throws IllegalArgumentException if no field with that name exists in the class hierarchy
     */
    public static void setFieldValue(Object obj, String fieldName, Object value) {
        if (obj == null) {
            return;
        }
        Field targetField = requireTargetField(obj, fieldName);
        try {
            // forceAccess=true, consistent with getFieldValue, so private fields can be written.
            FieldUtils.writeField(targetField, obj, value, true);
        } catch (IllegalAccessException e) {
            logger.error("cannot write field {} of {}", fieldName, obj.getClass().getName(), e);
        }
    }

    /** Looks up the field or fails with a message naming the field and class. */
    private static Field requireTargetField(Object obj, String fieldName) {
        Field field = getTargetField(obj.getClass(), fieldName);
        if (field == null) {
            throw new IllegalArgumentException(
                    "No field '" + fieldName + "' in " + obj.getClass().getName() + " or its superclasses");
        }
        return field;
    }

    /**
     * Finds a declared field on {@code targetClass} or the nearest superclass below {@link Object}.
     *
     * @return the field (made accessible), or {@code null} if not found or the lookup failed
     */
    private static Field getTargetField(Class<?> targetClass, String fieldName) {
        for (Class<?> c = targetClass; c != null && !Object.class.equals(c); c = c.getSuperclass()) {
            try {
                Field field = FieldUtils.getDeclaredField(c, fieldName, true);
                if (field != null) {
                    return field;
                }
            } catch (Exception e) {
                logger.error("cannot look up field {} on {}", fieldName, c.getName(), e);
                return null;
            }
        }
        return null;
    }
}

/**
 * Mutable bookkeeping used while resolving MyBatis foreach placeholders ({@code __frch_...}) for the
 * pagination count query. Lombok {@code @Data} generates the accessors, equals/hashCode and toString.
 */
@Data
public class FrchPropertyParam {

    /** Foreach item name of the placeholder handled most recently. */
    private String lastParameter;

    /** Index (as text) of the first placeholder seen for {@code lastParameter}; later indexes are rebased against it. */
    private String currStartIndex;
}

/**
 * Contract for query parameter objects that the pagination interceptors recognise.
 */
public interface PaginableParameter {

    /** @return the page index, starting at 0 */
    Integer getPageIndex();

    /** @param pageIndex the page index, starting at 0 */
    void setPageIndex(Integer pageIndex);

    /** @return the number of rows per page */
    Integer getPageSize();

    /** @param pageSize the number of rows per page */
    void setPageSize(Integer pageSize);

    /** @return the total number of rows the query matches */
    Integer getTotal();

    /** @param total the total number of rows the query matches */
    void setTotal(Integer total);

    /**
     * @return whether the query should be paged (documented as defaulting to {@code true};
     *         the actual default is up to the implementation)
     */
    boolean getPagingable();

    /** @param pagingable whether the query should be paged */
    void setPagingable(boolean pagingable);

    /** @param countField select expression for the total-count query, e.g. {@code count(1)} */
    void setCountField(String countField);

    /** @return select expression for the total-count query; the count interceptor uses {@code count(1)} when empty */
    String getCountField();

    /** @return number of extra rows to skip on top of {@code pageIndex * pageSize} */
    int getStartOffset();

    /** @param offset number of extra rows to skip on top of {@code pageIndex * pageSize} */
    void setStartOffset(int offset);
}
/**
 * One page of query results together with the paging metadata of the {@link PaginableParameter}
 * that produced it.
 *
 * @param <T> element type of the page
 */
public class PagedResult<T> implements Serializable {

    private static final long serialVersionUID = -1589842116452346939L;

    /**
     * Maximum rows per page. NOTE(review): this class does not enforce it.
     */
    public static final int MAX_PAGE_SIZE = 2000;

    /**
     * Page index, starting at 0.
     */
    @ApiModelProperty("页索引,0为起始")
    private int pageIndex;

    /**
     * Rows per page.
     */
    @ApiModelProperty("页行数")
    private int pageSize;

    /**
     * Total number of rows.
     */
    @ApiModelProperty("总行数")
    private long total;

    /**
     * Total number of pages.
     */
    @ApiModelProperty("总页数")
    private int pageCount;

    /**
     * Rows of the current page.
     */
    @ApiModelProperty("当前页数据")
    private List<T> list;//NOSONAR

    public PagedResult() {

    }

    /**
     * Builds a result from the rows and the paging parameters used to fetch them.
     *
     * <p>Paging metadata is filled in only when {@code paginableParameter.getPagingable()} is true.
     * Null index/size/total values are treated as 0. For example, the total stays unset when the
     * interceptor's count query fails. The page size is clamped to at least 1.
     *
     * @param paginableParameter the paging parameters of the query
     * @param list               the rows of this page
     */
    public PagedResult(PaginableParameter paginableParameter, List<T> list) {
        this.list = list;

        if (paginableParameter.getPagingable()) {
            // The getters return boxed Integers; unboxing a null used to throw NullPointerException.
            this.pageIndex = valueOrZero(paginableParameter.getPageIndex());
            this.pageSize = Math.max(1, valueOrZero(paginableParameter.getPageSize()));
            this.total = valueOrZero(paginableParameter.getTotal());

            // Ceiling division: a partly filled last page still counts as a page.
            this.pageCount = (int) (total / pageSize);
            if (total % pageSize != 0) {
                this.pageCount++;
            }
        }
    }

    /** Unboxes {@code value}, mapping null to 0. */
    private static int valueOrZero(Integer value) {
        return value == null ? 0 : value;
    }
}

如果不想直接复制代码,也可以点击原文中的链接进入代码仓库,自行 clone。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值