Spring Boot 在网关层实现白名单 + 限流

Spring Boot 在网关层实现白名单 + 限流

拒绝废话,上代码

1.限流实现

1.1 限流config

package com.*.gateway.config;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.*.common.redis.util.RedisUtils;
import com.*.common.tools.StringUtil;
import com.*.gateway.rate.model.RateLimiter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Per-IP request limiter whose state is kept in Redis.
 *
 * <p>The client IP is used directly as the Redis key. Its value is a {@link RateLimiter} whose map is
 * keyed by {@code apiName + delimiter + paramValue}; each entry holds a request counter and the
 * start (epoch seconds) of its current window of {@link #ipSecond} seconds.
 *
 * <p>NOTE(review): the Redis read-modify-write below is not atomic, so concurrent requests from the
 * same IP may under-count. Consider a Lua script or INCR + EXPIRE if strict limits are required.
 */
@Slf4j
@Component
public class ApiRateLimiterConfig {

    @Autowired
    private RedisUtils redisUtils;

    /** API path fragments that are never rate limited (could be moved to configuration). */
    private String[] whiteList = {"oms", "wms", "outos"};

    /** Only requests whose parameter string contains this name are rate limited. */
    private String vailParam = "timestamp";

    /**
     * Window length in seconds (compared against epoch seconds); also used as the TTL passed to
     * {@code redisUtils.set} for the per-IP record.
     */
    private Long ipSecond = 3600L;

    /** Separator between the API name and the parameter string in the per-IP map key. */
    private String delimiter = "#";

    /**
     * Records one request and reports whether it is within the limit.
     *
     * @param ip         client IP, used as the Redis key
     * @param apiName    request path
     * @param paramValue request parameter string; also part of the counter key, so identical
     *                   parameter strings from the same IP share one counter
     * @param limitCount maximum number of requests allowed per window for one key
     * @return {@code true} if the request is exempt or within the limit, {@code false} if the
     *         limit has been exceeded
     */
    public boolean checkRate(String ip, String apiName, String paramValue, int limitCount) {
        // Whitelisted APIs and requests without the marker parameter are not limited.
        if (!checkWhiteListAndParam(apiName, paramValue)) {
            return true;
        }
        RateLimiter rateLimiter = loadRateLimiter(ip);
        Map<String, RateLimiter.ApiLimiter> apiMap = rateLimiter.getApiMap();
        if (apiMap == null) {
            apiMap = new HashMap<>(18);
            rateLimiter.setApiMap(apiMap);
        }
        long now = System.currentTimeMillis() / 1000;
        RateLimiter.ApiLimiter apiL = apiMap.computeIfAbsent(apiName + delimiter + paramValue,
                k -> new RateLimiter.ApiLimiter(new AtomicInteger(0), now));

        // Start a new window once the current one is older than ipSecond seconds
        // (or the stored record is incomplete).
        Long windowStart = apiL.getTimestamp();
        if (windowStart == null || apiL.getCount() == null || now - windowStart > ipSecond) {
            apiL.setTimestamp(now);
            apiL.setCount(new AtomicInteger(0));
        }
        AtomicInteger count = apiL.getCount();

        log.info("限流作用rateLimiter:{}",rateLimiter);
        // Every request, including the very first one from an IP, is counted here.
        if (count.incrementAndGet() > limitCount) {
            log.info("限流IP,{},API,{}",ip,apiName);
            return false;
        }
        redisUtils.set(ip, rateLimiter, ipSecond);
        return true;
    }

    /**
     * Decides whether a request is subject to rate limiting.
     *
     * @param apiName    request path
     * @param paramValue request parameter string; may be {@code null}
     * @return {@code false} (skip limiting) when the parameters do not contain {@link #vailParam}
     *         or the API name contains a whitelist entry; {@code true} otherwise
     */
    public boolean checkWhiteListAndParam(String apiName, String paramValue) {
        // Parameters without "timestamp" (including none at all) skip rate limiting.
        if (paramValue == null || !paramValue.contains(vailParam)) {
            log.info("放行api,{}",apiName);
            return false;
        }
        // APIs whose path contains a whitelist entry (oms, wms, outos) skip rate limiting.
        for (String white : whiteList) {
            if (apiName != null && apiName.contains(white)) {
                log.info("放行api,{}",apiName);
                return false;
            }
        }
        return true;
    }

    /**
     * Reads the per-IP record from Redis, or returns a fresh record with an empty map when none is
     * stored.
     */
    private RateLimiter loadRateLimiter(String ip) {
        String cached = redisUtils.get(ip);
        if (StringUtil.isBlank(cached)) {
            RateLimiter fresh = new RateLimiter();
            fresh.setApiMap(new HashMap<>(18));
            return fresh;
        }
        return JSON.toJavaObject(JSONObject.parseObject(cached), RateLimiter.class);
    }

}



1.2 限流Filter

package com.*.gateway.fiflt;

import com.*.gateway.config.ApiRateLimiterConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

/**
 * Global gateway filter that runs {@link ApiRateLimiterConfig#checkRate} for every request and
 * answers {@code 429 Too Many Requests} when the limit is exceeded.
 *
 * <p>NOTE(review): the IP is taken from the socket's remote address; behind a reverse proxy or load
 * balancer this is the proxy's address, not the client's — confirm deployment topology.
 */
@Component
public class ApiRateLimitFilter implements GlobalFilter, Ordered {
    @Autowired
    private ApiRateLimiterConfig apiRateLimiter;

    /** Maximum number of requests per IP for one (path, query) key within the limiter's window. */
    private int ipLimitCount = 10;

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        String ip = resolveClientIp(exchange);
        if (ip == null) {
            // No remote address available: the limiter cannot be keyed, so let the request
            // through instead of failing with a NullPointerException.
            return chain.filter(exchange);
        }
        String apiName = exchange.getRequest().getPath().value();
        String paramValue = exchange.getRequest().getQueryParams().toString();

        if (!apiRateLimiter.checkRate(ip, apiName, paramValue, ipLimitCount)) {
            exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
            return exchange.getResponse().setComplete();
        }

        return chain.filter(exchange);
    }

    /**
     * Returns the textual client address, or {@code null} when the request carries no remote
     * address.
     */
    private static String resolveClientIp(ServerWebExchange exchange) {
        java.net.InetSocketAddress remote = exchange.getRequest().getRemoteAddress();
        if (remote == null) {
            return null;
        }
        java.net.InetAddress address = remote.getAddress();
        // An unresolved socket address has no InetAddress; fall back to its host string.
        return address != null ? address.getHostAddress() : remote.getHostString();
    }

    @Override
    public int getOrder() {
        return -1000;
    }
}

2.白名单实现

如果已经在 Gateway 中实现了鉴权（Auth），则可以直接把白名单校验写到 AuthFilter 里。

/**
 * Copyright (c) 2019-2099, XIA HUI
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.*.gateway.filter;

import java.io.UnsupportedEncodingException;
import java.util.Arrays;

import javax.annotation.Resource;

import com.*.common.tools.IdUtil;
import com.*.gateway.constants.GatewayConstants;
import com.*.gateway.props.AuthProperties;
import com.*.gateway.provider.AuthProvider;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;

import com.alibaba.fastjson.JSON;
import com.*.common.base.constants.BaseConstants;
import com.*.common.base.resp.R;
import com.*.sso.config.TokenDecoder;
import com.*.sso.vo.TokenUser;

import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * 网关鉴权
 */
@Slf4j
@Component
public class AuthFilter implements GlobalFilter, Ordered {

    @Resource
    private AuthProperties authProperties;

    // 排除过滤的 uri 地址
    // swagger排除自行添加
//	private static final String[] whiteList = { "/auth/login", "/user/register", "/system/v2/api-docs",
//			"/auth/captcha/check", "/auth/captcha/get", "/auth/captcha/captchaImage", "/auth/login/slide",
//			"/auth/formLogin", "/wms-rfbiz/login", "/auth/rfLogin", "/auth/getToken" };

    /**
     * 暂时写死可以后续放入配置里或数据库
     */
    private static final String[] whiteList = { "testwilmar-intl.com","wilmarapps.com" };
    @Resource
    private TokenDecoder tokenDecoder;

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        String url = exchange.getRequest().getURI().getPath();
        if (log.isDebugEnabled()) {
            log.debug("url:{}", url);
        }
        log.info("Headers:{}", exchange.getRequest().getHeaders());
        String host = exchange.getRequest().getHeaders().getFirst(BaseConstants.HOST);
        String origin = exchange.getRequest().getHeaders().getFirst(BaseConstants.ORIGIN);
        if (origin ==null || origin.indexOf(host)==-1  || verifyWhitelist(origin)) {
            log.info("origin,白名单校验origin,{},host,{}",origin,host);
            return setUnauthorizedResponse(exchange, "Not on the whitelist. Please contact the administrator to add the whitelist");
        }


        String traceId = IdUtil.getInstance().uuid();
        String realIp = "";
        if (null != exchange.getRequest().getRemoteAddress()) {
            realIp = exchange.getRequest().getRemoteAddress().getAddress().getHostAddress();
        }
        exchange.getRequest().mutate().header(GatewayConstants.TRACEID, traceId);
        exchange.getRequest().mutate().header(GatewayConstants.REQUESTTIME, System.currentTimeMillis() + "");
        // 跳过不需要验证的路径
        if (isSkip(url, realIp)) {
            return chain.filter(exchange);
        }
        String token = exchange.getRequest().getHeaders().getFirst(BaseConstants.TOKEN);
        // token为空
        if (StringUtils.isBlank(token)) {
            return setUnauthorizedResponse(exchange, "token can't null or empty string");
        }
        token = token.replace(BaseConstants.TOKEN_START_WITH, "");
        Long ct = System.currentTimeMillis();
        TokenUser user = tokenDecoder.getTokenUser(token);
        Long ct1 = System.currentTimeMillis();

        if (log.isDebugEnabled()) {
            log.debug("decode token spend{} ms", ct1 - ct);
        }
        if (null == user) {
            return setUnauthorizedResponse(exchange, "token verify error");
        }
        // 查询token信息
        if (null == user.getId()) {
            return setUnauthorizedResponse(exchange, "token verify error");
        }
        // 设置userId到request里,后续根据userId,获取用户信息
        ServerHttpRequest mutableReq = exchange.getRequest().mutate()
                .header(BaseConstants.CURRENT_ID, user.getId().toString())
                .header(BaseConstants.CURRENT_USERNAME, user.getLoginName()).build();
        ServerWebExchange mutableExchange = exchange.mutate().request(mutableReq).build();
        return chain.filter(mutableExchange);
    }

    private Boolean verifyWhitelist(String origin){
        Boolean flag = true;
        log.info("origin,白名单校验,{}",origin);
        for(String white:whiteList){
            if(origin.indexOf(white)!=-1){
                log.info("origin,白名单校验white,{}",white);
                flag = false;
            }
        }
        return flag;
    }

    private Mono<Void> setUnauthorizedResponse(ServerWebExchange exchange, String msg) {
        ServerHttpResponse originalResponse = exchange.getResponse();
        originalResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
        originalResponse.getHeaders().add("Content-Type", "application/json;charset=UTF-8");
        byte[] response = null;
        try {
            response = JSON.toJSONString(R.err(401, msg)).getBytes(BaseConstants.UTF8);
        } catch (UnsupportedEncodingException e) {
            log.error(msg);
        }
        DataBuffer buffer = originalResponse.bufferFactory().wrap(response);
        return originalResponse.writeWith(Flux.just(buffer));
    }

    @Override
    public int getOrder() {
        return -200;
    }

    /**
     * 是否放行路径
     *
     * @param path 路径
     * @return 是否
     */
    private boolean isSkip(String path, String realIp) {
        return AuthProvider.getDefaultSkipUrl().stream().map(url -> url.replace(AuthProvider.TARGET, AuthProvider.REPLACEMENT)).anyMatch(path::startsWith)
                || ((CollectionUtils.isEmpty(authProperties.getSkipIp()) || authProperties.getSkipIp().stream().anyMatch(ip -> realIp.equals(ip)))
                && authProperties.getSkipUrl().stream().map(url -> url.replace(AuthProvider.TARGET, AuthProvider.REPLACEMENT)).anyMatch(path::startsWith));
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值