场景如下:一共三个线程a,b,c,其中c需要用到a,b,执行的结果,应该怎么处理?
1)CountDownLatch,主线程中调用await方法,每个线程调用countdown
上面两种方法需要分别调用多次join或future的get方法,不太好,有一种方法是使用CountDownLatch类
认知CountDownLatch的方法:
await():阻塞主线程,直到countDownLatch的计数器减少到0的位置,
countDown:将当前的计数器减1
getCount:返回当前的数
思路如下:让a,b线程使用CountDowmLatch,然后执行countDownLatch的await方法
主类:
public static void main(String[] args) {
CountDownLatch countDownLaunch = new CountDownLatch(5);
for(int i=1;i<6;i++){
ThreadWithCountDownLatch threadWithCountDownLatch = new ThreadWithCountDownLatch(i*1000L,countDownLaunch,"THREAD"+i);
Thread thread = new Thread(threadWithCountDownLatch);
thread.start();
}
mainThreadWork();
try {
// 强烈建议带上等待时间,避免出现因为countdown未执行而出现死等问题
countDownLaunch.await();//下面的代码要等待所有的countDown的计数器为零再执行
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
System.out.println("all done");
}
线程类:
public class ThreadWithCountDownLatch implements Runnable{
long time;
CountDownLatch countDownLatch;
String name;
public ThreadWithCountDownLatch(long time,CountDownLatch countDownLatch,String name){
this.time=time;
this.countDownLatch = countDownLatch;
this.name=name;
}
@Override
public void run() {
try {
System.out.println(name+" start");
work(time);
System.out.println(name+" end");
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}finally{
// 在finally里面执行非常重要,否则可能导致countdown未执行然后一直阻塞线程
countDownLatch.countDown();
}
}
private void work(long time2) throws InterruptedException {
Thread.sleep(time2);
}
}
2)CyclicBarrier-每个线程中调用await,调用达到次数再一起放行
调用栅栏(计数器为5)的await,相当于告诉栅栏已经有一个线程到达栅栏,同时线程本身不再继续执行;当cyclicBarrier的await方法调用5次时,所有线程继续执行,同时触发栅栏的run方法
主类:
public static void main(String[] args) {
CyclicBarrier cylicBarrier = new CyclicBarrier(5,new Runnable() {
//栅栏动作,在计数器为零的时候执行
@Override
public void run() {//栅栏的await方法执行五次后会调用此处
// TODO Auto-generated method stub
System.out.println("all work done");
}
});
ExecutorService pool = Executors.newFixedThreadPool(5);
for(int i=1;i<6;i++){
pool.submit(new ThreadWithCyclicBarrier(i,cylicBarrier));
}
pool.shutdown();
System.out.println("last ");
}
线程类:
public class ThreadWithCyclicBarrier implements Runnable{
int id;
CyclicBarrier cyclicBarrier;
public ThreadWithCyclicBarrier(int id,CyclicBarrier cyclicBarrier){
this.id = id;
this.cyclicBarrier=cyclicBarrier;
}
@Override
public void run() {
System.out.println("thread"+id+" start");
try {
cyclicBarrier.await();//栅栏计数值为零后会继续执行
Thread.sleep(1000L);
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (BrokenBarrierException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
System.out.println("thread"+id+" end");
}
}
3)join方法实现
对a,b使用a.join()和b.join(); 然后下面使用c.start()
SubTThread subthread = new SubTThread("subthread");
SubTThread subthread2 = new SubTThread("subthread2");
SubTThread subthread3 = new SubTThread("subthread3");
Thread thread = new Thread(subthread);
Thread thread2 = new Thread(subthread2);
Thread thread3 = new Thread(subthread3);
thread.start();
thread2.start();
mainThreadWork();
System.out.println("wait subthread done");
try {
thread.join();
thread2.join();//当thread和thread2的run方法执行完毕后才会继续下面的代码
} catch (InterruptedException e) {
e.printStackTrace();
}
thread3.start();
System.out.println("all work done");
4)线程池submit线程返回Future,并调用future的get方法
使用线程池,线程池submit线程返回Future,并调用future的get方法.则get方法下的代码都需要等待调用了get方法的线程完全执行后再执行
public static void main(String[] args) {
ExecutorService pool = Executors.newFixedThreadPool(3);
SubTThreadByRun subthread = new SubTThreadByRun("subthread");
SubTThreadByRun subthread2 = new SubTThreadByRun("subthread2");
Thread thread = new Thread(subthread);
Thread thread2 = new Thread(subthread2);
ThreadByCall threadByCall = new ThreadByCall("call1");
Future future1 = pool.submit(thread);
Future future2 = pool.submit(thread2);
Future future3 = pool.submit(threadByCall);
mainThreadWork();
try {
future1.get();
future2.get();
String str = (String) future3.get();//下面的代码需要等待thread1/thread2和threadByCall 执行完毕后执行下面的代码
System.out.println("得到结果:"+str);
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
SubTThreadByRun subthread3 = new SubTThreadByRun("subthread3");
Thread thread3 = new Thread(subthread3);
thread3.start();
System.out.println("all work done");
pool.shutdown();//关闭线程池
}
四种实现方式比较
join--适用于少量线程
线程池submit+future.get--适用于实际使用的线程数量不定,且所有线程执行完毕
CountDownLatch--设置线程数,每个线程结尾使用countDown()方法减值,主线程中调用await方法来阻塞,当数据减为0时,主线程继续执行,适用于实际使用线程数量固定
栅栏--多个线程都执行到某个点暂停,然后再一起开始,并且增加栅栏的run方法
new CyclicBarrier(5,new Runnable());每个线程执行到CyclicBarrier.await方法,线程会暂停,当await方法执行5次时,所有线程继续运行,并且执行cyclicBarrier的runner的run方法
countdownLatch在项目中的实际使用
背景:现在n个班级缺了个需要展示的属性,需要rpc从外部获取,但是一次只能查50个班级的;所以需要使用多线程提升性能,分页查询,但是需要所有班级属性查询完成后,才能继续执行,这里就可以使用countdownLatch进行控制
线程的个数(即rpc的次数): (n-1)/50 +1
具体代码如下:
private Map<String, String> concurrentClassByTeacherProcessor(ClassByTeacherParamDto param, List<ClassByCourseDto> rows) {
Stopwatch stopwatch = Stopwatch.createStarted();
// 线程安全 - 同时提效
Map<String, String> concurrentHashMap = new ConcurrentHashMap<>();
if (CollectionUtils.isNotEmpty(rows)) {
// 获取classId list
List<String> classIds = rows.stream()
.map(ClassByCourseDto::getId)
.collect(Collectors.toList());
int times = (classIds.size() - 1) / 50 + 1;
CountDownLatch countDownLatch = new CountDownLatch(times);
ExecutorService executorService = Executors.newFixedThreadPool(times);
try {
// 50 一组开启多线程
for (int i = 0; i < classIds.size(); i += 50) {
List<String> classIdList = classIds.subList(i, Math.min(i + 50, classIds.size()));
executorService.submit(()->{
List<ClassBySelfBuiltDto> classBySelfBuiltDtoList = courseClient.queueSelfBuiltClassById(classIdList, null);
classBySelfBuiltDtoList.forEach(classBySelfBuiltDto -> concurrentHashMap.put(classBySelfBuiltDto.getClassId(), classBySelfBuiltDto.getClassTimeDisplay()));
countDownLatch.countDown();
});
}
try {
countDownLatch.await();// 这里阻塞主线程,等待子线程执行完成
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
} finally {
executorService.shutdown();
}
}
log.info("concurrentClassByTeacherProcessor param:[{}], it costs:[{}]", param, stopwatch.elapsed(TimeUnit.MILLISECONDS));
return concurrentHashMap;
}