一、自定义限流注解
package com.jiuqi.std.db.bean;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a method whose invocations should be rate limited.
 *
 * <p>Retained at runtime so that an aspect can read the attributes reflectively: within each
 * window of {@link #time()} seconds, at most {@link #count()} calls are allowed per limiter key.
 *
 * @author haijian.li
 * @version 1.0
 * 2022/11/22 21:37
 **/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {
    /**
     * Prefix of the Redis key under which the call counter is stored.
     *
     * @return the key prefix
     */
    String key() default "rate_limiter";

    /**
     * Length of the limiting window, in seconds (it is used as the TTL of the counter key).
     *
     * @return the window length in seconds
     */
    int time() default 60;

    /**
     * Maximum number of calls allowed within one window.
     *
     * @return the per-window call limit
     */
    int count() default 100;

    /**
     * How the limiter key is built; see {@link RateLimiterType}.
     *
     * @return the limiter type
     */
    RateLimiterType limiterType() default RateLimiterType.DEFAULT;
}
/**
 * Strategy used to scope a rate limiter's counter.
 */
public enum RateLimiterType {
    /** Default type: one counter shared by every caller. */
    DEFAULT,

    /** Separate counters per client IP address. */
    IP
}
二、配置 Redis 序列化和 Lua 脚本
/**
 * Redis infrastructure beans: a JSON-serializing template and the rate-limit Lua script.
 */
@Configuration
public class RedisConfig {

    /**
     * Creates a {@code RedisTemplate} registered under the bean name "test".
     *
     * <p>Keys, values, hash keys and hash values all go through the same Jackson JSON serializer.
     *
     * @param factory connection factory supplied by the container
     * @return the configured template
     */
    @Bean("test")
    RedisTemplate<Object, Object> template(RedisConnectionFactory factory) {
        Jackson2JsonRedisSerializer<Object> jsonSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        RedisTemplate<Object, Object> tpl = new RedisTemplate<>();
        tpl.setConnectionFactory(factory);
        tpl.setKeySerializer(jsonSerializer);
        tpl.setHashKeySerializer(jsonSerializer);
        tpl.setValueSerializer(jsonSerializer);
        tpl.setHashValueSerializer(jsonSerializer);
        return tpl;
    }

    /**
     * Loads {@code lua/limit.lua} from the classpath as a script whose result is a {@link Long}.
     *
     * @return the rate-limit script
     */
    @Bean
    DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> luaScript = new DefaultRedisScript<>();
        luaScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        luaScript.setResultType(Long.class);
        return luaScript;
    }
}
limit.lua（放在 classpath 下的 lua 目录中，即 lua/limit.lua）
-- Fixed-window request counter.
-- KEYS[1]: counter key
-- ARGV[1]: window length in seconds (applied with EXPIRE)
-- ARGV[2]: maximum requests per window
-- Returns the counter value; once it is already above the limit it is returned without
-- being incremented further.
local counterKey = KEYS[1]
local windowSeconds = tonumber(ARGV[1])
local maxRequests = tonumber(ARGV[2])

local existing = redis.call('get', counterKey)
-- Already over the limit: report the value and stop counting.
if existing and tonumber(existing) > maxRequests then
    return tonumber(existing)
end

local hits = tonumber(redis.call('incr', counterKey))
-- First hit of a new window: start the window's TTL.
if hits == 1 then
    redis.call('expire', counterKey, windowSeconds)
end
return hits
三、定义限流异常和全局异常处理
/**
 * Thrown when a call is rejected because its rate limit has been exceeded.
 */
public class RateLimitException extends RuntimeException {

    private static final long serialVersionUID = 1L;

    /**
     * Creates the exception with a human-readable reason.
     *
     * @param message reason for the rejection
     */
    public RateLimitException(String message) {
        super(message);
    }

    /**
     * Creates the exception while preserving the underlying failure.
     *
     * @param message reason for the rejection
     * @param cause   underlying failure, kept for diagnostics
     */
    public RateLimitException(String message, Throwable cause) {
        super(message, cause);
    }
}
/**
 * Converts exceptions thrown from controllers into JSON response bodies.
 */
@RestControllerAdvice
public class GlobalException {

    /**
     * Builds the response body for a request rejected by the rate limiter.
     *
     * <p>NOTE(review): only the {@code status} field of the body is set here; the HTTP status line
     * is left at its default. Consider {@code @ResponseStatus(HttpStatus.TOO_MANY_REQUESTS)} so HTTP
     * clients and proxies also see 429.
     *
     * @param e the rate-limit rejection
     * @return a map with {@code status} (429, Too Many Requests) and {@code message}
     */
    @ExceptionHandler(RateLimitException.class)
    public Map<String, Object> rateLimitException(RateLimitException e) {
        // Capacity 4 keeps both entries under the default resize threshold (4 * 0.75 = 3).
        Map<String, Object> map = new HashMap<>(4);
        // 429 rather than 500: the request was throttled, not failed by a server error.
        map.put("status", 429);
        map.put("message", e.getMessage());
        return map;
    }
}
四、定义切面
/**
 * Aspect that enforces {@link RateLimiter} on annotated methods by running a Redis Lua counter
 * script before each invocation.
 */
@Aspect
@Component
public class RateLimitAspectj {
    @Autowired
    @Qualifier("test")
    RedisTemplate<Object,Object> redisTemplate;
    @Autowired
    RedisScript<Long> redisScript;
    public static final Logger logger = LoggerFactory.getLogger(RateLimitAspectj.class);

    /**
     * Counts the current call against its limiter key and rejects it when the window limit is
     * exceeded.
     *
     * @param jp          the intercepted join point
     * @param rateLimiter the annotation on the intercepted method
     * @throws RateLimitException if the limit is exceeded or the script returns no result
     */
    @Before("@annotation(rateLimiter)")
    public void before(JoinPoint jp, RateLimiter rateLimiter) {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String key = getKey(jp, rateLimiter);
        // The script runs atomically in Redis and returns the counter for the current window.
        Long number = redisTemplate.execute(redisScript, Collections.singletonList(key), time, count);
        // A missing result is treated as a rejection (fail closed), as before.
        if (number == null || number > count) {
            throw new RateLimitException("访问频繁,请稍后再试");
        }
        logger.info("一个时间窗内请求次数:{},当前请求次数:{},缓存key是:{}", count, number, key);
    }

    /**
     * Builds the Redis counter key for the current call.
     *
     * <p>The key is the annotation's prefix, then the client IP followed by {@code -} (only for
     * {@link RateLimiterType#IP}), then the declaring class name, {@code -}, and the method name,
     * so each annotated method gets its own counter.
     *
     * @param jp          the intercepted join point
     * @param rateLimiter the annotation on the intercepted method
     * @return the Redis key
     * @throws IllegalStateException if IP limiting is requested outside an HTTP request
     */
    private String getKey(JoinPoint jp, RateLimiter rateLimiter) {
        StringBuilder key = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limiterType() == RateLimiterType.IP) {
            ServletRequestAttributes attributes =
                    (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            if (attributes == null) {
                throw new IllegalStateException("IP rate limiting requires an active HTTP request");
            }
            key.append(LogUtil.getIpAddr(attributes.getRequest())).append("-");
        }
        // Always scope by method, so IP-limited endpoints do not share a single counter per IP.
        MethodSignature signature = (MethodSignature) jp.getSignature();
        Method method = signature.getMethod();
        key.append(method.getDeclaringClass().getName())
                .append("-")
                .append(method.getName());
        return key.toString();
    }
}