认真学习jdk1.7下ConcurrentHashMap的实现原理

本文详细解析了JDK1.7下ConcurrentHashMap的实现原理,包括其核心属性、构造函数、get、put、rehash等核心方法,以及统计元素个数的方法。ConcurrentHashMap通过分段锁机制实现了线程安全,支持高并发。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

HashMap无论是 1.7 还是 1.8 其实都能看出 JDK 没有对它做任何的同步操作,所以并发会出问题,甚至出现死循环导致系统不可用。这个问题就交给ConcurrentHashMap。

ConcurrentHashMap是一个 在juc包下的 map, 线程安全。 在jdk.1.8 之前采用数组+ 链表的结构 并且采用分段锁机制 来保证线程安全,而jdk1.8 改成了 数组+ 链表+ 红黑树,线程安全方面也改成了 cas+ synchronized 来保证线程安全。

ConcurrentHashMap类图如下:
在这里插入图片描述
本篇博文我们分析JDK1.7下ConcurrentHashMap的实现。

【1】核心属性和构造

① 核心属性

// table的默认初始化容量
static final int DEFAULT_INITIAL_CAPACITY = 16;

 // table的默认负载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;

/**
 * The default concurrency level for this table, used when not
 * otherwise specified in a constructor.
 */
// table的默认并发级别,换句话说其实是默认多少个“segment” 
static final int DEFAULT_CONCURRENCY_LEVEL = 16;

 // 最大容量
static final int MAXIMUM_CAPACITY = 1 << 30;

 // per-segment tables的最小容量,就是每段最少2个哈希桶位置
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

//段的最大数量
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative


// 在size()和containsValue()方法中,如果循环次数==RETRIES_BEFORE_LOCK ,
//则对每一段都进行加锁
static final int RETRIES_BEFORE_LOCK = 2;

/**
 * Mask value for indexing into segments. The upper bits of a
 * key's hash code are used to choose the segment.
 */
//用于索引到段的掩码值。key的散列码的高位用于选择段。 
final int segmentMask;


 //段内索引的移位值
final int segmentShift;

/**
 * The segments, each of which is a specialized hash table.
 */
 // 段数组,每一个端都是一个特殊的哈希表
final Segment<K,V>[] segments;

//三个常见的数据对象,使用了transient 不参与序列化和反序列
transient Set<K> keySet;
transient Set<Map.Entry<K,V>> entrySet;
transient Collection<V> values;

这里可能对segmentMask、segmentShift以及segments比较疑惑,别着急我们慢慢往下看。

// 默认情况下segmentShift =28  segmentMask =15
this.segmentShift = 32 - sshift;
this.segmentMask = ssize - 1;

② 核心对象HashEntry

如下所示,从成员来讲与HashMap中的Entry类似,都是hash、key、value、next。不同的是这里value和next使用了volatile 修饰,保证其他线程能够读取到当前变量的最新值。而且其内部使用了安全类 UNSAFE来保证volatile 语义

static final class HashEntry<K,V> {
    final int hash;
    final K key;
    volatile V value;
    volatile HashEntry<K,V> next;

    HashEntry(int hash, K key, V value, HashEntry<K,V> next) {
        this.hash = hash;
        this.key = key;
        this.value = value;
        this.next = next;
    }

    /**
     * Sets next field with volatile write semantics.  (See above
     * about use of putOrderedObject.)
     */
     // 这里使用了安全类 UNSAFE来保证volatile 语义
    final void setNext(HashEntry<K,V> n) {
        UNSAFE.putOrderedObject(this, nextOffset, n);
    }

    // Unsafe mechanics
    static final sun.misc.Unsafe UNSAFE;
    static final long nextOffset;
    static {
        try {
            UNSAFE = sun.misc.Unsafe.getUnsafe();
            Class k = HashEntry.class;
            nextOffset = UNSAFE.objectFieldOffset
                (k.getDeclaredField("next"));
        } catch (Exception e) {
            throw new Error(e);
        }
    }
}

③ 核心类Segment

如下所示,其实ConcurrentHashMap首先构建了Segment[],然后每一个Segment又包含了table,table最小长度是2。

static final class Segment<K,V> extends ReentrantLock implements Serializable {

         //tryLock 的最大次数
        static final int MAX_SCAN_RETRIES =
            Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;

         // 每一段的哈希桶,
         //元素通过entryAt/setEntryAt方法访问或者赋值以确保volatile 语义
        transient volatile HashEntry<K,V>[] table;

        //元素的个数 
        transient int count;

        // 结构修改计数器
        transient int modCount;

         // 需要rehashed的临界值/阈值 = capacity * loadFactor
         // 注意这些都是针对整个段来讲的,而不是某个tab[index].
        transient int threshold;

         // 负载因子 ,对所有segment来说是一致的,其是一个副本以避免与外部对象关联
        final float loadFactor;

//每一段的构造,主要包括负载因子、阈值以及哈希桶
        Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
            this.loadFactor = lf;
            this.threshold = threshold;
            this.table = tab;
        }
 //...
}       

③ 核心构造函数

如下所示是一系列重载的构造函数:

// DEFAULT_CONCURRENCY_LEVEL=16
public ConcurrentHashMap(int initialCapacity, float loadFactor) {
    this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
}

public ConcurrentHashMap(int initialCapacity) {
    this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}

// DEFAULT_INITIAL_CAPACITY = 16
// DEFAULT_LOAD_FACTOR = 0.75
// DEFAULT_CONCURRENCY_LEVEL = 16
public ConcurrentHashMap() {
    this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}

//使用给定的Map初始化
public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
    this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1,
                  DEFAULT_INITIAL_CAPACITY),
         DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    putAll(m);
}

可以看到其本质都是依赖于下面这个构造函数,这也是我们需要重点分析的。

// 假设为16  0.75  16
public ConcurrentHashMap(int initialCapacity,
                         float loadFactor, int concurrencyLevel) {
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
     
     //最大值65536
    if (concurrencyLevel > MAX_SEGMENTS)
        concurrencyLevel = MAX_SEGMENTS;
   
    // Find power-of-two sizes best matching arguments
    int sshift = 0;
    // 段的个数,如果小于16就增长到16 sshift记录增长的次数
    int ssize = 1;
    while (ssize < concurrencyLevel) {
        ++sshift;
        ssize <<= 1;
    }
    // 默认情况下 = 32-4 = 28
    this.segmentShift = 32 - sshift;
    //默认情况下 = 16-1 = 15
    this.segmentMask = ssize - 1;
    
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    
    //默认c=16/16=1    
    int c = initialCapacity / ssize;
    if (c * ssize < initialCapacity)
        ++c;
    
    //    MIN_SEGMENT_TABLE_CAPACITY=2
    //cap 为每段内数组的大小,默认是2
    int cap = MIN_SEGMENT_TABLE_CAPACITY;
    while (cap < c)
        cap <<= 1;
    // create segments and segments[0]
    // 默认情况下so(0.75,1,HashEntry[2])
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                         (HashEntry<K,V>[])new HashEntry[cap]);
    //初始化16个段                     
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
	//把s0放到ss中
    UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    this.segments = ss;
}

从上面代码可知s0段内tab[]的默认大小为2,阈值为1。这也就意味着上面初始化Segment<K,V> s0只能装入一个HashEntry即在插入第一个元素的时候不会触发扩容,插入第二个元素的时候就会进行第一次扩容。
也就是说,默认情况下数据结构如下图:
在这里插入图片描述

从上图可以发现其底层数据结构本质还是数组+链表。无非是在最外层分成了不同的段Segment,段内持有最少两个数组索引位置。同一个索引位置,通过next构成了链表。


【2】核心方法get

这里核心逻辑是首先定位到某个Segment,然后获取到Segment持有的tab[],再根据hash(key)定位到某个tab[i](索引位置或者称之为槽位)

public V get(Object key) {
    Segment<K,V> s; // manually integrate access methods to reduce overhead
    HashEntry<K,V>[] tab;
    // 获取到key的散列值
    int h = hash(key);
 
    // 默认情况下 h 无符号右移28位 & 15 ,然后 左移SSHIFT   ,然后 +SBASE
    // 其实也就是定位哪个Segment
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
   
    // 获取定位到的段
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
       
        // 使用(tab.length - 1) & h定位段内哪个数组位置 / 索引位置
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                 (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
             e != null; e = e.next) {
            K k;
            // 这段for循环就是基本的链表遍历
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}

方法总结如下:

  • ① 计算key的散列值并进而计算出属于哪个段Segment
  • ② UNSAFE获取到定位目标段
  • ((tab.length - 1) & h)) << TSHIFT) + TBASE定位到段内tab[]中索引位置
  • ④ 基本的链表遍历

可以看到无论是获取段还是获取某个元素,这里都是用了UNSAFE.getObjectVolatile来保证读取到目标的最新值(内存可见性)。

【3】核心方法put

虽然 HashEntry 中的 value 是用 volatile 关键词修饰的,但是并不能保证并发的原子性,所以 put 操作时仍然需要加锁处理。volatile 关键词只保证读取时候的内存可见性(读取到最新值)。

put是首先定位到段,对段进行加锁,然后put,最后解锁。故而其支持最大N个并发(N是段的个数,默认是16)。

public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)
        throw new NullPointerException();
     
     //计算key的散列值   
    int hash = hash(key);
    
    // hash 右移 28位  然后与 15 进行 & 操作
    // segmentMask:散列运算的掩码
    int j = (hash >>> segmentShift) & segmentMask;
    
    // 尝试获取段,判断是否为null,为null则创建
    if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
         (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
        s = ensureSegment(j);
   
    return s.put(key, hash, value, false);
}

可以看到这里首先计算key的hash值与段的索引位置,尝试获取到段然后触发s.put(key, hash, value, false)

① ensureSegment

ensureSegment方法是为了确保Segment,如果不存在则创建Segment。这里需要特别注意的是使用到了“自旋”(while循环)和CAS(UNSAFE.compareAndSwapObject)。

private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    // 段的索引位置进行偏移
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
	//如果获取到的段位null
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        Segment<K,V> proto = ss[0]; // use segment 0 as prototype
        int cap = proto.table.length;//获取到cap
        float lf = proto.loadFactor;//负载因子
        int threshold = (int)(cap * lf);//计算阈值
        //实例化tab[]
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];

        //再次判断是否为null
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) { // recheck
			//实例化段
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);

			//这一步是自旋,当获取到段位null时,调用UNSAFE的CAS算法进行赋值
			//while循环,直到成功
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) {
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}

ok,获取到段后我们继续往下看如何put。

② put(K key, int hash, V value, boolean onlyIfAbsent)

final V put(K key, int hash, V value, boolean onlyIfAbsent) {
//如果tryLock返回true,那么node为null;这里触发的是父类ReentrantLock的tryLock
//否则scanAndLockForPut 自旋获取锁
     HashEntry<K,V> node = tryLock() ? null :
         scanAndLockForPut(key, hash, value);
     V oldValue;
     try {
         HashEntry<K,V>[] tab = table;
         // 数组中的索引位置
         int index = (tab.length - 1) & hash;
         // 处于索引位置的结点
         HashEntry<K,V> first = entryAt(tab, index);
         for (HashEntry<K,V> e = first;;) {
             if (e != null) {//
                 K k;
                 // 如果key相等,则直接覆盖旧值
                 if ((k = e.key) == key ||
                     (e.hash == hash && key.equals(k))) {
                     oldValue = e.value;
                     if (!onlyIfAbsent) {
                         e.value = value;
                         ++modCount;
                     }
                     break;
                 }
                 // 向后遍历
                 e = e.next;
             }
             //如果不存在当前key,那么头插法,插入链表,first作为node的next结点
             else {
                 if (node != null)
                     node.setNext(first);
                 else
                     node = new HashEntry<K,V>(hash, key, value, first);
                 int c = count + 1;
                 //如果元素个数大于threshold ,且tab.length <  1 << 30
                 if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                 	//其实就是扩容
                     rehash(node);
                 else
                 //如果不需要扩容就将node放到目标位置
                     setEntryAt(tab, index, node);
                 ++modCount;
                 count = c;
                 oldValue = null;
                 break;
             }
         }
     } finally {
     	// 释放锁
         unlock();
     }
     return oldValue;
 }

put流程如下:

  • ① 尝试获取锁,tryLock() 或者scanAndLockForPut
  • ② 将当前 Segment 中的 table 通过 key 的 hash (tab.length - 1) & hash定位到 tab[]中的索引位置。
  • ③ 遍历索引位置的链表,获取每一个结点进行判断。如果不为空则判断传入的 key 和当前遍历的 key 是否相等,相等则覆盖旧的 value。
  • ④ 如果第三步不成功则判断node是否为空,如果不为空则node.setNext(first);否则需要新建一个 HashEntry 。判断是否需要扩容,如果需要就进行rehash(node),不需要扩容就setEntryAt(tab, index, node);
  • ⑤ 最后会解除在 1 中所获取当前 Segment 的锁。

可以看到这里链表插入元素采用了“头插法”。


③ scanAndLockForPut(K key, int hash, V value)

在put第一步的时候会尝试获取锁,如果获取失败肯定就有其他线程存在竞争,则利用 scanAndLockForPut() 自旋获取锁。

在尝试获取锁时扫描包含给定key的节点,如果未找到,则可能创建并返回一个。返回时,保证lock被保持。这个方法返回的node不一定为null,但是一定持有了锁。

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {

// 根据hash确定其在哪个段的哪个数组位置
     HashEntry<K,V> first = entryForHash(this, hash);
     HashEntry<K,V> e = first;
     HashEntry<K,V> node = null;
     int retries = -1; // negative while locating node
     
     // 自旋获取锁
     while (!tryLock()) {
         HashEntry<K,V> f; // to recheck first below
         if (retries < 0) {
             if (e == null) {
                 if (node == null) // speculatively create node
                     node = new HashEntry<K,V>(hash, key, value, null);
                 retries = 0;
             }
             else if (key.equals(e.key))
                 retries = 0;
             else
                 e = e.next;
         }
         // 当自旋次数大于MAX_SCAN_RETRIES(1 或者 64),直接使用lock()来获取锁
         else if (++retries > MAX_SCAN_RETRIES) {
             lock();
             break;
         }
         // 当retries 为偶数时 且
         //f = entryForHash(this, hash)) != first-也就是entry发生了改变
         //比如第一次、第三次、第五次
         else if ((retries & 1) == 0 &&
                  (f = entryForHash(this, hash)) != first) {
             e = first = f; // re-traverse if entry changed
             retries = -1;
         }
     }
     return node;
 }

原理上来说:ConcurrentHashMap 采用了分段锁技术,其中 Segment 继承于 ReentrantLock。不会像 HashTable 那样不管是 put 还是 get 操作都需要做同步处理。理论上 ConcurrentHashMap 支持 CurrencyLevel (Segment 数组数量)的线程并发。每当一个线程占用锁访问一个 Segment 时,不会影响到其他的 Segment。

【4】核心方法rehash

这里会对某个段持有的tab[]进行二倍扩容,然后重新梳理链表进行定位并将新结点node放入。

private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    //二倍扩容
    int newCapacity = oldCapacity << 1;
    //新的临界值
    threshold = (int)(newCapacity * loadFactor);
    //实例化扩容后的数组进行迁移
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    // 容量大小掩码 其实就是length-1    
    int sizeMask = newCapacity - 1;
    for (int i = 0; i < oldCapacity ; i++) {
	      //数组索引位置的头结点
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            // 在新tab[]中的索引位置
            int idx = e.hash & sizeMask;

            //只有一个节点,直接换地方
            if (next == null)   //  Single node on list
                newTable[idx] = e;
            else { // Reuse consecutive sequence at same slot
            	// 链表遍历
                HashEntry<K,V> lastRun = e;
                
                //记录idx=e.hash & sizeMask
                int lastIdx = idx;
                for (HashEntry<K,V> last = next;
                     last != null;
                     last = last.next) {
                     //当前遍历节点的索引位置
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {
                        lastIdx = k;//修改lastIdx  lastRun 
                        lastRun = last;
                    }
                }
                //记录 i 位置链表遍历最后的lastIdx,放到新的数组里面
                //把lastRun及以后的节点指向 lastIdx位置
                newTable[lastIdx] = lastRun;
                
                // 其他节点则采用头插法放到k = h & sizeMask位置,
                //k 可能等于idx等于lastIdx
                // Clone remaining nodes
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }

    //把新结点node采用头插法插入newTable
    int nodeIndex = node.hash & sizeMask; // add the new node
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

这个方法流程还是很清晰的,梳理如下:

  • 这里首先对oldCapacity进行了遍历,对每一个tab[i]进行链表遍历确定其在新数组的位置。
  • 在链表遍历过程中会尝试找到index发生变化的lastIdxlastRun,通过newTable[lastIdx] = lastRun;代码把lastRun及以后的节点指向 lastIdx位置。
  • 然后再处理节点e到节点lastRun的节点,采用头插法插入到newTable[k]位置。

【5】统计元素个数size

也就是统计map中key-value键值对的个数。如果map包含的元素个数超过了Integer.MAX_VALUE,那么就返回Integer.MAX_VALUE

尝试几次以获得准确的计数。如果由于表中的连续异步更改而导致失败,则求助于锁定(也就是会锁住全部Segment)。

public int size() {
    // Try a few times to get accurate count. On failure due to
    // continuous async changes in table, resort to locking.
    final Segment<K,V>[] segments = this.segments;
    int size;
    boolean overflow; // true if size overflows 32 bits
    long sum;         // sum of modCounts
    long last = 0L;   // previous sum
    int retries = -1; // first iteration isn't retry
    try {
    // 无限循环
        for (;;) {
        // 如果尝试次数达到了RETRIES_BEFORE_LOCK ,就将每一个segment加锁
        // 先拿retries与RETRIES_BEFORE_LOCK进行==判断,然后retries+1
            if (retries++ == RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    ensureSegment(j).lock(); // force creation
            }
            sum = 0L;
            size = 0;
            overflow = false;
            for (int j = 0; j < segments.length; ++j) {
                Segment<K,V> seg = segmentAt(segments, j);
                if (seg != null) {
                    sum += seg.modCount;//结构修改次数
                    int c = seg.count; //元素个数
                    // 判断是否溢出
                    if (c < 0 || (size += c) < 0)
                        overflow = true;
                }
            }
            // 如果前后一致,break,否则就更新last为当前sum
            if (sum == last)
                break;
            last = sum;
        }
    } finally {
    // 解锁
        if (retries > RETRIES_BEFORE_LOCK) {
            for (int j = 0; j < segments.length; ++j)
                segmentAt(segments, j).unlock();
        }
    }
    // 返回size ,如果溢出了,就返回Integer.MAX_VALUE=2^31-1
    return overflow ? Integer.MAX_VALUE : size;
}

也就是说首先直接将所有的 Segment 不加锁, 直接统计数量。统计过程中同时对每个 Segment 的 modCount 进行加总(modCount 记录了每个 Segment 被修改的次数)。重复上面的过程, 然后比较前后两次 modCount 总和是否一样, 相等就说明中间没有线程更改过结构(比如添加或者移除), 直接返回得到的 size 大小即可。

如果重试次数达到了3次,也就是总共循环了四次,那么直接将所有的Segment加锁进行元素数量统计。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

流烟默

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值