简易版接口防重实现(分布式锁)

一、前言

接口防重亦是分布式锁的一种特殊形式。市面上最典型的就是大名鼎鼎的Redisson实现的分布式锁。

但使用Redisson仅仅作为一个锁,未免大材小用,同时也显得臃肿。

因此,有必要实现一个类似功能且安全高效的、简单的分布式锁。

二、技术

利用Spring提供的切面(AOP)技术、Spring EL(SpEL)表达式语言、JDK自定义注解,以及Redis和Lua脚本等。

三、应用

1.自定义注解,用于标注在Controller层的方法上

import java.lang.annotation.*;

/**
 * Marks a controller handler method as protected against duplicate (concurrent) requests.
 * <p>
 * Implemented with Spring AOP and Redis.
 * <p>
 * Note: only takes effect on methods of a Controller.
 * <p>
 * {@code @Inherited} is intentionally not declared: it only affects annotations placed on
 * classes and has no effect on a method-level annotation such as this one.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface PreventDuplicate {

    /**
     * Time in seconds after which the lock is released automatically; defaults to 180.
     */
    int expire() default 180;

    /**
     * Whether the client IP is part of the de-duplication key.
     */
    boolean useIP() default false;

    /**
     * Whether the caller's identity is part of the de-duplication key.
     * The aspect shown alongside this annotation uses the {@code Authorization}
     * request header for this purpose.
     */
    boolean useSession() default false;

    /**
     * <pre>
     * Spring EL expression used to derive the business part of the key.
     * Examples:
     * 1. For a method getUser(UserParam param) where UserParam has a userId property,
     *    use "#parameterName.propertyName", e.g. expression = "#param.userId"
     * 2. For a method getUser(String userId), use "#parameterName",
     *    e.g. expression = "#userId"
     * </pre>
     */
    String expression() default "";

    /**
     * Error message reported when a duplicate request is rejected.
     */
    String msg() default "业务正在处理中,请勿重复操作!";
}

2.定义切面逻辑处理

/**
 * Aspect that rejects concurrent duplicate invocations of controller methods annotated with
 * {@link PreventDuplicate}, using a Redis key as a distributed lock.
 * <p>
 * The lock key is built from the declaring type and method name, optionally the client IP and
 * the {@code Authorization} header, and the value of {@link PreventDuplicate#expression()}.
 * While the intercepted method is still running, a background task extends the lock's TTL so
 * that it does not expire mid-request.
 * <p>
 * NOTE(review): earlier documentation listed {@code IPreventParam} and {@code @PreventKey} as
 * further key sources; this class only implements the expression-based key.
 */
@Slf4j
@Aspect
@Component
public class InterfacePreventDuplicateAspect {

    private static final String LOCK_PREFIX = "PLOCK:";

    /**
     * Identifies this JVM. Thread ids alone are not unique across application instances, so the
     * lock value combines this id with the thread id to avoid releasing another node's lock.
     */
    private static final String INSTANCE_ID = java.util.UUID.randomUUID().toString();

    /**
     * Locks currently held by this JVM (key: lockKey + "-" + threadId; value: TTL in seconds).
     * Written by request threads and read by the renewal thread, hence a concurrent map.
     */
    private static final Map<String, Long> LOCKS = new java.util.concurrent.ConcurrentHashMap<>();

    /** SET key owner NX EX ttl: returns 1 when the lock was acquired, otherwise 0. */
    private static final DefaultRedisScript<Long> LOCK_SCRIPT = new DefaultRedisScript<>(
            "if redis.call('set', KEYS[1], ARGV[1], 'NX', 'EX', ARGV[2]) then return 1 else return 0 end",
            Long.class);

    /** Deletes the key only when it is still owned by the caller. */
    private static final DefaultRedisScript<Long> UNLOCK_SCRIPT = new DefaultRedisScript<>(
            "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end",
            Long.class);

    /**
     * Returns 0 when the key is no longer owned by the caller; otherwise resets the TTL to
     * ARGV[2] seconds if less than half of it remains, and returns 1.
     */
    private static final DefaultRedisScript<Long> RENEW_SCRIPT = new DefaultRedisScript<>(
            "if redis.call('get', KEYS[1]) ~= ARGV[1] then return 0 end "
                    + "if redis.call('pttl', KEYS[1]) < tonumber(ARGV[2]) * 500 then "
                    + "redis.call('expire', KEYS[1], ARGV[2]) end "
                    + "return 1",
            Long.class);

    private final StringRedisTemplate redisTemplate;
    private final SpringExpressionEvaluator evaluator = new SpringExpressionEvaluator();

    public InterfacePreventDuplicateAspect(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
        startTask();
    }

    /**
     * Acquires the lock, runs the intercepted handler, and releases the lock afterwards.
     *
     * @param pjp     the intercepted invocation
     * @param prevent the annotation instance on the intercepted method
     * @return whatever the intercepted method returns
     * @throws RuntimeException with {@link PreventDuplicate#msg()} when the lock is already held
     * @throws Throwable        anything thrown by the intercepted method
     */
    @Around("@annotation(prevent)" +
            "&&(@annotation(org.springframework.web.bind.annotation.PostMapping)" +
            "||@annotation(org.springframework.web.bind.annotation.GetMapping)" +
            "||@annotation(org.springframework.web.bind.annotation.RequestMapping))")
    public Object aroundHandle(ProceedingJoinPoint pjp, PreventDuplicate prevent) throws Throwable {
        HttpServletRequest req = getRequest();
        // Feign (service-to-service) calls are not de-duplicated.
        if (req == null || StringUtils.isNotEmpty(req.getHeader(Constant.FEIGN_REQ_HEADER))) {
            return pjp.proceed();
        }
        String lockKey = getLockKey(req, pjp, prevent);
        if (!tryLock(lockKey, prevent.expire())) {
            throw new RuntimeException(prevent.msg());
        }
        try {
            return pjp.proceed();
        } finally {
            unlock(lockKey);
        }
    }

    /**
     * Tries to create the Redis key for the current thread.
     *
     * @param lockKey Redis key
     * @param expire  TTL in seconds
     * @return {@code true} when the lock was acquired
     */
    private boolean tryLock(String lockKey, long expire) {
        long tid = Thread.currentThread().getId();
        Long result = redisTemplate.execute(LOCK_SCRIPT,
                Collections.singletonList(lockKey), ownerValue(String.valueOf(tid)), String.valueOf(expire));
        if (result != null && result == 1) {
            LOCKS.put(getFullKey(lockKey, tid), expire);
            return true;
        }
        return false;
    }

    /**
     * Releases the lock held by the current thread.
     * <p>
     * The local entry is removed first and unconditionally, so the renewal task can never keep a
     * lock alive after its request has finished, even when Redis is unreachable here. A failed
     * delete is only logged: the key then expires through its TTL, and the failure does not mask
     * the outcome of the intercepted method.
     */
    private void unlock(String lockKey) {
        long tid = Thread.currentThread().getId();
        LOCKS.remove(getFullKey(lockKey, tid));
        try {
            redisTemplate.execute(UNLOCK_SCRIPT,
                    Collections.singletonList(lockKey), ownerValue(String.valueOf(tid)));
        } catch (RuntimeException e) {
            log.warn("Failed to release lock {}; it will expire via its TTL", lockKey, e);
        }
    }

    /**
     * Builds the Redis key for an invocation.
     * <p>
     * The declaring type is part of the key so that equally named methods in different
     * controllers do not block each other.
     */
    private String getLockKey(HttpServletRequest req, ProceedingJoinPoint pjp, PreventDuplicate pd) {
        StringBuilder lockKey = new StringBuilder(LOCK_PREFIX)
                .append(pjp.getSignature().getDeclaringTypeName())
                .append('.')
                .append(pjp.getSignature().getName());
        if (pd.useIP()) {
            lockKey.append(":").append(IpUtils.getRealIP(req));
        }
        if (pd.useSession()) {
            // NOTE(review): this stores the raw Authorization header in the Redis key;
            // consider hashing it or using a user id instead.
            lockKey.append(":").append(req.getHeader("Authorization"));
        }
        lockKey.append(":").append(evaluator.getValue(pjp, pd.expression()));
        return lockKey.toString();
    }

    /**
     * Starts the once-per-second renewal task on a daemon thread, so it never blocks JVM exit.
     */
    private void startTask() {
        Executors.newSingleThreadScheduledExecutor(runnable -> {
            Thread thread = new Thread(runnable, "prevent-duplicate-lock-renewal");
            thread.setDaemon(true);
            return thread;
        }).scheduleAtFixedRate(this::renewLocks, 1, 1, TimeUnit.SECONDS);
    }

    /**
     * Extends every lock still owned by this JVM and drops local entries whose Redis key is gone
     * or owned by someone else.
     * <p>
     * Failures are caught per entry: an exception escaping a task scheduled with
     * {@code scheduleAtFixedRate} would cancel all further executions.
     */
    private void renewLocks() {
        for (Map.Entry<String, Long> entry : LOCKS.entrySet()) {
            String fullKey = entry.getKey();
            Long expire = entry.getValue();
            try {
                // The thread id is the suffix after the LAST '-'; the lock key itself may contain '-'.
                int separator = fullKey.lastIndexOf('-');
                String lockKey = fullKey.substring(0, separator);
                String owner = ownerValue(fullKey.substring(separator + 1));
                Long stillOwned = redisTemplate.execute(RENEW_SCRIPT,
                        Collections.singletonList(lockKey), owner, String.valueOf(expire));
                if (stillOwned == null || stillOwned == 0) {
                    LOCKS.remove(fullKey, expire);
                }
            } catch (RuntimeException e) {
                log.warn("Failed to renew lock entry {}", fullKey, e);
            }
        }
    }

    /**
     * Returns the servlet request bound to the current thread, or {@code null} if there is none.
     */
    private HttpServletRequest getRequest() {
        ServletRequestAttributes sra = (ServletRequestAttributes) RequestContextHolder
                .getRequestAttributes();
        return sra == null ? null : sra.getRequest();
    }

    /** Key used in {@link #LOCKS}: the Redis key followed by "-" and the thread id. */
    private String getFullKey(String lockKey, long tid) {
        return lockKey + "-" + tid;
    }

    /** Value stored in Redis for a lock held by the given thread of this JVM. */
    private static String ownerValue(String tid) {
        return INSTANCE_ID + ":" + tid;
    }
}

3.原理说明

利用Redis执行Lua脚本的原子性,在加锁时判断Redis中是否已存在当前的Key:如果已存在,则抛出异常,表示当前请求的数据已被占用;如果不存在,则写入该Key并设置超时时间。

业务执行完成后清除Redis锁时,需要先判断该锁是否由当前执行线程持有,只有是才会被删除。

另外,在当前线程尚未执行完(即锁还不能释放)时,需要定期给锁“续命”,以保证当前线程的任务能不受干扰地执行完成。具体为:本地在获取锁成功后,将锁的Key及当前线程ID作为Key、锁超时时间作为Value缓存在本地,并开启一个单线程定时任务,当锁的剩余时间不足锁超时时间的一半时,为Redis锁执行“续命”。

使用实例

/**
 * Usage example: a controller whose handler is guarded by {@link PreventDuplicate}.
 */
@RestController
public class OrderController {

    /**
     * Handles an order. The de-duplication key is derived from the {@code id} property of
     * the {@code order} argument.
     *
     * @param order the order to process
     */
    @PostMapping("/handleOrder")
    @PreventDuplicate(msg = "当前订单业务正在处理中,请稍后!", expression = "#order.id")
    public void handleOrder(Order order) {
        // TODO
    }
}

4.使用到的工具

import com.gynsh.constant.Constant;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;

import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.function.Function;

@Slf4j
public abstract class IpUtils {

    /**
     * 如果网站设置了nginx反向代理
     * 需要在nginx配置文件加上如下配置才能获取用户真实ip
     * proxy_set_header Host $host;
     * proxy_set_header X-Real-IP $remote_addr;
     * proxy_set_header REMOTE-HOST $remote_addr;
     * proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
     */
    public static String getRealIP(HttpServletRequest request) {
        if (request == null) {
            return "127.0.0.1";
        }
        String ip = request.getHeader("x-forwarded-for");
        if (StringUtils.isNotEmpty(ip)) {
            // 多次反向代理后会有多个ip值,第一个ip才是真实ip
            ip = ip.split(",")[0];
        }
        if (isSpecialEmpty(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isSpecialEmpty(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isSpecialEmpty(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (isSpecialEmpty(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (isSpecialEmpty(ip)) {
            ip = request.getHeader("X-Real-IP");
        }
        if (isSpecialEmpty(ip)) {
            if (Constant.LOCALHOST_IPv4.equals(request.getRemoteAddr()) || Constant.LOCALHOST_IPv6.equals(request.getRemoteAddr())) {
                try {
                    // 根据网卡获取真实IP
                    ip = InetAddress.getLocalHost().getHostAddress();
                } catch (UnknownHostException e) {
                    log.error("UnknownHostException异常:{}", e.toString());
                }
            }
        }
        return isSpecialEmpty(ip) ? request.getRemoteAddr() : ip;
    }

    public static String getRealIP(ServerHttpRequest request) {
        if (request == null) {
            return "127.0.0.1";
        }
        HttpHeaders headers = request.getHeaders();
        String ip = headers.getFirst("x-forwarded-for");
        if (StringUtils.isNotEmpty(ip)) {
            // 多次反向代理后会有多个ip值,第一个ip才是真实ip
            ip = ip.split(",")[0];
        }
        if (isSpecialEmpty(ip)) {
            ip = headers.getFirst("Proxy-Client-IP");
        }
        if (isSpecialEmpty(ip)) {
            ip = headers.getFirst("WL-Proxy-Client-IP");
        }
        if (isSpecialEmpty(ip)) {
            ip = headers.getFirst("HTTP_CLIENT_IP");
        }
        if (isSpecialEmpty(ip)) {
            ip = headers.getFirst("HTTP_X_FORWARDED_FOR");
        }
        if (isSpecialEmpty(ip)) {
            ip = headers.getFirst("X-Real-IP");
        }
        if (isSpecialEmpty(ip)) {
            try {
                // 根据网卡获取真实IP
                ip = InetAddress.getLocalHost().getHostAddress();
            } catch (UnknownHostException e) {
                log.error("UnknownHostException异常:{}", e.toString());
            }
        }
        return ip;
    }

    static boolean isSpecialEmpty(String ip) {
        return StringUtils.isEmpty(ip) || Constant.UNKNOWN.equalsIgnoreCase(ip);
    }

    public static String getLocalIp() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (Exception ignored) {
        }
        return null;
    }

    @Getter
    private static final List<String> localIps = new ArrayList<>();

    public static void initLocalIps() {
        try {
            Enumeration<NetworkInterface> networkInterfaces = NetworkInterface.getNetworkInterfaces();
            while (networkInterfaces.hasMoreElements()) {
                NetworkInterface networkInterface = networkInterfaces.nextElement();
                Enumeration<InetAddress> inetAddresses = networkInterface.getInetAddresses();
                while (inetAddresses.hasMoreElements()) {
                    InetAddress inetAddress = inetAddresses.nextElement();
                    localIps.add(inetAddress.getHostAddress());
                }
            }
        } catch (Exception e) {
            log.warn("a exception occur when take local ips", e);
        }
    }

    static {
        initLocalIps();
    }
}
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.aop.support.AopUtils;
import org.springframework.context.expression.AnnotatedElementKey;
import org.springframework.context.expression.CachedExpressionEvaluator;
import org.springframework.context.expression.MethodBasedEvaluationContext;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;

import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

/**
 * Evaluates Spring EL expressions against the arguments of an intercepted method.
 */
public class SpringExpressionEvaluator extends CachedExpressionEvaluator {
    private final static Pattern EL_PATTERN = Pattern.compile("^#.+$");
    private final ParameterNameDiscoverer paramNameDiscoverer = new DefaultParameterNameDiscoverer();
    private final Map<ExpressionKey, Expression> conditionCache = new ConcurrentHashMap<>(128);
    private final Map<AnnotatedElementKey, Method> targetMethodCache = new ConcurrentHashMap<>(128);

    /**
     * Evaluates {@code express} for the given join point.
     * <p>
     * Text that does not start with {@code #} is not treated as an expression and is returned
     * unchanged.
     *
     * @param jp      the intercepted invocation supplying the target, method and arguments
     * @param express the Spring EL expression, or literal text
     * @return the expression result converted to a String, or {@code express} itself
     */
    public String getValue(ProceedingJoinPoint jp, String express) {
        Object target = jp.getTarget();
        Class<?> targetClass = target.getClass();
        Method method = ((MethodSignature) jp.getSignature()).getMethod();

        // The context is built before the pattern check, exactly as before.
        EvaluationContext context = buildContext(target, targetClass, method, jp.getArgs());
        if (!EL_PATTERN.matcher(express).matches()) {
            return express;
        }
        AnnotatedElementKey elementKey = new AnnotatedElementKey(method, targetClass);
        Expression parsed = getExpression(this.conditionCache, elementKey, express);
        return parsed.getValue(context, String.class);
    }

    /**
     * Creates a method-based evaluation context whose root object carries the target and its
     * arguments.
     *
     * @param target      the object the method is invoked on
     * @param targetClass the class of {@code target}
     * @param method      the method taken from the join point signature
     * @param args        the invocation arguments
     * @return the evaluation context
     */
    private EvaluationContext buildContext(Object target, Class<?> targetClass, Method method, Object[] args) {
        Method specificMethod = resolveTargetMethod(targetClass, method);
        ExpressionRootObject rootObject = new ExpressionRootObject(target, args);
        return new MethodBasedEvaluationContext(rootObject, specificMethod, args, this.paramNameDiscoverer);
    }

    /**
     * Looks up, and caches per (method, class) pair, the most specific method for
     * {@code targetClass}.
     */
    private Method resolveTargetMethod(Class<?> targetClass, Method method) {
        return this.targetMethodCache.computeIfAbsent(
                new AnnotatedElementKey(method, targetClass),
                key -> AopUtils.getMostSpecificMethod(method, targetClass));
    }

    /**
     * Root object of the evaluation context: the invoked object and its arguments.
     */
    @Setter
    @Getter
    @AllArgsConstructor
    static class ExpressionRootObject {
        private Object object;
        private Object[] args;
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

流沙QS

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值