定义注解
/**
 * Marks a method whose invocations are counted in a Redis sliding window and
 * rejected once the configured number of calls per window is exceeded.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiter {

    /** Prefix of the Redis key that stores the sliding window. */
    String key() default "redis_limiter:";

    /** Length of the sliding window in seconds (the aspect converts it to milliseconds). */
    int time() default 10;

    /** Maximum number of calls admitted within one window. */
    int count() default 100;

    /** Strategy used to build the limiter key, e.g. per method or per client IP. */
    RateLimiterType limiterType() default RateLimiterType.DEFAULT;
}
使用 AOP 实现注解的限流逻辑
@Slf4j
@Aspect
@Component
public class RateLimitAspect {

    @Resource
    StringRedisTemplate stringRedisTemplate;

    /**
     * Sliding-window limiter script loaded from the classpath.
     * Invoked with KEYS = {zsetKey, uuid, nowMillis} and ARGV = {windowMillis, maxCount};
     * it returns 1 when the call is admitted.
     */
    private static final DefaultRedisScript<Long> RATE_LIMIT_SCRIPT;

    static {
        RATE_LIMIT_SCRIPT = new DefaultRedisScript<>();
        RATE_LIMIT_SCRIPT.setLocation(new ClassPathResource("lua/rollingRateLimiter.lua"));
        RATE_LIMIT_SCRIPT.setResultType(Long.class);
    }

    /**
     * Runs before every method annotated with {@link RateLimiter} and records the call
     * in the Redis sliding window for that method's key.
     *
     * @param jp          the intercepted join point
     * @param rateLimiter the annotation instance bound from the intercepted method
     * @throws RateLimitException when the script does not return 1 (limit exhausted
     *                            or no result)
     */
    @Before("@annotation(rateLimiter)")
    public void before(JoinPoint jp, RateLimiter rateLimiter) {
        // Widen to long before multiplying: int seconds * 1000 overflows after ~24.8 days.
        long windowMillis = rateLimiter.time() * 1000L;
        String key = getKey(jp, rateLimiter);
        // Timestamps come from this JVM's clock, so instances sharing a key should keep
        // their clocks synchronized.
        String now = String.valueOf(System.currentTimeMillis());
        // Unique sorted-set member so concurrent calls in the same millisecond are all counted.
        String uuid = UUID.fastUUID().toString();

        // NOTE(review): uuid and now are passed as KEYS although they are not Redis keys;
        // under Redis Cluster this can be rejected with CROSSSLOT. Moving them to ARGV
        // requires a matching change in the Lua script.
        Long result = stringRedisTemplate.execute(
                RATE_LIMIT_SCRIPT,
                Arrays.asList(key, uuid, now),
                String.valueOf(windowMillis), String.valueOf(rateLimiter.count())
        );
        if (result == null || result != 1L) {
            throw new RateLimitException("访问频繁,请稍后再试");
        }
    }

    /**
     * Builds the Redis key: prefix + fully-qualified class name + "." + method name,
     * followed by ":" + client IP when the IP strategy is selected.
     *
     * @throws IllegalStateException if the IP strategy is used outside an HTTP request
     */
    private String getKey(JoinPoint jp, RateLimiter rateLimiter) {
        Method method = ((MethodSignature) jp.getSignature()).getMethod();
        // getName() rather than Class#toString(), which would add a "class " prefix.
        StringBuilder key = new StringBuilder(rateLimiter.key())
                .append(method.getDeclaringClass().getName())
                .append('.')
                .append(method.getName());
        if (rateLimiter.limiterType() == RateLimiterType.IP) {
            ServletRequestAttributes attributes =
                    (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            if (attributes == null) {
                throw new IllegalStateException(
                        "RateLimiterType.IP requires an active HTTP request: " + method);
            }
            // The ':' separator keeps method name and IP apart, so e.g. "test1"+"2.3.4.5"
            // and "test"+"12.3.4.5" no longer map to the same key.
            key.append(':').append(IPUtils.getIpRequest(attributes.getRequest()));
        }
        return key.toString();
    }
}
Lua 脚本（lua/rollingRateLimiter.lua）
-- Sliding-window rate limiter backed by a Redis sorted set (ZSet).
-- Each admitted request is stored as one member whose score is its timestamp
-- in milliseconds; members that fell out of the window are pruned before the
-- remaining ones are counted.
--
-- KEYS[1]  sorted-set key: prefix + fully-qualified method name
--          (the IP strategy also appends the client IP)
-- ARGV[1]  window size in milliseconds
-- ARGV[2]  maximum number of requests allowed inside the window
-- ARGV[3]  unique member id for this request (e.g. a UUID)
-- ARGV[4]  current timestamp in milliseconds
--
-- Only KEYS[1] is a Redis key. Taking the member id and timestamp from ARGV
-- keeps the script Redis Cluster compatible (every entry in KEYS must hash to
-- the same slot). The legacy convention KEYS = {key, uuid, now} with
-- ARGV = {interval, max} is still accepted so existing callers keep working.
--
-- Returns 1 when the request is admitted, -1 when it is rejected.
local zset_key = KEYS[1]
local interval = tonumber(ARGV[1])
local max_in_interval = tonumber(ARGV[2])
local member = ARGV[3] or KEYS[2]
local now = tonumber(ARGV[4] or KEYS[3])

if not (zset_key and interval and max_in_interval and member and now) then
  return redis.error_reply('rollingRateLimiter: missing or non-numeric arguments')
end

-- TTL in milliseconds, matching the window's own unit.
local ttl_ms = math.ceil(interval)

-- 1. Drop every entry older than the window (score <= now - interval).
redis.call('ZREMRANGEBYSCORE', zset_key, '-inf', now - interval)

-- 2. Count the requests still inside the window.
local in_window = redis.call('ZCARD', zset_key)

-- 3. Reject once the window is full.
if in_window >= max_in_interval then
  -- Refresh the TTL so the key lives at least as long as its newest member matters.
  redis.call('PEXPIRE', zset_key, ttl_ms)
  return -1
end

-- 4. Record this request and admit it.
redis.call('ZADD', zset_key, now, member)
redis.call('PEXPIRE', zset_key, ttl_ms)
return 1
异常定义及全局异常处理
/**
 * Thrown when a call exceeds the limit configured with {@code @RateLimiter}.
 * It is unchecked, so the aspect's advice can throw it without a throws clause.
 */
public class RateLimitException extends RuntimeException {

    /**
     * @param message human-readable reason for the rejection
     */
    public RateLimitException(String message) {
        super(message);
    }
}
@RestControllerAdvice
@Order(value = Ordered.HIGHEST_PRECEDENCE)
public class GlobalExceptionHandler {
@ExceptionHandler(RateLimitException.class)
@ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
public Result rateLimitException(RateLimitException e){
return new Result(500,e.getMessage());
}
}
思想：使用 Redis 的 ZSet 数据结构，以请求时间戳（毫秒）作为 score；每次请求先删除滑动窗口之外的记录，再统计窗口内剩余的请求数：未达到上限则记录本次请求并放行，否则拒绝。
// Example: each client IP may call this endpoint at most 5 times within a 10-second window.
@RequestMapping(method = RequestMethod.GET, value = "/test")
@RateLimiter(limiterType = RateLimiterType.IP, count = 5, time = 10)
public Result test() {
    return Result.success("成功");
}