一、前言
接口防重本质上是分布式锁的一种特殊应用。业界最典型的分布式锁实现当属大名鼎鼎的 Redisson。
但仅仅为了一个锁而引入 Redisson,未免大材小用,也显得臃肿。
因此,有必要自行实现一个功能类似、安全高效且简单的分布式锁。
二、技术
利用 Spring AOP(切面)、Spring EL(SpEL)表达式、JDK 自定义注解、Redis 及 Lua 脚本等技术实现。
三、应用
1. 自定义注解,用于标注在 Controller 层的方法上
import java.lang.annotation.*;
/**
 * Marks a controller method whose concurrent duplicate invocations must be rejected.
 * <p>
 * Enforced by a Spring AOP aspect backed by a Redis lock: while one invocation holds the lock,
 * any other invocation that resolves to the same lock key fails with an exception carrying
 * {@link #msg()}.
 * <p>
 * Note: only effective on controller methods that also carry a request-mapping annotation
 * ({@code @PostMapping}, {@code @GetMapping} or {@code @RequestMapping}).
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface PreventDuplicate {
    /**
     * Lock timeout in seconds, set as the TTL of the Redis lock key. Defaults to 180 seconds.
     */
    int expire() default 180;

    /**
     * Whether the client IP is included in the lock key.
     */
    boolean useIP() default false;

    /**
     * Whether the request's {@code Authorization} header (the caller's identity) is included in
     * the lock key.
     */
    boolean useSession() default false;

    /**
     * <pre>
     * Spring EL expression whose value is appended to the lock key.
     * Only values starting with '#' are evaluated; any other string is used literally.
     * Examples:
     * 1. For a method getUser(UserParam param) where UserParam has a property userId,
     *    use "#paramName.propertyName", e.g. expression = "#param.userId"
     * 2. For a method getUser(String userId), use "#parameterName",
     *    e.g. expression = "#userId"
     * </pre>
     */
    String expression() default "";

    /**
     * Message of the exception thrown when the lock is already held.
     */
    String msg() default "业务正在处理中,请勿重复操作!";
}
2. 定义切面处理逻辑
/**
 * Aspect that rejects concurrent duplicate calls to controller methods annotated with
 * {@link PreventDuplicate}.
 * <p>
 * The Redis lock key has the form
 * {@code PLOCK:<method name>[:<client IP>][:<Authorization header>]:<expression value>}; the IP and
 * Authorization parts are present only when {@link PreventDuplicate#useIP()} /
 * {@link PreventDuplicate#useSession()} are enabled.
 * <p>
 * Calls made outside a servlet request, or carrying the Feign marker header, bypass the lock.
 * While a request runs, a background thread keeps renewing its lock; the lock is released in a
 * {@code finally} block when the request completes.
 */
@Slf4j
@Aspect
@Component
public class InterfacePreventDuplicateAspect {
    private static final String LOCK_PREFIX = "PLOCK:";

    /**
     * Random per-JVM id. Thread ids repeat across service instances, so the owner token combines
     * this id with the thread id to keep one instance from releasing another instance's lock.
     */
    private static final String INSTANCE_ID = java.util.UUID.randomUUID().toString();

    /** Sets KEYS[1] to owner ARGV[1] with a TTL of ARGV[2] seconds, only if the key is absent. */
    private static final String LOCK_LUA = "if redis.call('set', KEYS[1], ARGV[1], 'NX', 'EX', ARGV[2]) then return 1 else return 0 end";
    /** Deletes KEYS[1] only if it is still owned by ARGV[1]. */
    private static final String UNLOCK_LUA = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";
    /**
     * Returns 0 if KEYS[1] is no longer owned by ARGV[1]; otherwise resets its TTL to ARGV[2]
     * seconds when less than half of it remains, and returns 1. The ownership check and the
     * renewal run atomically, so a lock taken over by someone else is never extended.
     */
    private static final String RENEW_LUA = "if redis.call('get', KEYS[1]) ~= ARGV[1] then return 0 end "
            + "if redis.call('pttl', KEYS[1]) < tonumber(ARGV[2]) * 500 then redis.call('expire', KEYS[1], ARGV[2]) end "
            + "return 1";

    // Shared script instances so the script SHA1 is computed once instead of on every call.
    private static final DefaultRedisScript<Long> LOCK_SCRIPT = new DefaultRedisScript<>(LOCK_LUA, Long.class);
    private static final DefaultRedisScript<Long> UNLOCK_SCRIPT = new DefaultRedisScript<>(UNLOCK_LUA, Long.class);
    private static final DefaultRedisScript<Long> RENEW_SCRIPT = new DefaultRedisScript<>(RENEW_LUA, Long.class);

    private final StringRedisTemplate redisTemplate;

    /**
     * Locks currently held by this JVM, keyed by Redis lock key. Written by request threads and
     * iterated by the renewal thread, hence a concurrent map.
     */
    private final Map<String, HeldLock> heldLocks = new java.util.concurrent.ConcurrentHashMap<>();

    private final SpringExpressionEvaluator evaluator = new SpringExpressionEvaluator();

    public InterfacePreventDuplicateAspect(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
        startTask();
    }

    /**
     * Runs the intercepted controller method while holding its duplicate-prevention lock.
     *
     * @param pjp     the intercepted call
     * @param prevent the {@link PreventDuplicate} annotation on the method
     * @return whatever the intercepted method returns
     * @throws RuntimeException with {@link PreventDuplicate#msg()} when the lock is already held
     * @throws Throwable        anything thrown by the intercepted method
     */
    @Around("@annotation(prevent)" +
            "&&(@annotation(org.springframework.web.bind.annotation.PostMapping)" +
            "||@annotation(org.springframework.web.bind.annotation.GetMapping)" +
            "||@annotation(org.springframework.web.bind.annotation.RequestMapping))")
    public Object aroundHandle(ProceedingJoinPoint pjp, PreventDuplicate prevent) throws Throwable {
        HttpServletRequest req = getRequest();
        // Feign (service-to-service) calls and non-request contexts are not deduplicated.
        if (req == null || StringUtils.isNotEmpty(req.getHeader(Constant.FEIGN_REQ_HEADER))) {
            return pjp.proceed();
        }
        String lockKey = getLockKey(req, pjp, prevent);
        if (!tryLock(lockKey, prevent.expire())) {
            throw new RuntimeException(prevent.msg());
        }
        try {
            return pjp.proceed();
        } finally {
            unlock(lockKey);
        }
    }

    /**
     * Tries to acquire {@code lockKey} for the current thread and registers it for renewal.
     *
     * @param lockKey Redis key of the lock
     * @param expire  lock TTL in seconds
     * @return {@code true} if the lock was acquired
     */
    private boolean tryLock(String lockKey, long expire) {
        String owner = currentOwner();
        Long result = redisTemplate.execute(LOCK_SCRIPT,
                Collections.singletonList(lockKey), owner, String.valueOf(expire));
        if (result != null && result == 1L) {
            heldLocks.put(lockKey, new HeldLock(owner, expire));
            return true;
        }
        return false;
    }

    /**
     * Releases {@code lockKey} if the current thread still owns it.
     *
     * @param lockKey Redis key of the lock
     */
    private void unlock(String lockKey) {
        String owner = currentOwner();
        // Stop renewing first, so the renewal thread cannot extend a lock that is being released.
        // The local entry is dropped even if the Redis delete finds the lock already gone.
        heldLocks.computeIfPresent(lockKey, (key, held) -> owner.equals(held.owner) ? null : held);
        redisTemplate.execute(UNLOCK_SCRIPT, Collections.singletonList(lockKey), owner);
    }

    /**
     * Builds the lock key: prefix, method name, optional client IP, optional Authorization
     * header, and the value of the annotation's expression.
     */
    private String getLockKey(HttpServletRequest req, ProceedingJoinPoint pjp, PreventDuplicate pd) {
        StringBuilder lockKey = new StringBuilder(LOCK_PREFIX).append(pjp.getSignature().getName());
        if (pd.useIP()) {
            lockKey.append(":").append(IpUtils.getRealIP(req));
        }
        if (pd.useSession()) {
            lockKey.append(":").append(req.getHeader("Authorization"));
        }
        lockKey.append(":").append(evaluator.getValue(pjp, pd.expression()));
        return lockKey.toString();
    }

    /** Starts a daemon thread that runs {@link #renewHeldLocks()} once per second. */
    private void startTask() {
        Executors.newSingleThreadScheduledExecutor(runnable -> {
            Thread thread = new Thread(runnable, "prevent-duplicate-lock-renewal");
            // Daemon, so the renewal thread never keeps the JVM alive on shutdown.
            thread.setDaemon(true);
            return thread;
        }).scheduleAtFixedRate(this::renewHeldLocks, 1, 1, TimeUnit.SECONDS);
    }

    /**
     * Renews every lock held by this JVM whose remaining TTL has dropped below half of its
     * configured timeout, and forgets locks that expired or are now owned by someone else.
     */
    private void renewHeldLocks() {
        for (Map.Entry<String, HeldLock> entry : heldLocks.entrySet()) {
            String lockKey = entry.getKey();
            HeldLock held = entry.getValue();
            try {
                Long owned = redisTemplate.execute(RENEW_SCRIPT, Collections.singletonList(lockKey),
                        held.owner, String.valueOf(held.expireSeconds));
                if (owned == null || owned == 0L) {
                    heldLocks.remove(lockKey, held);
                }
            } catch (RuntimeException e) {
                // An exception escaping a scheduleAtFixedRate task cancels all later runs,
                // so log it and retry on the next tick instead.
                log.warn("Failed to renew lock {}", lockKey, e);
            }
        }
    }

    private HttpServletRequest getRequest() {
        ServletRequestAttributes sra = (ServletRequestAttributes) RequestContextHolder
                .getRequestAttributes();
        return sra == null ? null : sra.getRequest();
    }

    /** Owner token stored as the lock value: unique per JVM instance and thread. */
    private static String currentOwner() {
        return INSTANCE_ID + ":" + Thread.currentThread().getId();
    }

    /** Owner token and configured TTL (seconds) of a lock held by this JVM. */
    private static final class HeldLock {
        private final String owner;
        private final long expireSeconds;

        private HeldLock(String owner, long expireSeconds) {
            this.owner = owner;
            this.expireSeconds = expireSeconds;
        }
    }
}
3.原理说明
利用 Redis 执行 Lua 脚本的原子性,判断 Redis 中是否已存在当前的 Key:如果存在,则抛出异常,表示当前请求的数据已被占用;如果不存在,则写入该 Key(值为当前锁持有者的标识,并设置超时时间)后放行。
业务执行完成后释放 Redis 锁时,需要先判断该锁是否属于当前执行线程,只有属于当前线程才会删除。
另外,当前线程在业务执行完之前不能释放锁,而锁本身又带有超时时间,因此需要定期给锁“续命”,以保证当前线程的任务能不受干扰地执行完成。具体做法是:获取锁成功后,将锁信息(锁 Key、持有者标识及锁超时时间)缓存在本地,并开启一个单线程定时任务,每秒检查一次;当锁的剩余超时时间不足总超时时间的一半时,将其重新设置为完整的超时时间。
使用示例
/**
 * Usage example: concurrent {@code /handleOrder} requests for the same order id are rejected
 * while one of them is still being processed.
 */
@RestController
public class OrderController {
    @PostMapping("/handleOrder")
    @PreventDuplicate(msg = "当前订单业务正在处理中,请稍后!", expression = "#order.id")
    public void handleOrder(final Order order) {
        // TODO: order handling goes here
    }
}
4.使用到的工具
import com.gynsh.constant.Constant;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.function.Function;
/**
 * Utilities for resolving the client IP of an HTTP request (servlet or reactive) and for listing
 * the addresses of this host's network interfaces.
 * <p>
 * Security note: headers such as {@code X-Forwarded-For} come from the client unless a trusted
 * reverse proxy overwrites them, so the resolved IP can be spoofed and must not be used for
 * authentication.
 */
@Slf4j
public abstract class IpUtils {
    /** Headers consulted, in order, when {@code x-forwarded-for} yields no usable IP. */
    private static final String[] FALLBACK_IP_HEADERS = {
            "Proxy-Client-IP", "WL-Proxy-Client-IP", "HTTP_CLIENT_IP", "HTTP_X_FORWARDED_FOR", "X-Real-IP"
    };

    /**
     * Addresses of all local network interfaces. Replaced as a whole (never mutated in place) by
     * {@link #initLocalIps()}, so readers always see a complete, unmodifiable snapshot.
     */
    private static volatile List<String> localIps = Collections.emptyList();

    static {
        initLocalIps();
    }

    /**
     * Resolves the client IP of a servlet request.
     * <p>
     * If the site sits behind an nginx reverse proxy, nginx must forward the client address for
     * the real IP to be visible, e.g.:
     * <pre>
     * proxy_set_header Host $host;
     * proxy_set_header X-Real-IP $remote_addr;
     * proxy_set_header REMOTE-HOST $remote_addr;
     * proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
     * </pre>
     *
     * @param request the request; may be {@code null}
     * @return the first usable IP from the proxy headers; otherwise the remote address, with a
     * loopback remote address replaced by this host's address; {@code "127.0.0.1"} for a
     * {@code null} request
     */
    public static String getRealIP(HttpServletRequest request) {
        if (request == null) {
            return "127.0.0.1";
        }
        String ip = ipFromHeaders(request::getHeader);
        return isSpecialEmpty(ip) ? resolveRemoteAddr(request.getRemoteAddr()) : ip;
    }

    /**
     * Resolves the client IP of a reactive (WebFlux) request, using the same header order as the
     * servlet variant.
     *
     * @param request the request; may be {@code null}
     * @return the first usable IP from the proxy headers; otherwise the connection's remote
     * address, with a loopback address replaced by this host's address; this host's address when
     * the remote address is unknown; {@code "127.0.0.1"} for a {@code null} request
     */
    public static String getRealIP(ServerHttpRequest request) {
        if (request == null) {
            return "127.0.0.1";
        }
        String ip = ipFromHeaders(request.getHeaders()::getFirst);
        if (!isSpecialEmpty(ip)) {
            return ip;
        }
        InetSocketAddress remote = request.getRemoteAddress();
        if (remote == null) {
            // No connection information at all: fall back to this host's address.
            return localHostAddress();
        }
        // getAddress() is null for an unresolved socket address; use the raw host string then.
        String remoteAddr = remote.getAddress() != null
                ? remote.getAddress().getHostAddress()
                : remote.getHostString();
        return resolveRemoteAddr(remoteAddr);
    }

    /** Returns {@code true} if {@code ip} is null, empty, or the placeholder "unknown". */
    static boolean isSpecialEmpty(String ip) {
        return StringUtils.isEmpty(ip) || Constant.UNKNOWN.equalsIgnoreCase(ip);
    }

    /** Returns this host's address, or {@code null} if it cannot be determined. */
    public static String getLocalIp() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (Exception ignored) {
            // Best effort: callers treat null as "unknown".
        }
        return null;
    }

    /** Returns an unmodifiable snapshot of the addresses of all local network interfaces. */
    public static List<String> getLocalIps() {
        return localIps;
    }

    /**
     * Re-reads the addresses of all local network interfaces. Safe to call repeatedly: the
     * previous snapshot is replaced rather than appended to.
     */
    public static void initLocalIps() {
        List<String> ips = new ArrayList<>();
        try {
            Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
            while (interfaces != null && interfaces.hasMoreElements()) {
                Enumeration<InetAddress> addresses = interfaces.nextElement().getInetAddresses();
                while (addresses.hasMoreElements()) {
                    ips.add(addresses.nextElement().getHostAddress());
                }
            }
        } catch (Exception e) {
            log.warn("a exception occur when take local ips", e);
        }
        localIps = Collections.unmodifiableList(ips);
    }

    /**
     * Reads the client IP from the proxy headers via {@code header}.
     *
     * @return the first usable value, or the last value read (null, empty or "unknown")
     */
    private static String ipFromHeaders(Function<String, String> header) {
        String ip = header.apply("x-forwarded-for");
        if (StringUtils.isNotEmpty(ip)) {
            // After several proxies the header reads "client, proxy1, proxy2"; the first entry is the client.
            ip = StringUtils.substringBefore(ip, ",").trim();
        }
        for (int i = 0; i < FALLBACK_IP_HEADERS.length && isSpecialEmpty(ip); i++) {
            ip = header.apply(FALLBACK_IP_HEADERS[i]);
        }
        return ip;
    }

    /**
     * Replaces a loopback remote address with this host's address (the request came from the same
     * machine); any other address is returned unchanged.
     */
    private static String resolveRemoteAddr(String remoteAddr) {
        if (Constant.LOCALHOST_IPv4.equals(remoteAddr) || Constant.LOCALHOST_IPv6.equals(remoteAddr)) {
            String local = localHostAddress();
            if (local != null) {
                return local;
            }
        }
        return remoteAddr;
    }

    /** Returns this host's address, logging and returning {@code null} on failure. */
    private static String localHostAddress() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            log.error("UnknownHostException异常:{}", e.toString());
            return null;
        }
    }
}
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.aop.support.AopUtils;
import org.springframework.context.expression.AnnotatedElementKey;
import org.springframework.context.expression.CachedExpressionEvaluator;
import org.springframework.context.expression.MethodBasedEvaluationContext;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
/**
 * Evaluates Spring EL (SpEL) expressions against the arguments of an intercepted method.
 * <p>
 * Only expressions starting with {@code #} (for example {@code #param.userId}) are evaluated;
 * any other string, including the empty string, is returned unchanged as a literal. Parsed
 * expressions and resolved target methods are cached per method and target class.
 */
public class SpringExpressionEvaluator extends CachedExpressionEvaluator {
    private final static Pattern EL_PATTERN = Pattern.compile("^#.+$");
    private final ParameterNameDiscoverer paramNameDiscoverer = new DefaultParameterNameDiscoverer();
    private final Map<ExpressionKey, Expression> conditionCache = new ConcurrentHashMap<>(128);
    private final Map<AnnotatedElementKey, Method> targetMethodCache = new ConcurrentHashMap<>(128);

    /**
     * Returns the value of a Spring EL expression evaluated against the join point's arguments.
     *
     * @param jp      the intercepted join point; its signature must be a {@link MethodSignature}
     * @param express the expression; must not be {@code null}
     * @return the expression's value converted to {@code String} (may be {@code null}), or
     * {@code express} itself when it does not start with {@code #}
     */
    public String getValue(ProceedingJoinPoint jp, String express) {
        // Literal values need no evaluation context: return before doing any reflection work.
        if (!EL_PATTERN.matcher(express).matches()) {
            return express;
        }
        EvaluationContext context = this.createEvaluationContext(jp);
        AnnotatedElementKey elementKey = this.createAnnotatedElementKey(
                ((MethodSignature) jp.getSignature()).getMethod(), jp.getTarget().getClass());
        return getExpression(this.conditionCache, elementKey, express).getValue(context, String.class);
    }

    /**
     * Creates an evaluation context whose variables are the join point's method arguments.
     */
    private EvaluationContext createEvaluationContext(JoinPoint joinPoint) {
        return createEvaluationContext(joinPoint.getTarget(), joinPoint.getTarget().getClass(),
                ((MethodSignature) joinPoint.getSignature()).getMethod(), joinPoint.getArgs());
    }

    /**
     * Creates an evaluation context for {@code method} invoked on {@code object}.
     *
     * @param object      the target instance, exposed through the root object
     * @param targetClass the target's class, used to resolve the most specific method
     * @param method      the intercepted method
     * @param args        the invocation arguments, bound to the method's parameter names
     * @return a method-based evaluation context
     */
    private EvaluationContext createEvaluationContext(Object object, Class<?> targetClass, Method method, Object[] args) {
        Method targetMethod = getTargetMethod(targetClass, method);
        ExpressionRootObject root = new ExpressionRootObject(object, args);
        return new MethodBasedEvaluationContext(root, targetMethod, args, this.paramNameDiscoverer);
    }

    private AnnotatedElementKey createAnnotatedElementKey(Method method, Class<?> targetClass) {
        return new AnnotatedElementKey(method, targetClass);
    }

    /**
     * Returns the most specific implementation of {@code method} on {@code targetClass}, cached so
     * the proxy/interface-to-implementation lookup runs once per method.
     */
    private Method getTargetMethod(Class<?> targetClass, Method method) {
        return this.targetMethodCache.computeIfAbsent(new AnnotatedElementKey(method, targetClass),
                key -> AopUtils.getMostSpecificMethod(method, targetClass));
    }

    /** Root object of the evaluation context: the target instance and the invocation arguments. */
    @Setter
    @Getter
    @AllArgsConstructor
    static class ExpressionRootObject {
        private Object object;
        private Object[] args;
    }
}