JAVA并发-并发集合

115 阅读25分钟

ThreadLocal

ThreadLocal可以称为线程本地变量,或者线程本地存储;作用是提供线程内的局部变量(变量只在当前线程有效,并且在线程的生命周期内起作用,线程消失后变量也会消失)

原理:每个Thread创建时自带一个ThreadLocal.ThreadLocalMap容器,key是ThreadLocal的实例,value是对应ThreadLocal存储的值

public class Thread implements Runnable {
    // 保存在Thread实例中的ThreadLocalMap对象
    ThreadLocal.ThreadLocalMap threadLocals = null;
}

public class ThreadLocal<T> {
    // 构造方法
    public ThreadLocal() {
    }
     
    // 内部类ThreadLocalMap
    static class ThreadLocalMap {
        // Entry 是一个以ThreadLocal为key,Object为value的键值对
        // Entry 是弱引用,赋值为null是一定会被GC回收
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
 		// 初始容量
        private static final int INITIAL_CAPACITY = 16;
 		// Entry 数组
        private Entry[] table;
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); // 计算哈希
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }
        // set 方法
        private void set(ThreadLocal<?> key, Object value) {  
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1); // 计算hash
			// 循环,直到e == null; i = nextIndex(i, len)采用线性探测
            for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get(); 
                if (k == key) {
                    e.value = value;
                    return;
                } 
                // e != null && k == null 说明 ThreadLocal被GC回收了
                if (k == null) {
                    // 覆盖原来的 key
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
			// tab[i] == null 然后重新赋值
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }
        // 可以看出这里解决hash冲突,向数组的下一个位置
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }
    }
}

set方法

// 获取当前线程所对应的ThreadLocalMap,如果不为空,则调用ThreadLocalMap的set()方法,key就是当前ThreadLocal,
public void set(T value) {
    // 当前线程
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) // 调用 ThreadLocalMap 的set()方法,key就是当前ThreadLocal
        map.set(this, value);
    else // 如果不存在,则调用createMap()方法新建一个
        createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}
// 未当前线程t 创建一个 ThreadLocalMap
void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

get方法

public T get() {
    Thread t = Thread.currentThread(); // 获取当前线程
    // 获取当前线程的成员变量 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 获取map中 key == 当前threadLocal实例 的 entry
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}
// 初始化方法,和set方法一样,把初始值放入容器
private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}
// 初始值,默认为null,提供模板可以让子类重写,赋值其他的初始值
protected T initialValue() {
    return null;
}

总结

  • ThreadLocal 不是用于解决共享变量的问题的,也不是为了协调线程同步而存在,而是为了方便每个线程处理自己的状态而引入的一个机制。这点至关重要。
  • 每个Thread内部都有一个ThreadLocal.ThreadLocalMap类型的成员变量,该成员变量用来存储实际的ThreadLocal变量副本。
  • ThreadLocal并不是为线程保存对象的副本,它仅仅只起到一个索引的作用。它的主要目的是为每一个线程隔离一个类的实例,这个实例的作用范围仅限于线程内部。

注意事项

脏数据

线程复用会产生脏数据。由于线程池会重用Thread对象,那么与Thread绑定的类的静态属性ThreadLocal变量也会被重用。如果在实现的线程run()方法体中不显式地调用remove() 清理与线程相关的ThreadLocal信息,那么倘若下一个结程不调用set() 设置初始值,就可能get() 到重用的线程信息,包括 ThreadLocal所关联的线程对象的value值。

内存泄漏

通常我们会使用使用static关键字来修饰ThreadLocal(这也是在源码注释中所推荐的)。在此场景下,其生命周期就不会随着线程结束而结束,寄希望于ThreadLocal对象失去引用后,触发弱引用机制来回收Entry的Value就不现实了。如果不进行remove() 操作,那么这个线程执行完成后,通过ThreadLocal对象持有的对象是不会被释放的。

以上两个问题的解决办法很简单,就是在每次用完ThreadLocal时, 必须要及时调用 remove()方法清理。

父子线程共享线程变量

很多场景下通过ThreadLocal来透传全局上下文,会发现子线程的value和主线程不一致。比如用ThreadLocal来存储监控系统的某个标记位,暂且命名为traceId。某次请求下所有的traceld都是一致的,以获得可以统一解析的日志文件。但在实际开发过程中,发现子线程里的traceld为null,跟主线程的并不一致。这就需要使用InheritableThreadLocal来解决父子线程之间共享线程变量的问题,使整个连接过程中的traceId一致。

并发容器(线程安全集合)

ConcurrentHashMap

public class ConcurrentHashMap<K,V> extends AbstractMap<K,V>
    implements ConcurrentMap<K,V>, Serializable {
    transient volatile Node<K,V>[] table; // Node数组
    private transient volatile Node<K,V>[] nextTable; // 扩容数组
    private transient volatile CounterCell[] counterCells; // 元素数量数组,空间换取时间的思想
    // 用来控制table数组的大小
    // -1,表示有线程正在进行初始化操作 
    // -(1 + nThreads),表示有n个线程正在一起扩容
	// 0,默认值,后续在真正初始化的时候使用默认容量
	// > 0,初始化或扩容完成后下一次的扩容门槛
    private transient volatile int sizeCtl;
    // Unsafe类,Native方法,直接操作内存,通过CAS保证线程安全
    private static final sun.misc.Unsafe U;
    // 内部类,Node结点,单向链表
    static class Node<K,V> implements Map.Entry<K,V> {
        final int hash; // hash值
        final K key;
        volatile V val; // 值,和HashMap相比,多了volatile修饰符,
        volatile Node<K,V> next; // next结点,和HashMap相比,多了volatile修饰符,
        Node(int hash, K key, V val, Node<K,V> next) {
            this.hash = hash;
            this.key = key;
            this.val = val;
            this.next = next;
        } 
        public final K getKey()       { return key; }
        public final V getValue()     { return val; }
        public final int hashCode()   { return key.hashCode() ^ val.hashCode(); }
        public final String toString(){ return key + "=" + val; }
        public final V setValue(V value) {
            throw new UnsupportedOperationException();
        }
        // 遍历链表
        Node<K,V> find(int h, Object k) {
            Node<K,V> e = this;
            if (k != null) {
                do {
                    K ek;
                    if (e.hash == h &&
                        ((ek = e.key) == k || (ek != null && k.equals(ek))))
                        return e;
                } while ((e = e.next) != null);
            }
            return null;
        }
    }
    // 构造方法,和HashMap相比取消了loadFactor和threshold,使用sizeCtl控制table数组
    public ConcurrentHashMap() {
    } 
    public ConcurrentHashMap(int initialCapacity) {
        if (initialCapacity < 0)
            throw new IllegalArgumentException();
        int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
                   MAXIMUM_CAPACITY :
                   tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
        this.sizeCtl = cap;
    } 
    public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
        this.sizeCtl = DEFAULT_CAPACITY;
        putAll(m);
    } 
    public ConcurrentHashMap(int initialCapacity, float loadFactor) {
        this(initialCapacity, loadFactor, 1);
    } 
    public ConcurrentHashMap(int initialCapacity,  float loadFactor, int concurrencyLevel) {
        if (!(loadFactor > 0.0f) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (initialCapacity < concurrencyLevel)   // Use at least as many bins
            initialCapacity = concurrencyLevel;   // as estimated threads
        long size = (long)(1.0 + (long)initialCapacity / loadFactor);
        int cap = (size >= (long)MAXIMUM_CAPACITY) ?
            MAXIMUM_CAPACITY : tableSizeFor((int)size);
        this.sizeCtl = cap;
    }

}

put方法

public V put(K key, V value) {
    return putVal(key, value, false);
} 
final V putVal(K key, V value, boolean onlyIfAbsent) {
    if (key == null || value == null) 
        throw new NullPointerException();
    // 获取key的hash值 32位的hashCode 前16位和后16位做^(异或运算),
    // 尽快能保证hash 不重复,离散性高
    int hash = spread(key.hashCode());
    int binCount = 0;
    // 自旋,直到成功 高并发 table 有 volatile 修饰,保证可见性
    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh;
        if (tab == null || (n = tab.length) == 0)
            tab = initTable(); // 初始化数组
        // (n - 1) & hash  key映射数组的下标
        // 注意f 在这里赋值了
        else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
            // 如果这个位置为空,通过cas直接复制
            if (casTabAt(tab, i, null,new Node<K,V>(hash, key, value, null)))
                break;                 
        }
        // f.hash == -1 说明有其他线程对数组进行扩容
        else if ((fh = f.hash) == MOVED)
            // 帮助扩容
            tab = helpTransfer(tab, f);
        else {
            // 找到hash结点
            V oldVal = null;
            synchronized (f) { // 锁住这个结点
                // 再次判断,防止别的线程修改,如果不等,会重新循环
                if (tabAt(tab, i) == f) {
                    if (fh >= 0) { // fh >= 0 说明结点是链表
                        binCount = 1;
                        // 遍历这个链表,找到key对应的结点
                        for (Node<K,V> e = f;; ++binCount) {
                            K ek;
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                 (ek != null && key.equals(ek)))) {
                                oldVal = e.val;
                                // 如果目标key和当前结点key相同,直接赋值
                                if (!onlyIfAbsent)
                                    e.val = value;
                                break;
                            }
                            Node<K,V> pred = e;
                            if ((e = e.next) == null) {
                                // 如果目标key和链表中所有结点key都不同,
                                // 创建一个新结点放到链表尾部
                                pred.next = new Node<K,V>(hash, key,value, null);
                                break;
                            }
                        }
                    }
                    // 判断结点是不是红黑树
                    else if (f instanceof TreeBin) {
                        Node<K,V> p;
                        binCount = 2;
                        // 红黑树put
                        if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key, value)) != null) {
                            oldVal = p.val;
                            if (!onlyIfAbsent)
                                p.val = value;
                        }
                    }
                }
            }
            if (binCount != 0) {
                //如果链表长度已经达到临界值8 就需要把链表转换为红黑树
                if (binCount >= TREEIFY_THRESHOLD)
                    treeifyBin(tab, i);
                if (oldVal != null)
                    return oldVal;
                break;
            }
        }
    }
    // 计算元素个数,并检查是否需要扩容
    addCount(1L, binCount);
    return null;
} 

hash算法

// 获取key的hash值 32位的hashCode 前16位和后16位做^(异或运算),
// 尽快能保证hash 不重复,离散性高
int hash = spread(key.hashCode());
static final int HASH_BITS = 0x7fffffff; 
static final int spread(int h) {
    // 前16位和后16位异或运算
    return (h ^ (h >>> 16)) & HASH_BITS;
}
// 根据hash计算tab数组的下标 i = (n - 1) & hash 和 HashMap一致
// 使用tabAt
Node<K,V>[] tab = table;
f = tabAt(tab, i = (n - 1) & hash)

CAS

使用Unsafe试下Node结点的线程安全操作

static final <K,V> Node<K,V> tabAt(Node<K,V>[] tab, int i) {
    return (Node<K,V>)U.getObjectVolatile(tab, ((long)i << ASHIFT) + ABASE);
} 
static final <K,V> boolean casTabAt(Node<K,V>[] tab, int i, Node<K,V> c, Node<K,V> v) {
    return U.compareAndSwapObject(tab, ((long)i << ASHIFT) + ABASE, c, v);
} 
static final <K,V> void setTabAt(Node<K,V>[] tab, int i, Node<K,V> v) {
    U.putObjectVolatile(tab, ((long)i << ASHIFT) + ABASE, v);
}

扩容

// 初始化Node数组
private final Node<K,V>[] initTable() {
    Node<K,V>[] tab; int sc;
    while ((tab = table) == null || tab.length == 0) {
        // sizeCtl < 0 说明存在其他线程执行初始化
        if ((sc = sizeCtl) < 0)
            Thread.yield(); // 让出CPU时间片
        //通过cas操作,将sizeCtl替换为-1,标识当前线程抢占到了初始化资格
        else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
            try {
                if ((tab = table) == null || tab.length == 0) {
                    int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                    @SuppressWarnings("unchecked")
                    Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                    table = tab = nt;
                    // 计算下次扩容的阈值 n - n/4 = 0.75n
                    sc = n - (n >>> 2);
                }
            } finally {
                sizeCtl = sc;
            }
            break;
        }
    }
    return tab;
}
// 协助扩容,并返回新的Node数组
final Node<K,V>[] helpTransfer(Node<K,V>[] tab, Node<K,V> f) {
    Node<K,V>[] nextTab; int sc; 
    // 扩容时会把第一个元素置为ForwardingNode,并让其nextTab指向新的数组
    if (tab != null && (f instanceof ForwardingNode) &&
        (nextTab = ((ForwardingNode<K,V>)f).nextTable) != null) {
        int rs = resizeStamp(tab.length);
        while (nextTab == nextTable && table == tab &&
               (sc = sizeCtl) < 0) { // sizeCtl<0,说明正在扩容
            if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                sc == rs + MAX_RESIZERS || transferIndex <= 0)
                break;
            if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1)) {
                transfer(tab, nextTab);
                break;
            }
        }
        return nextTab;
    }
    return table;
}
// 扩容操作,把tab数组元素复制给nextTab(新的Node数组)
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    int n = tab.length, stride;
    if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
        stride = MIN_TRANSFER_STRIDE; // subdivide range
    if (nextTab == null) {            // initiating
        // 如果nextTab为空,说明还没开始迁移 
        try {
            @SuppressWarnings("unchecked")
            // 就新建一个新数组,大小翻倍n << 1
            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
            nextTab = nt;
        } catch (Throwable ex) {      // try to cope with OOME
            sizeCtl = Integer.MAX_VALUE;
            return;
        }
        nextTable = nextTab;
        transferIndex = n;
    }
    int nextn = nextTab.length;
    // 新建一个ForwardingNode类型的节点,并把nextTab存储在里面
    // ForwardingNode(Node<K,V>[] tab) { ForwardingNode构造方法会把结点的hash值赋值为MOVED(-1)
    //	  super(MOVED, null, null, null);
    //    this.nextTable = tab; }
    ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
    boolean advance = true;
    boolean finishing = false; // to ensure sweep before committing nextTab
    for (int i = 0, bound = 0;;) {// 自旋
        Node<K,V> f; int fh;
        while (advance) {
            int nextIndex, nextBound;
            if (--i >= bound || finishing)
                advance = false;
            else if ((nextIndex = transferIndex) <= 0) {
                i = -1;
                advance = false;
            }
            else if (U.compareAndSwapInt
                     (this, TRANSFERINDEX, nextIndex,
                      nextBound = (nextIndex > stride ?
                                   nextIndex - stride : 0))) {
                bound = nextBound;
                i = nextIndex - 1;
                advance = false;
            }
        }
        if (i < 0 || i >= n || i + n >= nextn) {// 遍历结束
            int sc;
            if (finishing) { // 扩容完成时,将新数组nextTab赋值给table
                nextTable = null;
                table = nextTab;
                sizeCtl = (n << 1) - (n >>> 1); // sizeCtl 1.5n = 0.75N(新数组大小)
                return;
            }
            if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                    return;
                finishing = advance = true;
                i = n; // recheck before commit
            }
        }
        // 如果位置 i 是空的,放入fwd结点,通知其他线程数组正在扩容
        else if ((f = tabAt(tab, i)) == null)
            advance = casTabAt(tab, i, null, fwd);
        else if ((fh = f.hash) == MOVED)// 表示该位置已经完成了迁移
            advance = true; // already processed
        else {
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    Node<K,V> ln, hn; // 高低位链表,和HashMap中的方案一样
                    if (fh >= 0) { // 处理链表
                        int runBit = fh & n;
                        Node<K,V> lastRun = f;
                        for (Node<K,V> p = f.next; p != null; p = p.next) {
                            int b = p.hash & n;
                            if (b != runBit) {
                                runBit = b;
                                lastRun = p;
                            }
                        }
                        if (runBit == 0) {
                            ln = lastRun;
                            hn = null;
                        }
                        else {
                            hn = lastRun;
                            ln = null;
                        }
                        for (Node<K,V> p = f; p != lastRun; p = p.next) {
                            int ph = p.hash; K pk = p.key; V pv = p.val;
                            if ((ph & n) == 0)
                                ln = new Node<K,V>(ph, pk, pv, ln); // 低位链表
                            else
                                hn = new Node<K,V>(ph, pk, pv, hn); // 高位链表
                        }
                        setTabAt(nextTab, i, ln);// 低位链表还在原来的位置i
                        setTabAt(nextTab, i + n, hn);// 高位链表 在原来的位置i+n
                        setTabAt(tab, i, fwd); // 在table的i位置上插入forwardNode节点  表示已处理
                        advance = true;
                    }
                    // TreeBin(TreeNode<K,V> b) { super(TREEBIN, null, null, null);}
            		// TreeBin的构造方法中把hash字段赋值为TREEBIN(-2)
                    else if (f instanceof TreeBin) { // 红黑树
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> lo = null, loTail = null;
                        TreeNode<K,V> hi = null, hiTail = null;
                        int lc = 0, hc = 0;
                        for (Node<K,V> e = t.first; e != null; e = e.next) {
                            int h = e.hash;
                            TreeNode<K,V> p = new TreeNode<K,V>
                                (h, e.key, e.val, null, null);
                            if ((h & n) == 0) {
                                if ((p.prev = loTail) == null)
                                    lo = p;
                                else
                                    loTail.next = p;
                                loTail = p;
                                ++lc;
                            }
                            else {
                                if ((p.prev = hiTail) == null)
                                    hi = p;
                                else
                                    hiTail.next = p;
                                hiTail = p;
                                ++hc;
                            }
                        }
                        ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                        (hc != 0) ? new TreeBin<K,V>(lo) : t;
                        hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                        (lc != 0) ? new TreeBin<K,V>(hi) : t;
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        setTabAt(tab, i, fwd);
                        advance = true;
                    }
                }
            }
        }
    }
}

size方法

public int size() {
    long n = sumCount();
    return ((n < 0L) ? 0 :
            (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
            (int)n);
}
// 对counterCells数组中的元素求和,空间换取时间
final long sumCount() {
    CounterCell[] as = counterCells; CounterCell a;
    long sum = baseCount;
    if (as != null) {
        for (int i = 0; i < as.length; ++i) {
            if ((a = as[i]) != null)
                sum += a.value;
        }
    }
    return sum;
}
// 计算元素个数
private final void addCount(long x, int check) {
    CounterCell[] as; long b, s;
    if ((as = counterCells) != null ||
        !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
        CounterCell a; long v; int m;
        boolean uncontended = true;
        if (as == null || (m = as.length - 1) < 0 ||
            // 通过 ThreadLocalRandom.getProbe() & m 找到 线程所在的数组下标,减少并发
            (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
            !(uncontended =
              U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
            fullAddCount(x, uncontended);
            return;
        }
        if (check <= 1)
            return;
        // 计算元素个数
        s = sumCount();
    }
    // 检验是否需要进行扩容操作
    if (check >= 0) {
        Node<K,V>[] tab, nt; int n, sc;
        while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
               (n = tab.length) < MAXIMUM_CAPACITY) {
            int rs = resizeStamp(n);
            if (sc < 0) { // sc < 0说明正在扩容中
                if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                    sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                    transferIndex <= 0)
                    break;
                if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                    transfer(tab, nt);
            }
            else if (U.compareAndSwapInt(this, SIZECTL, sc, (rs << RESIZE_STAMP_SHIFT) + 2))
                transfer(tab, null);
            s = sumCount();
        }
    }
}

get方法

public V get(Object key) {
    Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
    int h = spread(key.hashCode());  //获取hash
    if ((tab = table) != null && (n = tab.length) > 0 &&
        (e = tabAt(tab, (n - 1) & h)) != null) {
        // 第一个元素就是要找的元素,直接返回
        if ((eh = e.hash) == h) {
            if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                return e.val;
        }
        else if (eh < 0) // 当前节点hash小于0说明为树节点,在红黑树中查找即可
            // TreeBin(TreeNode<K,V> b) { super(TREEBIN, null, null, null);}
            // TreeBin的构造方法中把hash字段赋值为TREEBIN(-2)
            return (p = e.find(h, key)) != null ? p.val : null;
        while ((e = e.next) != null) {// 链表循环
            if (e.hash == h &&
                ((ek = e.key) == key || (ek != null && key.equals(ek))))
                return e.val;
        }
    }
    return null;
}

remove方法

public V remove(Object key) {
    return replaceNode(key, null, null);
} 
final V replaceNode(Object key, V value, Object cv) {
    int hash = spread(key.hashCode());
    for (Node<K,V>[] tab = table;;) { // 自旋
        Node<K,V> f; int n, i, fh;
        if (tab == null || (n = tab.length) == 0 ||
            // 目标key的元素不存在,跳出循环返回null
            (f = tabAt(tab, i = (n - 1) & hash)) == null)
            break;
        else if ((fh = f.hash) == MOVED)
            tab = helpTransfer(tab, f); // 协助扩容
        else {
            V oldVal = null;
            boolean validated = false;
            synchronized (f) {
                // 验证单曲Node结点是否被修改
                if (tabAt(tab, i) == f) {
                    if (fh >= 0) {
                        validated = true;
                        // 遍历链表寻找目标节点
                        for (Node<K,V> e = f, pred = null;;) {
                            K ek;
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                 (ek != null && key.equals(ek)))) {
                                V ev = e.val;
                                if (cv == null || cv == ev ||
                                    (ev != null && cv.equals(ev))) {
                                    oldVal = ev;
                                    if (value != null)
                                        e.val = value;
                                    else if (pred != null)
                                        pred.next = e.next;
                                    else
                                        setTabAt(tab, i, e.next);
                                }
                                break;
                            }
                            pred = e;
                            if ((e = e.next) == null)
                                break;
                        }
                    }
                    else if (f instanceof TreeBin) { // 红黑树
                        validated = true;
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> r, p;
                        // 遍历树找到目标节点
                        if ((r = t.root) != null &&
                            (p = r.findTreeNode(hash, key, null)) != null) {
                            V pv = p.val;
                            if (cv == null || cv == pv ||
                                (pv != null && cv.equals(pv))) {
                                oldVal = pv;
                                if (value != null)
                                    p.val = value;
                                else if (t.removeTreeNode(p))
                                    setTabAt(tab, i, untreeify(t.first));
                            }
                        }
                    }
                }
            }
            if (validated) {
                if (oldVal != null) {
                    if (value == null)
                        addCount(-1L, -1); // 重新计算size
                    return oldVal;
                }
                break;
            }
        }
    }
    return null;
}

CopyOnWriteArrayList

CopyOnWriteArrayList是线程安全的List,采用数组存储,通过ReentrantLock保证线程安全

public class CopyOnWriteArrayList<E>
    implements List<E>, RandomAccess, Cloneable, java.io.Serializable {
    /** 通过重入锁保证线程安全 */
    final transient ReentrantLock lock = new ReentrantLock();
    /** 通过数组记录元素,并通过volatile保证可见性 */
    private transient volatile Object[] array;
    // set方法,给array赋值
    final void setArray(Object[] a) {
        array = a;
    }
 	// 默认构造方法,创建空集合
    public CopyOnWriteArrayList() {
        setArray(new Object[0]);
    }
 	// Collection构造方法
    public CopyOnWriteArrayList(Collection<? extends E> c) {
        Object[] elements;
        // 把Collection转换为数组
        if (c.getClass() == CopyOnWriteArrayList.class)
            elements = ((CopyOnWriteArrayList<?>)c).getArray();
        else {
            elements = c.toArray(); 
            if (elements.getClass() != Object[].class)
                elements = Arrays.copyOf(elements, elements.length, Object[].class);
        }
        setArray(elements);
    } 
    public CopyOnWriteArrayList(E[] toCopyIn) {
        setArray(Arrays.copyOf(toCopyIn, toCopyIn.length, Object[].class));
    }
    
    // get操作,通过volatile保证可见性 
    public E get(int index) {
        return get(getArray(), index);
    }
    private E get(Object[] a, int index) {
        return (E) a[index];
    }
}

添加元素

// 新增元素
public boolean add(E e) {
    ensureCapacityInternal(size + 1);  // 扩容
    elementData[size++] = e; // 直接赋值
    return true;
}
// 新增元素,通过ReentrantLock保证线程安全 
public boolean add(E e) {
    final ReentrantLock lock = this.lock;
    lock.lock(); // 获得锁
    try {
        Object[] elements = getArray();
        int len = elements.length;
        // 每次把原来的数组复制到新数组,新数组大小是旧数组大小加1
        Object[] newElements = Arrays.copyOf(elements, len + 1);
        newElements[len] = e; // 将元素e放到最后
        // 每次重新赋值
        setArray(newElements);
        return true;
    } finally {
        lock.unlock();// 释放锁
    }
}
// 在 index 位置插入元素
public void add(int index, E element) {
    final ReentrantLock lock = this.lock;
    lock.lock(); // 获得锁
    try {
        Object[] elements = getArray();
        int len = elements.length;
        if (index > len || index < 0) // 检查index是否越界
            throw new IndexOutOfBoundsException("Index: "+index+ ", Size: "+len);
        Object[] newElements;
        int numMoved = len - index;
        // 如果不需要移动元素,直接复制
        if (numMoved == 0)
            newElements = Arrays.copyOf(elements, len + 1);
        else {
            // 如果需要移动元素(在中间进行插入) ,先创建len + 1大小的数组
            newElements = new Object[len + 1];
            // 把旧数组0~index位置的元素复制到新数组
            System.arraycopy(elements, 0, newElements, 0, index);
            // 把旧数组index~最后位置的元素复制到新数组index+1位置
            System.arraycopy(elements, index, newElements, index + 1, numMoved);
        }
        // index位置赋值
        newElements[index] = element;
        setArray(newElements);
    } finally {
        lock.unlock(); // 释放锁
    }
}

删除元素

// 移除第index个元素
public E remove(int index) {
    final ReentrantLock lock = this.lock;
    lock.lock(); // 获得锁
    try {
        Object[] elements = getArray();
        int len = elements.length;
        E oldValue = get(elements, index);
        int numMoved = len - index - 1;
        // 如果不需要移动元素,直接复制0~len-1到新数组
        if (numMoved == 0)
            setArray(Arrays.copyOf(elements, len - 1));
        else {
            // 如果需要移动元素(在中间进行删除) ,先创建len - 1大小的数组
            Object[] newElements = new Object[len - 1];
            // 把旧数组0~index位置的元素复制到新数组
            System.arraycopy(elements, 0, newElements, 0, index);
            // 把旧数组index+1~最后位置的元素复制到新数组index位置 
            System.arraycopy(elements, index + 1, newElements, index, numMoved);
            setArray(newElements);
        }
        return oldValue;
    } finally {
        lock.unlock(); // 释放锁
    }
}

修改元素

// 替换元素
public E set(int index, E element) {
    final ReentrantLock lock = this.lock;
    lock.lock(); // 获得锁
    try {
        Object[] elements = getArray();
        E oldValue = get(elements, index); 
        if (oldValue != element) {
            int len = elements.length;
            // 把旧数组全部复制,然后在新数组上修改index
            Object[] newElements = Arrays.copyOf(elements, len);
            newElements[index] = element;
            setArray(newElements);
        } else { 
            setArray(elements);
        }
        return oldValue;
    } finally {
        lock.unlock(); // 释放锁
    }
}

查找元素

// 返回list中索引为index的元素
public E get(int index) { 
    return get(getArray(), index);
}
// 获取数组中的元素,并类型转换 
private E get(Object[] a, int index) {
    return (E) a[index];
}

CopyOnWriteArraySet

通过对CopyOnWriteArrayList新增时判断元素是否存在,保证Set元素不重复

public class CopyOnWriteArraySet<E> extends AbstractSet<E>
        implements java.io.Serializable {
    // 内部封装了一个CopyOnWriteArrayList    
    private final CopyOnWriteArrayList<E> al;   
    // 构造方法
    public CopyOnWriteArraySet() {
        al = new CopyOnWriteArrayList<E>();
    }
    // 新增方法 CopyOnWriteArrayList 先判断元素e,是否在数组中存在,不存在新增
    public boolean add(E e) {
        return al.addIfAbsent(e); 
    }
}
// CopyOnWriteArrayList
public boolean addIfAbsent(E e) {
    Object[] snapshot = getArray();
    // 判断元素e是否存在,存在则不用add直接返回false,保证Set元素不重复
    return indexOf(e, snapshot, 0, snapshot.length) >= 0 ? false : 
    addIfAbsent(e, snapshot);
}
private boolean addIfAbsent(E e, Object[] snapshot) {
    final ReentrantLock lock = this.lock;
    lock.lock(); // 获得锁
    try {
        Object[] current = getArray(); // 获取数组元素
        int len = current.length;
        // 判断数组是否被修改过
        if (snapshot != current) {
            // Optimize for lost race to another addXXX operation
            int common = Math.min(snapshot.length, len);
            for (int i = 0; i < common; i++)
                if (current[i] != snapshot[i] && eq(e, current[i]))
                    return false;
            if (indexOf(e, current, common, len) >= 0) // 重新判断元素e是否存在
                return false;
        }
        // 创建一个新数组,长度为len+1,把旧数组的元素复制到新数组
        Object[] newElements = Arrays.copyOf(current, len + 1);
        newElements[len] = e; // 把元素e放到最后
        setArray(newElements);
        return true;
    } finally {
        lock.unlock(); // 释放锁
    }
}

ConcurrentLinkedQueue

线程安全的队列

public class ConcurrentLinkedQueue<E> extends AbstractQueue<E>
        implements Queue<E>, java.io.Serializable {
    // 头结点
    private transient volatile Node<E> head;
	// 尾结点
    private transient volatile Node<E> tail;
    // 内部类Node 结点 -- 单向链表    
    private static class Node<E> {
        volatile E item;
        volatile Node<E> next; 
        Node(E item) {
            UNSAFE.putObject(this, itemOffset, item);
        } 
        // CAS操作
        boolean casItem(E cmp, E val) {
            return UNSAFE.compareAndSwapObject(this, itemOffset, cmp, val);
        }

        void lazySetNext(Node<E> val) {
            UNSAFE.putOrderedObject(this, nextOffset, val);
        }
		// CAS next结点
        boolean casNext(Node<E> cmp, Node<E> val) {
            return UNSAFE.compareAndSwapObject(this, nextOffset, cmp, val);
        }
        private static final sun.misc.Unsafe UNSAFE;
        private static final long itemOffset;
        private static final long nextOffset;
        static {
            try {
                UNSAFE = sun.misc.Unsafe.getUnsafe();
                Class<?> k = Node.class;
                // 直接取内存地址偏移量
                itemOffset = UNSAFE.objectFieldOffset(k.getDeclaredField("item"));
                nextOffset = UNSAFE.objectFieldOffset(k.getDeclaredField("next"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }
    }
	// 构造方法
    public ConcurrentLinkedQueue() { head = tail = new Node<E>(null);}
     
}

入队

public boolean offer(E e) {
    checkNotNull(e); // 元素e不能为null
    final Node<E> newNode = new Node<E>(e); // 新节点
    for (Node<E> t = tail, p = t;;) { // 自旋,找到尾结点
        Node<E> q = p.next; 
        if (q == null) { // 确定是尾结点
            // 通过CAS把当前结点设位 p结点的next
            if (p.casNext(null, newNode)) {
                // 成功,通过CAS将 newNode 设置为 tail 结点
                if (p != t) // hop two nodes at a time
                    casTail(t, newNode);  // Failure is OK.
                return true; // 结束自旋
            }

        } else if (p == q)
            // 如果p等于q(p的next),说明p已经出队了
            p = (t != (t = tail)) ? t : head;
        else 
            // 如果尾部t发生变化,则重新获取尾部,再重试
            p = (p != t && t != (t = tail)) ? t : q;
    }
}

出队

public E poll() {
    restartFromHead: //goto,跳出循环
    for (;;) {
        for (Node<E> h = head, p = h, q;;) { // 自旋
            E item = p.item; 
            if (item != null && p.casItem(item, null)) {// 通过CAS尝试把p.item设置为null 
                if (p != h) // hop two nodes at a time
                    updateHead(h, ((q = p.next) != null) ? q : p);
                return item;
            }
            else if ((q = p.next) == null) { // 如果p的next为空,说明队列中没有元素了
                updateHead(h, p);
                return null;
            }
            else if (p == q)
                 // 如果p等于q(p的next),说明p已经出队了 
                continue restartFromHead;
            else
                p = q; // p = p.next 重新CAS
        }
    }
}

取队首元素

public E peek() {
    restartFromHead: //goto,跳出循环
    for (;;) {
        for (Node<E> h = head, p = h, q;;) {
            E item = p.item;
            if (item != null || (q = p.next) == null) {
                updateHead(h, p);
                return item;
            }
            else if (p == q) // 如果p等于q(p的next),说明p已经出队了 
                continue restartFromHead;
            else
                p = q; // p = p.next 重新CAS
        }
    }
}

阻塞队列

阻塞队列:一个指定长度的队列,如果队列满了,添加新元素的操作会被阻塞等待,直到有空位为止。同样,当队列为空时候,请求队列元素的操作同样会阻塞等待,直到有可用元素为止。

// 阻塞队列接口BlockingQueue
public interface BlockingQueue<E> extends Queue<E> {
    // 元素入队,队列为满时阻塞
    void put(E e) throws InterruptedException;
    // 元素入队,带有超时机制
    boolean offer(E e, long timeout, TimeUnit unit) throws InterruptedException;
    // 元素出队,队列为空时阻塞   
    E take() throws InterruptedException;
    // 元素出队,带有超时机制
    E poll(long timeout, TimeUnit unit) throws InterruptedException;
}

ArrayBlockingQueue

// 阻塞队列 数组实现
public class ArrayBlockingQueue<E> extends AbstractQueue<E>
		implements BlockingQueue<E>, java.io.Serializable { 
    // 数组元素集合,循环队列
    final Object[] items;
    int takeIndex;// head元素指针,指向数组中head的位置 
    int putIndex;// tail元素指针,指向数组中tail的位置
    int count; // 队列元素个数
    // 重入锁,保证线程安全   
    final ReentrantLock lock;
    // 通过Condition,出队时,队列如果为空,take方法阻塞 
    private final Condition notEmpty;
    // 通过Condition,入队时,队列如果满了,put方法阻塞
    private final Condition notFull;
   
    // 默认构造方法,默认非公平锁
    public ArrayBlockingQueue(int capacity) {
        this(capacity, false);
    } 
    public ArrayBlockingQueue(int capacity, boolean fair) {
        if (capacity <= 0)
            throw new IllegalArgumentException();
        // 创建数组
        this.items = new Object[capacity];
        // 创建重入锁
        lock = new ReentrantLock(fair);
        notEmpty = lock.newCondition();
        notFull =  lock.newCondition();
    }     
}

入队

// 元素入队,队列为满时阻塞
public void put(E e) throws InterruptedException {
    checkNotNull(e);
    final ReentrantLock lock = this.lock;
    lock.lockInterruptibly(); // 加锁
    try {
        // 如果数组满了,使用notFull等待
        while (count == items.length) // 队列元素已满
            notFull.await(); // 阻塞,等待
        enqueue(e);// 入队
    } finally {
        lock.unlock(); // 释放锁
    }
}
// 元素入队,带有超时机制
public boolean offer(E e, long timeout, TimeUnit unit)
    throws InterruptedException {
    checkNotNull(e);
    long nanos = unit.toNanos(timeout); // 超时时间
    final ReentrantLock lock = this.lock;
    lock.lockInterruptibly(); // 加锁
    try {
        while (count == items.length) { // 队列元素已满
            if (nanos <= 0)
                return false;
            nanos = notFull.awaitNanos(nanos); // notFull阻塞,等待时长为 nanos
        }
        enqueue(e); // 入队
        return true;
    } finally {
        lock.unlock(); // 释放锁
    }
}
// 入队,唤醒notEmpty 
private void enqueue(E x) { 
    final Object[] items = this.items;
    items[putIndex] = x;
    if (++putIndex == items.length)
        putIndex = 0;
    count++;
    // 唤醒notEmpty ,因为入队了一个元素,所以肯定不为空了
    notEmpty.signal(); 
}

出队

// 元素出队,队列为空时阻塞
public E take() throws InterruptedException {
    final ReentrantLock lock = this.lock;
    lock.lockInterruptibly(); // 加锁
    try {
        while (count == 0) // 队列元素为空
            notEmpty.await(); // 阻塞,等待
        return dequeue(); // 出队
    } finally {
        lock.unlock(); // 释放锁
    }
}
// 元素出队,带有超时机制
public E poll(long timeout, TimeUnit unit) throws InterruptedException {
    long nanos = unit.toNanos(timeout); // 超时时间
    final ReentrantLock lock = this.lock;
    lock.lockInterruptibly(); // 加锁
    try {
        while (count == 0) {// 队列元素为空
            if (nanos <= 0)
                return null;
            nanos = notEmpty.awaitNanos(nanos);// 阻塞,等待时长为 nanos
        }
        return dequeue(); // 出队
    } finally {
        lock.unlock(); // 释放锁
    }
}
// 出队,唤醒notFull
private E dequeue() { 
    final Object[] items = this.items;
    @SuppressWarnings("unchecked")
    E x = (E) items[takeIndex];
    items[takeIndex] = null;
    if (++takeIndex == items.length)
        takeIndex = 0;
    count--;
    if (itrs != null)
        itrs.elementDequeued();
    // 唤醒notFull,因为出队了一个元素,所以肯定不为满
    notFull.signal(); 
    return x;
}

取队首元素

public E peek() {
    final ReentrantLock lock = this.lock;
    lock.lock(); // 加锁
    try {
        return itemAt(takeIndex); // null when queue is empty
    } finally {
        lock.unlock();// 释放锁
    }
}
final E itemAt(int i) {
    return (E) items[i];
}

LinkedBlockingQueue

// 阻塞队列 链表实现
public class LinkedBlockingQueue<E> extends AbstractQueue<E> 
		implements BlockingQueue<E>, java.io.Serializable {
    // 队列元素,单链表    
    static class Node<E> {
        E item; 
        Node<E> next;
        Node(E x) { item = x; }
    }
   	// 队列容量 
    private final int capacity;
	// 队列大小,用原子类保证线程安全
    private final AtomicInteger count = new AtomicInteger();
	// 队首结点
    transient Node<E> head;
	// 队尾结点
    private transient Node<E> last;
 	// take锁和take Condition
    private final ReentrantLock takeLock = new ReentrantLock();
    private final Condition notEmpty = takeLock.newCondition();
	// put锁和put Condition
    private final ReentrantLock putLock = new ReentrantLock();
    private final Condition notFull = putLock.newCondition();   
    
    // 默认构造方法,默认Integer.MAX_VALUE个元素
    public LinkedBlockingQueue() {
        this(Integer.MAX_VALUE);
    } 
    public LinkedBlockingQueue(int capacity) {
        if (capacity <= 0) throw new IllegalArgumentException();
        this.capacity = capacity;
        last = head = new Node<E>(null);
    }
}

入队

// 元素入队,队列为满时阻塞
public void put(E e) throws InterruptedException {
    if (e == null) throw new NullPointerException(); 
    int c = -1;
    Node<E> node = new Node<E>(e);
    final ReentrantLock putLock = this.putLock;
    final AtomicInteger count = this.count;
    putLock.lockInterruptibly(); // 加锁
    try { 
        while (count.get() == capacity) { // 队列元素已满
            notFull.await(); // 阻塞,等待 
        }
        enqueue(node);
        // count 是原子类,多线程下能保证可见性
        // 这里 c 获取到的是count原来的值,而不是自增后的值
        c = count.getAndIncrement();
        // 如果现队列长度如果小于容量 唤醒 notFull 
        if (c + 1 < capacity)
            notFull.signal();
    } finally {
        putLock.unlock(); // 释放锁
    }
    // c 获取到的是count原来的值,如果原队列长度为0,现在加了一个元素后立即唤醒notEmpty条件
    if (c == 0)
        signalNotEmpty(); // 唤醒notEmpty 
}
// 元素入队,带有超时机制
public boolean offer(E e, long timeout, TimeUnit unit)
    throws InterruptedException {
    if (e == null) throw new NullPointerException();
    long nanos = unit.toNanos(timeout);
    int c = -1;
    final ReentrantLock putLock = this.putLock;
    final AtomicInteger count = this.count;
    putLock.lockInterruptibly(); // 加锁
    try {
        while (count.get() == capacity) { // 队列元素已满
            if (nanos <= 0)
                return false;
            nanos = notFull.awaitNanos(nanos); // 阻塞,等待时长为 nanos
        }
        enqueue(new Node<E>(e));
        // count 是原子类,多线程下能保证可见性
        // 这里 c 获取到的是count原来的值,而不是自增后的值
        c = count.getAndIncrement();  
        // 如果现队列长度如果小于容量 唤醒 notFull 
        if (c + 1 < capacity)
            notFull.signal();  
    } finally {
        putLock.unlock(); // 释放锁
    }
    // c 获取到的是count原来的值,如果原队列长度为0,现在加了一个元素后立即唤醒notEmpty条件
    if (c == 0)
        signalNotEmpty(); // 唤醒 notEmpty 
    return true;
} 
// 入队,直接加到last之后
private void enqueue(Node<E> node) { 
    last = last.next = node;
}
private void signalNotEmpty() {
    final ReentrantLock takeLock = this.takeLock;
    takeLock.lock();
    try { 
        notEmpty.signal();
    } finally {
        takeLock.unlock();
    }
}

出队

// 元素出队,队列为空时阻塞
public E take() throws InterruptedException {
    E x;
    int c = -1;
    final AtomicInteger count = this.count;
    final ReentrantLock takeLock = this.takeLock;
    takeLock.lockInterruptibly(); // 加锁
    try {
        while (count.get() == 0) { // 队列元素为空
            notEmpty.await(); // 阻塞,等待 
        }
        x = dequeue();
        // count 是原子类,多线程下能保证可见性
        // 这里 c 获取到的是count原来的值,而不是自减后的值
        c = count.getAndDecrement();
        if (c > 1)  
            notEmpty.signal(); 
    } finally {
        takeLock.unlock(); // 释放锁
    }
    // c 获取到的是count原来的值,如果原队列长度为capacity(已满),
    // 现在出队一个元素后立即达成唤醒 notFull 条件
    if (c == capacity)
        signalNotFull();  // 唤醒 notFull 
    return x;
}
// 元素出队,带有超时机制
public E poll(long timeout, TimeUnit unit) throws InterruptedException {
    E x = null;
    int c = -1;
    long nanos = unit.toNanos(timeout);
    final AtomicInteger count = this.count;
    final ReentrantLock takeLock = this.takeLock;
    takeLock.lockInterruptibly(); // 加锁
    try {
        while (count.get() == 0) { // 队列元素为空
            if (nanos <= 0)
                return null;
            nanos = notEmpty.awaitNanos(nanos); // 阻塞,等待时长为 nanos
        }
        x = dequeue();
        // count 是原子类,多线程下能保证可见性
        // 这里 c 获取到的是count原来的值,而不是自减后的值
        c = count.getAndDecrement();
        if (c > 1)
            notEmpty.signal();
    } finally {
        takeLock.unlock(); // 释放锁
    }
    // c 获取到的是count原来的值,如果原队列长度为capacity(已满),
    // 现在出队一个元素后立即达成唤醒 notFull 条件
    if (c == capacity)
        signalNotFull(); // 唤醒 notFull 
    return x;
}

取队首元素

public E peek() {
    if (count.get() == 0)
        return null;
    final ReentrantLock takeLock = this.takeLock;
    takeLock.lock(); // 加锁
    try {
        Node<E> first = head.next;
        if (first == null)
            return null;
        else
            return first.item;
    } finally {
        takeLock.unlock();// 释放锁
    }
}

ArrayBlockingQueue与LinkedBlockingQueue的比较

相同点:ArrayBlockingQueue和LinkedBlockingQueue都是通过condition通知机制来实现可阻塞式插入和删除元素,并满足线程安全的特性;

不同点

  1. ArrayBlockingQueue底层是采用的数组进行实现,而LinkedBlockingQueue则是采用链表数据结构;
  2. ArrayBlockingQueue插入和删除数据,只采用了一个lock,而LinkedBlockingQueue则是在插入和删除分别采用了putLocktakeLock,这样可以降低线程由于线程无法获取到lock而进入WAITING状态的可能性,从而提高了线程并发执行的效率。