SpringCloud Gateway自定义均衡策略

本文介绍了如何在SpringCloudGateway项目中,针对特定需求定制负载均衡策略,通过Redis缓存和自定义ReactorServiceInstanceLoadBalancer实现对指定请求的负载均衡,同时处理了SpringCloudGateway新版本中默认负载均衡器的变化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目前项目存在需求:需要将某个请求负载均衡到指定的服务实例,于是写一篇文章来总结一下开发过程

1.前言

首先项目需求大概是这样的:因为某个请求会下载文件到本地,而另一个请求又会使用到之前请求下载到本地的文件,所以需要将第二个请求路由到指定的某台机器上

而整体的思路就是在第一个请求中将当前ip和端口记录在redis缓存中,然后在第二个请求路由之前再取出缓存,负载均衡到指定的服务实例上

且由于项目版本问题,SpringCloud Gateway版本较新,而稍微高版本中默认负载均衡已经不是ribbon了,而是Spring LoadBalancer,所以这里自定义负载均衡策略就使用Spring LoadBalancer

2.代码开发

首先编写的就是自定义负载均衡器,要实现该负载均衡器,需要实现接口ReactorServiceInstanceLoadBalancer, 主要就是实现choose方法,而其中具体代码可以直接查看Spring中默认实现的RoundRobinLoadBalancer源码,将源码复制过来简单修改即可,想要具体研究可以自行搜索Spring LoadBalancer,我的代码如下

public class CustomLoadBalancer implements ReactorServiceInstanceLoadBalancer {
    private static final Log log = LogFactory.getLog(CustomLoadBalancer.class);
    private final String serviceId;
    private final AtomicInteger position;
    private final ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;
    private final RedissonClient redissonClient;

    public CustomLoadBalancer(String serviceId,
                            ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider, RedissonClient redissonClient) {
        this(serviceId, new Random().nextInt(1000), serviceInstanceListSupplierProvider, redissonClient);
    }

    public CustomLoadBalancer(String serviceId, int seed,
                              ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider, RedissonClient redissonClient) {
        this.serviceId = serviceId;
        this.position = new AtomicInteger(seed);
        this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
        this.redissonClient = redissonClient;
    }

    public Mono<Response<ServiceInstance>> choose(Request request) {
        // 获得自定义请求的上下文,原本的request包含信息太少,且无请求体
        ServerWebExchange exchange = (ServerWebExchange) request.getContext();
        ServerHttpRequest context = exchange.getRequest();
        ServiceInstanceListSupplier supplier = (ServiceInstanceListSupplier)this.serviceInstanceListSupplierProvider
                .getIfAvailable(NoopServiceInstanceListSupplier::new);

        // 获取请求路径和请求方法类型
        String path = context.getURI().getPath();
        HttpMethod method = context.getMethod();

        StringBuilder builder = new StringBuilder();

        // 若为指定接口
        if (...) {
            return getXXX(exchange).doOnNext(builder::append).then(process(request, supplier, builder));
        }

        return process(request, supplier, builder);
    }

    /**
     * 处理返回Mono响应
     *
     * @param request
     * @param supplier
     * @param builder
     * @return
     */
    private Mono<Response<ServiceInstance>> process(Request request, ServiceInstanceListSupplier supplier, StringBuilder builder) {
        return supplier.get(request).next().map((serviceInstances) -> {
            return this.processInstanceResponse(supplier, getInstanceResponse(serviceInstances, builder.toString()));
        });
    }

    /**
     * 处理服务实例响应
     *
     * @param supplier
     * @param serviceInstanceResponse
     * @return
     */
    private Response<ServiceInstance> processInstanceResponse(ServiceInstanceListSupplier supplier, Response<ServiceInstance> serviceInstanceResponse) {
        if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
            ((SelectedInstanceCallback)supplier).selectedServiceInstance((ServiceInstance)serviceInstanceResponse.getServer());
        }

        return serviceInstanceResponse;
    }

    /**
     * 获得服务实例响应
     *
     * @param instances
     * @param XXX
     * @return
     */
    private Response<ServiceInstance> getInstanceResponse(List<ServiceInstance> instances, String XXX) {
        // 若服务实例列表为空
        if (instances.isEmpty()) {
            if (log.isWarnEnabled()) {
                log.warn("No servers available for service:" + this.serviceId);
            }

            return new EmptyResponse();
        } else {
            // 为空表示非指定接口,走正常轮询策略,否则走指定策略
            if (!StringUtils.isEmpty(XXX)) {
                return getCachedInstance(instances, XXX);
            }

            return getRoundRobinInstance(instances);
        }
    }

    /**
     * 根据redis中缓存到指定服务实例
     *
     * @param instances
     * @param XXX
     * @return
     */
    private Response<ServiceInstance> getCachedInstance(List<ServiceInstance> instances, String XXX) {
        int idx = 0;
        for (ServiceInstance instance : instances) {
            String host = instance.getHost();
            int port = instance.getPort();
            String instanceAddress = host + RouteConstants.ADDRESS_SEPARATOR + port;
            String cachedKey = XXX + RouteConstants.HOST;
            // 获取缓存中服务实例
            RBucket<String> bucket = redissonClient.getBucket(cachedKey);
            String cachedAddress = bucket.get();
            if (instanceAddress.equals(cachedAddress)) {
                log.info("choose cached instance: " + cachedAddress);
                return new DefaultResponse(instances.get(idx));
            }

            idx++;
        }

        log.info("no matched instance from cached");
        return getRoundRobinInstance(instances);
    }

    /**
     * 轮询获取服务实例
     *
     * @param instances
     * @return
     */
    private Response<ServiceInstance> getRoundRobinInstance(List<ServiceInstance> instances) {
        int pos = this.position.incrementAndGet() & Integer.MAX_VALUE;
        ServiceInstance instance = instances.get(pos % instances.size());
        return new DefaultResponse(instance);
    }

    /**
     * 获取POST方法体中的XXX
     *
     * @param exchange
     * @return
     */
    private Mono<String> getXXX(ServerWebExchange exchange){
        ServerHttpRequest request = exchange.getRequest();
        //获取请求体
        Flux<DataBuffer> body = request.getBody();

        return DataBufferUtils.join(body).flatMap(dataBuffer -> {
            byte[] bytes = new byte[dataBuffer.readableByteCount()];
            dataBuffer.read(bytes);
            String bodyString = new String(bytes, StandardCharsets.UTF_8);

            DataBufferUtils.release(dataBuffer);

            // 将方法体字符串设置到attributes中,方便后续再从其中获取并重置请求,因为请求中的body只能获取一次
            exchange.getAttributes().put(RouteConstants.REQUEST_BODY, bodyString);
            // 获取XXX,接受数据类型: application/json
            Map<String, String> map = JsonUtil.jsonToObject(bodyString, Map.class);

            return Mono.just(map.get(RouteConstants.XXX));
        });

    }

}

上方代码中需要注意的地方是:首先要在第一个接口中缓存的服务实例地址自己决定写法;然后choose方法默认的参数request中是获取不到post方法的请求体的,也就无法获取post方法传递的参数,因为我需要根据参数缓存中获取对应服务实例地址,所以需要获得请求体,且该请求体由于流式读取所以只能读取一次,我是读取之后将其设置到参数request中,而该request是我们自身写的负载均衡全局过滤器中传递的,request类型是我们自己定义的,该类代码如下

public class CustomRequest extends DefaultRequest<ServerWebExchange> {

    public CustomRequest(ServerWebExchange context) {
        super(context);
    }
}

在自定义负载均衡器中choose方法中可以看到,我们转换获取的类型就是ServerWebExchange,该类型也是全局过滤器中filter方法的参数,就是将获后的请求体设置到该exchange的attribute中,然后后续在全局过滤中重置请求,自定义负载均衡全局过滤器代码如下,该代码也是将ReactiveLoadBalancerClientFilter源码复制过来进行一定的修改

@Slf4j
public class CustomReactiveLoadBalancerClientFilter extends ReactiveLoadBalancerClientFilter {
    public static final int LOAD_BALANCER_CLIENT_FILTER_ORDER = 10150;
    private final LoadBalancerClientFactory clientFactory;
    private final GatewayLoadBalancerProperties properties;

    public CustomReactiveLoadBalancerClientFilter(LoadBalancerClientFactory clientFactory, GatewayLoadBalancerProperties properties) {
        super(clientFactory, properties);
        this.clientFactory = clientFactory;
        this.properties = properties;
    }

    public int getOrder() {
        return LOAD_BALANCER_CLIENT_FILTER_ORDER;
    }

    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        URI url = (URI)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
        // 获取原始路由前缀 若包含 lb 则表示需要进行负载均衡
        String schemePrefix = (String)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_SCHEME_PREFIX_ATTR);
        if (url != null && ("lb".equals(url.getScheme()) || "lb".equals(schemePrefix))) {
            ServerWebExchangeUtils.addOriginalRequestUrl(exchange, url);
            if (log.isTraceEnabled()) {
                log.trace(ReactiveLoadBalancerClientFilter.class.getSimpleName() + " url before: " + url);
            }

            URI requestUri = (URI)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
            String serviceId = requestUri.getHost();
            Set<LoadBalancerLifecycle> supportedLifecycleProcessors = LoadBalancerLifecycleValidator.getSupportedLifecycleProcessors(this.clientFactory.getInstances(serviceId, LoadBalancerLifecycle.class), RequestDataContext.class, ResponseData.class, ServiceInstance.class);
            // 自定义请求,将参数ServerWebExchange作为context传递,使得自定义负载均衡器中可以使用
            DefaultRequest<ServerWebExchange> lbRequest = new CustomRequest(exchange);
            return this.choose(lbRequest, serviceId, supportedLifecycleProcessors).flatMap(response -> {
                if (!response.hasServer()) {
                    supportedLifecycleProcessors.forEach((lifecycle) -> {
                        lifecycle.onComplete(new CompletionContext(Status.DISCARD, lbRequest, response));
                    });
                    throw NotFoundException.create(this.properties.isUse404(), "Unable to find instance for " + url.getHost());
                } else {
                    ServiceInstance retrievedInstance = (ServiceInstance)response.getServer();
                    URI uri = exchange.getRequest().getURI();
                    String overrideScheme = retrievedInstance.isSecure() ? "https" : "http";
                    if (schemePrefix != null) {
                        overrideScheme = url.getScheme();
                    }

                    DelegatingServiceInstance serviceInstance = new DelegatingServiceInstance(retrievedInstance, overrideScheme);
                    URI requestUrl = this.reconstructURI(serviceInstance, uri);
                    if (log.isTraceEnabled()) {
                        log.trace("LoadBalancerClientFilter url chosen: " + requestUrl);
                    }

                    exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR, requestUrl);
                    exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_LOADBALANCER_RESPONSE_ATTR, response);
                    supportedLifecycleProcessors.forEach((lifecycle) -> {
                        lifecycle.onStartRequest(lbRequest, response);
                    });
                }
                return chain.filter(getMutatedExchange(exchange));
            });
        } else {
            return chain.filter(exchange);
        }
    }

    private ResponseData buildResponseData(ServerWebExchange exchange, boolean useRawStatusCodes) {
        return useRawStatusCodes ? new ResponseData(new RequestData(exchange.getRequest()), exchange.getResponse()) : new ResponseData(exchange.getResponse(), new RequestData(exchange.getRequest()));
    }

    protected URI reconstructURI(ServiceInstance serviceInstance, URI original) {
        return LoadBalancerUriTools.reconstructURI(serviceInstance, original);
    }

    /**
     * 获取对应负载均衡器并调用choose方法
     *
     * @param lbRequest
     * @param serviceId
     * @param supportedLifecycleProcessors
     * @return
     */
    private Mono<Response<ServiceInstance>> choose(Request<ServerWebExchange> lbRequest, String serviceId, Set<LoadBalancerLifecycle> supportedLifecycleProcessors) {
        ReactorLoadBalancer<ServiceInstance> loadBalancer = (ReactorLoadBalancer)this.clientFactory.getInstance(serviceId, ReactorServiceInstanceLoadBalancer.class);
        if (loadBalancer == null) {
            throw new NotFoundException("No loadbalancer available for " + serviceId);
        } else {
            supportedLifecycleProcessors.forEach((lifecycle) -> {
                lifecycle.onStart(lbRequest);
            });
            return loadBalancer.choose(lbRequest);
        }
    }

    /**
     * 若为指定接口,则方法body会被流式获取,由于该body只能获取一次,需要重新构造request
     *
     * @param exchange
     * @return
     */
    private ServerWebExchange getMutatedExchange(ServerWebExchange exchange) {
        // 获取请求路径和请求方法类型
        ServerHttpRequest context = exchange.getRequest();
        String path = context.getURI().getPath();
        HttpMethod method = context.getMethod();

        // 若不为指定接口则不需要重置请求
        if (...) {
            return exchange;
        }

        // 负载均衡器中若获取了body过后会将值设置到Attribute中
        String requestBody = exchange.getAttribute(RouteConstants.REQUEST_BODY);
        if (Objects.isNull(requestBody)) {
            return exchange;
        }

        // 构造重置请求
        ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public Flux<DataBuffer> getBody() {
                byte[] bytes = requestBody.getBytes(StandardCharsets.UTF_8);
                DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
                return Flux.just(buffer);
            }
        };

        return exchange.mutate().request(mutatedRequest).build();
    }

}

然后进行配置

@Configuration
public class CustomLoadBalanceClientConfiguration {

    @Bean
    public ReactorServiceInstanceLoadBalancer reactorServiceInstanceLoadBalancer(Environment environment, LoadBalancerClientFactory loadBalancerClientFactory, RedissonClient redissonClient) {
        String name = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
        return new CustomLoadBalancer(name, loadBalancerClientFactory.
                getLazyProvider(name, ServiceInstanceListSupplier.class), redissonClient);
    }

    @Bean
    public ReactiveLoadBalancerClientFilter reactiveLoadBalancerClientFilter(LoadBalancerClientFactory clientFactory,
                                                                             GatewayLoadBalancerProperties properties) {
        return new CustomReactiveLoadBalancerClientFilter(clientFactory, properties);
    }

}

最后在启动类或随意任意类上使用注解

@LoadBalancerClients(defaultConfiguration = CustomLoadBalanceClientConfiguration.class)

3.结尾

参考文章:
SpringCloud自定义负载均衡策略–LoadBalancer
gateway自定义负载均衡策略
springcloud的gateway之GlobalFilter获取请求信息及requestBody

Spring Cloud Gateway支持多种负载均衡策略,例如随机、轮询、权重等。如果现有的负载均衡策略不能满足你的需求,你可以自定义负载均衡策略。 首先,你需要实现`org.springframework.cloud.client.loadbalancer.reactive.LoadBalancer`接口来定义你的负载均衡策略。然后,你需要创建一个`org.springframework.cloud.gateway.filter.factory.rewrite.RewriteFunction`实例,用于将服务的URI重写为负载均衡的服务实例地址。最后,你需要将这个自定义负载均衡策略应用到Spring Cloud Gateway的路由规则中。 以下是一个示例,展示了如何定义一个基于特定请求头的自定义负载均衡策略: ```java public class CustomLoadBalancer implements LoadBalancer<ServiceInstance> { private final String headerName; public CustomLoadBalancer(String headerName) { this.headerName = headerName; } @Override public Mono<Response<ServiceInstance>> choose(Request request) { Object headerValue = request.headers().getFirst(headerName); String serviceName = "my-service"; // 根据请求头的值选择服务实例 ServiceInstance serviceInstance = ...; return Mono.just(new DefaultResponse(serviceInstance)); } } public class CustomLoadBalancerGatewayFilterFactory extends AbstractGatewayFilterFactory<CustomLoadBalancerGatewayFilterFactory.Config> { public CustomLoadBalancerGatewayFilterFactory() { super(Config.class); } @Override public GatewayFilter apply(Config config) { LoadBalancer<ServiceInstance> loadBalancer = new CustomLoadBalancer(config.getHeaderName()); RewriteFunction<String, String> rewriteFunction = uri -> { // 将URI重写为负载均衡的服务实例地址 ServiceInstance serviceInstance = loadBalancer.choose(Request.create("", new HttpHeaders())).block().getServer(); return "http://" + serviceInstance.getHost() + ":" + serviceInstance.getPort() + uri; }; return new RewritePathGatewayFilterFactory().apply(new RewritePathGatewayFilterFactory.Config().setRewriteFunction(rewriteFunction)); } public static class Config { private String headerName; public String getHeaderName() { return headerName; } public void setHeaderName(String headerName) { this.headerName = headerName; } } } ``` 在上面的示例中,`CustomLoadBalancer`实现了`LoadBalancer`接口,并基于特定的请求头选择服务实例。`CustomLoadBalancerGatewayFilterFactory`则将`CustomLoadBalancer`应用到Spring Cloud Gateway的路由规则中,并将服务的URI重写为负载均衡的服务实例地址。最后,你可以在路由规则中使用`CustomLoadBalancerGatewayFilterFactory`来定义自定义负载均衡策略。 ```yaml spring: cloud: gateway: routes: - id: my-route uri: http://my-service filters: - CustomLoadBalancer=my-header ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值