解决问题
在很多并发业务场景下的时候,当多个 线程
过来获取某个数据时,如果需要通过从 缓存
中获取到对应数据时,但是第一次可能 缓存
数据还并没有建立,这时需要有一个 线程
进行热点缓存数据的建立,然后其他 线程
再进行获取。一般的解决方案可能就是通过 while
循环,一直等待数据的获取,一旦代码复杂起来了可能就很难看,而且这里每次线程来了都需要 加锁
可能并发高一点性能就很低下
while(true) {
lock.lock();
try {
Object obj = getData();
if(obj != null) {
return obj;
} else {
setData();
}
Thread.sleep(100);
} finally {
lock.unlock();
}
}
代码
这里的话,封装一个简单的 等待缓存
的工具类
- lockMap :key和锁对应的一个映射关系,每一个key可能都会有很多的线程来进行获取,但是不能锁住所有的线程
- obj :给某一个具体的 key 创建锁的时候需要要一个全局锁,这里只是通过获取然后创建,并不会导致很低的性能问题
- executorService:执行子任务获取数据时的线程池,需要获取到结果值和超时
public class WaitCache {
/** core */
private static final Integer core = Runtime.getRuntime().availableProcessors();
/** max */
private static final Integer max = Runtime.getRuntime().availableProcessors() * 2;
/** 锁映射 */
private final Map<Object, ReentrantLock> lockMap = new ConcurrentHashMap<>();
/** 用于锁住创建分段锁的地方 */
private final Lock obj = new ReentrantLock();
/** Executor service */
private final ExecutorService executorService;
/**
* Wait cache
*
* @param executorService executor service
* @since 1.0.0
*/
public WaitCache(ExecutorService executorService) {
this.executorService = Optional.ofNullable(executorService).orElseGet(() -> new ThreadPoolExecutor(core,
max,
60,
TimeUnit.SECONDS,
new ArrayBlockingQueue<>(500)));
}
/**
* Wait cache
*
* @since 1.0.0
*/
public WaitCache() {
this(null);
}
/**
* Write
*
* @param <T> parameter
* @param key key
* @param source 获取数据的方法
* @param getSource 当获取到数据的方法,返回的数据为0,这时需要构建缓存数据的方法
* @param consumer 消费的方法,拿到数据后需要执行的方法
* @since 1.0.0
*/
public <T> void write(Object key, Supplier<T> source, Supplier<T> getSource, Consumer<T> consumer) {
T t = source.get();
retry:
if (t == null) {
//创建缓存锁
Lock lock = createLock(key);
//给对应的key加上锁
lock.lock();
try {
//调用获取缓存数据
Future<T> task = executorService.submit(source::get);
//如果超时还没有获取到数据,抛出异常
t = task.get(200, TimeUnit.MILLISECONDS);
if (t != null) {
break retry;
} else {
//执行获取数据源
if (getSource != null) {
t = executorService.submit(getSource::get).get(200, TimeUnit.MILLISECONDS);
}else {
//如果没有获取数据源的地方,直接结束任务
return;
}
}
System.out.println(Thread.currentThread().getName() + "拿到缓存锁,设置数据:" + t);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
System.out.println("获取数据超时,线程ID:" + Thread.currentThread().getName());
e.printStackTrace();
return;
} finally {
//必须释放锁
lock.unlock();
}
}
consumer.accept(t);
}
/**
* Create lock
*
* @param key key
* @return the lock
* @since 1.0.0
*/
private Lock createLock(Object key) {
ReentrantLock reentrantLock = lockMap.get(key);
if (reentrantLock == null) {
obj.lock();
try {
reentrantLock = Optional.ofNullable(lockMap.get(key)).orElseGet(() -> {
ReentrantLock lock = new ReentrantLock();
lockMap.put(key, lock);
return lock;
});
} finally {
obj.unlock();
}
}
return reentrantLock;
}
/**
* 用于删除锁,注意如果删除了,可能会重复创建 ReentrantLock 导致效率低下
*
* @param key key
* @since 1.0.0
*/
public void release(Object key) {
Assert.notNull(key, () -> "key不能为空");
this.lockMap.remove(key);
}
}
3. 测试
开启了 400 个线程,正确的情况下时,只有4个线程能获取到锁,然后设置数据,其余的线程只需要等待获取到锁的线程设置数据即可,然后再执行获取数据的逻辑
- datasource :设置的数据库
- getConsumer() : 因为线程比较多,所以当线程id等于多少时,让他打印一下获取的数据对不对
public class ConcurrentApplicationTest
{
//设置一个数据库
private final Map<String, Object> datasource = new ConcurrentHashMap<>();
/**
* Rigorous Test :-)
*/
@Test
public void shouldAnswerWithTrue() throws InterruptedException {
WaitCache cache = new WaitCache();
List<Thread> list = new ArrayList<>();
String key1 = "asdgajsdg";
for (int i = 0; i < 100; i++) {
final int temp = i;
Thread thread = new Thread(() -> {
cache.write(key1, () -> datasource.get(key1), () -> {
User o = new User();
o.name = temp + ":" + 456 + temp;
datasource.put(key1, o);
return o;
}, getConsumer(temp == 10));
});
thread.setName("id:" + i);
list.add(thread);
}
String key2 = "999999";
for (int i = 100; i < 200; i++) {
final int temp = i;
Thread thread = new Thread(() -> {
cache.write(key2, () -> datasource.get(key2), () -> {
User o = new User();
o.name = temp + ":" + 999 + temp;
datasource.put(key2, o);
return o;
}, getConsumer(temp == 178));
});
thread.setName("id:" + i);
list.add(thread);
}
String key3 = "9qweqqweqw";
for (int i = 200; i < 300; i++) {
final int temp = i;
Thread thread = new Thread(() -> {
cache.write(key3, () -> datasource.get(key3), () -> {
User o = new User();
o.name = temp + ":" + 999 + temp;
datasource.put(key3, o);
return o;
}, getConsumer(temp == 210));
});
thread.setName("id:" + i);
list.add(thread);
}
String key4 = "zxcbzxcxc";
for (int i = 300; i < 400; i++) {
final int temp = i;
Thread thread = new Thread(() -> {
cache.write(key4, () -> {
try {
Thread.sleep(500);
} catch (InterruptedException e) {
e.printStackTrace();
}
return null;
}, () -> {
User o = new User();
o.name = temp + ":" + 999 + temp;
datasource.put(key4, o);
return o;
}, getConsumer(temp == 350));
});
thread.setName("id:" + i);
list.add(thread);
}
list.forEach(Thread::start);
Thread.currentThread().join();
}
private Consumer<Object> getConsumer(boolean b) {
return b ? c -> {
String name = Thread.currentThread().getName();
System.out.println("线程:" + name + "获取缓存数据:" + c);
} : c -> {};
}
private User get(int id) {
if (id == 1) {
User user = new User();
user.name = "123";
return user;
}
return null;
}
public static class User {
String name;
@Override
public String toString() {
return "名称:" + name;
}
}
}
测试结果可以看到,400个线程,只有3个线程获取到了锁,然后设置对应的数据到缓存里面,剩下的其余线程获取到的数据都是上面设置线程的id,因为通过 getConsumer()
方法,只在对应 id
的线程打印数据,所以这里 3个线程打印数据,还有一个线程由于设置了超时获取,所有直接抛出了异常