Redis + Lua 脚本实现 IP 限流
1. 所需的 pom 依赖（Spring Boot 版本：3.2.0）
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter</artifactId>
    </dependency>
    <!-- Spring MVC + embedded Tomcat; also brings the jakarta.servlet API that the code imports -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- Provides Preconditions.checkNotNull used by IpLimterHandler -->
    <dependency>
        <groupId>com.google.guava</groupId>
        <artifactId>guava</artifactId>
        <version>19.0</version>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-aop</artifactId>
    </dependency>
    <!-- Object pool (commons-pool2); needed when Redis connection pooling is enabled -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-pool2</artifactId>
    </dependency>
    <!-- Pinned version (not "latest"); NOTE(review): no Hutool class is used in the code shown here -->
    <dependency>
        <groupId>cn.hutool</groupId>
        <artifactId>hutool-all</artifactId>
        <version>5.7.17</version>
    </dependency>
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>
    <!-- javax.servlet-api removed: Spring Boot 3 is built on the Jakarta namespace
         (jakarta.servlet.*, as imported by IpUtil and IpLimterHandler), which
         spring-boot-starter-web already provides. The javax artifact was unused. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
</dependencies>
2. 自定义注解 IpLimiter
/**
 * Declares a per-client-IP rate limit on a method.
 *
 * <p>Applicable to methods only and retained at runtime so an aspect can read it. In this
 * project it is processed by {@code IpLimterHandler}, which allows at most {@link #limit()}
 * calls within a window of {@link #time()} seconds and otherwise returns {@link #message()}
 * instead of invoking the method.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface IpLimiter {

    /** Maximum number of requests allowed within one time window (default 10). */
    long limit() default 10;

    /** Length of the time window, in seconds (default 1). */
    long time() default 1;

    /** Text returned to the caller once the limit has been reached; has no default, so it is required. */
    String message();

    /**
     * Whether to use the dynamic IP, per the original author's note (default {@code true}).
     * NOTE(review): IpLimterHandler never reads this attribute, so it currently has no effect.
     */
    boolean useDynamicIp() default true;
}
3. IpUtil 工具类
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import jakarta.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
/**
* IP相关工具类
*
* @author xiegege
* @date 2021/02/22 16:08
*/
@Slf4j
public class IpUtil {
/**
* 获取当前网络ip
*/
public static String getIpAddr(HttpServletRequest request) {
String ipAddress = request.getHeader("x-forwarded-for");
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("WL-Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getRemoteAddr();
if ("127.0.0.1".equals(ipAddress) || "0:0:0:0:0:0:0:1".equals(ipAddress)) {
// 根据网卡取本机配置的IP
InetAddress inet = null;
try {
inet = InetAddress.getLocalHost();
} catch (UnknownHostException e) {
e.printStackTrace();
}
ipAddress = inet.getHostAddress();
}
}
// 对于通过多个代理的情况,第一个IP为客户端真实IP,多个IP按照','分割 //"***.***.***.***".length() = 15
if (ipAddress != null && ipAddress.length() > 15) {
if (ipAddress.indexOf(",") > 0) {
ipAddress = ipAddress.substring(0, ipAddress.indexOf(","));
}
}
return ipAddress;
}
/**
* 获取真实IP
*/
public static String getRealIp(HttpServletRequest request) {
String ip = request.getHeader("x-forwarded-for");
return checkIp(ip) ? ip : (
checkIp(ip = request.getHeader("Proxy-Client-IP")) ? ip : (
checkIp(ip = request.getHeader("WL-Proxy-Client-IP")) ? ip :
request.getRemoteAddr()));
}
/**
* 校验IP
*/
private static boolean checkIp(String ip) {
return !StringUtils.isEmpty(ip) && !"unknown".equalsIgnoreCase(ip);
}
/**
* 获取操作系统,浏览器及浏览器版本信息
*/
public static Map<String, String> getOsAndBrowserInfo(HttpServletRequest request) {
String userAgent = request.getHeader("User-Agent");
String user = userAgent.toLowerCase();
String os;
String browser = "";
//=================OS Info=======================
if (userAgent.toLowerCase().contains("windows")) {
os = "Windows";
} else if (userAgent.toLowerCase().contains("mac")) {
os = "Mac";
} else if (userAgent.toLowerCase().contains("x11")) {
os = "Unix";
} else if (userAgent.toLowerCase().contains("android")) {
os = "Android";
} else if (userAgent.toLowerCase().contains("iphone")) {
os = "IPhone";
} else {
os = "UnKnown, More-Info: " + userAgent;
}
//===============Browser===========================
try {
if (user.contains("edge")) {
browser = (userAgent.substring(userAgent.indexOf("Edge")).split(" ")[0]).replace("/", "-");
} else if (user.contains("msie")) {
String substring = userAgent.substring(userAgent.indexOf("MSIE")).split(";")[0];
browser = substring.split(" ")[0].replace("MSIE", "IE") + "-" + substring.split(" ")[1];
} else if (user.contains("safari") && user.contains("version")) {
browser = (userAgent.substring(userAgent.indexOf("Safari")).split(" ")[0]).split("/")[0]
+ "-" + (userAgent.substring(userAgent.indexOf("Version")).split(" ")[0]).split("/")[1];
} else if (user.contains("opr") || user.contains("opera")) {
if (user.contains("opera")) {
browser = (userAgent.substring(userAgent.indexOf("Opera")).split(" ")[0]).split("/")[0]
+ "-" + (userAgent.substring(userAgent.indexOf("Version")).split(" ")[0]).split("/")[1];
} else if (user.contains("opr")) {
browser = ((userAgent.substring(userAgent.indexOf("OPR")).split(" ")[0]).replace("/", "-"))
.replace("OPR", "Opera");
}
} else if (user.contains("chrome")) {
browser = (userAgent.substring(userAgent.indexOf("Chrome")).split(" ")[0]).replace("/", "-");
} else if ((user.contains("mozilla/7.0")) || (user.contains("netscape6")) ||
(user.contains("mozilla/4.7")) || (user.contains("mozilla/4.78")) ||
(user.contains("mozilla/4.08")) || (user.contains("mozilla/3"))) {
browser = "Netscape-?";
} else if (user.contains("firefox")) {
browser = (userAgent.substring(userAgent.indexOf("Firefox")).split(" ")[0]).replace("/", "-");
} else if (user.contains("rv")) {
String ieVersion = (userAgent.substring(userAgent.indexOf("rv")).split(" ")[0]).replace("rv:", "-");
browser = "IE" + ieVersion.substring(0, ieVersion.length() - 1);
} else {
browser = "UnKnown";
}
} catch (Exception e) {
log.error("获取浏览器版本失败");
log.error(e.getMessage());
browser = "UnKnown";
}
Map<String, String> result = new HashMap<>(2);
result.put("OS", os);
result.put("BROWSER", browser);
return result;
}
/**
* 判断是否是内网IP
*/
public static boolean internalIp(String ip) {
byte[] addr = textToNumericFormatV4(ip);
return internalIp(addr) || "127.0.0.1".equals(ip);
}
private static boolean internalIp(byte[] addr) {
if (StringUtils.isEmpty(addr) || addr.length < 2) {
return true;
}
final byte b0 = addr[0];
final byte b1 = addr[1];
// 10.x.x.x/8
final byte SECTION_1 = 0x0A;
// 172.16.x.x/12
final byte SECTION_2 = (byte) 0xAC;
final byte SECTION_3 = (byte) 0x10;
final byte SECTION_4 = (byte) 0x1F;
// 192.168.x.x/16
final byte SECTION_5 = (byte) 0xC0;
final byte SECTION_6 = (byte) 0xA8;
switch (b0) {
case SECTION_1:
return true;
case SECTION_2:
if (b1 >= SECTION_3 && b1 <= SECTION_4) {
return true;
}
case SECTION_5:
if (b1 == SECTION_6) {
return true;
}
default:
return false;
}
}
/**
* 将IPv4地址转换成字节
*
* @param text IPv4地址
* @return byte 字节
*/
public static byte[] textToNumericFormatV4(String text) {
if (text.length() == 0) {
return null;
}
byte[] bytes = new byte[4];
String[] elements = text.split("\\.", -1);
try {
long l;
int i;
switch (elements.length) {
case 1:
l = Long.parseLong(elements[0]);
if ((l < 0L) || (l > 4294967295L)) {
return null;
}
bytes[0] = (byte) (int) (l >> 24 & 0xFF);
bytes[1] = (byte) (int) ((l & 0xFFFFFF) >> 16 & 0xFF);
bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
bytes[3] = (byte) (int) (l & 0xFF);
break;
case 2:
l = Integer.parseInt(elements[0]);
if ((l < 0L) || (l > 255L)) {
return null;
}
bytes[0] = (byte) (int) (l & 0xFF);
l = Integer.parseInt(elements[1]);
if ((l < 0L) || (l > 16777215L)) {
return null;
}
bytes[1] = (byte) (int) (l >> 16 & 0xFF);
bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
bytes[3] = (byte) (int) (l & 0xFF);
break;
case 3:
for (i = 0; i < 2; ++i) {
l = Integer.parseInt(elements[i]);
if ((l < 0L) || (l > 255L)) {
return null;
}
bytes[i] = (byte) (int) (l & 0xFF);
}
l = Integer.parseInt(elements[2]);
if ((l < 0L) || (l > 65535L)) {
return null;
}
bytes[2] = (byte) (int) (l >> 8 & 0xFF);
bytes[3] = (byte) (int) (l & 0xFF);
break;
case 4:
for (i = 0; i < 4; ++i) {
l = Integer.parseInt(elements[i]);
if ((l < 0L) || (l > 255L)) {
return null;
}
bytes[i] = (byte) (int) (l & 0xFF);
}
break;
default:
return null;
}
} catch (NumberFormatException e) {
return null;
}
return bytes;
}
/**
* 获取IP
*/
public static String getHostIp() {
try {
return InetAddress.getLocalHost().getHostAddress();
} catch (UnknownHostException e) {
e.printStackTrace();
}
return "127.0.0.1";
}
/**
* 获取主机名
*/
public static String getHostName() {
try {
return InetAddress.getLocalHost().getHostName();
} catch (UnknownHostException e) {
e.printStackTrace();
}
return "未知";
}
}
4. AOP 切面类
import com.google.common.base.Preconditions;
import jakarta.annotation.PostConstruct;
import jakarta.servlet.http.HttpServletRequest;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import java.util.ArrayList;
import java.util.List;
/**
 * Aspect that enforces {@link IpLimiter} on annotated methods with a Redis-backed counter.
 *
 * <p>Each call runs the classpath script {@code ipLimiter.lua} against a counter keyed by the
 * target method and the client IP. When the script reports that the limit is exceeded, the
 * annotation's {@code message} is returned instead of invoking the method. Annotated methods
 * therefore need a return type a {@code String} can be assigned to; otherwise the caller will
 * likely see a ClassCastException.
 *
 * <p>NOTE(review): {@link IpLimiter#useDynamicIp()} is not consulted here.
 */
@Aspect
@Component
public class IpLimterHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(IpLimterHandler.class);

    /** Namespaces the rate-limit counters so they cannot collide with other Redis keys. */
    private static final String KEY_PREFIX = "ip_limiter:";

    @Autowired
    RedisTemplate redisTemplate;

    /**
     * Script wrapper. The result type is Long because ipLimiter.lua returns an integer
     * (1 = allowed, 0 = limited).
     */
    private DefaultRedisScript<Long> getRedisScript;

    /** Loads ipLimiter.lua from the classpath once the bean is constructed. */
    @PostConstruct
    public void init() {
        getRedisScript = new DefaultRedisScript<>();
        getRedisScript.setResultType(Long.class);
        getRedisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("ipLimiter.lua")));
        LOGGER.info("IpLimterHandler[分布式限流处理器]脚本加载完成");
    }

    /**
     * Counts the call for the current client IP and either proceeds or returns the limit message.
     *
     * @param proceedingJoinPoint the intercepted method call
     * @param ipLimiter           the annotation on the intercepted method
     * @return the method's own result, or {@code ipLimiter.message()} when the limit is exceeded
     * @throws IllegalArgumentException if the join point is not a method
     * @throws IllegalStateException    if no HTTP request is bound to the current thread, or the
     *                                  script returns no result
     */
    @Around("@annotation(ipLimiter)")
    public Object around(ProceedingJoinPoint proceedingJoinPoint, IpLimiter ipLimiter) throws Throwable {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("IpLimterHandler[分布式限流处理器]开始执行限流操作");
        }
        Signature signature = proceedingJoinPoint.getSignature();
        if (!(signature instanceof MethodSignature)) {
            throw new IllegalArgumentException("the Annotation @IpLimter must used on method!");
        }

        // The client IP comes from the current HTTP request; outside a request there is nothing to limit on.
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            throw new IllegalStateException("@IpLimiter requires an HTTP request bound to the current thread");
        }
        HttpServletRequest request = attributes.getRequest();
        String limitIp = IpUtil.getIpAddr(request);
        Preconditions.checkNotNull(limitIp);

        // Maximum calls per window, and window length in seconds.
        long limitTimes = ipLimiter.limit();
        long expireTime = ipLimiter.time();
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("IpLimterHandler[分布式限流处理器]参数值为-limitTimes={},limitTimeout={}", limitTimes, expireTime);
        }
        String message = ipLimiter.message();

        // Scope the counter per method as well as per IP. With the bare IP as key, every annotated
        // method would share one counter even though each declares its own limit/time.
        String key = KEY_PREFIX + signature.getDeclaringTypeName() + "." + signature.getName() + ":" + limitIp;
        List<String> keys = List.of(key);

        // ARGV[1] = window seconds, ARGV[2] = max calls (order expected by ipLimiter.lua).
        Long result = (Long) redisTemplate.execute(getRedisScript, keys, expireTime, limitTimes);
        if (result == null) {
            // e.g. executed inside a pipeline/transaction; unboxing would otherwise throw an NPE.
            throw new IllegalStateException("ipLimiter.lua returned no result for key " + key);
        }
        if (result == 0L) {
            String msg = "由于超过单位时间=" + expireTime + "-允许的请求次数=" + limitTimes + "[触发限流]";
            LOGGER.debug(msg);
            // Limit reached: hand the configured message back to the caller.
            return message;
        }
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("IpLimterHandler[分布式限流处理器]限流执行结果-result={},请求[正常]响应", result);
        }
        return proceedingJoinPoint.proceed();
    }
}
5. Redis key/value 序列化配置类
@Configuration
public class RedisCacheConfig {
private static final Logger LOGGER = LoggerFactory.getLogger(RedisCacheConfig.class);
@Bean
public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {
RedisTemplate<String, Object> template = new RedisTemplate<>();
template.setConnectionFactory(factory);
// 使用Jackson2JsonRedisSerializer来序列化和反序列化redis的value值(默认使用JDK的序列化方式)
Jackson2JsonRedisSerializer serializer = new Jackson2JsonRedisSerializer(Object.class);
ObjectMapper mapper = new ObjectMapper();
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
mapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
serializer.setObjectMapper(mapper);
template.setValueSerializer(serializer);
// 使用StringRedisSerializer来序列化和反序列化redis的key值
template.setKeySerializer(new StringRedisSerializer());
template.afterPropertiesSet();
LOGGER.info("Springboot RedisTemplate 加载完成");
return template;
}
}
6. 控制层（测试）
/**
 * Demo endpoints for trying out {@link IpLimiter}.
 */
@Controller
public class IpController {
    private static final Logger LOGGER = LoggerFactory.getLogger(IpController.class);
    private static final String MESSAGE = "请求失败,你的IP访问太频繁";

    /**
     * Rate-limited endpoint ({@code limit = 5, time = 10}) that logs what {@link IpUtil}
     * extracts from the request and the local host.
     *
     * @param request the current request
     * @return "请求成功" ("request succeeded")
     */
    @ResponseBody
    @RequestMapping("iplimiter")
    @IpLimiter(limit = 5, time = 10, message = MESSAGE)
    public String sendPayment(HttpServletRequest request) throws Exception {
        // Diagnostics go through the class logger rather than System.out.
        LOGGER.info("ipAddr = {}", IpUtil.getIpAddr(request));
        LOGGER.info("hostIp = {}", IpUtil.getHostIp());
        LOGGER.info("hostName = {}", IpUtil.getHostName());
        LOGGER.info("osAndBrowser = {}", IpUtil.getOsAndBrowserInfo(request));
        LOGGER.info("IpUtil.getRealIp(request) = {}", IpUtil.getRealIp(request));
        return "请求成功";
    }

    /**
     * Second rate-limited endpoint ({@code limit = 4, time = 10}) with no side effects.
     *
     * @param request the current request (unused)
     * @return "请求成功" ("request succeeded")
     */
    @ResponseBody
    @RequestMapping("iplimiter1")
    @IpLimiter(limit = 4, time = 10, message = MESSAGE)
    public String sendPayment1(HttpServletRequest request) throws Exception {
        return "请求成功";
    }
}
7. Lua 脚本：ipLimiter.lua（放在 classpath 根目录，例如 src/main/resources，切面通过 ClassPathResource("ipLimiter.lua") 加载）
-- Fixed-window rate limiter; Redis runs the whole script atomically.
-- KEYS[1]: counter key
-- ARGV[1]: window length in seconds
-- ARGV[2]: maximum requests allowed per window
-- Returns 1 if this request is allowed, 0 if the limit is exceeded.
local key = KEYS[1]
local expire = tonumber(ARGV[1])
local times = tonumber(ARGV[2])

local val = redis.call('incr', key)
if val == 1 then
    -- First hit: start the window.
    redis.call('expire', key, expire)
elseif redis.call('ttl', key) == -1 then
    -- Counter exists without a TTL: re-arm it so the key cannot block the client forever.
    redis.call('expire', key, expire)
end

-- DEBUG rather than NOTICE: Redis' default loglevel ("notice") would log every request.
redis.log(redis.LOG_DEBUG, "incr " .. key .. " " .. val .. " limit=" .. tostring(times) .. " window=" .. tostring(expire))

if val > times then
    return 0
end
return 1