本文示例代码基于 Java 11+ 实现。
构建可靠的分布式系统时,一致性问题是核心挑战之一。ZooKeeper 的 ZAB 协议和 Paxos 算法作为两种主流解决方案,在理论基础和工程实现上各有特点。本文深入分析它们的实现机制、性能特性和最佳实践。
一、基本概念
ZAB 协议
ZAB (ZooKeeper Atomic Broadcast) 是专为 ZooKeeper 设计的分布式一致性协议,核心目标是保证分布式系统中数据更新的原子性和顺序一致性。
Paxos 算法
Paxos 是 Leslie Lamport 提出的通用分布式一致性算法,是众多分布式系统的理论基础,解决的是在不可靠网络中如何达成共识的问题。
二、ZAB 协议实现
ZAB 协议工作在两种模式下:
- 恢复模式:系统启动、Leader 崩溃或 Leader 失去过半 Follower 支持时触发
- 广播模式:正常运行时处理写请求
核心接口定义
/**
 * Entry points of a ZAB node: role/epoch queries, recovery, and broadcast-mode reads and writes.
 */
public interface ZabProcessor {

    /** Reports whether this node currently acts as the leader. */
    boolean isLeader();

    /** Returns the epoch this node currently operates in. */
    long getCurrentEpoch();

    /**
     * Runs the recovery phase.
     *
     * @return {@code true} when recovery completed
     * @throws RecoveryException if recovery could not be completed
     */
    boolean startRecovery() throws RecoveryException;

    /**
     * Submits a write for atomic broadcast.
     *
     * @param request the write to replicate
     * @return future completing with a success flag
     */
    CompletableFuture<Boolean> processWrite(Request request);

    /**
     * Reads {@code key} at the requested consistency level.
     *
     * @param key   key to read
     * @param level requested consistency guarantee
     * @return future completing with the read result
     */
    CompletableFuture<Result> processRead(String key, ConsistencyLevel level);
}
/**
 * Transport used by a ZAB server to talk to its peers. Every operation except
 * {@code disconnect} may fail with an {@link IOException}.
 */
public interface NetworkClient {

    // ---- connection management ----

    /** Opens a connection to {@code serverId} at {@code address:port}. */
    void connect(String serverId, String address, int port) throws IOException;

    /** Closes the connection to {@code serverId}. */
    void disconnect(String serverId);

    // ---- recovery: discovery and synchronization ----

    /** Sends the new epoch to a peer and returns its last-zxid response. */
    LastZxidResponse sendEpochRequest(String serverId, EpochPacket epochPkt) throws IOException;

    /** Asks a peer to truncate its log; the result reports whether it accepted. */
    boolean sendTruncate(String serverId, TruncatePacket truncPkt) throws IOException;

    /** Ships missing transactions (DIFF); the result reports whether the peer accepted them. */
    boolean sendTransactions(String serverId, List<Transaction> txns) throws IOException;

    /** Ships a full state snapshot tagged with {@code zxid}. */
    void sendSnapshot(String serverId, byte[] snapshot, long zxid) throws IOException;

    /** Announces the new leader; the result reports whether the peer acknowledged. */
    boolean sendNewLeader(String serverId, NewLeaderPacket newLeaderPkt) throws IOException;

    // ---- broadcast ----

    /** Sends a PROPOSAL and returns the peer's ACK. */
    ACK sendProposal(String serverId, ProposalPacket proposal) throws IOException;

    /** Sends a COMMIT for a previously proposed transaction. */
    void sendCommit(String serverId, CommitPacket commit) throws IOException;

    /** Sends a heartbeat carrying the sender's current zxid. */
    void sendHeartbeat(String serverId, long zxid) throws IOException;
}
/**
 * Replicated state machine driven by committed ZAB transactions.
 */
public interface StateMachine {

    /** Returns the zxid of the last transaction applied to this state machine. */
    long getLastAppliedZxid();

    /**
     * Applies the command of the transaction identified by {@code zxid}.
     *
     * @throws Exception if the command cannot be applied
     */
    void apply(long zxid, byte[] command) throws Exception;

    /**
     * Serializes the current state.
     *
     * @throws Exception if the snapshot cannot be produced
     */
    byte[] takeSnapshot() throws Exception;

    /**
     * Replaces the current state with {@code snapshot}, which reflects transactions up to
     * {@code zxid}.
     *
     * @throws Exception if the snapshot cannot be restored
     */
    void restoreSnapshot(byte[] snapshot, long zxid) throws Exception;
}
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
ZAB 恢复模式实现
/**
 * Leader-side recovery for the ZAB protocol.
 *
 * <p>{@link #startRecovery()} runs four steps:
 * <ol>
 *   <li>discover the last zxid reported by each registered follower;</li>
 *   <li>choose the highest zxid held by a quorum as the truncation point;</li>
 *   <li>ask servers holding later (non-quorum) transactions to truncate;</li>
 *   <li>synchronize followers, requiring a quorum of them to acknowledge NEWLEADER before this
 *       server switches to LEADING.</li>
 * </ol>
 *
 * <p>NOTE(review): {@link #getQuorum()} treats {@code quorumSize} as the ensemble size (the
 * quorum is {@code quorumSize / 2 + 1}). Every quorum check in this class goes through
 * {@link #getQuorum()} so the interpretation is consistent. Confirm this matches what callers
 * pass in.
 */
public class ZABRecovery {
    /** Zxid distance beyond which a lagging follower is caught up with a snapshot (100,000). */
    private static final long SNAPSHOT_THRESHOLD = 100000;

    private final AtomicInteger epoch = new AtomicInteger(0);
    private volatile ServerState state = ServerState.LOOKING;
    private final Logger logger = LoggerFactory.getLogger(ZABRecovery.class);
    // Followers taking part in recovery, keyed by server id (registered via addFollower).
    private final ConcurrentMap<String, ServerData> serverDataMap;
    private final int quorumSize;
    private final NetworkClient networkClient;
    private final StateMachine stateMachine;
    private final String serverId;

    public ZABRecovery(String serverId, int quorumSize, NetworkClient networkClient,
            StateMachine stateMachine) {
        this.serverId = serverId;
        this.quorumSize = quorumSize;
        this.networkClient = networkClient;
        this.stateMachine = stateMachine;
        this.serverDataMap = new ConcurrentHashMap<>();
    }

    /**
     * Registers (or replaces) a follower that takes part in recovery, keyed by its id.
     * Discovery and synchronization only contact registered followers.
     *
     * @param follower follower to register
     */
    public void addFollower(ServerData follower) {
        serverDataMap.put(follower.getId(), follower);
    }

    /**
     * Runs leader recovery and switches this server to LEADING on success.
     *
     * @return {@code true} once a quorum of followers has been synchronized
     * @throws RecoveryException if any phase fails; the server state is reset to LOOKING
     */
    public boolean startRecovery() throws RecoveryException {
        MDC.put("component", "zab-recovery");
        MDC.put("serverId", serverId);
        try {
            int newEpoch = epoch.incrementAndGet();
            logger.info("Starting recovery with epoch: {}", newEpoch);

            // Discovery: last zxid reported by each follower.
            Map<Long, Set<String>> commitMap = discoverFollowerStates();

            // Highest zxid a quorum of the responding servers is known to hold.
            long truncateZxid = determineMaxCommittedZxid(commitMap);
            logger.info("Determined truncate zxid: {}", Long.toHexString(truncateZxid));

            // Servers that hold transactions beyond the quorum point (e.g. after a partition).
            resolveConflictsAfterPartition(truncateZxid, commitMap);

            // Synchronization: bring followers to truncateZxid and announce the new leader.
            syncFollowers(truncateZxid);

            state = ServerState.LEADING;
            logger.info("Recovery completed, switching to broadcast mode");
            return true;
        } catch (IOException e) {
            logger.error("Recovery failed due to I/O error", e);
            state = ServerState.LOOKING;
            throw new RecoveryException("I/O error during recovery", e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.warn("Recovery interrupted", e);
            state = ServerState.LOOKING;
            throw new RecoveryException("Recovery process interrupted", e);
        } catch (Exception e) {
            logger.error("Unexpected error during recovery", e);
            state = ServerState.LOOKING;
            throw new RecoveryException("Unexpected error during recovery", e);
        } finally {
            MDC.remove("component");
            MDC.remove("serverId");
        }
    }

    /**
     * Sends the current epoch to every registered follower and collects their last zxids.
     * Waits for all followers or 10 seconds, whichever comes first.
     *
     * @return map from last zxid to the ids of the servers that reported it; each responding
     *         server appears in exactly one set. The map is a private copy, so responses that
     *         arrive after the timeout do not alter it.
     */
    private Map<Long, Set<String>> discoverFollowerStates() throws IOException, InterruptedException {
        Map<Long, Set<String>> acceptedZxids = new ConcurrentHashMap<>();
        CountDownLatch latch = new CountDownLatch(serverDataMap.size());
        List<CompletableFuture<?>> futures = new ArrayList<>();
        for (String targetServerId : serverDataMap.keySet()) {
            CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
                MDC.put("targetServerId", targetServerId);
                try {
                    EpochPacket epochPkt = new EpochPacket(epoch.get());
                    LastZxidResponse response = networkClient.sendEpochRequest(
                            targetServerId, epochPkt);
                    if (response == null) {
                        logger.warn("Server {} returned no epoch response", targetServerId);
                        return;
                    }
                    // HashSet values are not thread-safe: guard every access with the map lock.
                    synchronized (acceptedZxids) {
                        acceptedZxids.computeIfAbsent(response.getLastZxid(), k -> new HashSet<>())
                                .add(targetServerId);
                    }
                    logger.info("Server {} last zxid: {}", targetServerId,
                            Long.toHexString(response.getLastZxid()));
                } catch (IOException e) {
                    logger.error("Failed to discover state from server: {}", targetServerId, e);
                } finally {
                    MDC.remove("targetServerId");
                    latch.countDown();
                }
            });
            futures.add(future);
        }

        // Wait for every follower (or the timeout); quorum is evaluated on what arrived.
        if (!latch.await(10, TimeUnit.SECONDS)) {
            logger.warn("Discovery phase timed out, proceeding with available responses");
        }

        // CompletableFuture.cancel does not interrupt a running task; late responses are
        // excluded by returning a copy below.
        for (CompletableFuture<?> future : futures) {
            if (!future.isDone()) {
                future.cancel(true);
            }
        }

        Map<Long, Set<String>> snapshot = new ConcurrentHashMap<>();
        synchronized (acceptedZxids) {
            acceptedZxids.forEach((z, servers) -> snapshot.put(z, new HashSet<>(servers)));
        }
        return snapshot;
    }

    /**
     * Returns the highest zxid held by at least a quorum of the responding servers, or 0.
     *
     * <p>ZAB followers accept proposals in zxid order, so a server whose last zxid is {@code z}
     * also holds every earlier zxid. The holders of a candidate zxid are therefore all servers
     * whose last zxid is greater than or equal to it, not only those that stopped exactly there.
     * Candidates are scanned from highest to lowest while accumulating holders.
     */
    private long determineMaxCommittedZxid(Map<Long, Set<String>> commitMap) {
        int quorum = getQuorum();
        List<Long> candidates = new ArrayList<>(commitMap.keySet());
        candidates.sort(Collections.reverseOrder());
        int holders = 0;
        for (Long candidate : candidates) {
            holders += commitMap.get(candidate).size();
            if (holders >= quorum) {
                return candidate;
            }
        }
        return 0;
    }

    /**
     * Logs transactions above {@code truncateZxid} that no quorum holds, and asks the servers
     * holding them to truncate to {@code truncateZxid} (asynchronously, best-effort).
     */
    private void resolveConflictsAfterPartition(long truncateZxid,
            Map<Long, Set<String>> commitMap) {
        logger.info("Checking for potential conflicts after network partition");
        int quorum = getQuorum();
        int truncateEpoch = ZxidUtils.getEpochFromZxid(truncateZxid);

        List<ConflictingTransaction> conflicts = new ArrayList<>();
        for (var entry : commitMap.entrySet()) {
            long txnZxid = entry.getKey();
            Set<String> servers = entry.getValue();
            if (txnZxid > truncateZxid && servers.size() < quorum) {
                int txnEpoch = ZxidUtils.getEpochFromZxid(txnZxid);
                conflicts.add(new ConflictingTransaction(txnZxid, truncateZxid,
                        txnEpoch, truncateEpoch, servers));
            }
        }

        if (conflicts.isEmpty()) {
            logger.info("No conflicting transactions found");
            return;
        }
        logger.warn("Found {} potential conflicting transactions after partition",
                conflicts.size());
        for (ConflictingTransaction conflict : conflicts) {
            if (conflict.isFromHigherEpoch()) {
                logger.warn("Conflict: transaction with zxid {} from higher epoch {} " +
                        "found but not in majority. Will be discarded.",
                        Long.toHexString(conflict.getConflictZxid()),
                        conflict.getConflictEpoch());
            } else {
                logger.warn("Conflict: transaction with zxid {} from same epoch {} " +
                        "found but not in majority. Will be discarded.",
                        Long.toHexString(conflict.getConflictZxid()),
                        conflict.getConflictEpoch());
            }
            notifyServersToTruncate(conflict.getServers(), truncateZxid);
        }
    }

    /** Fire-and-forget TRUNC to each of {@code servers}; failures are only logged. */
    private void notifyServersToTruncate(Set<String> servers, long truncateZxid) {
        for (String target : servers) {
            CompletableFuture.runAsync(() -> {
                try {
                    TruncatePacket truncPkt = new TruncatePacket(truncateZxid);
                    boolean success = networkClient.sendTruncate(target, truncPkt);
                    if (success) {
                        logger.info("Successfully notified server {} to truncate to zxid {}",
                                target, Long.toHexString(truncateZxid));
                    } else {
                        logger.warn("Failed to notify server {} to truncate", target);
                    }
                } catch (IOException e) {
                    logger.error("Error notifying server {} to truncate", target, e);
                }
            });
        }
    }

    /**
     * Synchronizes every registered follower and requires a quorum of them to acknowledge
     * NEWLEADER within 30 seconds.
     *
     * <p>A follower more than {@link #SNAPSHOT_THRESHOLD} behind receives a snapshot. Otherwise
     * it receives TRUNC followed by the transactions from {@code truncateZxid} (DIFF). In both
     * cases NEWLEADER follows, and only an acknowledged NEWLEADER counts as a success.
     *
     * @throws QuorumNotFoundException if fewer than {@link #getQuorum()} followers synced
     */
    private void syncFollowers(long truncateZxid) throws IOException, InterruptedException {
        List<Transaction> txns = loadTransactionsFromLog(truncateZxid);
        logger.info("Syncing {} transactions to followers", txns.size());

        CountDownLatch syncLatch = new CountDownLatch(serverDataMap.size());
        AtomicInteger successCount = new AtomicInteger(0);
        List<CompletableFuture<?>> futures = new ArrayList<>();
        for (var entry : serverDataMap.entrySet()) {
            final String targetServerId = entry.getKey();
            final ServerData serverData = entry.getValue();
            CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
                MDC.put("targetServerId", targetServerId);
                try {
                    long followerZxid = serverData.getLastZxid();
                    boolean dataSynced;
                    // NOTE(review): zxids carry the epoch in their high 32 bits, so any epoch
                    // difference exceeds SNAPSHOT_THRESHOLD and forces a snapshot.
                    if (truncateZxid - followerZxid > SNAPSHOT_THRESHOLD) {
                        syncFollowerWithSnapshot(targetServerId, followerZxid);
                        dataSynced = true;
                    } else {
                        // TRUNC to the quorum point, then DIFF with the transactions after it.
                        TruncatePacket truncPkt = new TruncatePacket(truncateZxid);
                        dataSynced = networkClient.sendTruncate(targetServerId, truncPkt)
                                && networkClient.sendTransactions(targetServerId, txns);
                    }
                    // NEWLEADER confirms synchronization for both sync paths.
                    if (dataSynced && networkClient.sendNewLeader(targetServerId,
                            new NewLeaderPacket(epoch.get()))) {
                        successCount.incrementAndGet();
                        logger.info("Successfully synced server: {}", targetServerId);
                    }
                } catch (IOException e) {
                    logger.error("Failed to sync server {} with {} transactions, last zxid: {}",
                            targetServerId, txns.size(), Long.toHexString(truncateZxid), e);
                } finally {
                    MDC.remove("targetServerId");
                    syncLatch.countDown();
                }
            });
            futures.add(future);
        }

        if (!syncLatch.await(30, TimeUnit.SECONDS)) {
            logger.warn("Sync phase timed out");
        }
        for (CompletableFuture<?> future : futures) {
            if (!future.isDone()) {
                future.cancel(true);
            }
        }

        int required = getQuorum();
        if (successCount.get() < required) {
            throw new QuorumNotFoundException("Failed to sync with quorum of followers",
                    successCount.get(), required);
        }
    }

    /**
     * Sends a full snapshot to a follower that is too far behind.
     *
     * <p>The snapshot is tagged with the state machine's last applied zxid, read before
     * serializing. NOTE(review): if transactions are applied concurrently the snapshot may
     * contain slightly more than that zxid; this assumes replaying them is harmless.
     *
     * @throws IOException wrapping any failure to take or send the snapshot
     */
    private void syncFollowerWithSnapshot(String followerId, long followerZxid) throws IOException {
        try {
            logger.info("Follower {} is too far behind (zxid: {}), syncing with snapshot",
                    followerId, Long.toHexString(followerZxid));
            long snapshotZxid = stateMachine.getLastAppliedZxid();
            byte[] snapshot = stateMachine.takeSnapshot();
            networkClient.sendSnapshot(followerId, snapshot, snapshotZxid);
            logger.info("Successfully sent snapshot to follower: {}", followerId);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            logger.error("Failed to sync follower {} with snapshot", followerId, e);
            throw new IOException("Snapshot sync failed", e);
        }
    }

    /** Placeholder: a real implementation reads persisted transactions from {@code fromZxid}. */
    private List<Transaction> loadTransactionsFromLog(long fromZxid) throws IOException {
        List<Transaction> result = new ArrayList<>();
        logger.info("Loading transactions starting from zxid: {}", Long.toHexString(fromZxid));
        return result;
    }

    /** Strict majority of {@code quorumSize}. */
    private int getQuorum() {
        return quorumSize / 2 + 1;
    }

    /** A transaction above the truncation point that no quorum holds. */
    static class ConflictingTransaction {
        private final long conflictZxid;
        private final long truncateZxid;
        private final int conflictEpoch;
        private final int truncateEpoch;
        private final Set<String> servers;

        public ConflictingTransaction(long conflictZxid, long truncateZxid,
                int conflictEpoch, int truncateEpoch,
                Set<String> servers) {
            this.conflictZxid = conflictZxid;
            this.truncateZxid = truncateZxid;
            this.conflictEpoch = conflictEpoch;
            this.truncateEpoch = truncateEpoch;
            this.servers = new HashSet<>(servers);
        }

        public boolean isFromHigherEpoch() {
            return conflictEpoch > truncateEpoch;
        }

        public long getConflictZxid() {
            return conflictZxid;
        }

        public int getConflictEpoch() {
            return conflictEpoch;
        }

        public Set<String> getServers() {
            return Collections.unmodifiableSet(servers);
        }
    }

    enum ServerState {
        LOOKING,   // searching for a leader
        FOLLOWING, // follower role
        LEADING    // leader role
    }
}
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.
- 102.
- 103.
- 104.
- 105.
- 106.
- 107.
- 108.
- 109.
- 110.
- 111.
- 112.
- 113.
- 114.
- 115.
- 116.
- 117.
- 118.
- 119.
- 120.
- 121.
- 122.
- 123.
- 124.
- 125.
- 126.
- 127.
- 128.
- 129.
- 130.
- 131.
- 132.
- 133.
- 134.
- 135.
- 136.
- 137.
- 138.
- 139.
- 140.
- 141.
- 142.
- 143.
- 144.
- 145.
- 146.
- 147.
- 148.
- 149.
- 150.
- 151.
- 152.
- 153.
- 154.
- 155.
- 156.
- 157.
- 158.
- 159.
- 160.
- 161.
- 162.
- 163.
- 164.
- 165.
- 166.
- 167.
- 168.
- 169.
- 170.
- 171.
- 172.
- 173.
- 174.
- 175.
- 176.
- 177.
- 178.
- 179.
- 180.
- 181.
- 182.
- 183.
- 184.
- 185.
- 186.
- 187.
- 188.
- 189.
- 190.
- 191.
- 192.
- 193.
- 194.
- 195.
- 196.
- 197.
- 198.
- 199.
- 200.
- 201.
- 202.
- 203.
- 204.
- 205.
- 206.
- 207.
- 208.
- 209.
- 210.
- 211.
- 212.
- 213.
- 214.
- 215.
- 216.
- 217.
- 218.
- 219.
- 220.
- 221.
- 222.
- 223.
- 224.
- 225.
- 226.
- 227.
- 228.
- 229.
- 230.
- 231.
- 232.
- 233.
- 234.
- 235.
- 236.
- 237.
- 238.
- 239.
- 240.
- 241.
- 242.
- 243.
- 244.
- 245.
- 246.
- 247.
- 248.
- 249.
- 250.
- 251.
- 252.
- 253.
- 254.
- 255.
- 256.
- 257.
- 258.
- 259.
- 260.
- 261.
- 262.
- 263.
- 264.
- 265.
- 266.
- 267.
- 268.
- 269.
- 270.
- 271.
- 272.
- 273.
- 274.
- 275.
- 276.
- 277.
- 278.
- 279.
- 280.
- 281.
- 282.
- 283.
- 284.
- 285.
- 286.
- 287.
- 288.
- 289.
- 290.
- 291.
- 292.
- 293.
- 294.
- 295.
- 296.
- 297.
- 298.
- 299.
- 300.
- 301.
- 302.
- 303.
- 304.
- 305.
- 306.
- 307.
- 308.
- 309.
- 310.
- 311.
- 312.
- 313.
- 314.
- 315.
- 316.
- 317.
- 318.
- 319.
- 320.
- 321.
- 322.
- 323.
- 324.
- 325.
- 326.
- 327.
- 328.
- 329.
- 330.
- 331.
- 332.
- 333.
- 334.
- 335.
- 336.
- 337.
- 338.
- 339.
- 340.
- 341.
- 342.
- 343.
- 344.
- 345.
- 346.
- 347.
- 348.
- 349.
ZAB 广播模式实现
public class ZABBroadcast implements AutoCloseable {
// Last issued zxid; the same AtomicLong instance the constructor caller passed in.
// Per createNewZxid, the high 32 bits hold the epoch and the low 32 bits a counter.
private final AtomicLong zxid;
// Current epoch; the same AtomicInteger instance the constructor caller passed in.
private final AtomicInteger epoch;
// Followers keyed by server id (see addFollower).
private final ConcurrentMap<String, ServerData> followers;
private final Logger logger = LoggerFactory.getLogger(ZABBroadcast.class);
// Guards processWrite / processBatchWrite.
private final CircuitBreaker circuitBreaker;
private final NetworkClient networkClient;
// NOTE(review): not referenced by any method visible in this part of the class.
private final StateMachine stateMachine;
private final String serverId;
// Write lock serializes zxid generation (createNewZxid); read lock is held by readLocal.
private final ReadWriteLock rwLock = new ReentrantReadWriteLock();
// Runs the periodic heartbeat task started in the constructor.
private final ScheduledExecutorService scheduler;
private final MetricsCollector metrics;
private final RateLimiter heartbeatLogLimiter = RateLimiter.create(0.1); // at most one log line every 10 seconds
public ZABBroadcast(String serverId, AtomicLong zxid, AtomicInteger epoch,
NetworkClient networkClient, StateMachine stateMachine) {
this.serverId = serverId;
this.zxid = zxid;
this.epoch = epoch;
this.networkClient = networkClient;
this.stateMachine = stateMachine;
this.followers = new ConcurrentHashMap<>();
this.circuitBreaker = new CircuitBreaker(5, 10000); // 5次失败,10秒重置
this.scheduler = Executors.newScheduledThreadPool(2, r -> {
Thread t = new Thread(r, "zab-scheduler-" + serverId);
t.setDaemon(true);
return t;
});
this.metrics = new MetricsCollector("zab_broadcast");
// 启动心跳任务
scheduler.scheduleWithFixedDelay(this::sendHeartbeats,
500, 500, TimeUnit.MILLISECONDS);
}
/**
 * Registers a follower, replacing any follower already registered under the same id.
 *
 * @param follower follower to register
 */
public void addFollower(ServerData follower) {
    String followerId = follower.getId();
    followers.put(followerId, follower);
}
/**
 * Replicates a single write request through the ZAB broadcast phase (leader side).
 *
 * <p>The work runs synchronously on the calling thread inside {@code circuitBreaker.execute}.
 * It allocates a zxid, sends a PROPOSAL to every follower, waits for a majority of follower
 * ACKs (see waitForMajority) and then sends COMMIT. The returned future is already complete
 * when this method returns. MDC keys are set for the duration of the call and removed in
 * {@code finally}.
 *
 * <p>Outcomes:
 * <ul>
 *   <li>{@code true} when committed;</li>
 *   <li>{@code false} when no majority ACKed;</li>
 *   <li>a failed future with {@code ProcessingException} on I/O error, interruption, or an
 *       open circuit breaker.</li>
 * </ul>
 * NOTE(review): this method does not apply the write to the leader's own {@code stateMachine}.
 *
 * @param request the write to replicate
 * @return future of the commit outcome
 */
public CompletableFuture<Boolean> processWrite(Request request) {
Stopwatch stopwatch = Stopwatch.createStarted();
MDC.put("component", "zab-broadcast");
MDC.put("serverId", serverId);
MDC.put("requestId", request.getId());
try {
return GlobalExceptionHandler.withExceptionHandling(
circuitBreaker.execute(() -> {
try {
// 1. Allocate a zxid for the request (high 32 bits: epoch, low 32 bits: counter)
long newZxid = createNewZxid();
MDC.put("zxid", Long.toHexString(newZxid));
logger.info("Processing write request: {} with zxid: {}",
request.getId(), Long.toHexString(newZxid));
// 2. Send the proposal to every follower
List<Future<ACK>> futures = sendToFollowers(request, newZxid);
// 3. Wait for ACKs from a majority of followers
if (waitForMajority(futures)) {
// 4. Tell every follower to commit (asynchronous, fire-and-forget)
commit(newZxid);
logger.info("Request {} committed successfully", request.getId());
// 5. Record metrics
metrics.recordSuccessfulWrite(stopwatch.elapsed(TimeUnit.MILLISECONDS));
return CompletableFuture.completedFuture(true);
} else {
// No majority: report false rather than failing the future.
logger.warn("Failed to get majority ACKs for request {}", request.getId());
metrics.recordFailedWrite();
return CompletableFuture.completedFuture(false);
}
} catch (IOException e) {
logger.error("Failed to process write request: {}", request.getId(), e);
metrics.recordFailedWrite();
return CompletableFuture.failedFuture(
new ProcessingException("Failed to process write request", e));
} catch (InterruptedException e) {
// Restore the interrupt flag before reporting the failure.
Thread.currentThread().interrupt();
logger.warn("Interrupted while processing write request: {}", request.getId(), e);
metrics.recordFailedWrite();
return CompletableFuture.failedFuture(
new ProcessingException("Interrupted during write processing", e));
}
})
);
} catch (CircuitBreakerOpenException e) {
// Rejected up front: the breaker is open and the action never ran.
logger.error("Circuit breaker is open, rejecting request: {}", request.getId());
metrics.recordRejectedWrite();
return CompletableFuture.failedFuture(
new ProcessingException("Circuit breaker open, system overloaded", e));
} finally {
MDC.remove("component");
MDC.remove("serverId");
MDC.remove("requestId");
MDC.remove("zxid");
}
}
/**
 * Replicates a batch of write requests under a single zxid to raise throughput.
 *
 * <p>All requests in the batch share one outcome. Each entry is {@code true} when a majority of
 * followers acknowledged the batch proposal; otherwise every entry is {@code false}. Processing
 * errors, including interruption, are logged and reported as {@code false} entries rather than
 * as a failed future. Only an open circuit breaker yields a failed future. An interruption
 * restores the thread's interrupt flag.
 *
 * @param requests requests to replicate; an empty list completes immediately with an empty map
 * @return future of per-request results keyed by request id
 */
public CompletableFuture<Map<String, Boolean>> processBatchWrite(List<Request> requests) {
    if (requests.isEmpty()) {
        return CompletableFuture.completedFuture(Collections.emptyMap());
    }
    Stopwatch stopwatch = Stopwatch.createStarted();
    MDC.put("component", "zab-broadcast");
    MDC.put("serverId", serverId);
    MDC.put("batchSize", String.valueOf(requests.size()));
    try {
        return GlobalExceptionHandler.withExceptionHandling(
                circuitBreaker.execute(() -> {
                    Map<String, Boolean> results = new HashMap<>();
                    try {
                        BatchRequest batch = new BatchRequest();
                        for (Request req : requests) {
                            batch.addRequest(req);
                            results.put(req.getId(), false); // pessimistic default until committed
                        }

                        // One zxid covers the whole batch.
                        long batchZxid = createNewZxid();
                        MDC.put("zxid", Long.toHexString(batchZxid));
                        logger.info("Processing batch of {} requests with zxid: {}",
                                requests.size(), Long.toHexString(batchZxid));

                        List<Future<ACK>> futures = sendBatchToFollowers(batch, batchZxid);
                        if (waitForMajority(futures)) {
                            commitBatch(batchZxid);
                            logger.info("Batch with {} requests committed successfully", requests.size());
                            for (Request req : requests) {
                                results.put(req.getId(), true);
                            }
                            metrics.recordSuccessfulBatchWrite(
                                    requests.size(), stopwatch.elapsed(TimeUnit.MILLISECONDS));
                        } else {
                            logger.warn("Failed to get majority ACKs for batch");
                            metrics.recordFailedBatchWrite(requests.size());
                        }
                    } catch (InterruptedException e) {
                        // Restore the interrupt flag; the generic handler below would drop it.
                        Thread.currentThread().interrupt();
                        logger.warn("Interrupted while processing batch write of {} requests",
                                requests.size(), e);
                        metrics.recordFailedBatchWrite(requests.size());
                    } catch (Exception e) {
                        logger.error("Error processing batch write of {} requests", requests.size(), e);
                        metrics.recordFailedBatchWrite(requests.size());
                    }
                    return CompletableFuture.completedFuture(results);
                })
        );
    } catch (CircuitBreakerOpenException e) {
        logger.error("Circuit breaker is open, rejecting batch of {} requests", requests.size());
        metrics.recordRejectedBatchWrite(requests.size());
        return CompletableFuture.failedFuture(
                new ProcessingException("Circuit breaker open, system overloaded", e));
    } finally {
        MDC.remove("component");
        MDC.remove("serverId");
        MDC.remove("batchSize");
        MDC.remove("zxid");
    }
}
/**
 * Reads {@code key} using the strategy registered for {@code level}. Unknown levels fall back
 * to the EVENTUAL strategy.
 *
 * <p>Metrics are recorded when the returned future completes. A successful read is recorded
 * with its latency. A failure is recorded whether it happens synchronously here or later when
 * the strategy's future completes exceptionally.
 *
 * @param key   key to read
 * @param level requested consistency level
 * @return future of the read result; fails with {@code ProcessingException} when the strategy
 *         throws synchronously
 */
public CompletableFuture<Result> readWithConsistency(String key, ConsistencyLevel level) {
    Stopwatch stopwatch = Stopwatch.createStarted();
    MDC.put("component", "zab-broadcast");
    MDC.put("serverId", serverId);
    MDC.put("key", key);
    MDC.put("consistency", level.name());
    try {
        ReadStrategy strategy = readStrategies.getOrDefault(
                level, readStrategies.get(ConsistencyLevel.EVENTUAL));
        // ReadStrategy takes a Supplier<Result>; readLocal needs the key, so bind it here.
        CompletableFuture<Result> result = strategy.execute(key, () -> readLocal(key));
        result.whenComplete((value, failure) -> {
            if (failure == null) {
                metrics.recordRead(level, stopwatch.elapsed(TimeUnit.MILLISECONDS));
            } else {
                logger.error("Error performing {} read for key: {}", level, key, failure);
                metrics.recordFailedRead(level);
            }
        });
        return result;
    } catch (Exception e) {
        logger.error("Error performing {} read for key: {}", level, key, e);
        metrics.recordFailedRead(level);
        return CompletableFuture.failedFuture(
                new ProcessingException("Read operation failed", e));
    } finally {
        MDC.remove("component");
        MDC.remove("serverId");
        MDC.remove("key");
        MDC.remove("consistency");
    }
}
/**
 * Reads {@code key} from this node's local state while holding the read lock.
 * Placeholder: always returns the literal value {@code "value"}.
 */
private Result readLocal(String key) {
    var guard = rwLock.readLock();
    guard.lock();
    try {
        // A real implementation would query the local database here.
        Result value = new Result(key, "value", true);
        return value;
    } finally {
        guard.unlock();
    }
}
/**
 * Issues the next zxid. A zxid packs the epoch into its high 32 bits and a per-epoch counter
 * into its low 32 bits. Generation is serialized by the write lock.
 *
 * <ul>
 *   <li>If the shared {@code epoch} is ahead of the epoch encoded in the last zxid (for example
 *       because the caller advanced it), numbering restarts at {@code epoch << 32}. New
 *       proposals are then tagged with the current epoch instead of the stale one.</li>
 *   <li>If the 32-bit counter is exhausted, the epoch is incremented and the counter restarts
 *       at 0.</li>
 *   <li>Otherwise the zxid is simply incremented.</li>
 * </ul>
 *
 * @return the newly issued zxid
 */
private long createNewZxid() {
    rwLock.writeLock().lock();
    try {
        long last = zxid.get();
        long lastEpoch = last >>> 32;
        long currentEpoch = epoch.get() & 0xFFFFFFFFL;

        if (currentEpoch > lastEpoch) {
            long restarted = currentEpoch << 32;
            logger.info("Epoch advanced to {}, restarting zxid counter", currentEpoch);
            zxid.set(restarted);
            return restarted;
        }

        if ((last & 0xFFFFFFFFL) >= 0xFFFFFFFFL) {
            // Counter exhausted: move to the next epoch and reset the counter.
            int newEpoch = epoch.incrementAndGet();
            logger.warn("ZXID counter overflow, incrementing epoch to {}", newEpoch);
            long newZxid = ((long) newEpoch << 32);
            zxid.set(newZxid);
            return newZxid;
        }
        return zxid.incrementAndGet();
    } finally {
        rwLock.writeLock().unlock();
    }
}
/**
 * Sends a PROPOSAL for {@code request} with zxid {@code newZxid} to every follower.
 *
 * <p>Delegates to {@code sendProposalToFollowers}, so single writes and batch writes share a
 * single sending path instead of two copies of the same thread-pool logic.
 *
 * @param request write request to propose
 * @param newZxid zxid assigned to the request
 * @return one future per follower completing with its ACK, or {@code null} if the send failed
 *         with an I/O error
 */
private List<Future<ACK>> sendToFollowers(Request request, long newZxid)
        throws IOException {
    ProposalPacket proposal = new ProposalPacket(newZxid, request);
    return sendProposalToFollowers(proposal, newZxid);
}
/**
 * Waits until a majority of followers has acknowledged a proposal.
 *
 * <p>The majority is computed over the follower count only ({@code followers.size() / 2 + 1});
 * the leader's own vote is not counted. All futures share a single 5-second deadline, so slow
 * or dead followers cannot stretch the total wait to 5 seconds per follower. The method
 * returns early as soon as a majority is reached, or once it can no longer be reached.
 *
 * @param futures per-follower ACK futures; a {@code null}, unsuccessful, failed or timed-out
 *                ACK counts as a "no" vote
 * @return whether a majority acknowledged before the deadline
 * @throws InterruptedException if interrupted while waiting
 */
private boolean waitForMajority(List<Future<ACK>> futures)
        throws InterruptedException {
    int majority = (followers.size() / 2) + 1;
    long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5);
    int ackCount = 0;
    int pending = futures.size();
    for (Future<ACK> future : futures) {
        if (ackCount + pending < majority) {
            break; // even if every remaining follower ACKs, a majority is impossible
        }
        pending--;
        try {
            // After the deadline this becomes 0: completed futures still count, others time out.
            long remaining = Math.max(0L, deadline - System.nanoTime());
            ACK ack = future.get(remaining, TimeUnit.NANOSECONDS);
            if (ack != null && ack.isSuccess()) {
                ackCount++;
                if (ackCount >= majority) {
                    return true;
                }
            }
        } catch (ExecutionException e) {
            logger.warn("Error getting ACK", e.getCause());
        } catch (TimeoutException e) {
            logger.warn("Timeout waiting for ACK");
        }
    }
    return ackCount >= majority;
}
/**
 * Broadcasts COMMIT for {@code zxid} to every follower.
 *
 * <p>Each send runs asynchronously via {@code CompletableFuture.runAsync} and is not awaited.
 * Delivery failures are only logged.
 */
private void commit(long zxid) throws IOException {
    final CommitPacket packet = new CommitPacket(zxid);
    final String hexZxid = Long.toHexString(zxid);
    followers.keySet().forEach(target -> CompletableFuture.runAsync(() -> {
        MDC.put("targetServerId", target);
        try {
            networkClient.sendCommit(target, packet);
            logger.debug("Sent commit to {} for zxid {}", target, hexZxid);
        } catch (IOException e) {
            logger.error("Failed to send commit to follower {}, zxid: {}", target, hexZxid, e);
        } finally {
            MDC.remove("targetServerId");
        }
    }));
}
/**
 * Sends the batch as one PROPOSAL tagged with {@code batchZxid} to every follower.
 *
 * @return one ACK future per follower
 */
private List<Future<ACK>> sendBatchToFollowers(BatchRequest batch, long batchZxid)
        throws IOException {
    return sendProposalToFollowers(new ProposalPacket(batchZxid, batch), batchZxid);
}
/** Commits a batch. A batch occupies a single zxid, so this is an ordinary commit. */
private void commitBatch(long batchZxid) throws IOException {
    this.commit(batchZxid);
}
/**
 * Sends the current zxid as a heartbeat to every follower, asynchronously.
 *
 * <p>Failures are logged at debug level, throttled by {@code heartbeatLogLimiter}.
 */
private void sendHeartbeats() {
    final long observedZxid = zxid.get();
    for (String target : followers.keySet()) {
        CompletableFuture.runAsync(() -> {
            try {
                networkClient.sendHeartbeat(target, observedZxid);
            } catch (IOException e) {
                // Throttled: an unreachable follower would otherwise log on every tick.
                if (!heartbeatLogLimiter.tryAcquire()) {
                    return;
                }
                logger.debug("Failed to send heartbeat to {}", target, e);
            }
        });
    }
}
/**
 * Sends {@code proposal} to every currently registered follower in parallel.
 *
 * <p>A dedicated daemon pool, sized to a snapshot of the follower ids, performs the sends. The
 * pool is shut down gracefully: in-flight and queued sends still run to completion, and its
 * threads exit afterwards. The calling thread is not blocked. Waiting for the results, and
 * enforcing a timeout, is left to the caller through the returned futures.
 *
 * @param proposal packet to send
 * @param zxid     zxid of the proposal, used for logging
 * @return one future per follower completing with its ACK, or {@code null} if the send failed
 *         with an I/O error; an empty list when no follower is registered
 */
private List<Future<ACK>> sendProposalToFollowers(ProposalPacket proposal, long zxid)
        throws IOException {
    // Snapshot the targets so the pool size matches the number of submitted tasks even if
    // followers are added concurrently.
    List<String> targets = new ArrayList<>(followers.keySet());
    if (targets.isEmpty()) {
        // Executors.newFixedThreadPool rejects a pool size of 0.
        return new ArrayList<>();
    }

    List<Future<ACK>> futures = new ArrayList<>(targets.size());
    ExecutorService executor = Executors.newFixedThreadPool(targets.size(), r -> {
        Thread t = new Thread(r, "proposal-sender-" + serverId);
        t.setDaemon(true);
        return t;
    });
    try {
        for (String targetServerId : targets) {
            futures.add(executor.submit(() -> {
                MDC.put("targetServerId", targetServerId);
                try {
                    ACK ack = networkClient.sendProposal(targetServerId, proposal);
                    logger.debug("Received ACK from {} for zxid {}",
                            targetServerId, Long.toHexString(zxid));
                    return ack;
                } catch (IOException e) {
                    logger.error("Failed to send proposal to follower {}, zxid: {}",
                            targetServerId, Long.toHexString(zxid), e);
                    return null;
                } finally {
                    MDC.remove("targetServerId");
                }
            }));
        }
    } finally {
        // Graceful shutdown only. Forcing shutdownNow here would interrupt in-flight sends and
        // leave futures of never-started tasks incomplete forever.
        executor.shutdown();
    }
    return futures;
}
/** Serves a read of one key at a particular consistency level. */
@FunctionalInterface
private interface ReadStrategy {
    /**
     * @param key           key being read
     * @param readFromLocal reads the value from this node's local state
     * @return future completing with the read result
     */
    CompletableFuture<Result> execute(String key, Supplier<Result> readFromLocal);
}
/** One read strategy per consistency level; readWithConsistency falls back to EVENTUAL. */
private final Map<ConsistencyLevel, ReadStrategy> readStrategies = new EnumMap<>(ConsistencyLevel.class);
{
    // Instance initializer: register a strategy for each supported level.
    readStrategies.put(ConsistencyLevel.EVENTUAL, new EventualReadStrategy());
    readStrategies.put(ConsistencyLevel.BOUNDED_STALENESS, new BoundedStalenessStrategy());
    readStrategies.put(ConsistencyLevel.READ_YOUR_WRITES, new ReadYourWritesStrategy());
    readStrategies.put(ConsistencyLevel.SEQUENTIAL, new SequentialReadStrategy());
    readStrategies.put(ConsistencyLevel.LINEARIZABLE, new LinearizableReadStrategy());
}
/**
 * LINEARIZABLE reads guarded by a 5-second leader lease.
 *
 * <p>While the lease is valid the read is served locally. Otherwise the lease is renewed first,
 * and the read fails with {@code ConsistencyException} if renewal is refused. The lease is
 * measured with {@link System#nanoTime()}, a monotonic clock, so wall-clock adjustments cannot
 * extend it.
 *
 * <p>NOTE(review): {@code renewLease} is a placeholder that grants the lease without quorum
 * confirmation, so linearizability is not actually guaranteed yet.
 */
private class LinearizableReadStrategy implements ReadStrategy {
    // Lease deadline as a System.nanoTime() value; starts "now", i.e. already expired.
    private final AtomicLong leaseDeadlineNanos = new AtomicLong(System.nanoTime());
    private final long leaderLeaseMs = 5000; // 5-second lease

    @Override
    public CompletableFuture<Result> execute(String key, Supplier<Result> readFromLocal) {
        if (leaseIsValid()) {
            return CompletableFuture.completedFuture(readFromLocal.get());
        }
        return renewLease().thenApply(renewed -> {
            if (renewed) {
                return readFromLocal.get();
            }
            throw new ConsistencyException("Cannot guarantee linearizable read");
        });
    }

    /** Overflow-safe comparison of nanoTime values. */
    private boolean leaseIsValid() {
        return System.nanoTime() - leaseDeadlineNanos.get() < 0;
    }

    private CompletableFuture<Boolean> renewLease() {
        // Placeholder: a real implementation must first obtain acknowledgements from a quorum.
        leaseDeadlineNanos.set(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(leaderLeaseMs));
        logger.info("Renewed leader lease for {} ms", leaderLeaseMs);
        return CompletableFuture.completedFuture(true);
    }
}
/** SEQUENTIAL reads: wait until committed transactions are applied locally, then read. */
private class SequentialReadStrategy implements ReadStrategy {
    @Override
    public CompletableFuture<Result> execute(String key, Supplier<Result> readFromLocal) {
        CompletableFuture<Void> caughtUp = ensureAppliedUpToDate();
        return caughtUp.thenApply(ignored -> readFromLocal.get());
    }

    /** Placeholder: logs at debug level and completes immediately. */
    private CompletableFuture<Void> ensureAppliedUpToDate() {
        logger.debug("Ensuring all committed transactions are applied");
        return CompletableFuture.completedFuture(null);
    }
}
/**
 * READ_YOUR_WRITES reads.
 *
 * <p>If {@code key} was written (as recorded by {@link #recordWrite}) less than 100 ms ago, the
 * local read is deferred until 100 ms have passed since that write. The delay uses
 * {@link CompletableFuture#delayedExecutor} instead of {@code Thread.sleep}, so the calling
 * thread is never blocked. Deferred reads run on the default async executor. Elapsed time is
 * measured with the monotonic {@link System#nanoTime()}.
 *
 * <p>NOTE(review): the fixed 100 ms wait is a heuristic, not a guarantee that the write has
 * been applied. Nothing in the visible code calls {@code recordWrite}.
 */
private class ReadYourWritesStrategy implements ReadStrategy {
    private final ConcurrentMap<String, Long> writeTimestamps = new ConcurrentHashMap<>();
    private final long settleNanos = TimeUnit.MILLISECONDS.toNanos(100);

    @Override
    public CompletableFuture<Result> execute(String key, Supplier<Result> readFromLocal) {
        Long writtenAt = writeTimestamps.get(key);
        if (writtenAt != null) {
            long waitNanos = settleNanos - (System.nanoTime() - writtenAt);
            if (waitNanos > 0) {
                return CompletableFuture.supplyAsync(readFromLocal,
                        CompletableFuture.delayedExecutor(waitNanos, TimeUnit.NANOSECONDS));
            }
        }
        return CompletableFuture.completedFuture(readFromLocal.get());
    }

    /** Records that {@code key} has just been written. */
    public void recordWrite(String key) {
        writeTimestamps.put(key, System.nanoTime());
    }
}
/**
 * BOUNDED_STALENESS reads.
 *
 * <p>A cached result at most 1 second old (wall-clock millis) is served directly. Otherwise the
 * value is read locally and the cache is refreshed. {@link #cleanup()} evicts expired entries;
 * no caller of it is visible here.
 */
private class BoundedStalenessStrategy implements ReadStrategy {
    private final ConcurrentMap<String, CacheEntry> cache = new ConcurrentHashMap<>();
    private final long maxStalenessMs = 1000; // staleness bound: 1 second

    @Override
    public CompletableFuture<Result> execute(String key, Supplier<Result> readFromLocal) {
        CacheEntry cached = cache.get(key);
        boolean fresh = cached != null
                && System.currentTimeMillis() - cached.getTimestamp() <= maxStalenessMs;
        if (fresh) {
            return CompletableFuture.completedFuture(cached.getResult());
        }
        Result loaded = readFromLocal.get();
        cache.put(key, new CacheEntry(loaded, System.currentTimeMillis()));
        return CompletableFuture.completedFuture(loaded);
    }

    /** Removes cache entries older than the staleness bound. */
    public void cleanup() {
        final long now = System.currentTimeMillis();
        cache.values().removeIf(cached -> now - cached.getTimestamp() > maxStalenessMs);
    }
}
/** EVENTUAL reads: read local state directly; may not reflect the latest committed write. */
private class EventualReadStrategy implements ReadStrategy {
    @Override
    public CompletableFuture<Result> execute(String key, Supplier<Result> readFromLocal) {
        Result local = readFromLocal.get();
        return CompletableFuture.completedFuture(local);
    }
}
/** Immutable pair of a cached read result and the timestamp supplied when it was cached. */
private static class CacheEntry {
    private final Result result;
    private final long timestamp;

    public CacheEntry(Result cachedResult, long cachedAt) {
        result = cachedResult;
        timestamp = cachedAt;
    }

    /** Returns the cached result. */
    public Result getResult() {
        return this.result;
    }

    /** Returns the timestamp passed to the constructor. */
    public long getTimestamp() {
        return this.timestamp;
    }
}
/**
 * Stops the heartbeat scheduler. Running tasks are interrupted via {@code shutdownNow},
 * pending ones are dropped, and the method waits up to 5 seconds for termination.
 */
@Override
public void close() {
    List<Runnable> dropped = scheduler.shutdownNow();
    if (!dropped.isEmpty()) {
        logger.warn("Scheduler shutdown with {} pending tasks", dropped.size());
    }
    try {
        boolean terminated = scheduler.awaitTermination(5, TimeUnit.SECONDS);
        if (!terminated) {
            logger.warn("Scheduler did not terminate in time");
        }
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        logger.warn("Interrupted while waiting for scheduler termination");
    }
}
/**
 * Three-state circuit breaker guarding asynchronous actions.
 *
 * <ul>
 *   <li>CLOSED: calls pass through. After {@code threshold} failures with no success in
 *       between, the breaker opens.</li>
 *   <li>OPEN: calls are rejected with {@code CircuitBreakerOpenException} until
 *       {@code resetTimeoutMs} has elapsed since the last failure. The first caller after that
 *       moves the breaker to HALF_OPEN.</li>
 *   <li>HALF_OPEN: calls pass through (they are not limited to a single trial). A success
 *       closes the breaker; a failure immediately re-opens it and restarts the cool-down.</li>
 * </ul>
 * State transitions are serialized by a {@link StampedLock}; state reads use an optimistic read.
 */
static class CircuitBreaker {
    private final AtomicReference<State> state = new AtomicReference<>(State.CLOSED);
    private final AtomicLong failureCount = new AtomicLong(0);
    private final AtomicLong lastFailureTime = new AtomicLong(0);
    private final int threshold;
    private final long resetTimeoutMs;
    private final StampedLock stateLock = new StampedLock();
    private final Logger logger = LoggerFactory.getLogger(CircuitBreaker.class);

    public enum State { CLOSED, OPEN, HALF_OPEN }

    public CircuitBreaker(int threshold, long resetTimeoutMs) {
        this.threshold = threshold;
        this.resetTimeoutMs = resetTimeoutMs;
    }

    /**
     * Runs {@code action} unless the breaker is open.
     *
     * <p>A failed returned future, or a synchronous exception from {@code action}, is recorded
     * as a failure. A successful completion resets the failure count and closes a half-open
     * breaker.
     *
     * @throws CircuitBreakerOpenException if the breaker is open and the cool-down has not
     *         elapsed, or another caller won the OPEN to HALF_OPEN transition
     */
    public <T> CompletableFuture<T> execute(Supplier<CompletableFuture<T>> action)
            throws CircuitBreakerOpenException {
        State currentState = getCurrentState();
        if (currentState == State.OPEN) {
            boolean coolDownElapsed =
                    System.currentTimeMillis() - lastFailureTime.get() > resetTimeoutMs;
            if (!coolDownElapsed || !tryTransitionState(State.OPEN, State.HALF_OPEN)) {
                throw new CircuitBreakerOpenException("Circuit breaker is open");
            }
            currentState = State.HALF_OPEN;
        }

        final State executionState = currentState;
        try {
            CompletableFuture<T> future = action.get();
            return future.handle((result, ex) -> {
                if (ex != null) {
                    recordFailure();
                    // Do not wrap a CompletionException in another one.
                    throw ex instanceof CompletionException
                            ? (CompletionException) ex
                            : new CompletionException(ex);
                }
                if (executionState == State.HALF_OPEN) {
                    tryTransitionState(State.HALF_OPEN, State.CLOSED);
                }
                failureCount.set(0);
                return result;
            });
        } catch (Exception e) {
            recordFailure();
            throw e;
        }
    }

    /** Counts a failure and opens the breaker when warranted. */
    private void recordFailure() {
        long stamp = stateLock.writeLock();
        try {
            long failures = failureCount.incrementAndGet();
            lastFailureTime.set(System.currentTimeMillis());
            State current = state.get();
            if (current == State.HALF_OPEN) {
                // A trial call failed: go straight back to OPEN and restart the cool-down.
                logger.warn("Circuit breaker re-opening after failure in half-open state");
                state.set(State.OPEN);
            } else if (current == State.CLOSED && failures >= threshold) {
                logger.warn("Circuit breaker opening after {} failures", failures);
                state.set(State.OPEN);
            }
        } finally {
            stateLock.unlockWrite(stamp);
        }
    }

    /** Atomically moves from {@code fromState} to {@code toState}; returns whether it did. */
    private boolean tryTransitionState(State fromState, State toState) {
        long stamp = stateLock.writeLock();
        try {
            if (state.get() == fromState) {
                state.set(toState);
                logger.info("Circuit breaker state changed from {} to {}", fromState, toState);
                return true;
            }
            return false;
        } finally {
            stateLock.unlockWrite(stamp);
        }
    }

    /** Reads the current state with an optimistic read, falling back to a read lock. */
    public State getCurrentState() {
        long stamp = stateLock.tryOptimisticRead();
        State result = state.get();
        if (!stateLock.validate(stamp)) {
            stamp = stateLock.readLock();
            try {
                result = state.get();
            } finally {
                stateLock.unlockRead(stamp);
            }
        }
        return result;
    }
}
// Global exception handler
/**
 * Attaches uniform, category-based logging to failed futures.
 */
static class GlobalExceptionHandler {
    private static final Logger logger = LoggerFactory.getLogger(GlobalExceptionHandler.class);

    /**
     * Returns a future that behaves like {@code future}, except that failures are logged.
     *
     * <p>On failure, all {@link CompletionException} layers are stripped. The underlying cause
     * is logged according to its type and then rethrown inside one {@link CompletionException}.
     * Successful results pass through unchanged.
     *
     * @param future the future to observe
     * @return the derived future
     */
    public static <T> CompletableFuture<T> withExceptionHandling(CompletableFuture<T> future) {
        return future.exceptionally(e -> {
            Throwable cause = unwrap(e);
            if (cause instanceof ConsistencyException) {
                logger.error("Consistency error: {}", cause.getMessage());
            } else if (cause instanceof IOException) {
                logger.error("I/O error: {}", cause.getMessage());
            } else if (cause instanceof InterruptedException) {
                // Do NOT call Thread.currentThread().interrupt() here. The interrupt happened on
                // the thread that produced the failure. This callback may run on an unrelated
                // pool thread, or on the caller's thread when the future is already complete.
                // Setting the flag here would interrupt the wrong thread.
                logger.warn("Operation interrupted");
            } else {
                logger.error("Unexpected error: {}", cause.getClass().getName(), cause);
            }
            throw new CompletionException(cause);
        });
    }

    /**
     * Strips (possibly nested) {@link CompletionException} wrappers.
     * Stops at a wrapper without a cause, so the result is never {@code null}.
     */
    private static Throwable unwrap(Throwable t) {
        Throwable current = t;
        while (current instanceof CompletionException && current.getCause() != null) {
            current = current.getCause();
        }
        return current;
    }
}
// Metrics collector
/**
 * Prometheus counters and histograms for write and read traffic.
 *
 * <p>All metric names start with the caller-supplied prefix, and every metric is registered
 * in the constructor. Latencies are observed in milliseconds.
 */
private static class MetricsCollector {
    /**
     * Histogram bucket upper bounds, in milliseconds. The client's default buckets
     * (0.005 .. 10) assume seconds, so millisecond observations would land almost entirely
     * in the +Inf bucket.
     */
    private static final double[] LATENCY_BUCKETS_MS = {
        1, 2, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000
    };

    private final Counter writeRequests;
    private final Counter writeSuccess;
    private final Counter writeFailed;
    private final Counter writeRejected;
    private final Counter batchWrites;
    private final Counter batchWriteRequests;
    private final Counter readRequests;
    private final Counter readFailed;
    private final Map<ConsistencyLevel, Counter> readsByLevel = new EnumMap<>(ConsistencyLevel.class);
    private final Histogram writeLatency;
    private final Histogram batchWriteLatency;
    private final Map<ConsistencyLevel, Histogram> readLatency = new EnumMap<>(ConsistencyLevel.class);

    /**
     * Builds and registers all metrics.
     *
     * @param prefix metric-name prefix, e.g. {@code prefix + "_write_requests_total"}
     */
    public MetricsCollector(String prefix) {
        this.writeRequests = counter(prefix + "_write_requests_total",
                "Total number of write requests");
        this.writeSuccess = counter(prefix + "_write_success_total",
                "Total number of successful writes");
        this.writeFailed = counter(prefix + "_write_failed_total",
                "Total number of failed writes");
        this.writeRejected = counter(prefix + "_write_rejected_total",
                "Total number of rejected writes");
        this.batchWrites = counter(prefix + "_batch_writes_total",
                "Total number of batch writes");
        this.batchWriteRequests = counter(prefix + "_batch_write_requests_total",
                "Total number of requests in batch writes");
        this.readRequests = counter(prefix + "_read_requests_total",
                "Total number of read requests");
        this.readFailed = counter(prefix + "_read_failed_total",
                "Total number of failed reads");
        this.writeLatency = latencyHistogram(prefix + "_write_latency_ms",
                "Write latency in milliseconds");
        this.batchWriteLatency = latencyHistogram(prefix + "_batch_write_latency_ms",
                "Batch write latency in milliseconds");
        // One read counter and one read-latency histogram per consistency level.
        for (ConsistencyLevel level : ConsistencyLevel.values()) {
            String segment = metricSegment(level);
            readsByLevel.put(level, counter(prefix + "_reads_" + segment + "_total",
                    "Total " + level + " reads"));
            readLatency.put(level, latencyHistogram(prefix + "_read_" + segment + "_latency_ms",
                    level + " read latency in milliseconds"));
        }
    }

    /** Builds and registers a counter. */
    private static Counter counter(String name, String help) {
        return Counter.build().name(name).help(help).register();
    }

    /** Builds and registers a histogram with millisecond buckets. */
    private static Histogram latencyHistogram(String name, String help) {
        return Histogram.build().name(name).help(help).buckets(LATENCY_BUCKETS_MS).register();
    }

    /**
     * Lower-cases the enum name with {@code Locale.ROOT}. Under a default locale such as
     * Turkish, "I" lower-cases to a dotless "ı", which is not a legal metric-name character.
     */
    private static String metricSegment(ConsistencyLevel level) {
        return level.name().toLowerCase(java.util.Locale.ROOT);
    }

    public void recordSuccessfulWrite(long latencyMs) {
        writeRequests.inc();
        writeSuccess.inc();
        writeLatency.observe(latencyMs);
    }

    public void recordFailedWrite() {
        writeRequests.inc();
        writeFailed.inc();
    }

    public void recordRejectedWrite() {
        writeRequests.inc();
        writeRejected.inc();
    }

    /** Counts one batch, plus {@code batchSize} individual successful write requests. */
    public void recordSuccessfulBatchWrite(int batchSize, long latencyMs) {
        batchWrites.inc();
        batchWriteRequests.inc(batchSize);
        writeRequests.inc(batchSize);
        writeSuccess.inc(batchSize);
        batchWriteLatency.observe(latencyMs);
    }

    /** Counts one batch, plus {@code batchSize} individual failed write requests. */
    public void recordFailedBatchWrite(int batchSize) {
        batchWrites.inc();
        batchWriteRequests.inc(batchSize);
        writeRequests.inc(batchSize);
        writeFailed.inc(batchSize);
    }

    /** Counts one batch, plus {@code batchSize} individual rejected write requests. */
    public void recordRejectedBatchWrite(int batchSize) {
        batchWrites.inc();
        batchWriteRequests.inc(batchSize);
        writeRequests.inc(batchSize);
        writeRejected.inc(batchSize);
    }

    public void recordRead(ConsistencyLevel level, long latencyMs) {
        readRequests.inc();
        readsByLevel.get(level).inc();
        readLatency.get(level).observe(latencyMs);
    }

    /**
     * Records a failed read. Failures are counted in aggregate; {@code level} is accepted for
     * API symmetry but is not broken out.
     */
    public void recordFailedRead(ConsistencyLevel level) {
        readRequests.inc();
        readFailed.inc();
    }
}
// Exception types
/**
 * Checked exception thrown by {@code CircuitBreaker.execute} when the breaker is open and the
 * call is rejected without running the action.
 */
public static class CircuitBreakerOpenException extends Exception {
    /**
     * @param message human-readable reason for the rejection
     */
    public CircuitBreakerOpenException(final String message) {
        super(message);
    }
}
/**
 * Unchecked exception type for consistency errors. {@code GlobalExceptionHandler} logs it
 * under "Consistency error".
 */
public static class ConsistencyException extends RuntimeException {
    /**
     * @param message description of the consistency error
     */
    public ConsistencyException(final String message) {
        super(message);
    }
}
public static class ProcessingException extends RuntimeException {
public ProcessingException(String message, Throwable cause) {
super(mess

最低0.47元/天 解锁文章
797

被折叠的 条评论
为什么被折叠?



