想法
一个流量限制器,使用chatgpt生成后优化代码
一开始想用在controller上,后来想用到内部服务模块上,尤其是阿里云/腾讯云的接口
优化了一下,可以用在任何方法上
chatgpt提问过程放在最后
直接上代码
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor
public class TpsLimitAspect {
private final CacheService cacheService;
@Before("@annotation(tpsLimit)")
public void tpsLimitAdvice(JoinPoint joinPoint, TpsLimit tpsLimit) throws ServiceException {
if (Objects.isNull(tpsLimit)) {
throw new ServiceException();
}
//限制业务key
byte[] rateLimiterKey = tpsLimit.rateLimiterKey().getBytes();
//限制时间
long rateLimitPeriod = tpsLimit.rateLimitPeriod();
//限制令牌数
long rateLimitTokens = tpsLimit.rateLimitTokens();
//当前时间
long currentTimeMillis = System.currentTimeMillis();
//雪花算法生成key
String key = String.valueOf(SnowFlakeUtil.getFlowIdInstance().nextId());
// 使用管道执行Redis命令,实现原子操作
// 使用有序集合,排序方式为时间戳
// 每次将时间戳添加到集合中,然后删除 当前时间减去限制时间 之前的所有数据
// 统计数量,当前数量即为限制时间内的总数
// 判断数量是否符合令牌数要求,如果不符合则把本次添加的时间戳删除,防止相对限制时间内的token被占用
Boolean currentTime = cacheService.execute(connection -> {
connection.openPipeline();
RedisZSetCommands redisZSetCommands = connection.zSetCommands();
//添加请求数据
redisZSetCommands.zAdd(rateLimiterKey, currentTimeMillis, key.getBytes());
//移除小于当前时间减去时效区间内的值
redisZSetCommands.zRemRangeByScore(rateLimiterKey, 0, currentTimeMillis - (rateLimitPeriod * 1000));
//统计总数
redisZSetCommands.zCard(rateLimiterKey);
List<Object> results = connection.closePipeline();
Long count = (Long) results.get(2);
if (Objects.nonNull(count) && count <= rateLimitTokens) {
log.debug("key: {}, limitCount: {}, result: {}, time: {}", key, count, true, LocalDateUtils.getCurrentTimeFormat(LocalDateUtils.FORMAT_MMSS));
return true;
}
log.debug("key: {}, limitCount: {}, result: {}, time: {}", key, count, false, LocalDateUtils.getCurrentTimeFormat(LocalDateUtils.FORMAT_MMSS));
//移除本次请求,防止相对限制时间内的token被占用
redisZSetCommands.zRem(rateLimiterKey, key.getBytes());
return Boolean.FALSE;
});
if (Boolean.FALSE.equals(currentTime)) {
throw new ServiceException("接口限流了");
}
}
}
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface TpsLimit {
String rateLimiterKey();
// 限流周期,单位:秒
long rateLimitPeriod();
// 限流周期内可用令牌数量
long rateLimitTokens();
}
chatgpt聊天过程,可以略过
用springboot实现一个功能,限制某个请求的tps为每秒200次.
当每次有请求时,我会请求外部的接口,这个接口的限制为每秒最多200次.
功能需要考虑分布式
您可以使用Spring Boot实现以下功能来限制某个请求的TPS为每秒200次,并同时考虑分布式方案。
-
创建一个Spring Boot项目,并添加所需的依赖。
-
在你的配置文件(例如application.properties)中,设置外部接口的URL。
-
创建一个自定义注解,用于标记需要限制TPS的请求方法。
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface TpsLimit {
}
- 创建一个切面类,在该类中处理每个请求,并进行TPS限制。
@Aspect
@Component
public class TpsLimitAspect {
private static final int MAX_TPS = 200; // 每秒最多允许的请求数量
private Map<String, Long> requestCounts = new ConcurrentHashMap<>(); // 用来记录每个请求的时间戳
@Pointcut("@annotation(com.example.TpsLimit)") // 使用自定义注解作为切点
public void tpsLimitPointcut() {
}
@Around("tpsLimitPointcut()")
public Object tpsLimitAdvice(ProceedingJoinPoint joinPoint) throws Throwable {
String requestKey = getRequestKey(joinPoint); // 获取请求的唯一标识,可以使用请求的URL或方法名等
// 获取当前时间戳
long now = System.currentTimeMillis();
// 计算一秒内的请求数量
long count = requestCounts.entrySet().stream()
.filter(entry -> now - entry.getValue() < 1000) // 一秒内的请求
.count();
// 如果请求数量超过限制,返回错误信息
if (count >= MAX_TPS) {
return "Too many requests, please try again later";
}
// 记录当前请求的时间戳
requestCounts.put(requestKey, now);
// 继续执行原方法
return joinPoint.proceed();
}
// 根据joinPoint获取请求的唯一标识
private String getRequestKey(ProceedingJoinPoint joinPoint) {
// 这里可以根据实际需求,使用请求的URL或方法名等作为唯一标识
// 例如可以使用request.getRequestURI()获取URL
return joinPoint.getSignature().toShortString();
}
}
- 在需要限制TPS的请求方法上添加自定义注解。
@RestController
public class MyController {
@GetMapping("/limit")
@TpsLimit // 添加自定义注解
public String limitRequest() {
// 处理请求逻辑
return "Request handled successfully";
}
}
- 使用Spring Boot分布式限流方案,例如使用Redis进行分布式计数。
@Configuration
@EnableAspectJAutoProxy
public class TpsLimitConfig {
@Bean
public RedisConnectionFactory redisConnectionFactory() {
// 配置Redis连接工厂
return new JedisConnectionFactory();
}
@Bean
public RedisTemplate<String, Long> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
// 配置RedisTemplate
RedisTemplate<String, Long> redisTemplate = new RedisTemplate<>();
redisTemplate.setConnectionFactory(redisConnectionFactory);
redisTemplate.setKeySerializer(new StringRedisSerializer());
redisTemplate.setValueSerializer(new GenericToStringSerializer<>(Long.class));
return redisTemplate;
}
@Bean
public TpsLimitAspect tpsLimitAspect(RedisTemplate<String, Long> redisTemplate) {
// 创建TpsLimitAspect,并将RedisTemplate注入
return new TpsLimitAspect(redisTemplate);
}
}
- 在切面类中使用RedisTemplate来进行分布式计数。
@Aspect
@Component
public class TpsLimitAspect {
private static final int MAX_TPS = 200; // 每秒最多允许的请求数量
private StringRedisTemplate redisTemplate; // 使用StringRedisTemplate进行分布式计数
public TpsLimitAspect(StringRedisTemplate redisTemplate) {
this.redisTemplate = redisTemplate;
}
@Pointcut("@annotation(com.example.TpsLimit)") // 使用自定义注解作为切点
public void tpsLimitPointcut() {
}
@Around("tpsLimitPointcut()")
public Object tpsLimitAdvice(ProceedingJoinPoint joinPoint) throws Throwable {
String requestKey = getRequestKey(joinPoint); // 获取请求的唯一标识
// 获取当前时间戳
long now = System.currentTimeMillis();
// 设置Redis的Key和Value
String redisKey = "tpsLimit:" + requestKey;
String redisValue = Long.toString(now);
// 使用Redis原子操作,将当前时间戳写入到Redis
String result = redisTemplate.execute(new RedisCallback<String>() {
@Override
public String doInRedis(RedisConnection connection) throws DataAccessException {
Object nativeConnection = connection.getNativeConnection();
if (nativeConnection instanceof Jedis) {
Jedis jedis = (Jedis) nativeConnection;
// NX表示只在Key不存在时进行操作
// EX表示设置Key的过期时间为1秒
String reply = jedis.set(redisKey, redisValue, "NX", "EX", 1);
return reply;
}
return null;
}
});
// 如果Redis操作失败(表示该请求已经存在),返回错误信息
if (!"OK".equals(result)) {
return "Too many requests, please try again later";
}
// 使用Redis原子操作,计算一秒内的请求数量
long count = redisTemplate.execute(new RedisCallback<Long>() {
@Override
public Long doInRedis(RedisConnection connection) throws DataAccessException {
Object nativeConnection = connection.getNativeConnection();
if (nativeConnection instanceof Jedis) {
Jedis jedis = (Jedis) nativeConnection;
// 使用SCAN命令获取Redis中的Key,并过滤出一秒内的Key
String cursor = ScanParams.SCAN_POINTER_START;
long totalCount = 0;
ScanParams scanParams = new ScanParams().match("tpsLimit:*");
do {
ScanResult<String> scanResult = jedis.scan(cursor, scanParams);
List<String> keys = scanResult.getResult();
for (String key : keys) {
String value = jedis.get(key);
long timestamp = Long.parseLong(value);
if (now - timestamp < 1000) { // 一秒内的请求
totalCount++;
}
}
cursor = scanResult.getCursor();
} while (!"0".equals(cursor));
return totalCount;
}
return null;
}
});
// 如果请求数量超过限制,返回错误信息
if (count >= MAX_TPS) {
return "Too many requests, please try again later";
}
// 继续执行原方法
return joinPoint.proceed();
}
// 根据joinPoint获取请求的唯一标识
private String getRequestKey(ProceedingJoinPoint joinPoint) {
// 这里可以根据实际需求,使用请求的URL或方法名等作为唯一标识
// 例如可以使用request.getRequestURI()获取URL
return joinPoint.getSignature().toShortString();
}
}
这样,您就可以使用Spring Boot实现一个限制某个请求的TPS为每秒200次的功能,并考虑到了分布式的问题。注意,以上代码仅作为示例,您可以根据实际需求进行修改和优化。
请优化使用RedisTemplate来进行分布式计数的方法,这个方法中使用了不推荐的命令scan
如果使用RedisTemplate进行分布式计数,并且希望避免使用不推荐的scan命令,可以尝试以下优化方法:
- 使用Redis的自增命令(INCR):Redis提供了INCR命令用于对一个键进行自增操作。可以使用RedisTemplate的opsForValue()方法获取ValueOperations对象,然后使用increment(key, delta)方法进行自增操作,其中key是计数器的键名,delta是要增加的数量。
ValueOperations<String, Long> valueOps = redisTemplate.opsForValue();
Long count = valueOps.increment("counter", 1);
这样可以避免使用scan命令,而且自增操作是原子性的,可以保证在多线程环境下的并发安全。
- 使用Lua脚本:Redis支持Lua脚本,可以将自增逻辑封装在Lua脚本中,在Redis端进行原子性执行。可以使用RedisTemplate的execute()方法执行Lua脚本。
String script = "return redis.call('INCRBY', KEYS[1], ARGV[1])";
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(script, Long.class);
Long count = redisTemplate.execute(redisScript, Arrays.asList("counter"), 1);
通过将自增逻辑封装在Lua脚本中,可以减少网络传输开销,并且确保自增操作的原子性。
总体而言,推荐使用自增命令或Lua脚本来实现分布式计数,而不是使用scan命令。自增命令和Lua脚本都能够通过RedisTemplate进行简便的操作,并且保证了计数的原子性和性能。
请使用redis进行分布式锁的限流,避免使用scan方法
是的,除了使用Lua脚本,还有一种更简单的方式来从Redis中获取一个随机字符串,那就是使用Redis的SRANDMEMBER命令。
以下是一个使用RedisTemplate从Redis中获取随机字符串的示例:
public class RedisRandomStringGetter {
private RedisTemplate<String, String> redisTemplate;
private String keysPattern; // 匹配键的模式
public RedisRandomStringGetter(RedisTemplate<String, String> redisTemplate, String keysPattern) {
this.redisTemplate = redisTemplate;
this.keysPattern = keysPattern;
}
public String getRandomString() {
Set<String> keys = redisTemplate.keys(keysPattern);
if (keys != null && !keys.isEmpty()) {
int randomIndex = new Random().nextInt(keys.size());
return new ArrayList<>(keys).get(randomIndex);
}
return null;
}
}
在上述示例中,通过构造函数注入RedisTemplate,并指定一个键的模式(keysPattern)。然后,通过调用redisTemplate的keys方法获取与模式匹配的所有键,并将其存储在一个Set中。
接下来,如果keys集合不为空,就随机选择一个键,然后返回该键作为随机字符串。示例中使用了Random类来生成随机索引,然后通过ArrayList将Set转换为List,并根据索引获取随机键。
需要注意的是,使用keys方法在Redis中执行模式匹配操作会有性能问题,因为该操作会遍历所有的键。如果Redis中存储的键较多,建议使用更具体的模式来限定范围,以减少遍历的键的数量。
希望以上示例能够帮助你使用RedisTemplate从Redis中获取一个随机字符串。