目前项目存在需求:需要将某个请求负载均衡到指定的服务实例,于是写一篇文章来总结一下开发过程
1.前言
首先项目需求大概是这样的:因为某个请求会下载文件到本地,而另一个请求又会使用到之前请求下载到本地的文件,所以需要将第二个请求路由到指定的某台机器上
而整体的思路就是在第一个请求中将当前ip和端口记录在redis缓存中,然后在第二个请求路由之前再取出缓存,负载均衡到指定的服务实例上
且由于项目版本问题,SpringCloud Gateway版本较新,而稍微高版本中默认负载均衡已经不是ribbon了,而是Spring LoadBalancer,所以这里自定义负载均衡策略就使用Spring LoadBalancer
2.代码开发
首先编写的就是自定义负载均衡器,要实现该负载均衡器,需要实现接口ReactorServiceInstanceLoadBalancer, 主要就是实现choose方法,而其中具体代码可以直接查看Spring中默认实现的RoundRobinLoadBalancer源码,将源码复制过来简单修改即可,想要具体研究可以自行搜索Spring LoadBalancer,我的代码如下
public class CustomLoadBalancer implements ReactorServiceInstanceLoadBalancer {
private static final Log log = LogFactory.getLog(CustomLoadBalancer.class);
private final String serviceId;
private final AtomicInteger position;
private final ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;
private final RedissonClient redissonClient;
public CustomLoadBalancer(String serviceId,
ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider, RedissonClient redissonClient) {
this(serviceId, new Random().nextInt(1000), serviceInstanceListSupplierProvider, redissonClient);
}
public CustomLoadBalancer(String serviceId, int seed,
ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider, RedissonClient redissonClient) {
this.serviceId = serviceId;
this.position = new AtomicInteger(seed);
this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
this.redissonClient = redissonClient;
}
public Mono<Response<ServiceInstance>> choose(Request request) {
// 获得自定义请求的上下文,原本的request包含信息太少,且无请求体
ServerWebExchange exchange = (ServerWebExchange) request.getContext();
ServerHttpRequest context = exchange.getRequest();
ServiceInstanceListSupplier supplier = (ServiceInstanceListSupplier)this.serviceInstanceListSupplierProvider
.getIfAvailable(NoopServiceInstanceListSupplier::new);
// 获取请求路径和请求方法类型
String path = context.getURI().getPath();
HttpMethod method = context.getMethod();
StringBuilder builder = new StringBuilder();
// 若为指定接口
if (...) {
return getXXX(exchange).doOnNext(builder::append).then(process(request, supplier, builder));
}
return process(request, supplier, builder);
}
/**
* 处理返回Mono响应
*
* @param request
* @param supplier
* @param builder
* @return
*/
private Mono<Response<ServiceInstance>> process(Request request, ServiceInstanceListSupplier supplier, StringBuilder builder) {
return supplier.get(request).next().map((serviceInstances) -> {
return this.processInstanceResponse(supplier, getInstanceResponse(serviceInstances, builder.toString()));
});
}
/**
* 处理服务实例响应
*
* @param supplier
* @param serviceInstanceResponse
* @return
*/
private Response<ServiceInstance> processInstanceResponse(ServiceInstanceListSupplier supplier, Response<ServiceInstance> serviceInstanceResponse) {
if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
((SelectedInstanceCallback)supplier).selectedServiceInstance((ServiceInstance)serviceInstanceResponse.getServer());
}
return serviceInstanceResponse;
}
/**
* 获得服务实例响应
*
* @param instances
* @param XXX
* @return
*/
private Response<ServiceInstance> getInstanceResponse(List<ServiceInstance> instances, String XXX) {
// 若服务实例列表为空
if (instances.isEmpty()) {
if (log.isWarnEnabled()) {
log.warn("No servers available for service:" + this.serviceId);
}
return new EmptyResponse();
} else {
// 为空表示非指定接口,走正常轮询策略,否则走指定策略
if (!StringUtils.isEmpty(XXX)) {
return getCachedInstance(instances, XXX);
}
return getRoundRobinInstance(instances);
}
}
/**
* 根据redis中缓存到指定服务实例
*
* @param instances
* @param XXX
* @return
*/
private Response<ServiceInstance> getCachedInstance(List<ServiceInstance> instances, String XXX) {
int idx = 0;
for (ServiceInstance instance : instances) {
String host = instance.getHost();
int port = instance.getPort();
String instanceAddress = host + RouteConstants.ADDRESS_SEPARATOR + port;
String cachedKey = XXX + RouteConstants.HOST;
// 获取缓存中服务实例
RBucket<String> bucket = redissonClient.getBucket(cachedKey);
String cachedAddress = bucket.get();
if (instanceAddress.equals(cachedAddress)) {
log.info("choose cached instance: " + cachedAddress);
return new DefaultResponse(instances.get(idx));
}
idx++;
}
log.info("no matched instance from cached");
return getRoundRobinInstance(instances);
}
/**
* 轮询获取服务实例
*
* @param instances
* @return
*/
private Response<ServiceInstance> getRoundRobinInstance(List<ServiceInstance> instances) {
int pos = this.position.incrementAndGet() & Integer.MAX_VALUE;
ServiceInstance instance = instances.get(pos % instances.size());
return new DefaultResponse(instance);
}
/**
* 获取POST方法体中的XXX
*
* @param exchange
* @return
*/
private Mono<String> getXXX(ServerWebExchange exchange){
ServerHttpRequest request = exchange.getRequest();
//获取请求体
Flux<DataBuffer> body = request.getBody();
return DataBufferUtils.join(body).flatMap(dataBuffer -> {
byte[] bytes = new byte[dataBuffer.readableByteCount()];
dataBuffer.read(bytes);
String bodyString = new String(bytes, StandardCharsets.UTF_8);
DataBufferUtils.release(dataBuffer);
// 将方法体字符串设置到attributes中,方便后续再从其中获取并重置请求,因为请求中的body只能获取一次
exchange.getAttributes().put(RouteConstants.REQUEST_BODY, bodyString);
// 获取XXX,接受数据类型: application/json
Map<String, String> map = JsonUtil.jsonToObject(bodyString, Map.class);
return Mono.just(map.get(RouteConstants.XXX));
});
}
}
上方代码中需要注意的地方是:首先要在第一个接口中缓存的服务实例地址自己决定写法;然后choose方法默认的参数request中是获取不到post方法的请求体的,也就无法获取post方法传递的参数,因为我需要根据参数缓存中获取对应服务实例地址,所以需要获得请求体,且该请求体由于流式读取所以只能读取一次,我是读取之后将其设置到参数request中,而该request是我们自身写的负载均衡全局过滤器中传递的,request类型是我们自己定义的,该类代码如下
public class CustomRequest extends DefaultRequest<ServerWebExchange> {
public CustomRequest(ServerWebExchange context) {
super(context);
}
}
在自定义负载均衡器中choose方法中可以看到,我们转换获取的类型就是ServerWebExchange,该类型也是全局过滤器中filter方法的参数,就是将获后的请求体设置到该exchange的attribute中,然后后续在全局过滤中重置请求,自定义负载均衡全局过滤器代码如下,该代码也是将ReactiveLoadBalancerClientFilter源码复制过来进行一定的修改
@Slf4j
public class CustomReactiveLoadBalancerClientFilter extends ReactiveLoadBalancerClientFilter {
public static final int LOAD_BALANCER_CLIENT_FILTER_ORDER = 10150;
private final LoadBalancerClientFactory clientFactory;
private final GatewayLoadBalancerProperties properties;
public CustomReactiveLoadBalancerClientFilter(LoadBalancerClientFactory clientFactory, GatewayLoadBalancerProperties properties) {
super(clientFactory, properties);
this.clientFactory = clientFactory;
this.properties = properties;
}
public int getOrder() {
return LOAD_BALANCER_CLIENT_FILTER_ORDER;
}
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
URI url = (URI)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
// 获取原始路由前缀 若包含 lb 则表示需要进行负载均衡
String schemePrefix = (String)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_SCHEME_PREFIX_ATTR);
if (url != null && ("lb".equals(url.getScheme()) || "lb".equals(schemePrefix))) {
ServerWebExchangeUtils.addOriginalRequestUrl(exchange, url);
if (log.isTraceEnabled()) {
log.trace(ReactiveLoadBalancerClientFilter.class.getSimpleName() + " url before: " + url);
}
URI requestUri = (URI)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
String serviceId = requestUri.getHost();
Set<LoadBalancerLifecycle> supportedLifecycleProcessors = LoadBalancerLifecycleValidator.getSupportedLifecycleProcessors(this.clientFactory.getInstances(serviceId, LoadBalancerLifecycle.class), RequestDataContext.class, ResponseData.class, ServiceInstance.class);
// 自定义请求,将参数ServerWebExchange作为context传递,使得自定义负载均衡器中可以使用
DefaultRequest<ServerWebExchange> lbRequest = new CustomRequest(exchange);
return this.choose(lbRequest, serviceId, supportedLifecycleProcessors).flatMap(response -> {
if (!response.hasServer()) {
supportedLifecycleProcessors.forEach((lifecycle) -> {
lifecycle.onComplete(new CompletionContext(Status.DISCARD, lbRequest, response));
});
throw NotFoundException.create(this.properties.isUse404(), "Unable to find instance for " + url.getHost());
} else {
ServiceInstance retrievedInstance = (ServiceInstance)response.getServer();
URI uri = exchange.getRequest().getURI();
String overrideScheme = retrievedInstance.isSecure() ? "https" : "http";
if (schemePrefix != null) {
overrideScheme = url.getScheme();
}
DelegatingServiceInstance serviceInstance = new DelegatingServiceInstance(retrievedInstance, overrideScheme);
URI requestUrl = this.reconstructURI(serviceInstance, uri);
if (log.isTraceEnabled()) {
log.trace("LoadBalancerClientFilter url chosen: " + requestUrl);
}
exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR, requestUrl);
exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_LOADBALANCER_RESPONSE_ATTR, response);
supportedLifecycleProcessors.forEach((lifecycle) -> {
lifecycle.onStartRequest(lbRequest, response);
});
}
return chain.filter(getMutatedExchange(exchange));
});
} else {
return chain.filter(exchange);
}
}
private ResponseData buildResponseData(ServerWebExchange exchange, boolean useRawStatusCodes) {
return useRawStatusCodes ? new ResponseData(new RequestData(exchange.getRequest()), exchange.getResponse()) : new ResponseData(exchange.getResponse(), new RequestData(exchange.getRequest()));
}
protected URI reconstructURI(ServiceInstance serviceInstance, URI original) {
return LoadBalancerUriTools.reconstructURI(serviceInstance, original);
}
/**
* 获取对应负载均衡器并调用choose方法
*
* @param lbRequest
* @param serviceId
* @param supportedLifecycleProcessors
* @return
*/
private Mono<Response<ServiceInstance>> choose(Request<ServerWebExchange> lbRequest, String serviceId, Set<LoadBalancerLifecycle> supportedLifecycleProcessors) {
ReactorLoadBalancer<ServiceInstance> loadBalancer = (ReactorLoadBalancer)this.clientFactory.getInstance(serviceId, ReactorServiceInstanceLoadBalancer.class);
if (loadBalancer == null) {
throw new NotFoundException("No loadbalancer available for " + serviceId);
} else {
supportedLifecycleProcessors.forEach((lifecycle) -> {
lifecycle.onStart(lbRequest);
});
return loadBalancer.choose(lbRequest);
}
}
/**
* 若为指定接口,则方法body会被流式获取,由于该body只能获取一次,需要重新构造request
*
* @param exchange
* @return
*/
private ServerWebExchange getMutatedExchange(ServerWebExchange exchange) {
// 获取请求路径和请求方法类型
ServerHttpRequest context = exchange.getRequest();
String path = context.getURI().getPath();
HttpMethod method = context.getMethod();
// 若不为指定接口则不需要重置请求
if (...) {
return exchange;
}
// 负载均衡器中若获取了body过后会将值设置到Attribute中
String requestBody = exchange.getAttribute(RouteConstants.REQUEST_BODY);
if (Objects.isNull(requestBody)) {
return exchange;
}
// 构造重置请求
ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
@Override
public Flux<DataBuffer> getBody() {
byte[] bytes = requestBody.getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
return Flux.just(buffer);
}
};
return exchange.mutate().request(mutatedRequest).build();
}
}
然后进行配置
@Configuration
public class CustomLoadBalanceClientConfiguration {
@Bean
public ReactorServiceInstanceLoadBalancer reactorServiceInstanceLoadBalancer(Environment environment, LoadBalancerClientFactory loadBalancerClientFactory, RedissonClient redissonClient) {
String name = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
return new CustomLoadBalancer(name, loadBalancerClientFactory.
getLazyProvider(name, ServiceInstanceListSupplier.class), redissonClient);
}
@Bean
public ReactiveLoadBalancerClientFilter reactiveLoadBalancerClientFilter(LoadBalancerClientFactory clientFactory,
GatewayLoadBalancerProperties properties) {
return new CustomReactiveLoadBalancerClientFilter(clientFactory, properties);
}
}
最后在启动类或随意任意类上使用注解
@LoadBalancerClients(defaultConfiguration = CustomLoadBalanceClientConfiguration.class)
3.结尾
参考文章:
SpringCloud自定义负载均衡策略–LoadBalancer
gateway自定义负载均衡策略
springcloud的gateway之GlobalFilter获取请求信息及requestBody