林炳文Evankaka原创作品。转载请注明出处http://blog.youkuaiyun.com/evankaka
摘要:本文主要讲了Java中ConcurrentHashMap 的源码
ConcurrentHashMap 是java并发包中非常有用的一个类,在高并发场景用得非常多,它是线程安全的。要注意到虽然HashTable虽然也是线程安全的,但是它的性能在高并发场景下完全比不上ConcurrentHashMap ,这也是由它们的结构所决定的,可以认为ConcurrentHashMap 是HashTable的加强版,不过这加强版和原来的HashTable有非常大的区别,不仅是在结构上,而且在方法上也有差别。
下面先来简单说说它们的区别吧HashMap、 HashTable、ConcurrentHashMap 的区别吧
1、HashMap是非线程安全,HashTable、ConcurrentHashMap 都是线程安全,而且ConcurrentHashMap 、Hashtable不能传入nul的key或value,HashMap可以。
2、Hashtable是将数据放入到一个Entrty数组或者它Entrty数组上一个Entrty的链表节点。而ConcurrentHashMap 是由Segment数组组成,每一个Segment可以看成是一个单独的Map.然后每个Segment里又有一个HashEntrty数组用来存放数据。
3、HashTable的get/put/remove方法都是基于同步的synchronized方法,而ConcurrentHashMap 是基本锁的机制,并且每次不是锁全表,而是锁单独的一个Segment。所以ConcurrentHashMap 的性能比HashTable好。
4、如果不考虑线程安全因素,推荐使用HashMap,因为它性能最好。
首先来看看它包含的结构吧!
public class ConcurrentHashMap<K, V> extends AbstractMap<K, V>
implements ConcurrentMap<K, V>, Serializable {
private static final long serialVersionUID = 7249069246763182397L;
//默认数组容量大小
static final int DEFAULT_INITIAL_CAPACITY = 16;
//默认装载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;
//默认层级
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
//最大的每一个数组容量大小
static final int MAXIMUM_CAPACITY = 1 << 30;
//最大的分组数目
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative
//调用remove/contain/replace方法时不加锁的情况下操作重试次数
static final int RETRIES_BEFORE_LOCK = 2;
//segments 数组索引相关
final int segmentMask;
//segments 数组偏移相关
final int segmentShift;
//segments数组,每个segments单独就可以认为是一个map
final Segment<K,V>[] segments;
这里看到了一个Segment<K,V>[] segments;数组。下面再来看看Segment这个类。在下面可以看到Segment这个类继承了ReentrantLock锁。所以它也是一个锁。然后它里面还有一个HashEntry<K,V>[] table。这是真正用来存放数据的结构。
/**
* Segment内部类,注意它也是一个锁!可以认为它是一个带有锁方法的map
*/
static final class Segment<K,V> extends ReentrantLock implements Serializable {
private static final long serialVersionUID = 2249069246763182397L;
//元素个数
transient volatile int count;
//修改次数
transient int modCount;
//阈值,超过这个数会重新reSize
transient int threshold;
//注意,这里又一个数组,这个是真正存放数据元素的地方
transient volatile HashEntry<K,V>[] table;
//装载因子,用来计算threshold
final float loadFactor;
HashEntry它的结构很简单:
static final class HashEntry<K,V> {
final K key;
final int hash;//用来保存Segment索引的信息
volatile V value;
final HashEntry<K,V> next;
HashEntry(K key, int hash, HashEntry<K,V> next, V value) {
this.key = key;
this.hash = hash;
this.next = next;
this.value = value;
}
@SuppressWarnings("unchecked")
static final <K,V> HashEntry<K,V>[] newArray(int i) {
return new HashEntry[i];
}
}
经过上面的分析:可以得出如下的ConcurrentHashMap结构图
全部源码分析:
package java.util.concurrent;
import java.util.concurrent.locks.*;
import java.util.*;
import java.io.Serializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
public class ConcurrentHashMap<K, V> extends AbstractMap<K, V>
implements ConcurrentMap<K, V>, Serializable {
private static final long serialVersionUID = 7249069246763182397L;
//默认数组容量大小
static final int DEFAULT_INITIAL_CAPACITY = 16;
//默认装载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;
//默认层级
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
//最大的每一个数组容量大小
static final int MAXIMUM_CAPACITY = 1 << 30;
//最大的分组数目
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative
//调用remove/contain/replace方法时不加锁的情况下操作重试次数
static final int RETRIES_BEFORE_LOCK = 2;
//segments 数组索引相关
final int segmentMask;
//segments 数组偏移相关
final int segmentShift;
//segments数组,每个segments单独就可以认为是一个map
final Segment<K,V>[] segments;
/**
* 哈希算法
*/
private static int hash(int h) {
// Spread bits to regularize both segment and index locations,
// using variant of single-word Wang/Jenkins hash.
h += (h << 15) ^ 0xffffcd7d;
h ^= (h >>> 10);
h += (h << 3);
h ^= (h >>> 6);
h += (h << 2) + (h << 14);
return h ^ (h >>> 16);
}
/**
* 根据哈希值计算应该落在哪个segments上
*/
final Segment<K,V> segmentFor(int hash) {
return segments[(hash >>> segmentShift) & segmentMask];
}
/**
* 内部类,每个HashEntry都会存入到一个Segment中去
*/
static final class HashEntry<K,V> {
final K key;//关键字
final int hash;//哈希值
volatile V value;//值
final HashEntry<K,V> next;//不同的关键字,相再的哈希值时会组成 一个链表
HashEntry(K key, int hash, HashEntry<K,V> next, V value) {
this.key = key;
this.hash = hash;
this.next = next;
this.value = value;
}
@SuppressWarnings("unchecked")
static final <K,V> HashEntry<K,V>[] newArray(int i) {
return new HashEntry[i];
}
}
/**
* Segment内部类,注意它也是一个锁!可以认为它是一个带有锁方法的map
*/
static final class Segment<K,V> extends ReentrantLock implements Serializable {
private static final long serialVersionUID = 2249069246763182397L;
//元素个数
transient volatile int count;
//修改次数
transient int modCount;
//阈值,超过这个数会重新reSize
transient int threshold;
//注意,这里又一个数组,这个是真正存放数据元素的地方
transient volatile HashEntry<K,V>[] table;
//装载因子,用来计算threshold
final float loadFactor;
//构造函数,由initialCapacity确定table的大小
Segment(int initialCapacity, float lf) {
loadFactor = lf;
setTable(HashEntry.<K,V>newArray(initialCapacity));
}
@SuppressWarnings("unchecked")
static final <K,V> Segment<K,V>[] newArray(int i) {
return new Segment[i];
}
//设置threshold、table
void setTable(HashEntry<K,V>[] newTable) {
threshold = (int)(newTable.length * loadFactor);//注意,当table的元素个数超过这个时,会触发reSize;
table = newTable;
}
//取得头一个
HashEntry<K,V> getFirst(int hash) {
HashEntry<K,V>[] tab = table;
return tab[hash & (tab.length - 1)];
}
//在加锁情况下读数据,注意这个类继续了锁的方法
V readValueUnderLock(HashEntry<K,V> e) {
lock();
try {
return e.value;
} finally {
unlock();
}
}
//取元素
V get(Object key, int hash) {
if (count != 0) { //注意,没有加锁
HashEntry<K,V> e = getFirst(hash);//取得头一个
while (e != null) { //依次从table中取出元素判断
if (e.hash == hash && key.equals(e.key)) { //hash和key同时相等才表示存在
V v = e.value;
if (v != null) //有可能在这里时,运行了删除元素导致为Null,一般发生比较少
return v;
return readValueUnderLock(e); // 重新在加锁情况下读数据
}
e = e.next;
}
}
return null;
}
//是否包含一个元素
boolean containsKey(Object key, int hash) {
if (count != 0) { //
HashEntry<K,V> e = getFirst(hash);
while (e != null) {
if (e.hash == hash && key.equals(e.key))
return true;
e = e.next;
}
}
return false;
}
//是否包含一个元素
boolean containsValue(Object value) {
if (count != 0) { // read-volatile
HashEntry<K,V>[] tab = table;
int len = tab.length;
for (int i = 0 ; i < len; i++) {
for (HashEntry<K,V> e = tab[i]; e != null; e = e.next) { //table数组循环读数
V v = e.value;
if (v == null) // recheck
v = readValueUnderLock(e);
if (value.equals(v))
return true;
}
}
}
return false;
}
//替换时要加锁
boolean replace(K key, int hash, V oldValue, V newValue) {
lock();
try {
HashEntry<K,V> e = getFirst(hash);
while (e != null && (e.hash != hash || !key.equals(e.key)))//hash和key要同时相等才表示是找到了这个元素
e = e.next;
boolean replaced = false;
if (e != null && oldValue.equals(e.value)) { //判断是否要进行替换
replaced = true;
e.value = newValue;
}
return replaced;
} finally {
unlock();
}
}
//替换时要加锁
V replace(K key, int hash, V newValue) {
lock();
try {
HashEntry<K,V> e = getFirst(hash);
while (e != null && (e.hash != hash || !key.equals(e.key)))
e = e.next;
V oldValue = null;
if (e != null) {
oldValue = e.value;
e.value = newValue;
}
return oldValue;
} finally {
unlock();
}
}
//放入一个元素,onlyIfAbsent如果有false表示替换原来的旧值
V put(K key, int hash, V value, boolean onlyIfAbsent) {
lock();
try {
int c = count;
if (c++ > threshold) // table数组里的元素超过threshold。触发rehash,其实也就是扩大table
rehash();
HashEntry<K,V>[] tab = table;
int index = hash & (tab.length - 1);
HashEntry<K,V> first = tab[index];//头一个
HashEntry<K,V> e = first;
while (e != null && (e.hash != hash || !key.equals(e.key)))//一直不断判断不重复才停止
e = e.next;
V oldValue;
if (e != null) { //这个key、hash已经存在,修改原来的
oldValue = e.value;
if (!onlyIfAbsent)
e.value = value; //替换原来的旧值
}
else { //这个key、hash已经不存在,加入一个新的
oldValue = null;
++modCount;
tab[index] = new HashEntry<K,V>(key, hash, first, value);//加入一个新的元素
count = c; // 个数变化
}
return oldValue;
} finally {
unlock();
}
}
//重新哈希
void rehash() {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;//旧容量
if (oldCapacity >= MAXIMUM_CAPACITY) //超过默认的最大容量时就退出了
return;
HashEntry<K,V>[] newTable = HashEntry.newArray(oldCapacity<<1);//这里直接在原来的基础上扩大1倍
threshold = (int)(newTable.length * loadFactor);//重新计算新的阈值
int sizeMask = newTable.length - 1;
for (int i = 0; i < oldCapacity ; i++) { //下面要做的就是将旧的table上的数据拷贝到新的table
HashEntry<K,V> e = oldTable[i];
if (e != null) { //旧table上该处有数据
HashEntry<K,V> next = e.next;
int idx = e.hash & sizeMask;
// 单个节点key-value
if (next == null)
newTable[idx] = e;
else { //链表节点key-value
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
for (HashEntry<K,V> last = next;
last != null;
last = last.next) { //这里重新计算了链表上最后一个节点的位置
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
newTable[lastIdx] = lastRun;//将原table上对应的链表上的最后一个元素放在新table对应链表的首位置
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) { //for循环依次拷贝链表上的数据,注意最后整个链表相对原来会倒序排列
int k = p.hash & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(p.key, p.hash,
n, p.value);//新table数据赋值
}
}
}
}
table = newTable;
}
//删除一个元素
V remove(Object key, int hash, Object value) {
lock(); //删除要加锁
try {
int c = count - 1;
HashEntry<K,V>[] tab = table;
int index = hash & (tab.length - 1);
HashEntry<K,V> first = tab[index];
HashEntry<K,V> e = first;
while (e != null && (e.hash != hash || !key.equals(e.key)))
e = e.next;
V oldValue = null;
if (e != null) { ///找到元素
V v = e.value;
if (value == null || value.equals(v)) { //为null或者value相等时才删除
oldValue = v;
++modCount;
HashEntry<K,V> newFirst = e.next;
for (HashEntry<K,V> p = first; p != e; p = p.next)
newFirst = new HashEntry<K,V>(p.key, p.hash,
newFirst, p.value);//注意它这里会倒换原来链表的位置
tab[index] = newFirst;
count = c; //记录数减去1
}
}
return oldValue;
} finally {
unlock();
}
}
//清空整个map
void clear() {
if (count != 0) {
lock();
try {
HashEntry<K,V>[] tab = table;
for (int i = 0; i < tab.length ; i++)
tab[i] = null;//直接赋值为null
++modCount;
count = 0; // write-volatile
} finally {
unlock();
}
}
}
}
//构造函数
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0) //loadFactor和initialCapacity都得大于0
throw new IllegalArgumentException();
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
// Find power-of-two sizes best matching arguments
int sshift = 0;
int ssize = 1;
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
segmentShift = 32 - sshift;
segmentMask = ssize - 1;
this.segments = Segment.newArray(ssize);
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
int cap = 1;
while (cap < c)
cap <<= 1;
for (int i = 0; i < this.segments.length; ++i)
this.segments[i] = new Segment<K,V>(cap, loadFactor);//为每个segments初始化其里面的数组
}
//构造函数
public ConcurrentHashMap(int initialCapacity, float loadFactor) {
this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
}
//构造函数
public ConcurrentHashMap(int initialCapacity) {
this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
//默认构造函数
public ConcurrentHashMap() {
this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
//构造函数
public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1,
DEFAULT_INITIAL_CAPACITY),
DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
putAll(m);
}
//判断 是否为空
public boolean isEmpty() {
final Segment<K,V>[] segments = this.segments;//取得整个Segment数组
int[] mc = new int[segments.length];
int mcsum = 0;
for (int i = 0; i < segments.length; ++i) {
if (segments[i].count != 0) //有一个segments数组元素个数不为0,那么整个map肯定不为空
return false;
else
mcsum += mc[i] = segments[i].modCount;//累加总的修改次数
}
if (mcsum != 0) {
for (int i = 0; i < segments.length; ++i) {
if (segments[i].count != 0 ||
mc[i] != segments[i].modCount)//这里又做了一次mc[i] != segments[i].modCount判断,因为segments[i].count = 0时才会跳到这里,不相等那么肯定是又有元素加入
return false;
}
}
return true;
}
//整个mapr的大小,这里的大小指的是存放元素的个数
public int size() {
final Segment<K,V>[] segments = this.segments;
long sum = 0;
long check = 0;
int[] mc = new int[segments.length];
//这是里的for循环是尝试在不加锁的情况下来获取整个map的元素个数
for (int k = 0; k < RETRIES_BEFORE_LOCK; ++k) { //这里RETRIES_BEFORE_LOCK=2,最大会做两次的循环
check = 0;
sum = 0;
int mcsum = 0;
for (int i = 0; i < segments.length; ++i) {
sum += segments[i].count;//累加每一个segments上的count
mcsum += mc[i] = segments[i].modCount;//累加每一个segments上的modCount
}
if (mcsum != 0) { //修改次数不为0,要再做一次判断前后两次的modCount,count的累加
for (int i = 0; i < segments.length; ++i) {
check += segments[i].count;
if (mc[i] != segments[i].modCount) { //前后两次数据发生了变化
check = -1; // 前后两次取的个数不一到,注意sum还是之前的
break;
}
}
}
if (check == sum) //前后两次取的元素个数一样,直接跳出循环
break;
}
//这里会尝试在加锁的情况下来获取整个map的元素个数
if (check != sum) { // 这里一般check会等于-1才发生
sum = 0;//重新置0
for (int i = 0; i < segments.length; ++i)
segments[i].lock();//每一个segments上锁
for (int i = 0; i < segments.length; ++i)
sum += segments[i].count;//重新累加
for (int i = 0; i < segments.length; ++i)
segments[i].unlock();//依次释放锁
}
if (sum > Integer.MAX_VALUE)
return Integer.MAX_VALUE;//如果大于最大值,返回最大值
else
return (int)sum;
}
//取得一个元素,先是不加锁情况下去读
public V get(Object key) {
int hash = hash(key.hashCode());
return segmentFor(hash).get(key, hash);//具体看上面的代码注释
}
//是否包含一个元素,根据key来获取
public boolean containsKey(Object key) {
int hash = hash(key.hashCode());
return segmentFor(hash).containsKey(key, hash);
}
//是否包含一个元素,根据value来获取
public boolean containsValue(Object value) {
if (value == null)
throw new NullPointerException();
//取得整个Segment的内容
final Segment<K,V>[] segments = this.segments;
int[] mc = new int[segments.length];
// 尝试在不加锁的情况下做判断
for (int k = 0; k < RETRIES_BEFORE_LOCK; ++k) { //RETRIES_BEFORE_LOCK这里=2
int sum = 0;
int mcsum = 0;
for (int i = 0; i < segments.length; ++i) {
int c = segments[i].count;//累加个数
mcsum += mc[i] = segments[i].modCount;//累加修改次数
if (segments[i].containsValue(value)) //判断是否包含
return true;
}
boolean cleanSweep = true;
if (mcsum != 0) { //成立说明发生了改变
for (int i = 0; i < segments.length; ++i) { //再循环一次取得segments
int c = segments[i].count;//累加第二次循环得到的count
if (mc[i] != segments[i].modCount) { //如果有一个segments前后两次不一样,那么它的元素肯定发生了变化
cleanSweep = false;//
break;//跳出
}
}
}
if (cleanSweep) //为ture表示经过上面的两次判断还是无法找到
return false;
}
// cleanSweepo为false时,进行下面。注意,这里是在加锁情况下
for (int i = 0; i < segments.length; ++i)
segments[i].lock();//取得每一个segments的锁
boolean found = false;
try {
for (int i = 0; i < segments.length; ++i) { //每个segments取出来做判断
if (segments[i].containsValue(value)) {
found = true;
break;
}
}
} finally {
for (int i = 0; i < segments.length; ++i) //依次释放segments的锁
segments[i].unlock();
}
return found;
}
//是否包含
public boolean contains(Object value) {
return containsValue(value);
}
//放入一个元素
public V put(K key, V value) {
if (value == null)
throw new NullPointerException();
int hash = hash(key.hashCode());
return segmentFor(hash).put(key, hash, value, false);//放入时先根据key的hash值找到存放 的segments,再调用其put方法
}
//放入一个元素,如果key或value为null,那么为招出一个异常
public V putIfAbsent(K key, V value) {
if (value == null)
throw new NullPointerException();
int hash = hash(key.hashCode());
return segmentFor(hash).put(key, hash, value, true);
}
//放入一个map
public void putAll(Map<? extends K, ? extends V> m) {
for (Map.Entry<? extends K, ? extends V> e : m.entrySet()) //遍历entrySet取出再放入
put(e.getKey(), e.getValue());
}
//删除一个元素,根据key
public V remove(Object key) {
int hash = hash(key.hashCode());
return segmentFor(hash).remove(key, hash, null);//根据key的hash值找到存放的segments,再调用其remove方法
}
//删除一个元素,根据key-value
public boolean remove(Object key, Object value) {
int hash = hash(key.hashCode());
if (value == null)
return false;
return segmentFor(hash).remove(key, hash, value) != null;//根据key的hash值找到存放的segments,再调用其remove方法
}
//替换元素
public boolean replace(K key, V oldValue, V newValue) {
if (oldValue == null || newValue == null)
throw new NullPointerException();
int hash = hash(key.hashCode());
return segmentFor(hash).replace(key, hash, oldValue, newValue);//根据key的hash值找到存放的segments,再调用其replace方法
}
//替换元素
public V replace(K key, V value) {
if (value == null)
throw new NullPointerException();
int hash = hash(key.hashCode());
return segmentFor(hash).replace(key, hash, value);
}
//清空
public void clear() {
for (int i = 0; i < segments.length; ++i)
segments[i].clear();
}
//序列化方法
private void writeObject(java.io.ObjectOutputStream s) throws IOException {
s.defaultWriteObject();
for (int k = 0; k < segments.length; ++k) {
Segment<K,V> seg = segments[k];
seg.lock();//注意加锁了
try {
HashEntry<K,V>[] tab = seg.table;
for (int i = 0; i < tab.length; ++i) {
for (HashEntry<K,V> e = tab[i]; e != null; e = e.next) {
s.writeObject(e.key);
s.writeObject(e.value);
}
}
} finally {
seg.unlock();
}
}
s.writeObject(null);
s.writeObject(null);
}
//反序列化方法
private void readObject(java.io.ObjectInputStream s)
throws IOException, ClassNotFoundException {
s.defaultReadObject();
for (int i = 0; i < segments.length; ++i) {
segments[i].setTable(new HashEntry[1]);
}
for (;;) {
K key = (K) s.readObject();
V value = (V) s.readObject();
if (key == null)
break;
put(key, value);
}
}
}