Spring Boot + MyBatis（非 MyBatis-Plus）多租户实现，支持复杂查询


/**
 * Holds the tenant id associated with the current thread.
 *
 * <p>The value lives in an {@link InheritableThreadLocal}, so a thread created while a tenant is
 * set starts out with a copy of that tenant id.
 * NOTE(review): pooled threads keep whatever value they inherited when they were created; confirm
 * that async/pooled work sets or clears the tenant explicitly.
 */
public class TenantContext {
    /** Tenant id of the calling thread; {@code null} means "no tenant set". */
    private static final ThreadLocal<String> CURRENT_TENANT = new InheritableThreadLocal<>();

    /**
     * Binds a tenant id to the calling thread.
     *
     * @param tenantId tenant id to use for subsequent work on this thread (may be {@code null})
     */
    public static void setTenantId(String tenantId) {
        CURRENT_TENANT.set(tenantId);
    }

    /**
     * Reads the tenant id bound to the calling thread.
     *
     * @return the tenant id, or {@code null} if none has been set
     */
    public static String getTenantId() {
        return CURRENT_TENANT.get();
    }

    /** Removes the calling thread's tenant id (call when a request or task finishes). */
    public static void clear() {
        CURRENT_TENANT.remove();
    }
}

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Opts a type out of tenant filtering.
 *
 * <p>Applicable to types only and retained at runtime so it can be detected reflectively. The
 * tenant interceptor looks it up on the class named by a mapped statement's id prefix (the
 * statement namespace, typically the mapper interface), not on entity classes.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface TenantIgnore {
}

import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;


import java.lang.reflect.Field;
import java.util.*;


/**
 * MyBatis executor plugin that confines every statement to the tenant held in {@link TenantContext}.
 *
 * <p>The SQL of the intercepted {@link MappedStatement} is parsed with JSqlParser and rewritten in
 * place:
 * <ul>
 *   <li>SELECT: every non-ignored table (FROM, JOINs, subqueries in FROM/WHERE/HAVING/ON/select
 *       items, UNION branches) is restricted to {@code <qualifier>.tenant_id = '<tenant>'};</li>
 *   <li>UPDATE/DELETE: {@code tenant_id = '<tenant>'} is ANDed into the WHERE clause;</li>
 *   <li>INSERT: a {@code tenant_id} column and the tenant value are appended.</li>
 * </ul>
 * The statement is then executed through a copy of the mapped statement that returns the
 * rewritten SQL.
 *
 * <p>Nothing is rewritten when no tenant is set, or when the class named by the statement-id
 * prefix is annotated with {@link TenantIgnore}. Individual tables can be excluded with the
 * comma-separated {@code ignoreTables} plugin property (case-insensitive, quotes ignored).
 *
 * <p>Only the 4-argument {@code Executor.query} overload is intercepted; direct calls to other
 * overloads are not rewritten.
 * NOTE(review): written against the JSqlParser 4.x (&lt; 4.7) AST. WITH clauses,
 * UPDATE ... JOIN and ON DUPLICATE KEY UPDATE are not rewritten.
 */
@Intercepts({
    @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
    @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
})
public class TenantInterceptor implements Interceptor {
    private static final String TENANT_COLUMN = "tenant_id";

    /** Normalized (unquoted, lower-case) names of tables that are never tenant-filtered. */
    private final Set<String> ignoreTables = new HashSet<>();

    /** Class name -> whether it carries {@link TenantIgnore}; avoids a Class.forName per statement. */
    private final Map<String, Boolean> ignoredMappers = new java.util.concurrent.ConcurrentHashMap<>();

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String tenantId = TenantContext.getTenantId();
        if (tenantId == null) {
            return invocation.proceed();
        }

        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        // The annotation exempts the whole mapper, for reads and writes alike.
        if (isMapperIgnored(ms.getId())) {
            return invocation.proceed();
        }
        // The tenant id is embedded as an SQL literal, so it must not be able to break out of it.
        requireSqlSafeTenantId(tenantId);

        // Both intercepted signatures pass the mapper parameter object at index 1
        // (index 0 is the MappedStatement); getBoundSql must see that object, not the args array.
        Object parameter = args[1];
        BoundSql boundSql = ms.getBoundSql(parameter);

        // Parse errors propagate: running the unfiltered SQL would silently bypass tenant isolation.
        Statement statement = CCJSqlParserUtil.parse(boundSql.getSql());
        String newSql = processStatement(statement, tenantId);

        BoundSql newBoundSql = new BoundSql(
                ms.getConfiguration(),
                newSql,
                boundSql.getParameterMappings(),
                boundSql.getParameterObject());
        // Carry over additional parameters referenced by the mappings (presumably values produced
        // by dynamic SQL such as <foreach>), which the new BoundSql would otherwise lack.
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }

        // Execute through a copy of the statement whose SqlSource returns the rewritten SQL.
        args[0] = buildMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
        return invocation.proceed();
    }

    /**
     * Builds a copy of {@code ms} that uses {@code sqlSource} but keeps all other settings.
     */
    private MappedStatement buildMappedStatement(MappedStatement ms, SqlSource sqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
                ms.getConfiguration(), ms.getId(), sqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        builder.keyProperty(joinOrNull(ms.getKeyProperties()));
        builder.keyColumn(joinOrNull(ms.getKeyColumns()));
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.resultOrdered(ms.isResultOrdered());
        builder.resultSets(joinOrNull(ms.getResultSets()));
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /** Joins {@code values} with commas; returns null for a null or empty array. */
    private static String joinOrNull(String[] values) {
        return (values == null || values.length == 0) ? null : String.join(",", values);
    }

    //=============== Statement dispatch ===============//

    /**
     * Adds tenant restrictions to {@code statement} in place and returns the resulting SQL.
     * Statements other than SELECT/INSERT/UPDATE/DELETE are returned unchanged.
     */
    private String processStatement(Statement statement, String tenantId) {
        if (statement instanceof Select) {
            processSelectBody(((Select) statement).getSelectBody(), tenantId);
        } else if (statement instanceof Insert) {
            processInsert((Insert) statement, tenantId);
        } else if (statement instanceof Update) {
            processUpdate((Update) statement, tenantId);
        } else if (statement instanceof Delete) {
            processDelete((Delete) statement, tenantId);
        }
        return statement.toString();
    }

    //=============== SELECT ===============//

    /** Rewrites a plain SELECT or every branch of a UNION/INTERSECT/EXCEPT. */
    private void processSelectBody(SelectBody selectBody, String tenantId) {
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody, tenantId);
        } else if (selectBody instanceof SetOperationList) {
            for (SelectBody branch : ((SetOperationList) selectBody).getSelects()) {
                processSelectBody(branch, tenantId);
            }
        }
        // NOTE(review): other SelectBody kinds (e.g. WITH items) are left untouched.
    }

    /** Restricts the FROM item, every join, and every subquery nested in this SELECT. */
    private void processPlainSelect(PlainSelect plainSelect, String tenantId) {
        FromItem fromItem = plainSelect.getFromItem();
        // Table-less selects (e.g. SELECT LAST_INSERT_ID()) have no FROM item.
        if (fromItem != null) {
            processFromItem(fromItem, plainSelect, tenantId);
        }

        List<Join> joins = plainSelect.getJoins();
        if (joins != null) {
            for (Join join : joins) {
                processJoin(join, plainSelect, tenantId);
            }
        }

        ExpressionVisitorAdapter subSelects = subSelectVisitor(tenantId);
        List<SelectItem> selectItems = plainSelect.getSelectItems();
        if (selectItems != null) {
            for (SelectItem item : selectItems) {
                if (item instanceof SelectExpressionItem) {
                    ((SelectExpressionItem) item).getExpression().accept(subSelects);
                }
            }
        }
        if (plainSelect.getWhere() != null) {
            plainSelect.getWhere().accept(subSelects);
        }
        if (plainSelect.getHaving() != null) {
            plainSelect.getHaving().accept(subSelects);
        }
    }

    /** Filters a table in the WHERE clause, or recurses into a derived table (FROM subquery). */
    private void processFromItem(FromItem fromItem, PlainSelect plainSelect, String tenantId) {
        if (fromItem instanceof Table) {
            Expression condition = selectTenantCondition((Table) fromItem, tenantId);
            if (condition != null) {
                plainSelect.setWhere(and(plainSelect.getWhere(), condition));
            }
        } else if (fromItem instanceof SubSelect) {
            processSelectBody(((SubSelect) fromItem).getSelectBody(), tenantId);
        }
    }

    /**
     * Restricts a joined table.
     *
     * <p>For inner/left joins with a single ON expression the condition is ANDed into that ON
     * clause, so LEFT JOINs keep their outer-join semantics. In every other case (comma joins,
     * CROSS JOIN, USING, RIGHT/FULL joins, multiple ON expressions) it goes into WHERE. WHERE may
     * over-filter outer joins but never exposes another tenant's rows. Putting the condition into
     * the ON of a RIGHT/FULL join would keep the preserved side's foreign rows.
     */
    private void processJoin(Join join, PlainSelect plainSelect, String tenantId) {
        Collection<Expression> onExpressions = join.getOnExpressions();
        if (onExpressions != null) {
            ExpressionVisitorAdapter subSelects = subSelectVisitor(tenantId);
            for (Expression on : onExpressions) {
                on.accept(subSelects);
            }
        }

        FromItem rightItem = join.getRightItem();
        if (!(rightItem instanceof Table)) {
            processFromItem(rightItem, plainSelect, tenantId);
            return;
        }
        Expression condition = selectTenantCondition((Table) rightItem, tenantId);
        if (condition == null) {
            return;
        }

        boolean useOnClause = onExpressions != null && onExpressions.size() == 1
                && !join.isRight() && !join.isFull();
        if (useOnClause) {
            List<Expression> updated = new ArrayList<>(1);
            updated.add(and(onExpressions.iterator().next(), condition));
            join.setOnExpressions(updated);
        } else {
            plainSelect.setWhere(and(plainSelect.getWhere(), condition));
        }
    }

    /**
     * Builds {@code <qualifier>.tenant_id = '<tenantId>'} for a table in a SELECT.
     *
     * @return the condition, or {@code null} if the table is configured as ignored
     */
    private Expression selectTenantCondition(Table table, String tenantId) {
        if (isIgnoredTable(table.getName())) {
            return null;
        }
        // Qualify by alias when present so the condition binds to this occurrence of the table.
        // The table name is used verbatim because some databases match qualifiers case-sensitively.
        String qualifier = table.getAlias() != null ? table.getAlias().getName() : table.getName();
        return new EqualsTo(new Column(new Table(qualifier), TENANT_COLUMN), new StringValue(tenantId));
    }

    /** Returns a visitor that applies the tenant rewrite to every subquery inside an expression. */
    private ExpressionVisitorAdapter subSelectVisitor(String tenantId) {
        return new ExpressionVisitorAdapter() {
            @Override
            public void visit(SubSelect subSelect) {
                processSelectBody(subSelect.getSelectBody(), tenantId);
            }
        };
    }

    //=============== INSERT / UPDATE / DELETE ===============//

    /**
     * Appends the tenant column and value to an INSERT.
     *
     * <p>Supports single-row VALUES, multi-row VALUES and INSERT ... SELECT. Statements that
     * already list {@code tenant_id} are left as written.
     *
     * @throws IllegalStateException if the INSERT has no column list or an unsupported form
     */
    private void processInsert(Insert insert, String tenantId) {
        Table table = insert.getTable();
        if (isIgnoredTable(table.getName())) {
            return;
        }
        List<Column> columns = insert.getColumns();
        if (columns == null || columns.isEmpty()) {
            // Without a column list the tenant value cannot be placed reliably; fail loudly
            // instead of inserting a row without a tenant value.
            throw new IllegalStateException("INSERT into " + table.getName()
                    + " needs an explicit column list to add " + TENANT_COLUMN);
        }
        for (Column column : columns) {
            if (TENANT_COLUMN.equals(normalizeIdentifier(column.getColumnName()))) {
                return; // the statement already supplies the tenant column itself
            }
        }

        Select select = insert.getSelect();
        Object values = insert.getItemsList();
        if (select != null) {
            processSelectBody(select.getSelectBody(), tenantId);
            appendTenantSelectItem(select.getSelectBody(), tenantId);
        } else if (values instanceof net.sf.jsqlparser.expression.operators.relational.MultiExpressionList) {
            for (ExpressionList row
                    : ((net.sf.jsqlparser.expression.operators.relational.MultiExpressionList) values).getExpressionLists()) {
                row.getExpressions().add(new StringValue(tenantId));
            }
        } else if (values instanceof ExpressionList) {
            appendTenantValue((ExpressionList) values, tenantId);
        } else {
            throw new IllegalStateException("Unsupported INSERT form for table " + table.getName());
        }
        columns.add(new Column(TENANT_COLUMN));
    }

    /**
     * Adds the tenant value to a VALUES list. When every element is a row constructor (some
     * parser versions represent multi-row VALUES this way) the value is added to each row.
     */
    private static void appendTenantValue(ExpressionList values, String tenantId) {
        List<Expression> expressions = values.getExpressions();
        boolean rowPerElement = !expressions.isEmpty() && expressions.stream()
                .allMatch(net.sf.jsqlparser.expression.RowConstructor.class::isInstance);
        if (rowPerElement) {
            for (Expression row : expressions) {
                ((net.sf.jsqlparser.expression.RowConstructor) row)
                        .getExprList().getExpressions().add(new StringValue(tenantId));
            }
        } else {
            expressions.add(new StringValue(tenantId));
        }
    }

    /** Appends the tenant value as an extra select item for INSERT ... SELECT (every branch). */
    private static void appendTenantSelectItem(SelectBody body, String tenantId) {
        if (body instanceof PlainSelect) {
            ((PlainSelect) body).getSelectItems().add(new SelectExpressionItem(new StringValue(tenantId)));
        } else if (body instanceof SetOperationList) {
            for (SelectBody branch : ((SetOperationList) body).getSelects()) {
                appendTenantSelectItem(branch, tenantId);
            }
        } else {
            throw new IllegalStateException("Unsupported INSERT ... SELECT form: " + body);
        }
    }

    /** Filters subqueries in the WHERE clause and ANDs the tenant condition into it. */
    private void processUpdate(Update update, String tenantId) {
        if (update.getWhere() != null) {
            update.getWhere().accept(subSelectVisitor(tenantId));
        }
        if (!isIgnoredTable(update.getTable().getName())) {
            update.setWhere(and(update.getWhere(), dmlTenantCondition(tenantId)));
        }
    }

    /** Filters subqueries in the WHERE clause and ANDs the tenant condition into it. */
    private void processDelete(Delete delete, String tenantId) {
        if (delete.getWhere() != null) {
            delete.getWhere().accept(subSelectVisitor(tenantId));
        }
        if (!isIgnoredTable(delete.getTable().getName())) {
            delete.setWhere(and(delete.getWhere(), dmlTenantCondition(tenantId)));
        }
    }

    /** Builds the unqualified {@code tenant_id = '<tenantId>'} used for single-table UPDATE/DELETE. */
    private static Expression dmlTenantCondition(String tenantId) {
        return new EqualsTo(new Column(TENANT_COLUMN), new StringValue(tenantId));
    }

    //=============== Core helpers ===============//

    /**
     * Returns {@code existing AND condition}, or just {@code condition} when there is no existing
     * expression.
     *
     * <p>AndExpression is rendered without parentheses. A bare {@code a OR b} would therefore
     * print as {@code a OR b AND t.tenant_id = ...}, and the tenant filter would apply only to
     * {@code b}, leaking other tenants' rows. Anything that is not already an AND chain or
     * parenthesized is wrapped.
     */
    private static Expression and(Expression existing, Expression condition) {
        if (existing == null) {
            return condition;
        }
        boolean safeAsIs = (existing instanceof AndExpression)
                || (existing instanceof net.sf.jsqlparser.expression.Parenthesis);
        Expression left = safeAsIs ? existing : new net.sf.jsqlparser.expression.Parenthesis(existing);
        return new AndExpression(left, condition);
    }

    /**
     * Rejects tenant ids that could escape a single-quoted SQL literal.
     *
     * @throws IllegalArgumentException if the id contains a quote, backslash or control character
     */
    private static void requireSqlSafeTenantId(String tenantId) {
        for (int i = 0; i < tenantId.length(); i++) {
            char c = tenantId.charAt(i);
            if (c == '\'' || c == '\\' || Character.isISOControl(c)) {
                throw new IllegalArgumentException("Tenant id contains a character that cannot be embedded in SQL");
            }
        }
    }

    /**
     * Checks whether the class named by the statement id's prefix (everything before the last
     * '.') is annotated with {@link TenantIgnore}. Results are cached per class name.
     */
    private boolean isMapperIgnored(String mappedStatementId) {
        int lastDot = mappedStatementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return false; // no namespace prefix to look up
        }
        String className = mappedStatementId.substring(0, lastDot);
        return ignoredMappers.computeIfAbsent(className, TenantInterceptor::hasIgnoreAnnotation);
    }

    private static boolean hasIgnoreAnnotation(String className) {
        try {
            return Class.forName(className).isAnnotationPresent(TenantIgnore.class);
        } catch (ClassNotFoundException ignored) {
            // The namespace is not a loadable class, so there is no annotation to honour.
            return false;
        }
    }

    /** Checks a table name against the configured ignore list. */
    private boolean isIgnoredTable(String tableName) {
        return ignoreTables.contains(normalizeIdentifier(tableName));
    }

    /** Trims, strips one pair of identifier quotes (`x`, "x", [x]) and lower-cases locale-neutrally. */
    private static String normalizeIdentifier(String identifier) {
        String name = identifier.trim();
        if (name.length() >= 2) {
            char first = name.charAt(0);
            char last = name.charAt(name.length() - 1);
            if ((first == '`' && last == '`') || (first == '"' && last == '"') || (first == '[' && last == ']')) {
                name = name.substring(1, name.length() - 1);
            }
        }
        return name.toLowerCase(Locale.ROOT);
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads the optional {@code ignoreTables} property: a comma-separated list of table names.
     * Blank entries are skipped; whitespace and identifier quotes around names are ignored.
     */
    @Override
    public void setProperties(Properties properties) {
        String tables = properties.getProperty("ignoreTables");
        if (tables == null) {
            return;
        }
        for (String entry : tables.split(",")) {
            String name = normalizeIdentifier(entry);
            if (!name.isEmpty()) {
                ignoreTables.add(name);
            }
        }
    }

    /**
     * SqlSource that always returns the same, already-rewritten {@link BoundSql},
     * used to build the substitute MappedStatement.
     */
    public static class BoundSqlSqlSource implements SqlSource {
        BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }
}

import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.EventListener;

/**
 * Registers {@link TenantInterceptor} with the MyBatis configuration once the application is ready.
 *
 * <p>Registration is deferred until {@link ApplicationReadyEvent}. Judging by the method name, the
 * intent is to add it after PageHelper's interceptor.
 * NOTE(review): statements executed before this event fires run without the tenant interceptor.
 * Confirm nothing tenant-scoped runs earlier (e.g. runners, early requests).
 * NOTE(review): the interceptor is created without a call to {@code setProperties}, so its
 * {@code ignoreTables} list is empty here.
 */
@Configuration
public class TenantInterceptorConfig {

    @Autowired
    private SqlSessionFactory sqlSessionFactory;

    /** Adds a fresh {@link TenantInterceptor} to the injected factory's MyBatis configuration. */
    @EventListener(ApplicationReadyEvent.class)
    public void addInterceptorAfterPageHelper() {
        sqlSessionFactory.getConfiguration().addInterceptor(new TenantInterceptor());
    }

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值