spring boot 启动校验controller 接口方法一致性校验器

package com.runyan.tg.common.web.check;

import cn.hutool.core.date.StopWatch;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.reflections.Reflections;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;

import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.*;
import java.util.stream.Collectors;


/**
 * Controller 与接口方法一致性校验器
 * 支持配置化包扫描和白名单排除机制
 */
@Component
@ConfigurationProperties(prefix = "yy.web.check")
@Data
@Slf4j
public class ControllerParameterAnnotationChecker implements InitializingBean {

    /**
     * 是否启用参数校验
     */
    private boolean enableControllerParameterCheck = true;

    /**
     * 需要扫描的包路径
     */
    private List<String> scanPackages = Arrays.asList("扫描包路径");

    /**
     * 白名单:跳过检查的类或接口全限定名
     */
    private Set<String> excludeClasses = new HashSet<>();

    @Override
    public void afterPropertiesSet() {
        if (!enableControllerParameterParameterCheck) {
            log.info("接口参数校验已禁用");
            return;
        }

        StopWatch stopWatch = new StopWatch();
        stopWatch.start();

        Set<Class<?>> allClasses = new HashSet<>();

        for (String packageName : scanPackages) {
            Reflections reflections = new Reflections(packageName);
            Set<Class<?>> controllerClasses = getAnnotatedClasses(reflections);
            allClasses.addAll(controllerClasses);
        }

        List<Class<?>> sortedClasses = new ArrayList<>(allClasses);
        sortedClasses.sort(Comparator.comparing(Class::getName));

        // 过滤白名单中的类
        List<Class<?>> filteredClasses = sortedClasses.stream()
                .filter(clazz -> !excludeClasses.contains(clazz.getName()))
                .collect(Collectors.toList());

        checkClasses(filteredClasses);

        stopWatch.stop();
        log.info("微服务接口参数检查完成,耗时 {}ms", stopWatch.getTotalTimeMillis());
    }

    /**
     * 获取带有 @Controller 或 @RestController 注解的类
     */
    private Set<Class<?>> getAnnotatedClasses(Reflections reflections) {
        Set<Class<?>> result = new HashSet<>();
        result.addAll(reflections.getTypesAnnotatedWith(org.springframework.stereotype.Controller.class));
        result.addAll(reflections.getTypesAnnotatedWith(org.springframework.web.bind.annotation.RestController.class));
        return result;
    }

    /**
     * 校验每个 Controller 类实现且仅实现一个接口,并检查方法签名一致性
     */
    private void checkClasses(List<Class<?>> classes) {
        for (Class<?> clazz : classes) {
            if (clazz.getInterfaces().length != 1) {
                throw new IllegalStateException(
                        String.format("类 [%s] 必须且只能实现一个接口,当前实现了 %d 个",
                                clazz.getName(), clazz.getInterfaces().length));
            }

            Class<?> remoteInterface = clazz.getInterfaces()[0];
            Map<String, String> interfaceMethods = extractMethodMap(remoteInterface);
            Map<String, String> controllerMethods = extractMethodMap(clazz);

            boolean isEqual = compareMaps(interfaceMethods, controllerMethods);
            if (!isEqual) {
                throw new IllegalStateException(
                        String.format("类 [%s] 和接口 [%s] 方法签名不一致",
                                clazz.getName(), remoteInterface.getName()));
            }
        }
    }

    /**
     * 构建方法名到 URL 映射的 Map
     */
    private Map<String, String> extractMethodMap(Class<?> clazz) {
        Map<String, String> methodMap = new HashMap<>();
        Method[] methods = clazz.getDeclaredMethods();

        for (Method method : methods) {
            if (isLambdaOrSynthetic(method)) continue;

            if (!method.isAnnotationPresent(PostMapping.class)) {
                throw new IllegalStateException(
                        String.format("方法 [%s] 必须使用 @PostMapping 注解", method.getName()));
            }

            Parameter[] parameters = method.getParameters();
            for (Parameter parameter : parameters) {
                if (!parameter.isAnnotationPresent(RequestBody.class) &&
                    !parameter.isAnnotationPresent(PathVariable.class)) {
                    throw new IllegalStateException(
                            String.format("方法 [%s] 参数 [%s] 缺少 @RequestBody 或 @PathVariable 注解",
                                    method.getName(), parameter.getName()));
                }
            }

            PostMapping annotation = method.getAnnotation(PostMapping.class);
            String[] value = annotation.value();
            if (value.length != 1) {
                throw new IllegalStateException(
                        String.format("方法 [%s] 的 @PostMapping 必须指定一个唯一路径", method.getName()));
            }

            methodMap.put(method.getName(), value[0]);
        }

        return methodMap;
    }

    /**
     * 判断是否为 Lambda 表达式或合成方法
     */
    private boolean isLambdaOrSynthetic(Method method) {
        return method.getName().contains("$deserializeLambda$") || method.getName().contains("lambda$");
    }

    /**
     * 比较两个 Map 是否完全一致
     */
    public static boolean compareMaps(Map<String, String> map1, Map<String, String> map2) {
        if (map1 == null || map2 == null) {
            return false;
        }

        if (map1.size() != map2.size()) {
            log.warn("Map size 不一致:{} vs {}", map1.size(), map2.size());
            return false;
        }

        for (Map.Entry<String, String> entry : map1.entrySet()) {
            String key = entry.getKey();
            String value = entry.getValue();
            if (!map2.containsKey(key) || !map2.get(key).equals(value)) {
                log.warn("键值对不匹配:{}={} vs {}", key, value, map2.get(key));
                return false;
            }
        }

        for (Map.Entry<String, String> entry : map2.entrySet()) {
            String key = entry.getKey();
            String value = entry.getValue();
            if (!map1.containsKey(key) || !map1.get(key).equals(value)) {
                log.warn("键值对不匹配:{}={} vs {}", key, value, map1.get(key));
                return false;
            }
        }

        return true;
    }
}

nacos配置

yy:
  web:
    check:
      enable-controller-parameter-check: true
      scan-packages:
        - com.yy.mp
        - com.yy.order
      exclude-classes:
        - com.yy.api.controller.DemoController
        - com.yy.api.UserService

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值