Spark RPC Architecture: A Source Code Walkthrough

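This article walks through Spark RPC and its basic architecture. The spark-network-common module implements the core RPC process on top of Netty 4, through components such as TransportContext and TransportClientFactory. The article then covers the RPC protocol implementation, including the message types and how each is handled. It concludes that the module encapsulates the RPC request-handling logic that spark-core uses to process messages in its different startup modes.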


Spark RPC Basic Architecture

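The spark-network-common module implements Spark's entire core RPC process using Netty 4's basic APIs. The basic architecture of Spark's RPC framework is shown in the figure below:

[Figure: basic architecture of the Spark RPC framework]
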
The components of Spark's RPC framework are as follows:

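  • TransportContext: the transport context. It holds the context needed to create the transport server (TransportServer) and the transport client factory (TransportClientFactory), and uses TransportChannelHandler to set up the pipeline of the SocketChannel provided by Netty.
  • TransportConf: the configuration of the transport context.
  • RpcHandler: processes the messages sent through the sendRpc method of a transport client (TransportClient).
  • MessageEncoder: encodes each message before it is written to the pipeline, so that the receiving end can delimit and parse it correctly.
  • MessageDecoder: decodes each ByteBuf frame read from the pipeline back into a message.
  • TransportFrameDecoder: splits the bytes read from the pipeline into data frames.
  • RpcResponseCallback: the callback interface invoked once RpcHandler has finished processing a request message.
  • TransportClientFactory: the factory class that creates transport clients (TransportClient).
  • ClientPool: a pool of transport clients (TransportClient) maintained between two peer nodes. ClientPool is an internal component of TransportClientFactory.
  • TransportClient: the client of the RPC framework, used to fetch consecutive chunks of a pre-negotiated stream. It is designed for efficient transfer of large amounts of data, split into chunks ranging from a few hundred KB to a few MB. While TransportClient handles fetching chunks from the stream, the actual setup of the stream is done outside the transport layer; the sendRpc method provides control-plane communication between client and server to perform that setup.
  • TransportClientBootstrap: a bootstrap executed once per connection on the client side, before a newly connected TransportClient is returned.
  • TransportRequestHandler: handles client requests and writes chunk data back.
  • TransportResponseHandler: handles the server's responses to the requests issued by the client, by tracking the outstanding requests and their callbacks.
  • TransportChannelHandler: delegates requests to TransportRequestHandler and responses to TransportResponseHandler, and adds transport-level handling.
  • TransportServerBootstrap: a bootstrap executed once on the server side when a client connects to the server.
  • TransportServer: the server of the RPC framework, providing an efficient, low-level streaming service.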
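The contract between RpcHandler and RpcResponseCallback can be illustrated with a minimal in-memory sketch. The `Handler` and `ResponseCallback` interfaces below are simplified stand-ins that only mirror the roles of Spark's RpcHandler and RpcResponseCallback (the real `receive` also takes the sending TransportClient). There is no Netty or networking involved, and `EchoRpcHandler` and `sendRpc` here are hypothetical.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.concurrent.CompletableFuture;

final class RpcHandlerDemo {
  // Simplified stand-in for RpcResponseCallback (same two callback methods).
  interface ResponseCallback {
    void onSuccess(ByteBuffer response);
    void onFailure(Throwable e);
  }

  // Simplified stand-in for RpcHandler; the real receive() also gets the
  // TransportClient that sent the message, which is omitted here.
  interface Handler {
    void receive(ByteBuffer message, ResponseCallback callback);
  }

  // Hypothetical handler: replies with the upper-cased text, rejects empty messages.
  static final class EchoRpcHandler implements Handler {
    @Override
    public void receive(ByteBuffer message, ResponseCallback callback) {
      if (!message.hasRemaining()) {
        callback.onFailure(new IllegalArgumentException("empty message"));
        return;
      }
      String text = StandardCharsets.UTF_8.decode(message).toString();
      callback.onSuccess(StandardCharsets.UTF_8.encode(text.toUpperCase(Locale.ROOT)));
    }
  }

  // Plays the role of TransportClient.sendRpc: delivers the message to the handler
  // directly (no network) and exposes the eventual reply as a future.
  static CompletableFuture<String> sendRpc(Handler handler, String text) {
    CompletableFuture<String> reply = new CompletableFuture<>();
    handler.receive(StandardCharsets.UTF_8.encode(text), new ResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        reply.complete(StandardCharsets.UTF_8.decode(response).toString());
      }

      @Override
      public void onFailure(Throwable e) {
        reply.completeExceptionally(e);
      }
    });
    return reply;
  }

  public static void main(String[] args) {
    Handler handler = new EchoRpcHandler();
    System.out.println(sendRpc(handler, "ping").join());                 // PING
    System.out.println(sendRpc(handler, "").isCompletedExceptionally()); // true
  }
}
```

The callback style lets a handler reply later or from another thread; the caller only learns the outcome through onSuccess or onFailure.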
TransportContext

As the component list above shows, TransportContext is the central class of Spark's RPC implementation. Its basic fields and constructor are:

  private final TransportConf conf;
  // Handles messages sent via TransportClient.sendRpc().
  private final RpcHandler rpcHandler;
  // Whether idle connections should be closed.
  private final boolean closeIdleConnections;

  // The same encoder/decoder instances are added to every channel's pipeline
  // (see initializePipeline).
  private final MessageEncoder encoder;
  private final MessageDecoder decoder;

  public TransportContext(
      TransportConf conf,
      RpcHandler rpcHandler,
      boolean closeIdleConnections) {
    this.conf = conf;
    this.rpcHandler = rpcHandler;
    this.encoder = new MessageEncoder();
    this.decoder = new MessageDecoder();
    this.closeIdleConnections = closeIdleConnections;
  }

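The fields of TransportContext serve the following purposes:

  1. conf (TransportConf): the configuration of the transport context.
  2. rpcHandler (RpcHandler): receives and processes the messages that a TransportClient sends through its sendRpc() method.
  3. encoder and decoder (MessageEncoder and MessageDecoder): encode outgoing RPC messages and decode incoming ones.
  4. closeIdleConnections: whether idle connections should be closed.

The core methods of TransportContext are: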
  /** Creates a TransportClientFactory; the given bootstraps run on every new client. */
  public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
    return new TransportClientFactory(this, bootstraps);
  }

  /** Creates a TransportServer bound to the given port (host is null: any local address). */
  public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, null, port, rpcHandler, bootstraps);
  }

  /** Creates a TransportServer bound to the given host and port. */
  public TransportServer createServer(
      String host, int port, List<TransportServerBootstrap> bootstraps) {
    return new TransportServer(this, host, port, rpcHandler, bootstraps);
  }

  /** Initializes the Netty pipeline of a client or server SocketChannel. */
  public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      // RPC message processing chain. Inbound bytes flow: frame decoder -> decoder ->
      // idle-state handler -> handler. The encoder only acts on outbound messages.
      // IdleStateHandler fires ALL_IDLE after connectionTimeoutMs / 1000 seconds
      // without reads or writes.
      channel.pipeline()
        .addLast("encoder", encoder)
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", decoder)
        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        .addLast("handler", channelHandler);
      return channelHandler;
    } catch (RuntimeException e) {
      logger.error("Error while initializing Netty pipeline", e);
      throw e;
    }
  }

  /** Creates the TransportChannelHandler, wiring up its client and request/response handlers. */
  private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
    TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
    TransportClient client = new TransportClient(channel, responseHandler);
    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
      rpcHandler);
    return new TransportChannelHandler(client, responseHandler, requestHandler,
      conf.connectionTimeoutMs(), closeIdleConnections);
  }
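TransportFrameDecoder sits in front of MessageDecoder because TCP delivers a byte stream: one read may contain half a message or several messages. The sketch below mimics the framing convention as we read it from Spark's MessageEncoder and TransportFrameDecoder: an 8-byte big-endian frame length that counts the length field itself, followed by the 1-byte message type and the rest of the message. It is a simplified illustration using java.nio.ByteBuffer instead of Netty's ByteBuf; the class name `FrameCodecSketch` is hypothetical, and real frames also carry a type-specific header and possibly a body.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class FrameCodecSketch {
  static final int LENGTH_SIZE = 8; // the length prefix is counted in the frame length

  // Encoder side: [frameLength: long][type: byte][payload...]
  static byte[] encode(byte type, byte[] payload) {
    long frameLength = LENGTH_SIZE + 1 + payload.length;
    ByteBuffer buf = ByteBuffer.allocate((int) frameLength); // big-endian by default
    buf.putLong(frameLength).put(type).put(payload);
    return buf.array();
  }

  // Decoder side: accumulates arbitrary chunks and emits only complete frames
  // (without the length prefix), shielding the next handler from TCP fragmentation.
  static final class FrameDecoder {
    private final ByteArrayOutputStream pending = new ByteArrayOutputStream();

    List<byte[]> feed(byte[] chunk) {
      pending.write(chunk, 0, chunk.length);
      ByteBuffer buf = ByteBuffer.wrap(pending.toByteArray());
      List<byte[]> frames = new ArrayList<>();
      while (buf.remaining() >= LENGTH_SIZE) {
        buf.mark();
        long frameLength = buf.getLong();
        long frameSize = frameLength - LENGTH_SIZE;
        if (frameSize < 0 || frameSize > Integer.MAX_VALUE) {
          throw new IllegalStateException("Bad frame length: " + frameLength);
        }
        if (buf.remaining() < frameSize) {
          buf.reset(); // incomplete frame: wait for more bytes
          break;
        }
        byte[] frame = new byte[(int) frameSize];
        buf.get(frame);
        frames.add(frame);
      }
      // Keep the unconsumed tail for the next call.
      pending.reset();
      pending.write(buf.array(), buf.position(), buf.remaining());
      return frames;
    }
  }

  public static void main(String[] args) {
    byte[] a = encode((byte) 3, new byte[] {1, 2, 3}); // type id 3 (RpcRequest in Message.Type)
    byte[] b = encode((byte) 9, new byte[] {42});      // type id 9 (OneWayMessage)
    byte[] stream = new byte[a.length + b.length];
    System.arraycopy(a, 0, stream, 0, a.length);
    System.arraycopy(b, 0, stream, a.length, b.length);

    FrameDecoder decoder = new FrameDecoder();
    // Split the stream at an awkward point, as TCP may do.
    List<byte[]> first = decoder.feed(Arrays.copyOfRange(stream, 0, 10));
    List<byte[]> rest = decoder.feed(Arrays.copyOfRange(stream, 10, stream.length));
    System.out.println(first.size() + " " + rest.size()); // 0 2
  }
}
```

Netty runs inbound handlers in the order they were added and outbound handlers in reverse, which is why the encoder can sit first in the pipeline without interfering with decoding.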
TransportClientFactory

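TransportClientFactory is the factory class that creates transport clients (TransportClient). It maintains a pool of connections to remote hosts and returns the same TransportClient(s) for the same remote host whenever possible; all the clients it creates share a single worker thread pool.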

Client connection pool
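  /** Structure of the client connection pool kept for each remote address. */
  private static class ClientPool {
    // One slot per connection to the peer (numConnectionsPerPeer slots in total).
    TransportClient[] clients;
    // One lock per slot, so connecting in one slot does not block the others.
    Object[] locks;

    ClientPool(int size) {
      clients = new TransportClient[size];
      locks = new Object[size];
      for (int i = 0; i < size; i++) {
        locks[i] = new Object();
      }
    }
  }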

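createClient is the core method for creating a TransportClient. The factory keeps an array of clients per peer (its size is numConnectionsPerPeer) and picks a slot at random; if that slot holds no usable client, a new client is created and stored there. Before a new TransportClient is handed out, all TransportClientBootstraps registered with the factory are run on it.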

  /**
   * Steps:
   * 1. Get the client pool for the remote address, creating it if it does not exist yet.
   * 2. Pick a client from the pool; if it is missing or inactive, create a new connection.
   */
  public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
    final InetSocketAddress unresolvedAddress =
      InetSocketAddress.createUnresolved(remoteHost, remotePort);

    // Get the client pool; create it first if it has not been initialized yet.
    // putIfAbsent followed by get is race-safe: concurrent callers end up with the same pool.
    ClientPool clientPool = connectionPool.get(unresolvedAddress);
    if (clientPool == null) {
      connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
      clientPool = connectionPool.get(unresolvedAddress);
    }

    // Pick a random slot in the pool.
    int clientIndex = rand.nextInt(numConnectionsPerPeer);
    TransportClient cachedClient = clientPool.clients[clientIndex];
    // Check that the cached client exists and its channel is still active.
    if (cachedClient != null && cachedClient.isActive()) {
      // Refresh the last-request time of the channel's TransportResponseHandler so that the
      // idle-timeout check (synchronized on the same handler) does not close this connection
      // while it is being handed out.
      TransportChannelHandler handler = cachedClient.getChannel().pipeline()
        .get(TransportChannelHandler.class);
      synchronized (handler) {
        handler.getResponseHandler().updateTimeOfLastRequest();
      }

      // Re-check: the client may have timed out before the refresh above.
      if (cachedClient.isActive()) {
        logger.trace("Returning cached connection to {}: {}",
          cachedClient.getSocketAddress(), cachedClient);
        return cachedClient;
      }
    }

    // Reaching here means no usable cached client was found, so create a new connection.
    // (The full source uses preResolveHost to log how long DNS resolution took; omitted here.)
    final long preResolveHost = System.nanoTime();
    final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);

    // Lock only this slot and re-check, since another thread may have filled it meanwhile.
    synchronized (clientPool.locks[clientIndex]) {
      cachedClient = clientPool.clients[clientIndex];

      if (cachedClient != null) {
        if (cachedClient.isActive()) {
          logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
          return cachedClient;
        } else {
          logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
        }
      }
      clientPool.clients[clientIndex] = createClient(resolvedAddress);
      return clientPool.clients[clientIndex];
    }
  }
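The slot-and-lock scheme used by ClientPool and createClient can be reduced to a small, dependency-free sketch. `PeerClientPoolSketch`, its `Client` interface and the `connector` function are hypothetical. The sketch also deliberately deviates from the excerpt in three places: `computeIfAbsent` replaces the putIfAbsent/get pair (both are atomic), `ThreadLocalRandom` replaces the shared `rand`, and slots live in an `AtomicReferenceArray` so the lock-free fast-path read is safely published. The last-request-time refresh and the bootstraps are omitted.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.Function;

final class PeerClientPoolSketch {
  // Hypothetical minimal client abstraction: only liveness matters here.
  interface Client {
    boolean isActive();
  }

  // One slot and one lock object per connection, like ClientPool above.
  static final class ClientPool {
    final AtomicReferenceArray<Client> clients;
    final Object[] locks;

    ClientPool(int size) {
      clients = new AtomicReferenceArray<>(size);
      locks = new Object[size];
      for (int i = 0; i < size; i++) {
        locks[i] = new Object();
      }
    }
  }

  private final ConcurrentHashMap<String, ClientPool> connectionPool = new ConcurrentHashMap<>();
  private final int numConnectionsPerPeer;
  private final Function<String, Client> connector; // hypothetical "open a new connection"

  PeerClientPoolSketch(int numConnectionsPerPeer, Function<String, Client> connector) {
    this.numConnectionsPerPeer = numConnectionsPerPeer;
    this.connector = connector;
  }

  Client createClient(String hostPort) {
    // 1. Get (or atomically create) the pool for this peer.
    ClientPool pool =
        connectionPool.computeIfAbsent(hostPort, k -> new ClientPool(numConnectionsPerPeer));

    // 2. Pick a random slot; the fast path returns a live cached client without locking.
    int index = ThreadLocalRandom.current().nextInt(numConnectionsPerPeer);
    Client cached = pool.clients.get(index);
    if (cached != null && cached.isActive()) {
      return cached;
    }

    // 3. Slow path: lock only this slot, re-check, then connect if still needed.
    synchronized (pool.locks[index]) {
      cached = pool.clients.get(index);
      if (cached != null && cached.isActive()) {
        return cached; // another thread filled the slot meanwhile
      }
      Client fresh = connector.apply(hostPort);
      pool.clients.set(index, fresh);
      return fresh;
    }
  }

  public static void main(String[] args) {
    int[] opened = {0};
    PeerClientPoolSketch factory = new PeerClientPoolSketch(1, hp -> {
      opened[0]++;
      return () -> true; // a client that is always active
    });
    Client a = factory.createClient("host-a:7337");
    Client b = factory.createClient("host-a:7337");
    System.out.println((a == b) + " " + opened[0]); // true 1
  }
}
```

Locking per slot rather than per factory means a slow connection attempt in one slot does not block callers that land on other slots or other peers.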
Spark RPC Protocol Implementation
RPC Message

The spark-network-common module defines the RPC Message interface as follows:

public interface Message extends Encodable {
  /** Returns the message type. */
  Type type();

  /** Returns the message body (may be null for messages without a body). */
  ManagedBuffer body();

  /** Whether the body is sent in the same frame as the message itself. */
  boolean isBodyInFrame();

  /** All serializable message types transferred over RPC. */
  enum Type implements Encodable {
    ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
    RpcRequest(3), RpcResponse(4), RpcFailure(5),
    StreamRequest(6), StreamResponse(7), StreamFailure(8),
    OneWayMessage(9), User(-1);

    private final byte id;

    Type(int id) {
      assert id < 128 : "Cannot have more than 128 message types";
      this.id = (byte) id;
    }

    public byte id() { return id; }

    @Override public int encodedLength() { return 1; }

    /** Serializes the type into the ByteBuf as a single byte. */
    @Override public void encode(ByteBuf buf) { buf.writeByte(id); }

    /** Deserializes a message type from the ByteBuf. */
    public static Type decode(ByteBuf buf) {
      byte id = buf.readByte();
      switch (id) {
        case 0: return ChunkFetchRequest;
        case 1: return ChunkFetchSuccess;
        case 2: return ChunkFetchFailure;
        case 3: return RpcRequest;
        case 4: return RpcResponse;
        case 5: return RpcFailure;
        case 6: return StreamRequest;
        case 7: return StreamResponse;
        case 8: return StreamFailure;
        case 9: return OneWayMessage;
        case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
        default: throw new IllegalArgumentException("Unknown message type: " + id);
      }
    }
  }
}
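Because each type is encoded as a single byte (`encodedLength()` returns 1), the encode/decode pair can be checked with a round trip. The sketch below reuses the ids from Message.Type but replaces the hand-written switch with a lookup table built from `values()`. This is an alternative design for illustration, not Spark's code, and `MessageTypeSketch` is a hypothetical name.

```java
import java.nio.ByteBuffer;

final class MessageTypeSketch {
  // Same ids as Message.Type; User (-1) is never decodable.
  enum Type {
    ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
    RpcRequest(3), RpcResponse(4), RpcFailure(5),
    StreamRequest(6), StreamResponse(7), StreamFailure(8),
    OneWayMessage(9), User(-1);

    private final byte id;

    Type(int id) {
      this.id = (byte) id;
    }

    // Lookup table indexed by id, built once after all constants exist.
    private static final Type[] BY_ID;

    static {
      int maxId = 0;
      for (Type t : values()) {
        maxId = Math.max(maxId, t.id);
      }
      BY_ID = new Type[maxId + 1];
      for (Type t : values()) {
        if (t.id >= 0) {
          BY_ID[t.id] = t;
        }
      }
    }

    int encodedLength() {
      return 1; // one byte on the wire
    }

    void encode(ByteBuffer buf) {
      buf.put(id);
    }

    static Type decode(ByteBuffer buf) {
      byte id = buf.get();
      if (id == -1) {
        throw new IllegalArgumentException("User type messages cannot be decoded.");
      }
      if (id < 0 || id >= BY_ID.length || BY_ID[id] == null) {
        throw new IllegalArgumentException("Unknown message type: " + id);
      }
      return BY_ID[id];
    }
  }

  public static void main(String[] args) {
    ByteBuffer buf = ByteBuffer.allocate(Type.RpcResponse.encodedLength());
    Type.RpcResponse.encode(buf);
    buf.flip();
    System.out.println(Type.decode(buf)); // RpcResponse
  }
}
```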

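All of Spark's message classes implement, directly or indirectly, either the RequestMessage or the ResponseMessage interface. RequestMessage has four implementations:

  • ChunkFetchRequest: requests a single chunk of a stream.
  • RpcRequest: a message handled by the remote RPC server that requires the server to reply to the client.
  • OneWayMessage: also handled by the remote RPC server but, unlike RpcRequest, no reply is sent.
  • StreamRequest: requests streamed data from the remote service.

Since OneWayMessage needs no response, ResponseMessage has three success and three failure implementations:

  • ChunkFetchSuccess: returned after a ChunkFetchRequest is processed successfully;
  • ChunkFetchFailure: returned when processing a ChunkFetchRequest fails;
  • RpcResponse: returned after an RpcRequest is processed successfully;
  • RpcFailure: returned when processing an RpcRequest fails;
  • StreamResponse: returned after a StreamRequest is processed successfully;
  • StreamFailure: returned when processing a StreamRequest fails.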
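The request/response pairing above can be written down as a small lookup, which makes the asymmetry of OneWayMessage explicit. The enum `Kind` and the method `responsesTo` are hypothetical illustrations, not part of Spark; they only encode the pairs listed above.

```java
import java.util.Arrays;
import java.util.EnumSet;

final class RpcMessagePairing {
  enum Kind {
    ChunkFetchRequest, ChunkFetchSuccess, ChunkFetchFailure,
    RpcRequest, RpcResponse, RpcFailure,
    StreamRequest, StreamResponse, StreamFailure,
    OneWayMessage
  }

  // The four RequestMessage kinds listed above.
  static final EnumSet<Kind> REQUESTS =
      EnumSet.of(Kind.ChunkFetchRequest, Kind.RpcRequest, Kind.OneWayMessage, Kind.StreamRequest);

  // Returns {success, failure} for a request kind; OneWayMessage gets an empty array
  // because the server never replies to it.
  static Kind[] responsesTo(Kind request) {
    switch (request) {
      case ChunkFetchRequest:
        return new Kind[] {Kind.ChunkFetchSuccess, Kind.ChunkFetchFailure};
      case RpcRequest:
        return new Kind[] {Kind.RpcResponse, Kind.RpcFailure};
      case StreamRequest:
        return new Kind[] {Kind.StreamResponse, Kind.StreamFailure};
      case OneWayMessage:
        return new Kind[0];
      default:
        throw new IllegalArgumentException(request + " is a response, not a request");
    }
  }

  public static void main(String[] args) {
    for (Kind request : REQUESTS) {
      System.out.println(request + " -> " + Arrays.toString(responsesTo(request)));
    }
  }
}
```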
Message Handlers

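Spark RPC encodes and decodes every message it processes. The handler chain of Spark RPC is shown in the figure below:

[Figure: handler processing chain of Spark RPC]
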
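The most complex of these handlers is TransportChannelHandler. It delegates requests to TransportRequestHandler and responses to TransportResponseHandler, and adds transport-level handling such as closing timed-out or idle connections.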

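/** Single transport-level channel handler: delegates requests to TransportRequestHandler and responses to TransportResponseHandler. */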
public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {

  private final TransportClient client;

  /** Handlers for requests and responses on this channel. */
  private final TransportResponseHandler responseHandler;
  private final TransportRequestHandler requestHandler;
  private final long requestTimeoutNs;
  private final boolean closeIdleConnections;

  public TransportChannelHandler(
      TransportClient client,
      TransportResponseHandler responseHandler,
      TransportRequestHandler requestHandler,
      long requestTimeoutMs,
      boolean closeIdleConnections) {
    this.client = client;
    this.responseHandler = responseHandler;
    this.requestHandler = requestHandler;
    this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
    this.closeIdleConnections = closeIdleConnections;
  }

  public TransportClient getClient() {
    return client;
  }

  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    logger.warn("Exception in connection from " + getRemoteAddress(ctx.channel()),
      cause);
    requestHandler.exceptionCaught(cause);
    responseHandler.exceptionCaught(cause);
    ctx.close();
  }

  @Override
  public void channelActive(ChannelHandlerContext ctx) throws Exception {
    try {
      requestHandler.channelActive();
    } catch (RuntimeException e) {
      logger.error("Exception from request handler while registering channel", e);
    }
    try {
      responseHandler.channelActive();
    } catch (RuntimeException e) {
      logger.error("Exception from response handler while registering channel", e);
    }
    super.channelActive(ctx);
  }

  @Override
  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
    try {
      requestHandler.channelInactive();
    } catch (RuntimeException e) {
      logger.error("Exception from request handler while unregistering channel", e);
    }
    try {
      responseHandler.channelInactive();
    } catch (RuntimeException e) {
      logger.error("Exception from response handler while unregistering channel", e);
    }
    super.channelInactive(ctx);
  }

  /**
   * Handles an inbound message: requests go to the request handler, responses to the response handler.
   */
  @Override
  public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
    if (request instanceof RequestMessage) {
      requestHandler.handle((RequestMessage) request);
    } else {
      responseHandler.handle((ResponseMessage) request);
    }
  }

  @Override
  public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    if (evt instanceof IdleStateEvent) {
      IdleStateEvent e = (IdleStateEvent) evt;
      synchronized (this) {
        boolean isActuallyOverdue =
          System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
        if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
          if (responseHandler.numOutstandingRequests() > 0) {
            String address = getRemoteAddress(ctx.channel());
            logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
              "requests. Assuming connection is dead; please adjust spark.network.timeout if " +
              "this is wrong.", address, requestTimeoutNs / 1000 / 1000);
            client.timeOut();
            ctx.close();
          } else if (closeIdleConnections) {
            // When closeIdleConnections is enabled, idle connections are closed as well.
            client.timeOut();
            ctx.close();
          }
        }
      }
    }
    ctx.fireUserEventTriggered(evt);
  }

  public TransportResponseHandler getResponseHandler() {
    return responseHandler;
  }

}
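The decision made in userEventTriggered can be pulled out into a pure function, which makes its three outcomes easy to test. This is a hypothetical refactoring (`IdleTimeoutSketch`, `Action` and `onAllIdle` are not Spark names) and it assumes the event is already an ALL_IDLE event. The `isActuallyOverdue` check matters because IdleStateHandler only measures channel inactivity, whereas createClient refreshes the last-request time via `updateTimeOfLastRequest()` when it hands out a cached client; a channel that is idle at the Netty level but was just handed out is therefore kept open.

```java
final class IdleTimeoutSketch {
  enum Action { KEEP, CLOSE_DEAD_CONNECTION, CLOSE_IDLE_CONNECTION }

  // Decision taken when Netty's IdleStateHandler reports ALL_IDLE.
  static Action onAllIdle(long nowNs, long timeOfLastRequestNs, long requestTimeoutNs,
                          int outstandingRequests, boolean closeIdleConnections) {
    // Subtracting nanoTime values (rather than comparing them) stays correct across overflow.
    boolean isActuallyOverdue = nowNs - timeOfLastRequestNs > requestTimeoutNs;
    if (!isActuallyOverdue) {
      return Action.KEEP; // recently used, e.g. just handed out by createClient
    }
    if (outstandingRequests > 0) {
      return Action.CLOSE_DEAD_CONNECTION; // quiet while replies are pending: assume dead
    }
    return closeIdleConnections ? Action.CLOSE_IDLE_CONNECTION : Action.KEEP;
  }

  public static void main(String[] args) {
    long timeoutNs = 120_000L * 1000 * 1000; // assumed 120 s timeout, in nanoseconds
    long last = System.nanoTime();
    long now = last + timeoutNs + 1;
    System.out.println(onAllIdle(now, last, timeoutNs, 3, false)); // CLOSE_DEAD_CONNECTION
    System.out.println(onAllIdle(now, last, timeoutNs, 0, true));  // CLOSE_IDLE_CONNECTION
    System.out.println(onAllIdle(now, last, timeoutNs, 0, false)); // KEEP
  }
}
```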
TransportRequestHandler

The core processing logic of TransportRequestHandler is as follows:

  /**
   * Core logic of TransportRequestHandler: dispatch a request to the matching process* method.
   */
  @Override
  public void handle(RequestMessage request) {
    if (request instanceof ChunkFetchRequest) {
      // Chunk fetch request
      processFetchRequest((ChunkFetchRequest) request);
    } else if (request instanceof RpcRequest) {
      // RPC request that expects a reply
      processRpcRequest((RpcRequest) request);
    } else if (request instanceof OneWayMessage) {
      // One-way message: no reply is sent
      processOneWayMessage((OneWayMessage) request);
    } else if (request instanceof StreamRequest) {
      // Stream request
      processStreamRequest((StreamRequest) request);
    } else {
      throw new IllegalArgumentException("Unknown request type: " + request);
    }
  }
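The two-level routing (TransportChannelHandler.channelRead0 splits requests from responses, then TransportRequestHandler.handle picks a process* method by concrete type) can be reproduced with field-less stand-in classes. Everything below is a hypothetical simplification: `DispatchSketch` and its nested classes carry no data, and the process* steps just record their names.

```java
import java.util.ArrayList;
import java.util.List;

final class DispatchSketch {
  interface Message {}
  interface RequestMessage extends Message {}
  interface ResponseMessage extends Message {}

  // Hypothetical, field-less stand-ins for the concrete message classes.
  static final class ChunkFetchRequest implements RequestMessage {}
  static final class RpcRequest implements RequestMessage {}
  static final class OneWayMessage implements RequestMessage {}
  static final class StreamRequest implements RequestMessage {}
  static final class RpcResponse implements ResponseMessage {}

  final List<String> log = new ArrayList<>();

  // First level, as in TransportChannelHandler.channelRead0.
  void channelRead(Message msg) {
    if (msg instanceof RequestMessage) {
      handleRequest((RequestMessage) msg);
    } else {
      ResponseMessage response = (ResponseMessage) msg;
      log.add("response handler: " + response.getClass().getSimpleName());
    }
  }

  // Second level, as in TransportRequestHandler.handle.
  void handleRequest(RequestMessage request) {
    if (request instanceof ChunkFetchRequest) {
      log.add("processFetchRequest");
    } else if (request instanceof RpcRequest) {
      log.add("processRpcRequest");
    } else if (request instanceof OneWayMessage) {
      log.add("processOneWayMessage");
    } else if (request instanceof StreamRequest) {
      log.add("processStreamRequest");
    } else {
      throw new IllegalArgumentException("Unknown request type: " + request);
    }
  }

  public static void main(String[] args) {
    DispatchSketch d = new DispatchSketch();
    d.channelRead(new RpcRequest());
    d.channelRead(new OneWayMessage());
    d.channelRead(new RpcResponse());
    System.out.println(d.log); // [processRpcRequest, processOneWayMessage, response handler: RpcResponse]
  }
}
```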
TransportResponseHandler

The core processing logic of TransportResponseHandler is as follows:

/**
 * Core logic of TransportResponseHandler: complete the callback registered for each response.
 */
@Override
  public void handle(ResponseMessage message) throws Exception {

    /** Handle a successful chunk fetch */
    if (message instanceof ChunkFetchSuccess) {
      ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
      if (listener == null) {
        logger.warn("Ignoring response for block {} from {} since it is not outstanding",
          resp.streamChunkId, getRemoteAddress(channel));
        resp.body().release();
      } else {
        outstandingFetches.remove(resp.streamChunkId);
        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
        resp.body().release();
      }

    /** Handle a failed chunk fetch */
    } else if (message instanceof ChunkFetchFailure) {
      ChunkFetchFailure resp = (ChunkFetchFailure) message;
      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
      if (listener == null) {
        logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
          resp.streamChunkId, getRemoteAddress(channel), resp.errorString);
      } else {
        outstandingFetches.remove(resp.streamChunkId);
        listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
          "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
      }
    /** Handle a successful RPC */
    } else if (message instanceof RpcResponse) {
      RpcResponse resp = (RpcResponse) message;
      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
      if (listener == null) {
        logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
          resp.requestId, getRemoteAddress(channel), resp.body().size());
      } else {
        outstandingRpcs.remove(resp.requestId);
        try {
          listener.onSuccess(resp.body().nioByteBuffer());
        } finally {
          resp.body().release();
        }
      }

    /** Handle a failed RPC */
    } else if (message instanceof RpcFailure) {
      RpcFailure resp = (RpcFailure) message;
      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
      if (listener == null) {
        logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
          resp.requestId, getRemoteAddress(channel), resp.errorString);
      } else {
        outstandingRpcs.remove(resp.requestId);
        listener.onFailure(new RuntimeException(resp.errorString));
      }

    /** Handle a successful stream request */
    } else if (message instanceof StreamResponse) {
      StreamResponse resp = (StreamResponse) message;
      StreamCallback callback = streamCallbacks.poll();
      if (callback != null) {
        if (resp.byteCount > 0) {
          StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
            callback);
          try {
            TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
              channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
            frameDecoder.setInterceptor(interceptor);
            streamActive = true;
          } catch (Exception e) {
            logger.error("Error installing stream handler.", e);
            deactivateStream();
          }
        } else {
          try {
            callback.onComplete(resp.streamId);
          } catch (Exception e) {
            logger.warn("Error in stream handler onComplete().", e);
          }
        }
      } else {
        logger.error("Could not find callback for StreamResponse.");
      }

    /** Handle a failed stream request */
    } else if (message instanceof StreamFailure) {
      StreamFailure resp = (StreamFailure) message;
      StreamCallback callback = streamCallbacks.poll();
      if (callback != null) {
        try {
          callback.onFailure(resp.streamId, new RuntimeException(resp.error));
        } catch (IOException ioe) {
          logger.warn("Error in stream failure handler.", ioe);
        }
      } else {
        logger.warn("Stream failure with unknown callback: {}", resp.error);
      }
    } else {
      throw new IllegalStateException("Unknown response type: " + message.type());
    }
  }
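The RpcResponse and RpcFailure branches rely on request ids: the client registers a callback under an id before sending, and the response handler looks that id up when the reply arrives. A minimal sketch of this bookkeeping follows. `OutstandingRpcsSketch` and its methods are hypothetical, ids come from a simple counter (an assumption for illustration), and it uses a single atomic `remove` instead of the get-then-remove pair in the excerpt, so a duplicate or late response can never fire the same callback twice.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

final class OutstandingRpcsSketch {
  // Hypothetical stand-in for RpcResponseCallback, with a String body for brevity.
  interface Callback {
    void onSuccess(String response);
    void onFailure(Throwable e);
  }

  private final ConcurrentHashMap<Long, Callback> outstandingRpcs = new ConcurrentHashMap<>();
  private final AtomicLong nextRequestId = new AtomicLong();

  // Client side: register the callback under a fresh id before the request is sent.
  long addRpcRequest(Callback callback) {
    long requestId = nextRequestId.incrementAndGet();
    outstandingRpcs.put(requestId, callback);
    return requestId;
  }

  // RpcResponse branch: returns false for an id that is not outstanding
  // (the case where the handler above logs "Ignoring response ...").
  boolean onRpcResponse(long requestId, String body) {
    Callback listener = outstandingRpcs.remove(requestId);
    if (listener == null) {
      return false;
    }
    listener.onSuccess(body);
    return true;
  }

  // RpcFailure branch: the error string is wrapped in a RuntimeException, as above.
  boolean onRpcFailure(long requestId, String errorString) {
    Callback listener = outstandingRpcs.remove(requestId);
    if (listener == null) {
      return false;
    }
    listener.onFailure(new RuntimeException(errorString));
    return true;
  }

  // RPC-only counterpart of numOutstandingRequests(), which the idle-timeout check consults.
  int numOutstanding() {
    return outstandingRpcs.size();
  }

  public static void main(String[] args) {
    OutstandingRpcsSketch handler = new OutstandingRpcsSketch();
    StringBuilder seen = new StringBuilder();
    long id = handler.addRpcRequest(new Callback() {
      @Override
      public void onSuccess(String response) {
        seen.append("ok:").append(response);
      }

      @Override
      public void onFailure(Throwable e) {
        seen.append("fail:").append(e.getMessage());
      }
    });
    System.out.println(handler.onRpcResponse(id, "pong"));     // true
    System.out.println(handler.onRpcResponse(id, "pong"));     // false: no longer outstanding
    System.out.println(seen + " " + handler.numOutstanding()); // ok:pong 0
  }
}
```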
Summary

The spark-network-common module uses Netty 4 to encapsulate Spark's core RPC request and message handling logic. During startup, spark-core relies on this module to process RPC messages in each of its startup modes.
