1. 问题描述
目前的跑批任务有多种类型，存在大量冗余代码；同时发现部分任务跑批后出现数据丢失或卡死的问题，因此进行优化改造。
2. 核心代码
/**
 * Abstract base task service; the original logic of every method is kept unchanged.
 *
 * <p>Operations a concrete task type does not support throw {@link UnsupportedOperationException}.
 * Progress reporting is shared through {@link #sendMsg}.
 */
public abstract class AbstractTaskService implements TaskService {
    /**
     * Shared worker pool: 20 core / 40 max threads and a bounded queue of 5000 tasks.
     * When the queue is full, {@code CallerRunsPolicy} runs the task on the submitting thread,
     * which throttles the producer instead of rejecting work.
     * NOTE(review): the keep-alive is 120 ms, so threads above the core size are reclaimed almost
     * as soon as they go idle — confirm this (rather than e.g. seconds) is intended.
     */
    protected static final ExecutorService EXECUTOR_SERVICE = new ThreadPoolExecutor(20, 40, 120L, TimeUnit.MILLISECONDS,
            new LinkedBlockingDeque<>(5000), new ThreadPoolExecutor.CallerRunsPolicy());

    public AbstractTaskService() {
    }

    @Override
    public void apply(DataStatistics dataStatistics) {
        throw new UnsupportedOperationException("未实现apply");
    }

    @Override
    public boolean support(String type) {
        throw new UnsupportedOperationException("未实现support");
    }

    @Override
    public void exportEvaluate(DataStatistics dataStatistics, HttpServletResponse httpServletResponse) {
        throw new UnsupportedOperationException("未实现exportEvaluate");
    }

    @Override
    public void exportTaskResult(SeaweedRequestDto seaweedRequestDto, HttpServletResponse httpServletResponse) {
        throw new UnsupportedOperationException("未实现exportTaskResult");
    }

    @Override
    public Page<DataStatisticsDto> listTask(String keywords, Pageable pageable) {
        throw new UnsupportedOperationException("未实现listTask");
    }

    /**
     * Persists the current progress percentage of a task and pushes it to {@code /topic/listTask}.
     *
     * <p>When progress reaches 100% the task is marked {@code SUCCESS} and its end time is set.
     * The method is {@code synchronized} on this instance, so concurrent progress updates issued
     * through the same service object are serialized.
     *
     * @param taskId                   id of the {@code DataStatistics} record to update
     * @param totalCount               total number of items in the task; {@code null} or {@code <= 0} is a no-op
     * @param progressCount            number of items processed so far (read once, at call time)
     * @param dataPackageName          name attached to the pushed DTO
     * @param dataStatisticsRepository repository used to load and save the task record
     * @param messagingTemplate        STOMP template used to push the update
     * @throws java.util.NoSuchElementException if no task record exists for {@code taskId}
     */
    @Override
    public synchronized void sendMsg(String taskId, Long totalCount, AtomicLong progressCount, String dataPackageName,
                                     DataStatisticsRepository dataStatisticsRepository, SimpMessagingTemplate messagingTemplate) {
        // Guard against an empty/unknown total: avoids division by zero and an NPE when unboxing.
        if (totalCount == null || totalCount <= 0L) {
            return;
        }
        final long progress = progressCount.get() * 100 / totalCount;
        // Progress still rounds down to 0%: nothing worth reporting yet.
        if (progress <= 0) {
            return;
        }
        DataStatistics dataStatisticsNow = dataStatisticsRepository.findById(taskId)
                .orElseThrow(() -> new java.util.NoSuchElementException("DataStatistics not found: taskId=" + taskId));
        // TODO(review): replace stdout with the project's logger.
        System.out.println("进度:" + progress);
        dataStatisticsNow.setCount(progress);
        if (progress >= 100L) {
            // Clamp to 100 in case more items were counted than the declared total.
            dataStatisticsNow.setCount(100L);
            dataStatisticsNow.setEndTime(LocalDateTime.now());
            dataStatisticsNow.setStatus(Constant.TaskStatus.SUCCESS.val());
        }
        dataStatisticsRepository.save(dataStatisticsNow);
        DataStatisticsDto dataStatisticsDto = CommonDtoUtils.transform(dataStatisticsNow, DataStatisticsDto.class);
        if (dataStatisticsDto.getCount() > 0) {
            dataStatisticsDto.setDataPackageName(dataPackageName);
            messagingTemplate.convertAndSend("/topic/listTask", FangJiaApiResponse.ok(dataStatisticsDto));
        }
    }
}
新增一个抽象子类，抽离多线程处理逻辑；新任务或需要改造的任务直接继承该类即可。
/**
* 抽象批量任务服务类,封装多线程处理和批量更新逻辑
*/
@Slf4j
public abstract class AbstractBatchTaskService extends AbstractTaskService {
protected final SimpMessagingTemplate messagingTemplate;
private final ReentrantLock updateLock = new ReentrantLock();
private final Map<String, AddressResult> activeMap = new ConcurrentHashMap<>();
private final Map<String, AddressResult> bufferMap = new ConcurrentHashMap<>();
private final DataStatisticsRepository dataStatisticsRepository;
public AbstractBatchTaskService(SimpMessagingTemplate messagingTemplate, DataStatisticsRepository dataStatisticsRepository) {
this.messagingTemplate = messagingTemplate;
this.dataStatisticsRepository = dataStatisticsRepository;
}
/**
* 执行任务
*/
public void execute(DataStatistics dataStatistics, List<Address> addresses, String dataPackageName, boolean rerun) {
final String taskId = dataStatistics.getId();
final AtomicLong progressCount = new AtomicLong(0L);
final AtomicLong timeLong = new AtomicLong(System.currentTimeMillis());
final Long totalCount = (long) addresses.size();
log.info("任务开始: taskId={}, totalCount={}, rerun={}", taskId, totalCount, rerun);
for (final Address address : addresses) {
EXECUTOR_SERVICE.submit(() ->
processAndUpdate(address, taskId, totalCount, progressCount, timeLong, dataPackageName, rerun, dataStatistics));
}
}
/**
* 处理单个地址并更新进度
*/
private void processAndUpdate(Address address, String taskId, Long totalCount, AtomicLong progressCount,
AtomicLong timeLong, String dataPackageName, boolean rerun, DataStatistics dataStatistics) {
AddressResult addressResult = processAddress(address, dataStatistics);
addressResult.setOriginCity(address.getCity());
addressResult.setAddress(address.getAddress());
addressResult.setParentId(taskId);
addressResult.setNo(address.getNo());
if (rerun) {
addressResult.setRerun(Constant.YesOrNo.NO.desc());
}
activeMap.put(address.getNo(), addressResult);
long currentProgress = progressCount.incrementAndGet();
boolean shouldUpdate = (totalCount >= 10000 && currentProgress % 5000 == 0) ||
(totalCount < 10000 && currentProgress % 1000 == 0) ||
currentProgress >= totalCount;
if (shouldUpdate) {
updateLock.lock();
try {
bufferMap.putAll(activeMap);
activeMap.clear();
updateStructureTaskProgress(taskId, totalCount, progressCount, dataPackageName, bufferMap, rerun);
timeLong.set(System.currentTimeMillis());
} finally {
updateLock.unlock();
}
}
}
/**
* 更新任务进度并持久化数据
*/
private void updateStructureTaskProgress(String taskId, Long totalCount, AtomicLong progressCount,
String dataPackageName, Map<String, AddressResult> mapToUpdate, boolean rerun) {
if (mapToUpdate.isEmpty()) {
log.warn("任务进度更新时数据为空: taskId={}", taskId);
return;
}
try {
if (rerun) {
EsUtils.bulkUpdate(mapToUpdate.values(), IndexEnum.ADDRESS_RESULT.val());
} else {
EsUtils.bulkWrite(mapToUpdate.values(), IndexEnum.ADDRESS_RESULT.val());
}
log.info("批量写入ES成功: taskId={}, count={}", taskId, mapToUpdate.size());
} catch (Exception e) {
log.error("批量写入ES失败: taskId={}, error={}", taskId, e.getMessage(), e);
} finally {
mapToUpdate.clear();
}
sendMsg(taskId, totalCount, progressCount, dataPackageName, dataStatisticsRepository, messagingTemplate);
}
/**
* 抽象方法:子类实现具体的地址处理逻辑
*/
@Retryable(value = {feign.RetryableException.class}, maxAttempts = 10, backoff = @Backoff(delay = 10000, multiplier = 2))
protected abstract AddressResult processAddress(Address address, DataStatistics dataStatistics);
}
其中我们采用可重入锁（ReentrantLock）+ 双缓冲 Map 来保证线程安全，欢迎指教~