基于RateLimiter的接口限流器

本文介绍了Google Guava库中的RateLimiter如何用于接口限流,详细阐述了其工作原理,包括两种创建方法及限流策略。并展示了通过RateLimiterFactory创建实例、注解方式和代码插入方式实现限流的详细步骤。
部署运行你感兴趣的模型镜像

知识前述

RateLimiter是Google开源工具包Guava提供的限流工具类,该类基于令牌桶算法实现流量限制。

RateLimiter有两种构造方法:

(1)通过RateLimiter.create(double permitsPerSecond)创建实例,该实例以固定的速率生成令牌,参数permitsPerSecond代表每秒生成的令牌桶数量,令牌桶存放一秒的令牌数量;

(2)通过RateLimiter.create(double permitsPerSecond, long warmupPeriod, TimeUnit unit)方法创建限流器实例,该实例初始生成令牌的速率为零,经过一定时间段warmupPeriod后,速率达到指定的permitsPerSecond。

限流方法:

(1)通过ratelimiter.acquire(int permits)获取指定数量的令牌,等待直至获取到令牌;

(2)通过ratelimiter.tryAcquire(int permits, long timeout, TimeUnit unit)在指定时间内获取指定数量令牌,通过计算获取到令牌的时间,若小于指定时间,则等待直至获取到令牌返回true,若大于指定时间,则直接返回false。

限流器的功能

限流器需要实现两个功能:

(1)提供注解的方式,采用AOP对指定方法限流;

(2)采用代码插入方式限流。

由于一个RateLimiter实例生成令牌的速度固定,若多个地方同时使用,则该RateLimiter限制的速度是几个地方的调用速度之和。因此,为明确每个需要限流的接口的限流速率,需要为每个限流接口生成一个RateLimiter实例。

限流器的具体实现

1、RateLimiterFactory工厂类——生成RateLimiter实例

为了保证并发状态下,同一个接口多次调用,使用同一个RateLimiter实例,通过一个ConcurrentHashMap存储所有RateLimiter实例,key为该实例所在接口的信息。

调用getRateLimiter(String location, int rate)方法获取RateLimiter,并调整令牌生成速率rate。若该location没有对应的RateLimiter,则新建实例。

import com.google.common.util.concurrent.RateLimiter;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;

/**
 * Created by wangpy on 2019-12-02 10:34
 */
public class RateLimiterFactory {
    private static ConcurrentMap<String, RateLimiter> rateLimiterConcurrentMap = new ConcurrentHashMap<>();

    public static RateLimiter getRateLimiter(String location, int rate) {
        if(rateLimiterConcurrentMap.containsKey(location)) {
            RateLimiter rateLimiter = rateLimiterConcurrentMap.get(location);
            if (rateLimiter.getRate() != rate) {
                rateLimiter.setRate(rate);
            }
            return rateLimiter;
        } else {
            RateLimiter rateLimiter = RateLimiter.create(rate, 1000, TimeUnit.MILLISECONDS);
            rateLimiterConcurrentMap.putIfAbsent(location, rateLimiter);

            return rateLimiterConcurrentMap.get(location);
        }
    }
}

2、新建注解类

注解类有两个标签,(1)rate:表示RateLimiter生成令牌的速率;(2)expireTime:表示获取令牌的超时时间,默认为-1。若expireTime<=0,则表示等待至获取令牌成功;若expireTime>0,则表示获取令牌超时时间,超时则直接返回。

import java.lang.annotation.*;

/**
 * Created by wangpy on 2019-12-02 15:10
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiterAnno {
    //获取token速率
    int rate();
    //超时时间(单位:millsecond)
    int expireTime() default -1;
}

3、切面类

(1)定义切点为添加RateLimiterAnno注解的方法

@Pointcut("@annotation(me.common.ratelimiter.RateLimiterAnno)")

(2)创建环绕通知

通过ProceedingJoinPoint获取切点所在方法位置location,并获取注解的标签expireTime和rate,调用RateLimiterFactory方法获取RateLimiter实例,根据expireTime判断获取rateLimiter获取令牌的方式,是否采用acquire()方法持续等待直至获取令牌,还是采用tryAcquire()方法尝试获取令牌,超时则抛出异常。

import com.google.common.util.concurrent.RateLimiter;
import me.ele.elog.Log;
import me.ele.elog.LogFactory;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import org.aspectj.lang.annotation.Aspect;

import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;

/**
 * Created by wangpy on 2019-12-02 15:14
 */
@Component
@Aspect
public class RateLimiterAspect {
    private static final transient Log LOG = LogFactory.getLog(RateLimiterAspect.class);

    @Pointcut("@annotation(me.common.ratelimiter.RateLimiterAnno)")
    public void pointCut(){
    }

    @Around("pointCut()")
    public Object process(ProceedingJoinPoint point) {
        try {
            MethodSignature signature = (MethodSignature)point.getSignature();
            String location = signature.getMethod().toString();
            Method method = signature.getMethod();
            RateLimiterAnno rateLimiterAnno = method.getAnnotation(RateLimiterAnno.class);
            int expirTime = rateLimiterAnno.expireTime();
            int rate = rateLimiterAnno.rate();

            RateLimiter rateLimiter = RateLimiterFactory.getRateLimiter(location, rate);

            if (expirTime <= 0) {
                rateLimiter.acquire();
                return point.proceed();
            }

            if (rateLimiter.tryAcquire(expirTime, TimeUnit.MILLISECONDS)) {
                return point.proceed();
            } else {
                throw new RuntimeException("overstep the limits of interface rate");
            }

        } catch (Throwable throwable) {
            LOG.error("RateLimiterAspect aspectj error!", throwable);
            return null;
        }
    }
}

4、代码插入方式

由于添加注解的方法必须被反射调用,注解才能生效。因此,为了给不适用于反射调用方法添加限流器,我们封装了RateLimiterTool工具类,提供静态限流方法。

该工具类提供三个限流方法。

(1)waitToGetTokenForSingleThread(int rate):等待获取token。若获取不到,则线程睡眠一个令牌生成周期;

(2)waitToGetTokenWithExpireTime(int rate, long millsecond):等待获取token。若获取token的时长超过millsecond,则直接返回false;

(3)waitToGetTokenForMultiThread(int rate):等待获取token。若获取不到,则持续等待。

其中,getRateLimiterLocation()方法通过Thread.currentThread().getStackTrace()获取限流方法所在位置location,之后通过RateLimiterFactory获取RateLimiter实例。

package me.ele.lpd_ai.pressure_balance.common.ratelimiter;

import com.google.common.util.concurrent.RateLimiter;
import me.ele.elog.Log;
import me.ele.elog.LogFactory;

import java.util.concurrent.TimeUnit;

/**
 * Created by wangpy on 2019-11-29 15:21
 */
public class RateLimiterTool {
    public static Log log = LogFactory.getLog(RateLimiterTool.class);

    /**
     * 等待获取token。若获取不到,则线程睡眠一个令牌生成周期
     * @param rate 获取token的速率,单位:次/分钟
     */
    public static void waitToGetTokenForSingleThread(int rate) {
        String location = getRateLimiterLocation();
        RateLimiter rateLimiter = RateLimiterFactory.getRateLimiter(location, rate);
        
        int period = 1000 / rate;
        int timeout = 4 * period;
        if (!rateLimiter.tryAcquire(timeout, TimeUnit.MILLISECONDS)) {
            try {
                log.info("fail to get token and sleep {} ms", period);
                Thread.sleep(period);
            }catch (Exception e){
                log.error("Wait to get token sleep error", e);
            }
        }
    }

    /**
     * 等待获取token。若获取token的时长超过millsecond,则直接返回false
     * @param rate
     * @param millsecond
     * @return
     */
    public static boolean waitToGetTokenWithExpireTime(int rate, long millsecond) {
        String location = getRateLimiterLocation();
        RateLimiter rateLimiter = RateLimiterFactory.getRateLimiter(location, rate);

        return rateLimiter.tryAcquire(millsecond, TimeUnit.MILLISECONDS);
    }

    private static String getRateLimiterLocation() {
        String file = Thread.currentThread().getStackTrace()[3].getClassName();
        int line = Thread.currentThread().getStackTrace()[3].getLineNumber();
        String location = file + ":" + line;
        return location;
    }

    /**
     * 等待获取token。若获取不到,则持续等待
     * @param rate
     */
    public static void waitToGetTokenForMultiThread(int rate) {
        String location = getRateLimiterLocation();
        RateLimiter rateLimiter = RateLimiterFactory.getRateLimiter(location, rate);
        rateLimiter.acquire();
    }

    public static void test() {
        //waitToGetTokenForSingleThread(10);
        waitToGetTokenWithExpireTime(10, 1000);
    }

    public static void main(String[] args) {
//        RateLimiterTool.waitToGetTokenForSendZBAoiFast2MQ();
//        String clazz = Thread.currentThread() .getStackTrace() [1].getClassName();
//        System.out.println(clazz);
        long startTime = System.currentTimeMillis();
        test();
        log.info("start time: " + (System.currentTimeMillis() - startTime));
        for (int i = 0; i < 10; i++) {
            //test();
            //log.info(Thread.currentThread().getName() + "get token time: " + (System.currentTimeMillis() - startTime));
            Thread thread = new Thread(new TestThread(), "Thread:" + i);
            thread.start();
        }
        System.out.println();
    }

    static class TestThread implements Runnable {
        @Override
        public void run() {
            test();
            log.info(Thread.currentThread().getName() + "get token time: " + System.currentTimeMillis());
        }
    }
}

 

您可能感兴趣的与本文相关的镜像

Llama Factory

Llama Factory

模型微调
LLama-Factory

LLaMA Factory 是一个简单易用且高效的大型语言模型(Large Language Model)训练与微调平台。通过 LLaMA Factory,可以在无需编写任何代码的前提下,在本地完成上百种预训练模型的微调

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值