用注解的方式实现redis缓存控制qps/tps/接口请求/流量

文章展示了如何使用SpringAOP和Redis实现一个分布式流量限制器,代码中包含一个切面类,它在每个请求前检查TPS限制。通过Redis的有序集合和原子操作确保了限流的正确性,同时适配了阿里云和腾讯云的接口优化。该限流器可以应用于任何方法,具有很好的灵活性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

想法

一个流量限制器,使用chatgpt生成后优化代码
一开始想用在controller上,后来想用到内部服务模块上,尤其是阿里云/腾讯云的接口
优化了一下,可以用在任何方法上
chatgpt提问过程放在最后

直接上代码
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor
public class TpsLimitAspect {

	private final CacheService cacheService;

	@Before("@annotation(tpsLimit)")
	public void tpsLimitAdvice(JoinPoint joinPoint, TpsLimit tpsLimit) throws ServiceException {
		if (Objects.isNull(tpsLimit)) {
			throw new ServiceException();
		}
		//限制业务key
		byte[] rateLimiterKey = tpsLimit.rateLimiterKey().getBytes();
		//限制时间
		long rateLimitPeriod = tpsLimit.rateLimitPeriod();
		//限制令牌数
		long rateLimitTokens = tpsLimit.rateLimitTokens();
		//当前时间
		long currentTimeMillis = System.currentTimeMillis();
		//雪花算法生成key
		String key = String.valueOf(SnowFlakeUtil.getFlowIdInstance().nextId());

		// 使用管道执行Redis命令,实现原子操作
		// 使用有序集合,排序方式为时间戳
		// 每次将时间戳添加到集合中,然后删除 当前时间减去限制时间 之前的所有数据
		// 统计数量,当前数量即为限制时间内的总数
		// 判断数量是否符合令牌数要求,如果不符合则把本次添加的时间戳删除,防止相对限制时间内的token被占用
		Boolean currentTime = cacheService.execute(connection -> {
			connection.openPipeline();
			RedisZSetCommands redisZSetCommands = connection.zSetCommands();
			//添加请求数据
			redisZSetCommands.zAdd(rateLimiterKey, currentTimeMillis, key.getBytes());
			//移除小于当前时间减去时效区间内的值
			redisZSetCommands.zRemRangeByScore(rateLimiterKey, 0, currentTimeMillis - (rateLimitPeriod * 1000));
			//统计总数
			redisZSetCommands.zCard(rateLimiterKey);
			List<Object> results = connection.closePipeline();
			Long count = (Long) results.get(2);
			if (Objects.nonNull(count) && count <= rateLimitTokens) {
				log.debug("key: {}, limitCount: {}, result: {}, time: {}", key, count, true, LocalDateUtils.getCurrentTimeFormat(LocalDateUtils.FORMAT_MMSS));
				return true;
			}
			log.debug("key: {}, limitCount: {}, result: {}, time: {}", key, count, false, LocalDateUtils.getCurrentTimeFormat(LocalDateUtils.FORMAT_MMSS));
			//移除本次请求,防止相对限制时间内的token被占用
			redisZSetCommands.zRem(rateLimiterKey, key.getBytes());
			return Boolean.FALSE;
		});

		if (Boolean.FALSE.equals(currentTime)) {
			throw new ServiceException("接口限流了");
		}
	}
}

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface TpsLimit {

	String rateLimiterKey();

	// 限流周期,单位:秒
	long rateLimitPeriod();

	// 限流周期内可用令牌数量
	long rateLimitTokens();

}

chatgpt聊天过程,可以略过
用springboot实现一个功能,限制某个请求的tps为每秒200次.
当每次有请求时,我会请求外部的接口,这个接口的限制为每秒最多200次.
功能需要考虑分布式
您可以使用Spring Boot实现以下功能来限制某个请求的TPS为每秒200次,并同时考虑分布式方案。
  1. 创建一个Spring Boot项目,并添加所需的依赖。

  2. 在你的配置文件(例如application.properties)中,设置外部接口的URL。

  3. 创建一个自定义注解,用于标记需要限制TPS的请求方法。

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface TpsLimit {
}
  1. 创建一个切面类,在该类中处理每个请求,并进行TPS限制。
@Aspect
@Component
public class TpsLimitAspect {
  
  private static final int MAX_TPS = 200; // 每秒最多允许的请求数量
  private Map<String, Long> requestCounts = new ConcurrentHashMap<>(); // 用来记录每个请求的时间戳
  
  @Pointcut("@annotation(com.example.TpsLimit)") // 使用自定义注解作为切点
  public void tpsLimitPointcut() {
  }
  
  @Around("tpsLimitPointcut()")
  public Object tpsLimitAdvice(ProceedingJoinPoint joinPoint) throws Throwable {
    String requestKey = getRequestKey(joinPoint); // 获取请求的唯一标识,可以使用请求的URL或方法名等
    
    // 获取当前时间戳
    long now = System.currentTimeMillis();
    
    // 计算一秒内的请求数量
    long count = requestCounts.entrySet().stream()
        .filter(entry -> now - entry.getValue() < 1000) // 一秒内的请求
        .count();
    
    // 如果请求数量超过限制,返回错误信息
    if (count >= MAX_TPS) {
      return "Too many requests, please try again later";
    }
    
    // 记录当前请求的时间戳
    requestCounts.put(requestKey, now);
    
    // 继续执行原方法
    return joinPoint.proceed();
  }
  
  // 根据joinPoint获取请求的唯一标识
  private String getRequestKey(ProceedingJoinPoint joinPoint) {
    // 这里可以根据实际需求,使用请求的URL或方法名等作为唯一标识
    // 例如可以使用request.getRequestURI()获取URL
    return joinPoint.getSignature().toShortString();
  }
}
  1. 在需要限制TPS的请求方法上添加自定义注解。
@RestController
public class MyController {

  @GetMapping("/limit")
  @TpsLimit // 添加自定义注解
  public String limitRequest() {
    // 处理请求逻辑
    return "Request handled successfully";
  }
}
  1. 使用Spring Boot分布式限流方案,例如使用Redis进行分布式计数。
@Configuration
@EnableAspectJAutoProxy
public class TpsLimitConfig {
  
  @Bean
  public RedisConnectionFactory redisConnectionFactory() {
    // 配置Redis连接工厂
    return new JedisConnectionFactory();
  }
  
  @Bean
  public RedisTemplate<String, Long> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
    // 配置RedisTemplate
    RedisTemplate<String, Long> redisTemplate = new RedisTemplate<>();
    redisTemplate.setConnectionFactory(redisConnectionFactory);
    redisTemplate.setKeySerializer(new StringRedisSerializer());
    redisTemplate.setValueSerializer(new GenericToStringSerializer<>(Long.class));
    return redisTemplate;
  }
  
  @Bean
  public TpsLimitAspect tpsLimitAspect(RedisTemplate<String, Long> redisTemplate) {
    // 创建TpsLimitAspect,并将RedisTemplate注入
    return new TpsLimitAspect(redisTemplate);
  }
}
  1. 在切面类中使用RedisTemplate来进行分布式计数。
@Aspect
@Component
public class TpsLimitAspect {
  
  private static final int MAX_TPS = 200; // 每秒最多允许的请求数量
  private StringRedisTemplate redisTemplate; // 使用StringRedisTemplate进行分布式计数
  
  public TpsLimitAspect(StringRedisTemplate redisTemplate) {
    this.redisTemplate = redisTemplate;
  }
  
  @Pointcut("@annotation(com.example.TpsLimit)") // 使用自定义注解作为切点
  public void tpsLimitPointcut() {
  }
  
  @Around("tpsLimitPointcut()")
  public Object tpsLimitAdvice(ProceedingJoinPoint joinPoint) throws Throwable {
    String requestKey = getRequestKey(joinPoint); // 获取请求的唯一标识
    
    // 获取当前时间戳
    long now = System.currentTimeMillis();
    
    // 设置Redis的Key和Value
    String redisKey = "tpsLimit:" + requestKey;
    String redisValue = Long.toString(now);
    
    // 使用Redis原子操作,将当前时间戳写入到Redis
    String result = redisTemplate.execute(new RedisCallback<String>() {
        @Override
        public String doInRedis(RedisConnection connection) throws DataAccessException {
            Object nativeConnection = connection.getNativeConnection();
            if (nativeConnection instanceof Jedis) {
                Jedis jedis = (Jedis) nativeConnection;
                // NX表示只在Key不存在时进行操作
                // EX表示设置Key的过期时间为1秒
                String reply = jedis.set(redisKey, redisValue, "NX", "EX", 1);
                return reply;
            }
            return null;
        }
    });
    
    // 如果Redis操作失败(表示该请求已经存在),返回错误信息
    if (!"OK".equals(result)) {
      return "Too many requests, please try again later";
    }
    
    // 使用Redis原子操作,计算一秒内的请求数量
    long count = redisTemplate.execute(new RedisCallback<Long>() {
      @Override
      public Long doInRedis(RedisConnection connection) throws DataAccessException {
        Object nativeConnection = connection.getNativeConnection();
        if (nativeConnection instanceof Jedis) {
          Jedis jedis = (Jedis) nativeConnection;
          // 使用SCAN命令获取Redis中的Key,并过滤出一秒内的Key
          String cursor = ScanParams.SCAN_POINTER_START;
          long totalCount = 0;
          ScanParams scanParams = new ScanParams().match("tpsLimit:*");
          do {
            ScanResult<String> scanResult = jedis.scan(cursor, scanParams);
            List<String> keys = scanResult.getResult();
            for (String key : keys) {
              String value = jedis.get(key);
              long timestamp = Long.parseLong(value);
              if (now - timestamp < 1000) { // 一秒内的请求
                totalCount++;
              }
            }
            cursor = scanResult.getCursor();
          } while (!"0".equals(cursor));
          return totalCount;
        }
        return null;
      }
    });
    
    // 如果请求数量超过限制,返回错误信息
    if (count >= MAX_TPS) {
      return "Too many requests, please try again later";
    }
    
    // 继续执行原方法
    return joinPoint.proceed();
  }
  
  // 根据joinPoint获取请求的唯一标识
  private String getRequestKey(ProceedingJoinPoint joinPoint) {
    // 这里可以根据实际需求,使用请求的URL或方法名等作为唯一标识
    // 例如可以使用request.getRequestURI()获取URL
    return joinPoint.getSignature().toShortString();
  }
}

这样,您就可以使用Spring Boot实现一个限制某个请求的TPS为每秒200次的功能,并考虑到了分布式的问题。注意,以上代码仅作为示例,您可以根据实际需求进行修改和优化。

请优化使用RedisTemplate来进行分布式计数的方法,这个方法中使用了不推荐的命令scan

如果使用RedisTemplate进行分布式计数,并且希望避免使用不推荐的scan命令,可以尝试以下优化方法:

  1. 使用Redis的自增命令(INCR):Redis提供了INCR命令用于对一个键进行自增操作。可以使用RedisTemplate的opsForValue()方法获取ValueOperations对象,然后使用increment(key, delta)方法进行自增操作,其中key是计数器的键名,delta是要增加的数量。
ValueOperations<String, Long> valueOps = redisTemplate.opsForValue();
Long count = valueOps.increment("counter", 1);

这样可以避免使用scan命令,而且自增操作是原子性的,可以保证在多线程环境下的并发安全。

  1. 使用Lua脚本:Redis支持Lua脚本,可以将自增逻辑封装在Lua脚本中,在Redis端进行原子性执行。可以使用RedisTemplate的execute()方法执行Lua脚本。
String script = "return redis.call('INCRBY', KEYS[1], ARGV[1])";
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(script, Long.class);
Long count = redisTemplate.execute(redisScript, Arrays.asList("counter"), 1);

通过将自增逻辑封装在Lua脚本中,可以减少网络传输开销,并且确保自增操作的原子性。

总体而言,推荐使用自增命令或Lua脚本来实现分布式计数,而不是使用scan命令。自增命令和Lua脚本都能够通过RedisTemplate进行简便的操作,并且保证了计数的原子性和性能。

请使用redis进行分布式锁的限流,避免使用scan方法

是的,除了使用Lua脚本,还有一种更简单的方式来从Redis中获取一个随机字符串,那就是使用Redis的SRANDMEMBER命令。

以下是一个使用RedisTemplate从Redis中获取随机字符串的示例:

public class RedisRandomStringGetter {
    private RedisTemplate<String, String> redisTemplate;
    private String keysPattern; // 匹配键的模式
    
    public RedisRandomStringGetter(RedisTemplate<String, String> redisTemplate, String keysPattern) {
        this.redisTemplate = redisTemplate;
        this.keysPattern = keysPattern;
    }
    
    public String getRandomString() {
        Set<String> keys = redisTemplate.keys(keysPattern);
        if (keys != null && !keys.isEmpty()) {
            int randomIndex = new Random().nextInt(keys.size());
            return new ArrayList<>(keys).get(randomIndex);
        }
        return null;
    }
}

在上述示例中,通过构造函数注入RedisTemplate,并指定一个键的模式(keysPattern)。然后,通过调用redisTemplate的keys方法获取与模式匹配的所有键,并将其存储在一个Set中。

接下来,如果keys集合不为空,就随机选择一个键,然后返回该键作为随机字符串。示例中使用了Random类来生成随机索引,然后通过ArrayList将Set转换为List,并根据索引获取随机键。

需要注意的是,使用keys方法在Redis中执行模式匹配操作会有性能问题,因为该操作会遍历所有的键。如果Redis中存储的键较多,建议使用更具体的模式来限定范围,以减少遍历的键的数量。

希望以上示例能够帮助你使用RedisTemplate从Redis中获取一个随机字符串。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值