SpringBoot项目中对请求进行限流实现

一:自定义实现:

由于该场景一般在高并发场景下,及出于安全性考虑,采用支持线程安全的类+spring过滤器实现

核心类

ConcurrentMap:存储限流相关的key值和属性
AtomicInteger:记录当前容量及使用依赖原子性进行容量消费
/*
*  限流属性实体类
*/
@Data
public class StatItem {
    
    private String name; //key
​
    private long lastResetTime;  //记录开始时间
​
    private long interval;  //间隔时间(ms)
​
    private AtomicInteger token;  //当前剩余容量
​
    private int rate;   //总容量
    
    StatItem(String name, int rate, long interval) {
        this.name = name;
        this.rate = rate;
        this.interval = interval;
        this.lastResetTime = System.currentTimeMillis();
        this.token = new AtomicInteger(rate);
    }
    /**
     * 是否超过流速限制
     * @param servivce  服务名
     * @return true:超速;false:未超速
     */
    public boolean isAllowable(String servivce) {
        long now = System.currentTimeMillis();
        if (now > lastResetTime + interval) {
            token.set(rate);
            lastResetTime = now;
        }
​
        int value = token.get();
        boolean flag = false;
        while (value > 0 && !flag) {
            flag = token.compareAndSet(value, value - 1);
            value = token.get();
        }
​
        return flag;
    }
​
    long getLastResetTime() {
        return lastResetTime;
    }
    
    int getToken() {
        return token.get();
    }
}
​
​
/**
 * 限制单server处理能力的tps限制
 * 实现Filter接口,重写doFilter实现当次请求是否触发限流;未触发放行,触发则返回对应提示
 *
 */
@Service
@Slf4j
public class TpsLimitFilter implements Filter {
    private final TPSLimiter tpsLimiter = new DefaultTPSLimiter();
​
    @Override
    public void init(FilterConfig config) throws ServletException {
​
        ServletContext context = config.getServletContext();
        ApplicationContext ctx = WebApplicationContextUtils.getRequiredWebApplicationContext(context);
​
    }
​
    @Override
    public void destroy() {
​
    }
​
    @Override
    public void doFilter(ServletRequest req, ServletResponse res,
                         FilterChain chain) throws IOException, ServletException {
        if (!tpsLimiter.isAllowable("OPEN_SERVER_SINGLE")) {
            logger.error("超过TPS限制速率。每:" + 10000 / 1000 + "秒钟" + 3 + "次调用!");
            // 不放行
            Map<String, String> map = new HashMap<String, String>();
            map.put("code", "500");
            map.put("message", "请求超tps速率,稍后再试");
            JSONObject json = new JSONObject();
            json.putAll(map);
            HttpServletResponse response = (HttpServletResponse) res;
            response.setCharacterEncoding("utf-8");
            PrintWriter out = response.getWriter();
            out.print(json.toString());
            out.flush();
            out.close();
            logger.info("=====启动线程,单server 告警信息入库");
            return;
        }
        chain.doFilter(req, res);
    }
​
}
​
/**
     * 是否触发限流接口
     * @param url
     * @return
     */
public interface TPSLimiter {
    
    boolean isAllowable(String key);
    
}
​
​
/**
 * 是否触发限流实现类
 *
 */
public class DefaultTPSLimiter implements TPSLimiter {
    private static Logger log = LogManager.getLogger(DefaultTPSLimiter.class);
    private final ConcurrentMap<String, StatItem> stats = new ConcurrentHashMap<String, StatItem>();
​
    public boolean isAllowable(String service) {
        int rate = 3;  //规定时间内最大可以请求多少次,根据实际需求进行修改
        long interval = 10000; //时间间隔,根据实际需求进行修改
        if (rate > 0) {
            StatItem statItem = stats.get(service);
            if (statItem == null) {
                stats.putIfAbsent(service, new StatItem(service, rate, interval));
                statItem = stats.get(service);
            }
            log.info("TPS并发数[" + 3 + "]剩余" + statItem.getToken());
            return statItem.isAllowable(service);
        } else {
            StatItem statItem = stats.get(service);
            if (statItem != null) {
                stats.remove(service);
            }
        }
​
        return true;
    }
}

二:使用第三方工具类

使用bucket4j-core

1.导入坐标

 <dependency>
      <groupId>io.github.bucket4j</groupId>
      <artifactId>bucket4j-core</artifactId>
      <version>6.0.0</version> <!-- jdk8版本,如果是jdk11或者更高版本,请自行更换坐标 artifactId不一样 -->
 </dependency>

2.自定义限流过滤器实现OncePerRequestFilter抽象类

OncePerRequestFilter:推荐使用,它的作用是确保每个请求只被过滤一次,无论请求如何转发或包含其他请求。这是通过检查是否已经执行过过滤操作来实现的。通常用于避免某些重复操作,如日志记录、统计等。

@Component
public class RateLimitFilter extends OncePerRequestFilter {
​
    private final Bucket bucket = Bucket4j.builder()
        .addLimit(Bandwidth.simple(100, Duration.ofSeconds(10)))  // 每 10 秒最多 100 次请求,根据实际需求自行调整
        .build();
​
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {
​
        if (bucket.tryConsume(1)) {
            filterChain.doFilter(request, response);  // 允许请求继续
        } else {
            Map<String, String> map = new HashMap<String, String>();
            map.put("code", "500");
            map.put("message", "RateLimitFilter超tps速率");
            JSONObject json = new JSONObject();
            json.putAll(map);
            response.setStatus(HttpServletResponse.SC_GATEWAY_TIMEOUT);  // 返回 502
            response.setCharacterEncoding("UTF-8");
            response.getWriter().write(json.toString());
        }
    }
}

3.注册自定义过滤器

@Configuration
public class FilterConfig {
​
    @Bean
    public FilterRegistrationBean<RateLimitFilter> loggingFilterOne() {
        FilterRegistrationBean<RateLimitFilter> registrationBean = new FilterRegistrationBean<>();
        registrationBean.setFilter(new RateLimitFilter());
​
        // 设置拦截路径 (可以设置多个 URL 模式)
        registrationBean.addUrlPatterns("/*");  //   /*:拦截所有,/api/:只拦截 "/api/" 路径下的请求
​
        // 设置过滤器的优先级 (数字越小,优先级越高)
        registrationBean.setOrder(1);  // 1 表示最高优先级
​
        return registrationBean;
    }
}
  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值