MyBatis Plus 使用拦截器实现数据自定义权限查询数据

文章介绍了如何在MyBatis中使用自定义注解和拦截器实现数据权限控制,包括创建DataScope注解、DataEnum枚举和自定义DataPermission拦截器,以根据用户角色限制数据访问范围。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

开发中遇到根据当前用户的角色,只能查看数据权限范围的数据需求。列表实现方案有两种,一是在开发初期就做好判断赛选,但如果这个需求是中途加的,或不希望每个接口都加一遍,就可以方案二加拦截器的方式。在mybatis执行sql前修改语句,限定where范围。
当然拦截器生效后是全局性的,如何保证只对需要的接口进行拦截和转化,就可以应用注解进行识别,话不多说,直接show code

1. 创建自定义注解,数据条件枚举(非必须)

package com.example.ftserver.datapermission;

import java.lang.annotation.*;

/**
 * @author aaa
 * @date 2023-10-31 16:24
 * @description 自定义数据范围注解
 */
@Target({ElementType.TYPE,ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DataScope {
    /**
     * 表别名
     *
     * @return str
     */
    String alias() default "";


    /**
     * 数据范围方式:如 >,=
     * @return {@link DataEnum[]}
     */
    DataEnum[] dataScope() default {

    };
    /**
     * 数据范围,这只是为了方便测试,实际项目中可以在代码中实现
     * @return {@link int[]}
     */
    int[] scopeArrays() default {

    };
}

package com.example.ftserver.datapermission;

import com.test.ft.common.exception.CommonException;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.commons.lang3.StringUtils;

import java.util.Arrays;

/**
 * @author aaa
 * @description 数据权限条件,可以领先配置,非必须,此处只做一个测试样例
 */
@Getter
@AllArgsConstructor
public enum DataEnum {
    /**
     * 等于
     */
    EQ("="),
    LIKE("like"),

    IN("in"),

    ALL("all"),

    GT(">"),
    GE(">="),
    LT("<"),
    LE("<=");
    private final String value;

    public static String getDataValue(String value){
        return Arrays.stream(DataEnum.values()).filter(e -> StringUtils.equals(e.value, value)).findFirst().orElseThrow(() -> new CommonException("请输入正确的参数")).getValue();
    }

}

2.数据权限拦截器

package com.example.ftserver.datapermission;

import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.schema.Column;

/**
 * @author aaa
 * @description 自定义数据拦截器
 */

// @NoArgsConstructor 有些是把where条件的设置放置在一个独立的类里面, 然后再通过属性设置,此时是需要该注解的
@AllArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@Slf4j
public class MyDataPermission extends JsqlParserSupport implements InnerInterceptor {

    @SneakyThrows(Exception.class)
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql){
        String id = ms.getId();
        if (InterceptorIgnoreHelper.willIgnoreBlockAttack(id)) {
            return;
        }
        String className = id.substring(0, id.lastIndexOf("."));
        String methodName = id.substring(id.lastIndexOf(".") + 1);
        Method[] methods = Class.forName(className).getMethods();
        Optional<Method> first = Arrays.stream(methods).filter(a -> a.getName().equals(methodName)).findFirst();
        if (first.isPresent()) {
            DataScope annotation = first.get().getAnnotation(DataScope.class);
            if (annotation == null) {
                return;
            }
        }
        PluginUtils.MPBoundSql sql = PluginUtils.mpBoundSql(boundSql);
        sql.sql(this.parserSingle(sql.sql(), id));
    }

    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        SelectBody selectBody = select.getSelectBody();
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, (String) obj);
        } else if (selectBody instanceof SetOperationList) {
            SetOperationList setOperationList = (SetOperationList) selectBody;
            List<SelectBody> selectBodyList = setOperationList.getSelects();
            selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, (String) obj));
        }
    }

    /**
     * 设置 where 条件,过滤已删除数据
     *
     * @param plainSelect  查询对象
     * @param whereSegment 查询条件片段
     */
    @SneakyThrows(Exception.class)
    private void setWhere(PlainSelect plainSelect, String whereSegment) {
        Expression where = plainSelect.getWhere();
        if (where == null) {
            where = new HexValue(" 1 = 1 ");
        }
        //获取mapper名称
        String className = whereSegment.substring(0, whereSegment.lastIndexOf("."));
        //获取方法名
        String methodName = whereSegment.substring(whereSegment.lastIndexOf(".") + 1);
        Table fromItem = (Table) plainSelect.getFromItem();
        // 有别名用别名,无别名用表名,防止字段冲突报错
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? null : fromItemAlias.getName();
        log.info("interceptor begin,where :{},className:{},methodName:{}", where, className, methodName);

        //获取当前mapper 的方法
        Method[] methods = Class.forName(className).getMethods();
        //遍历判断mapper 的所以方法,判断方法上是否有 DataScope注解
        for (Method m : methods) {
            if (Objects.equals(m.getName(), methodName)) {
                DataScope annotation = m.getAnnotation(DataScope.class);
                if (annotation == null) {
                    break;
                }
                // 部门条件拼接
                Expression andExpression = new HexValue(" 1 = 1 ");
                String alias = annotation.alias();
                mainTableName = StringUtils.isNotBlank(alias) ? alias : mainTableName;
                // 查看不同部门数据
                DataEnum[] enums = annotation.dataScope();
                if (enums.length > 0 && StringUtils.equals(fromItem.getName(), "person_t")) { // 只是示例:测试一下一个表
                    if (enums.length == 1) {
                        DataEnum dataEnum = enums[0];
                        // = 条件配置
                        if (DataEnum.EQ.equals(dataEnum)) {
                            EqualsTo to = new EqualsTo();
                            to.setLeftExpression(new Column(mainTableName == null ? "dept_code" : mainTableName + ".dept_code"));
                            to.setRightExpression(new LongValue(annotation.scopeArrays()[0]));
                            andExpression = to;
                        }
                        // IN条件
                        if (DataEnum.IN.equals(dataEnum)) {
                            InExpression inExpression = new InExpression();
                            inExpression.setLeftExpression(new Column(mainTableName == null ? "dept_code" : mainTableName + ".dept_code"));
                            ItemsList expressionList = new ExpressionList(Arrays.stream(annotation.scopeArrays()).mapToObj(Long::valueOf).map(LongValue::new).collect(Collectors.toList()));
                            inExpression.withRightItemsList(expressionList);
                            andExpression = inExpression;
                        }
                        // > 条件
                        if (DataEnum.GT.equals(dataEnum)) {
                            GreaterThan to = new GreaterThan();
                            to.setLeftExpression(new Column(mainTableName == null ? "dept_code" : mainTableName + ".dept_code"));
                            to.setRightExpression(new LongValue(annotation.scopeArrays()[0]));
                            andExpression = to;
                        }
                        // >= 条件
                        if (DataEnum.GE.equals(dataEnum)) {
                            GreaterThanEquals to = new GreaterThanEquals();
                            to.setLeftExpression(new Column(mainTableName == null ? "dept_code" : mainTableName + ".dept_code"));
                            to.setRightExpression(new LongValue(annotation.scopeArrays()[0]));
                            andExpression = to;
                        }

                    }
                    // between and 条件
                    if (enums.length == 2) {
                        DataEnum dataEnum = enums[0];
                        DataEnum enum1 = enums[1];
                        if (DataEnum.GE.equals(dataEnum) && DataEnum.LE.equals(enum1)) {
                            Between to = new Between();
                            to.setLeftExpression(new Column(mainTableName == null ? "dept_code" : mainTableName + ".dept_code"));
                            to.setBetweenExpressionStart(new LongValue(annotation.scopeArrays()[0]));
                            to.setBetweenExpressionEnd(new LongValue(annotation.scopeArrays()[1]));
                            andExpression = to;
                        }
                    }
                }
                // 查看未删除的数据
                EqualsTo usesEqualsTo = new EqualsTo();
                usesEqualsTo.setLeftExpression(new Column(mainTableName == null ? "is_delete" : mainTableName + ".is_delete"));
                usesEqualsTo.setRightExpression(new LongValue(0));
                IsNullExpression nullExpression = new IsNullExpression();
                nullExpression.withLeftExpression(new Column(mainTableName == null ? "is_delete" : mainTableName + ".is_delete"));
                // ()括号拼接
                Parenthesis parenthesis = new Parenthesis(new OrExpression(usesEqualsTo, nullExpression));
                AndExpression expression = new AndExpression(where, parenthesis);
                plainSelect.setWhere(new AndExpression(expression, andExpression));
                log.info("interceptor end,where :{}", plainSelect);
                break;
            }
        }
    }
}

Mybatis Plus原始的BaseMapper里的方法查询,也可以使用自定义的数据权限注解来实现数据权限的控制,只需要新建一个继承BaseMapper的类,并重写相关方法即可

package com.example.ftserver.plugin;

import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.example.ftserver.datapermission.DataEnum;
import com.example.ftserver.datapermission.DataScope;
import org.apache.ibatis.annotations.Param;

import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * @author aaa
 * @description
 */
public interface RootMapper<T> extends BaseMapper<T> {

    /**
     * 自定义批量新增,代替mybatis plus 自带的批量新增
     *
     * @param batchList 批量新增参数
     * @return int
     */
    int insertBatch(@Param("list") Collection<T> batchList);

    /**
     * 根据id批量新增
     *
     * @param batchList 批量更新参数
     * @return int
     */
    int updateBatch(@Param("list") Collection<T> batchList);

    /**
     * 根据 ID 查询
     *
     * @param id 主键ID
     */
    @Override
    @DataScope
    T selectById(Serializable id);

    /**
     * 查询(根据ID 批量查询)
     *
     * @param idList 主键ID列表(不能为 null 以及 empty)
     */
    @Override
    @DataScope
    List<T> selectBatchIds(@Param(Constants.COLL) Collection<? extends Serializable> idList);

    /**
     * 查询(根据 columnMap 条件)
     *
     * @param columnMap 表字段 map 对象
     */
    @Override
    @DataScope
    List<T> selectByMap(@Param(Constants.COLUMN_MAP) Map<String, Object> columnMap);

    /**
     * 根据 Wrapper 条件,查询总记录数
     *
     * @param queryWrapper 实体对象封装操作类(可以为 null)
     */
    @Override
    @DataScope
    Long selectCount(@Param(Constants.WRAPPER) Wrapper<T> queryWrapper);

    /**
     * 根据 entity 条件,查询全部记录
     *
     * @param queryWrapper 实体对象封装操作类(可以为 null)
     */
    @Override
    // 测试只查询部门编码为1 的数据
    @DataScope(dataScope = {DataEnum.EQ}, scopeArrays = {1}) 
    List<T> selectList(@Param(Constants.WRAPPER) Wrapper<T> queryWrapper);

    /**
     * 根据 Wrapper 条件,查询全部记录
     *
     * @param queryWrapper 实体对象封装操作类(可以为 null)
     */
    @Override
    @DataScope
    List<Map<String, Object>> selectMaps(@Param(Constants.WRAPPER) Wrapper<T> queryWrapper);

    /**
     * 根据 Wrapper 条件,查询全部记录
     * <p>注意: 只返回第一个字段的值</p>
     *
     * @param queryWrapper 实体对象封装操作类(可以为 null)
     */
    @Override
    @DataScope
    List<Object> selectObjs(@Param(Constants.WRAPPER) Wrapper<T> queryWrapper);

    /**
     * 根据 entity 条件,查询全部记录(并翻页)
     *
     * @param page         分页查询条件(可以为 RowBounds.DEFAULT)
     * @param queryWrapper 实体对象封装操作类(可以为 null)
     */
    @Override
    // 测试只查询部门编码为1 到3 的数据
    @DataScope(dataScope = {DataEnum.GE, DataEnum.LE}, scopeArrays = {1, 3})
    <P extends IPage<T>> P selectPage(P page, @Param(Constants.WRAPPER) Wrapper<T> queryWrapper);

    /**
     * 根据 Wrapper 条件,查询全部记录(并翻页)
     *
     * @param page         分页查询条件
     * @param queryWrapper 实体对象封装操作类
     */
    @Override
    @DataScope
    <P extends IPage<Map<String, Object>>> P selectMapsPage(P page, @Param(Constants.WRAPPER) Wrapper<T> queryWrapper);

    /**
     * 根据 entity 条件,查询一条记录
     * <p>查询一条记录,例如 qw.last("limit 1") 限制取一条记录, 注意:多条数据会报异常</p>
     *
     * @param queryWrapper 实体对象封装操作类(可以为 null)
     */
    @Override
    @DataScope
    default T selectOne(@Param(Constants.WRAPPER) Wrapper<T> queryWrapper) {
        return BaseMapper.super.selectOne(queryWrapper);
    }

    /**
     * 根据 Wrapper 条件,判断是否存在记录
     *
     * @param queryWrapper 实体对象封装操作类
     * @return 是否存在记录
     */
    @Override
    @DataScope
    default boolean exists(@Param(Constants.WRAPPER) Wrapper<T> queryWrapper) {
        return BaseMapper.super.exists(queryWrapper);
    }
}

需要注意的一点:重写相关方法时,要把@Param内容加上,并且内容字段与BaseMapper里需要保持一致

3.将拦截器注入配置文件

在这里插入图片描述
大概的步骤已经完成,下面做一个简单的测试,新建一个表,里面有简单的部门编码和一些基础信息,如图所示:
在这里插入图片描述
测试场景1,使用mybatis plus 自带的selectList()查询,在该方法上已经加上自定义数据权限@DataScope注解:
在这里插入图片描述
测试结果,可以看得到SQL已经添加上注解所带条件:
在这里插入图片描述
测试场景2,在分页查询上加上自定义数据权限@DataScope注解:
在这里插入图片描述
打印SQL,也是生效的:
在这里插入图片描述
测试场景3,自定义的SQL使用自定义数据权限@DataScope注解:
在这里插入图片描述
测试结果,可以看到同样是带字符串S的列,dept_code不为2的就没有被查询出来:
在这里插入图片描述
在这里插入图片描述
以上就是全篇知识点, 需要注意的点可能有:

1.记得把拦截器加到MyBatis-Plus的插件中,确保生效

2.where条件的拼接需要多多调试,多注意别名和条件的构造

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值