知识前述
RateLimiter是Google开源工具包Guava提供的限流工具类,该类基于令牌桶算法实现流量限制。
RateLimiter有两种构造方法:
(1)通过RateLimiter.create(double permitsPerSecond)创建实例,该实例以固定的速率生成令牌,参数permitsPerSecond代表每秒生成的令牌桶数量,令牌桶存放一秒的令牌数量;
(2)通过RateLimiter.create(double permitsPerSecond, long warmupPeriod, TimeUnit unit)方法创建限流器实例,该实例初始生成令牌的速率为零,经过一定时间段warmupPeriod后,速率达到指定的permitsPerSecond。
限流方法:
(1)通过ratelimiter.acquire(int permits)获取指定数量的令牌,等待直至获取到令牌;
(2)通过ratelimiter.tryAcquire(int permits, long timeout, TimeUnit unit)在指定时间内获取指定数量令牌,通过计算获取到令牌的时间,若小于指定时间,则等待直至获取到令牌返回true,若大于指定时间,则直接返回false。
限流器的功能
限流器需要实现两个功能:
(1)提供注解的方式,采用AOP对指定方法限流;
(2)采用代码插入方式限流。
由于一个RateLimiter实例生成令牌的速度固定,若多个地方同时使用,则该RateLimiter限制的速度是几个地方的调用速度之和。因此,为明确每个需要限流的接口的限流速率,需要为每个限流接口生成一个RateLimiter实例。
限流器的具体实现
1、RateLimiterFactory工厂类——生成RateLimiter实例
为了保证并发状态下,同一个接口多次调用,使用同一个RateLimiter实例,通过一个ConcurrentHashMap存储所有RateLimiter实例,key为该实例所在接口的信息。
调用getRateLimiter(String location, int rate)方法获取RateLimiter,并调整令牌生成速率rate。若该location没有对应的RateLimiter,则新建实例。
import com.google.common.util.concurrent.RateLimiter;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
/**
* Created by wangpy on 2019-12-02 10:34
*/
public class RateLimiterFactory {
private static ConcurrentMap<String, RateLimiter> rateLimiterConcurrentMap = new ConcurrentHashMap<>();
public static RateLimiter getRateLimiter(String location, int rate) {
if(rateLimiterConcurrentMap.containsKey(location)) {
RateLimiter rateLimiter = rateLimiterConcurrentMap.get(location);
if (rateLimiter.getRate() != rate) {
rateLimiter.setRate(rate);
}
return rateLimiter;
} else {
RateLimiter rateLimiter = RateLimiter.create(rate, 1000, TimeUnit.MILLISECONDS);
rateLimiterConcurrentMap.putIfAbsent(location, rateLimiter);
return rateLimiterConcurrentMap.get(location);
}
}
}
2、新建注解类
注解类有两个标签,(1)rate:表示RateLimiter生成令牌的速率;(2)expireTime:表示获取令牌的超时时间,默认为-1。若expireTime<=0,则表示等待至获取令牌成功;若expireTime>0,则表示获取令牌超时时间,超时则直接返回。
import java.lang.annotation.*;
/**
* Created by wangpy on 2019-12-02 15:10
*/
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiterAnno {
//获取token速率
int rate();
//超时时间(单位:millsecond)
int expireTime() default -1;
}
3、切面类
(1)定义切点为添加RateLimiterAnno注解的方法
@Pointcut("@annotation(me.common.ratelimiter.RateLimiterAnno)")
(2)创建环绕通知
通过ProceedingJoinPoint获取切点所在方法位置location,并获取注解的标签expireTime和rate,调用RateLimiterFactory方法获取RateLimiter实例,根据expireTime判断获取rateLimiter获取令牌的方式,是否采用acquire()方法持续等待直至获取令牌,还是采用tryAcquire()方法尝试获取令牌,超时则抛出异常。
import com.google.common.util.concurrent.RateLimiter;
import me.ele.elog.Log;
import me.ele.elog.LogFactory;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import org.aspectj.lang.annotation.Aspect;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;
/**
* Created by wangpy on 2019-12-02 15:14
*/
@Component
@Aspect
public class RateLimiterAspect {
private static final transient Log LOG = LogFactory.getLog(RateLimiterAspect.class);
@Pointcut("@annotation(me.common.ratelimiter.RateLimiterAnno)")
public void pointCut(){
}
@Around("pointCut()")
public Object process(ProceedingJoinPoint point) {
try {
MethodSignature signature = (MethodSignature)point.getSignature();
String location = signature.getMethod().toString();
Method method = signature.getMethod();
RateLimiterAnno rateLimiterAnno = method.getAnnotation(RateLimiterAnno.class);
int expirTime = rateLimiterAnno.expireTime();
int rate = rateLimiterAnno.rate();
RateLimiter rateLimiter = RateLimiterFactory.getRateLimiter(location, rate);
if (expirTime <= 0) {
rateLimiter.acquire();
return point.proceed();
}
if (rateLimiter.tryAcquire(expirTime, TimeUnit.MILLISECONDS)) {
return point.proceed();
} else {
throw new RuntimeException("overstep the limits of interface rate");
}
} catch (Throwable throwable) {
LOG.error("RateLimiterAspect aspectj error!", throwable);
return null;
}
}
}
4、代码插入方式
由于添加注解的方法必须被反射调用,注解才能生效。因此,为了给不适用于反射调用方法添加限流器,我们封装了RateLimiterTool工具类,提供静态限流方法。
该工具类提供三个限流方法。
(1)waitToGetTokenForSingleThread(int rate):等待获取token。若获取不到,则线程睡眠一个令牌生成周期;
(2)waitToGetTokenWithExpireTime(int rate, long millsecond):等待获取token。若获取token的时长超过millsecond,则直接返回false;
(3)waitToGetTokenForMultiThread(int rate):等待获取token。若获取不到,则持续等待。
其中,getRateLimiterLocation()方法通过Thread.currentThread().getStackTrace()获取限流方法所在位置location,之后通过RateLimiterFactory获取RateLimiter实例。
package me.ele.lpd_ai.pressure_balance.common.ratelimiter;
import com.google.common.util.concurrent.RateLimiter;
import me.ele.elog.Log;
import me.ele.elog.LogFactory;
import java.util.concurrent.TimeUnit;
/**
* Created by wangpy on 2019-11-29 15:21
*/
public class RateLimiterTool {
public static Log log = LogFactory.getLog(RateLimiterTool.class);
/**
* 等待获取token。若获取不到,则线程睡眠一个令牌生成周期
* @param rate 获取token的速率,单位:次/分钟
*/
public static void waitToGetTokenForSingleThread(int rate) {
String location = getRateLimiterLocation();
RateLimiter rateLimiter = RateLimiterFactory.getRateLimiter(location, rate);
int period = 1000 / rate;
int timeout = 4 * period;
if (!rateLimiter.tryAcquire(timeout, TimeUnit.MILLISECONDS)) {
try {
log.info("fail to get token and sleep {} ms", period);
Thread.sleep(period);
}catch (Exception e){
log.error("Wait to get token sleep error", e);
}
}
}
/**
* 等待获取token。若获取token的时长超过millsecond,则直接返回false
* @param rate
* @param millsecond
* @return
*/
public static boolean waitToGetTokenWithExpireTime(int rate, long millsecond) {
String location = getRateLimiterLocation();
RateLimiter rateLimiter = RateLimiterFactory.getRateLimiter(location, rate);
return rateLimiter.tryAcquire(millsecond, TimeUnit.MILLISECONDS);
}
private static String getRateLimiterLocation() {
String file = Thread.currentThread().getStackTrace()[3].getClassName();
int line = Thread.currentThread().getStackTrace()[3].getLineNumber();
String location = file + ":" + line;
return location;
}
/**
* 等待获取token。若获取不到,则持续等待
* @param rate
*/
public static void waitToGetTokenForMultiThread(int rate) {
String location = getRateLimiterLocation();
RateLimiter rateLimiter = RateLimiterFactory.getRateLimiter(location, rate);
rateLimiter.acquire();
}
public static void test() {
//waitToGetTokenForSingleThread(10);
waitToGetTokenWithExpireTime(10, 1000);
}
public static void main(String[] args) {
// RateLimiterTool.waitToGetTokenForSendZBAoiFast2MQ();
// String clazz = Thread.currentThread() .getStackTrace() [1].getClassName();
// System.out.println(clazz);
long startTime = System.currentTimeMillis();
test();
log.info("start time: " + (System.currentTimeMillis() - startTime));
for (int i = 0; i < 10; i++) {
//test();
//log.info(Thread.currentThread().getName() + "get token time: " + (System.currentTimeMillis() - startTime));
Thread thread = new Thread(new TestThread(), "Thread:" + i);
thread.start();
}
System.out.println();
}
static class TestThread implements Runnable {
@Override
public void run() {
test();
log.info(Thread.currentThread().getName() + "get token time: " + System.currentTimeMillis());
}
}
}
本文介绍了Google Guava库中的RateLimiter如何用于接口限流,详细阐述了其工作原理,包括两种创建方法及限流策略。并展示了通过RateLimiterFactory创建实例、注解方式和代码插入方式实现限流的详细步骤。
85万+

被折叠的 条评论
为什么被折叠?



