Gateway白名单+限流
Spring Boot 项目在网关层(Spring Cloud Gateway)实现白名单+限流
拒绝废话,上代码
1.限流实现
1.1 限流config
package com.*.gateway.config;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.*.common.redis.util.RedisUtils;
import com.*.common.tools.StringUtil;
import com.*.gateway.rate.model.RateLimiter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Per-client request counter used by the gateway for rate limiting.
 *
 * <p>One {@link RateLimiter} record is stored in Redis per client IP. The record holds a map of
 * counters keyed by {@code apiName + "#" + paramValue}; every counter remembers the second at
 * which its current counting window started.
 *
 * <p>NOTE(review): the counter is updated with a read-modify-write sequence against Redis, so
 * concurrent requests from the same IP can overwrite each other's increments. Confirm whether
 * that is acceptable for this deployment.
 */
@Slf4j
@Component
public class ApiRateLimiterConfig {

    @Autowired
    private RedisUtils redisUtils;

    // API white list; could be moved into configuration.
    private String[] whiteList = {"oms", "wms", "outos"};

    // Query parameter whose presence makes a request subject to rate limiting.
    private String vailParam = "timestamp";

    /**
     * Length of a counting window in seconds. The same value is handed to
     * {@code redisUtils.set} as its third argument (presumably the key expiry; confirm the unit).
     */
    private Long ipSecond = 3600L;

    // Separator between the API name and the parameter string in a counter key.
    private String delimiter = "#";

    /**
     * Counts one request and reports whether it is still within the limit.
     *
     * @param ip         client IP; used as the Redis key of the stored record
     * @param apiName    request path
     * @param paramValue string form of the request's query parameters
     * @param limitCount maximum number of requests allowed per window
     * @return {@code true} when the request may proceed (exempt, or within the limit),
     *         {@code false} when the limit has been exceeded
     */
    public boolean checkRate(String ip, String apiName, String paramValue, int limitCount) {
        // Exempt requests are never counted.
        if (!checkWhiteListAndParam(apiName, paramValue)) {
            return true;
        }

        final long nowSeconds = System.currentTimeMillis() / 1000;
        String counterKey = apiName + delimiter + paramValue;

        RateLimiter rateLimiter = loadRateLimiter(ip);
        Map<String, RateLimiter.ApiLimiter> apiMap = rateLimiter.getApiMap();
        if (apiMap == null) {
            apiMap = new HashMap<>(16);
            rateLimiter.setApiMap(apiMap);
        }
        RateLimiter.ApiLimiter apiLimiter = apiMap.computeIfAbsent(counterKey,
                k -> new RateLimiter.ApiLimiter(new AtomicInteger(0), nowSeconds));

        AtomicInteger counter = apiLimiter.getCount();
        Long windowStart = apiLimiter.getTimestamp();
        // Start a fresh window when the stored one is incomplete or older than ipSecond seconds.
        if (counter == null || windowStart == null || nowSeconds - windowStart > ipSecond) {
            counter = new AtomicInteger(0);
            apiLimiter.setTimestamp(nowSeconds);
            apiLimiter.setCount(counter);
        }

        log.info("限流作用rateLimiter:{}", rateLimiter);
        // The very first request of a window is counted too, so exactly limitCount requests pass.
        if (counter.incrementAndGet() > limitCount) {
            log.info("限流IP,{},API,{}", ip, apiName);
            // Rejected requests are not written back to Redis.
            return false;
        }
        redisUtils.set(ip, rateLimiter, ipSecond);
        return true;
    }

    /**
     * Decides whether a request is subject to rate limiting.
     *
     * @param apiName    request path
     * @param paramValue string form of the request's query parameters
     * @return {@code false} when the request is exempt (its non-blank parameters do not contain
     *         {@code timestamp}, or the path contains a white-listed fragment), {@code true}
     *         when it must be rate limited
     */
    public boolean checkWhiteListAndParam(String apiName, String paramValue) {
        // Requests carrying parameters but no timestamp parameter skip the limit.
        if (StringUtil.isNotBlank(paramValue) && !paramValue.contains(vailParam)) {
            log.info("放行api,{}", apiName);
            return false;
        }
        if (apiName == null) {
            return true;
        }
        // Paths containing oms / wms / outos skip the limit.
        for (String white : whiteList) {
            if (apiName.contains(white)) {
                log.info("放行api,{}", apiName);
                return false;
            }
        }
        return true;
    }

    /** Reads the stored record for the given IP, or returns an empty one when nothing usable is stored. */
    private RateLimiter loadRateLimiter(String ip) {
        String stored = redisUtils.get(ip);
        if (StringUtil.isNotBlank(stored)) {
            JSONObject jsonObject = JSONObject.parseObject(stored);
            RateLimiter parsed = JSON.toJavaObject(jsonObject, RateLimiter.class);
            if (parsed != null) {
                return parsed;
            }
        }
        return new RateLimiter();
    }
}
1.2 限流Filter
package com.*.gateway.fiflt;
import com.*.gateway.config.ApiRateLimiterConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
/**
 * Global gateway filter that rejects requests exceeding the per-IP rate limit with
 * HTTP 429 (Too Many Requests) and forwards all other requests down the filter chain.
 *
 * <p>NOTE(review): {@code checkRate} is invoked synchronously inside the reactive filter; if it
 * performs blocking Redis I/O this may stall the event loop. Confirm against RedisUtils.
 */
@Component
public class ApiRateLimitFilter implements GlobalFilter, Ordered {

    @Autowired
    private ApiRateLimiterConfig apiRateLimiter;

    // Maximum number of requests allowed per client IP.
    private int ipLimitCount = 10;

    /**
     * Applies the rate limit to the current request.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filters
     * @return a completed response with status 429 when the limit is exceeded, otherwise the
     *         result of the remaining chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        String ip = resolveClientIp(exchange);
        if (ip == null) {
            // No remote address to count against; let the request through unthrottled
            // rather than failing it with a NullPointerException.
            return chain.filter(exchange);
        }
        String apiName = exchange.getRequest().getPath().value();
        String paramValue = exchange.getRequest().getQueryParams().toString();
        if (!apiRateLimiter.checkRate(ip, apiName, paramValue, ipLimitCount)) {
            exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
            return exchange.getResponse().setComplete();
        }
        return chain.filter(exchange);
    }

    /** Returns the caller's IP address, or {@code null} when the remote address is unavailable. */
    private String resolveClientIp(ServerWebExchange exchange) {
        if (exchange.getRequest().getRemoteAddress() == null
                || exchange.getRequest().getRemoteAddress().getAddress() == null) {
            return null;
        }
        return exchange.getRequest().getRemoteAddress().getAddress().getHostAddress();
    }

    @Override
    public int getOrder() {
        return -1000;
    }
}
2.白名单实现
如果已经在 Gateway 中实现了鉴权(Auth),则可以直接把白名单校验写到 AuthFilter 里
/**
* Copyright (c) 2019-2099, XIA HUI
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.*.gateway.filter;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import javax.annotation.Resource;
import com.*.common.tools.IdUtil;
import com.*.gateway.constants.GatewayConstants;
import com.*.gateway.props.AuthProperties;
import com.*.gateway.provider.AuthProvider;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import com.alibaba.fastjson.JSON;
import com.*.common.base.constants.BaseConstants;
import com.*.common.base.resp.R;
import com.*.sso.config.TokenDecoder;
import com.*.sso.vo.TokenUser;
import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
 * Gateway authentication filter.
 *
 * <p>For every request it (1) checks the {@code Origin} header against the request host and the
 * origin white list, (2) attaches a trace id and a request timestamp header, (3) lets skip-listed
 * paths through without a token, and (4) otherwise decodes the token and forwards the user id and
 * login name to downstream services as headers.
 */
@Slf4j
@Component
public class AuthFilter implements GlobalFilter, Ordered {

    @Resource
    private AuthProperties authProperties;

    // URI paths excluded from filtering
    // (add swagger exclusions yourself if needed)
    // private static final String[] whiteList = { "/auth/login", "/user/register", "/system/v2/api-docs",
    // "/auth/captcha/check", "/auth/captcha/get", "/auth/captcha/captchaImage", "/auth/login/slide",
    // "/auth/formLogin", "/wms-rfbiz/login", "/auth/rfLogin", "/auth/getToken" };

    /**
     * Origin fragments that are allowed through. Hard-coded for now; could later be moved into
     * configuration or a database.
     */
    private static final String[] whiteList = { "testwilmar-intl.com", "wilmarapps.com" };

    @Resource
    private TokenDecoder tokenDecoder;

    /**
     * Authenticates the request and forwards it with trace and user headers attached.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filters
     * @return a 401 response when the origin or token check fails, otherwise the result of the
     *         remaining chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        String url = request.getURI().getPath();
        if (log.isDebugEnabled()) {
            log.debug("url:{}", url);
            // Only header names are logged: the values include the caller's token.
            log.debug("Header names:{}", request.getHeaders().keySet());
        }

        String host = request.getHeaders().getFirst(BaseConstants.HOST);
        String origin = request.getHeaders().getFirst(BaseConstants.ORIGIN);
        // The origin must be present, contain the request host, and be white-listed.
        // NOTE(review): both checks are substring matches, so an origin such as
        // "https://wilmarapps.com.example.org" would pass; consider comparing parsed host names.
        if (origin == null || host == null || !origin.contains(host) || !isOriginWhitelisted(origin)) {
            log.info("origin,白名单校验origin,{},host,{}", origin, host);
            return setUnauthorizedResponse(exchange, "Not on the whitelist. Please contact the administrator to add the whitelist");
        }

        String traceId = IdUtil.getInstance().uuid();
        String realIp = "";
        if (null != request.getRemoteAddress() && null != request.getRemoteAddress().getAddress()) {
            realIp = request.getRemoteAddress().getAddress().getHostAddress();
        }

        // mutate() works on a copy of the request, so the builder has to be built and installed
        // into a new exchange; otherwise the trace headers are silently dropped.
        ServerHttpRequest tracedRequest = request.mutate()
                .header(GatewayConstants.TRACEID, traceId)
                .header(GatewayConstants.REQUESTTIME, System.currentTimeMillis() + "")
                .build();
        ServerWebExchange tracedExchange = exchange.mutate().request(tracedRequest).build();

        // Paths that do not require authentication are forwarded right away.
        if (isSkip(url, realIp)) {
            return chain.filter(tracedExchange);
        }

        String token = request.getHeaders().getFirst(BaseConstants.TOKEN);
        if (StringUtils.isBlank(token)) {
            return setUnauthorizedResponse(exchange, "token can't null or empty string");
        }
        token = token.replace(BaseConstants.TOKEN_START_WITH, "");

        long decodeStart = System.currentTimeMillis();
        TokenUser user = tokenDecoder.getTokenUser(token);
        if (log.isDebugEnabled()) {
            log.debug("decode token spend{} ms", System.currentTimeMillis() - decodeStart);
        }
        if (null == user || null == user.getId()) {
            return setUnauthorizedResponse(exchange, "token verify error");
        }

        // Pass the user id and login name downstream so services can look up the current user.
        ServerHttpRequest authenticatedRequest = tracedRequest.mutate()
                .header(BaseConstants.CURRENT_ID, user.getId().toString())
                .header(BaseConstants.CURRENT_USERNAME, user.getLoginName())
                .build();
        ServerWebExchange authenticatedExchange = tracedExchange.mutate().request(authenticatedRequest).build();
        return chain.filter(authenticatedExchange);
    }

    /**
     * Checks the origin against the white list.
     *
     * @param origin value of the Origin header, never {@code null}
     * @return {@code true} when the origin contains one of the white-listed fragments
     */
    private boolean isOriginWhitelisted(String origin) {
        log.info("origin,白名单校验,{}", origin);
        for (String white : whiteList) {
            if (origin.contains(white)) {
                log.info("origin,白名单校验white,{}", white);
                return true;
            }
        }
        return false;
    }

    /**
     * Completes the exchange with HTTP 401 and a JSON error body.
     *
     * @param exchange the current server exchange
     * @param msg      error message placed into the response body
     * @return the publisher that writes the response
     */
    private Mono<Void> setUnauthorizedResponse(ServerWebExchange exchange, String msg) {
        ServerHttpResponse originalResponse = exchange.getResponse();
        originalResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
        originalResponse.getHeaders().add("Content-Type", "application/json;charset=UTF-8");
        byte[] response;
        try {
            response = JSON.toJSONString(R.err(401, msg)).getBytes(BaseConstants.UTF8);
        } catch (UnsupportedEncodingException e) {
            // Fall back to an empty body so the 401 is still delivered instead of failing on null.
            log.error("Failed to encode unauthorized response body: {}", msg, e);
            response = new byte[0];
        }
        DataBuffer buffer = originalResponse.bufferFactory().wrap(response);
        return originalResponse.writeWith(Flux.just(buffer));
    }

    @Override
    public int getOrder() {
        return -200;
    }

    /**
     * Tells whether the path may bypass token verification.
     *
     * <p>A path is skipped when it starts with one of the default skip URLs, or when it starts
     * with one of the configured skip URLs and the configured skip-IP list is either empty or
     * contains the caller's IP.
     *
     * @param path   request path
     * @param realIp caller IP, or an empty string when unknown
     * @return {@code true} when the request can be forwarded without a token
     */
    private boolean isSkip(String path, String realIp) {
        boolean matchesDefaultSkipUrl = AuthProvider.getDefaultSkipUrl().stream()
                .map(url -> url.replace(AuthProvider.TARGET, AuthProvider.REPLACEMENT))
                .anyMatch(path::startsWith);
        if (matchesDefaultSkipUrl) {
            return true;
        }
        boolean ipAllowed = CollectionUtils.isEmpty(authProperties.getSkipIp())
                || authProperties.getSkipIp().stream().anyMatch(ip -> realIp.equals(ip));
        return ipAllowed
                && authProperties.getSkipUrl().stream()
                        .map(url -> url.replace(AuthProvider.TARGET, AuthProvider.REPLACEMENT))
                        .anyMatch(path::startsWith);
    }
}