CountDownLatch、CyclicBarrier、Semaphore实现原理 - litter-fish/ReadSource GitHub Wiki

CountDownLatch

基本结构

// 同步器
private final Sync sync;

构造函数

// 构造方法,会使用传入的count参数直接创建一个Sync对象
public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

同步器的构造方法,直接设置AQS的state值

Sync(int count) {
    // 直接将state设置为count
    setState(count);
}

await操作

await操作相当于锁的获取操作,获得锁的条件是state值为0 提供两个版本

public void await() throws InterruptedException {
    // Sync的acquireSharedInterruptibly()方法来自于父类AQS
    sync.acquireSharedInterruptibly(1);
}

// 超时版本
public boolean await(long timeout, TimeUnit unit)
    throws InterruptedException {
    // Sync的tryAcquireSharedNanos()方法来自于父类AQS
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

await()方法调用了sync对象的acquireSharedInterruptibly()方法,该方法继承自AQS, acquireSharedInterruptibly()方法会调用tryAcquireShared(arg),该方法Sync进行了重写 源码如下:

protected int tryAcquireShared(int acquires) {
    // 当state为0时表示获取共享锁成功,否则失败
    return (getState() == 0) ? 1 : -1;
}

countDown操作

countDown相当于锁的释放

public void countDown() {
    // Sync的releaseShared()方法来自于父类AQS
    sync.releaseShared(1);
}

countDown()方法调用了sync的releaseShared(1),该方法继承自AQS的 releaseShared(int)方法中又会调用tryReleaseShared(int),这个方法被CountDownLatch的Sync重写了

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    // 自旋更新state,当state为0时表示可以唤醒阻塞的线程了
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

CyclicBarrier

基本结构

/** 用于控制栅栏入口的锁 */
private final ReentrantLock lock = new ReentrantLock();
/** 用于控制线程等待的Condition对象 */
private final Condition trip = lock.newCondition();
/** 参与的线程总数 */
private final int parties;
/* 当栅栏被放开时执行的回调 */
private final Runnable barrierCommand;
/** 当前这一代栅栏 */
private Generation generation = new Generation();

// 用于记录当前已经处于等待状态的线程数
private int count;

CyclicBarrier使用了ReentrantLock来保证多线程对栅栏的同步访问,并且使用了Condition方式来控制线程的等待和唤醒。

CyclicBarrier是可以重复使用的,而实现重复使用的方式就是在其内部维护了一个辅助类Generation,该类是一个静态内部类,只有一个成员变量broken用于标识当前这一代的栅栏是否被强制释放:

// 表示一代栅栏
private static class Generation {
    // 栅栏是否被强制释放
    boolean broken = false;
}

CyclicBarrier提供了快捷方法nextGeneration()用于直接开启下一代栅栏:

// 开启下一代栅栏
private void nextGeneration() {
    // 唤醒所有阻塞的线程
    trip.signalAll();
    // 重置count
    count = parties;
    // 创建一个新的Generation对象
    generation = new Generation();
}

唤醒阻塞在当前这一代栅栏上的所有线程,然后重置count参数为原始的parties,接着创建新一代Generation。

reset()重置方法,用于强制释放栅栏并开启下一代Generation:

// 重置操作,这个操作会强制释放栅栏、唤醒所有线程并开启下一代
public void reset() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        // 强制释放栅栏
        breakBarrier();   // break the current generation
        // 开启下一代Generation
        nextGeneration(); // start a new generation
    } finally {
        lock.unlock();
    }
}

breakBarrier()方法如下:

// 强制释放栅栏
private void breakBarrier() {
    // 更新栅栏是否是被强制释放的记录
    generation.broken = true;
    // 重置count
    count = parties;
    // 唤醒所有阻塞的线程
    trip.signalAll();
}

一些获取某些内部状态的辅助方法:

// 获取parties
public int getParties() {
    return parties;
}

// 判断栅栏是否已被释放
public boolean isBroken() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        return generation.broken;
    } finally {
        lock.unlock();
    }
}

// 获取等待的线程数量
public int getNumberWaiting() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        return parties - count;
    } finally {
        lock.unlock();
    }
}

await操作

当我们调用CyclicBarrier的await()方法时就会让当前线程进入等待阻塞状态,直到指定数量的所有的线程都调用了await(),这些进入等待阻塞状态的线程才会被唤醒。await()方法另外有一个带有超时机制的重载方法,它们都调用了内部的dowait(false, 0L)方法:

// 主要的await等待方法
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    // 拿到重入锁并上锁
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        // 拿到Generation对象
        final Generation g = generation;
        // 如果当前Generation对象的栅栏是否已经被释放
        if (g.broken)
            throw new BrokenBarrierException();

        // 如果出现异常,将强制释放栅栏,避免被阻塞的线程饿死,同时抛出异常
        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }

        // 计数减1
        int index = --count;
        // 释放栅栏
        if (index == 0) {  // tripped
            // 如果计数变为0,表示栅栏可以释放了
            // 记录是否执行了释放栅栏的变量
            boolean ranAction = false;
            try {
                // 执行释放栅栏的回调线程
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();

                /**
                 * 更新是否执行了释放栅栏的变量为true
                 * 注意,如果ranAction更新为true,说明barrierCommand执行没有抛错
                 * 如果如果ranAction没有被更新为true,则可能是barrierCommand执行抛错了,
                 * 将不会执行后面的nextGeneration()代码唤醒阻塞的线程,
                 * 因此需要在finally块中强制释放栅栏,避免阻塞线程饿死
                 */
                ranAction = true;
                /**
                 * 释放栅栏后开启下一个新的Generation
                 * nextGeneration()里会调用trip.signalAll()唤醒所有阻塞线程
                 */
                nextGeneration();
                return 0;
            } finally {
                /**
                 * 检查是否执行了释放栅栏的回调线程,
                 * 如果没执行说明可能是在执行释放栅栏的回调线程时抛错了
                 * 因此就强制释放栅栏
                 */
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        // 自旋循环,直到栅栏被放开、被强制释放栅栏、遇到中断操作或超时
        for (;;) {
            try {
                if (!timed)
                    /**
                     * 没有使用超时机制
                     * 这里的await会将线程所在节点移入条件队列,
                     * 然后释放state值,挂起线程
                     */
                    trip.await();
                else if (nanos > 0L)
                    // 带有超时的挂起
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                // 运行到这里,说明在挂起期间被中断了
                if (g == generation && ! g.broken) {
                    // 如果还是当前这一代的栅栏,但栅栏没有被释放,就强制释放栅栏
                    breakBarrier();
                    // 然后抛出异常
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    /**
                     * 这种捕获了InterruptException之后调用Thread.currentThread().interrupt()是一种通用的方式。
                     * 其实就是为了保存中断状态,从而让其他更高层次的代码注意到这个中断
                     */
                    Thread.currentThread().interrupt();
                }
            }

            /**
             * 当栅栏被释放时,抛出BrokenBarrierException异常
             * 比如某个线程在await期间被中断了,它会调用breakBarrier()
             * 而breakBarrier()会将g.broken设置为true,然后唤醒所有线程
             * 因此其他线程唤醒后运行到这里就会抛出BrokenBarrierException异常
             */
            if (g.broken)
                throw new BrokenBarrierException();

            // 已经重置为下一代了,直接返回index
            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                // 超时了,强制释放栅栏,并抛出超时异常
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        // 释放锁
        lock.unlock();
    }
}

Semaphore

基本结构

Semaphore的实现使用的是AQS的共享模式;Semaphore中Sync是一个抽象类,它实现了大部分的基础结构和方法。在创建Semaphore实例时,Sync的构造方法会被调用,并将传给Semaphore构造方法的参数premits传入:

// Sync的构造方法
Sync(int permits) {
    // 直接将state设置为permits
    setState(permits);
}

Sync对permits参数的处理是直接使用AQS提供的state来控制:

// 获取permits
final int getPermits() {
    // 直接获取state
    return getState();
}

// 返回立即可用的所有许可,并将state置为0
final int drainPermits() {
    for (;;) {
        int current = getState();
        if (current == 0 || compareAndSetState(current, 0))
            return current;
    }
}

同时Semaphore提供了对permits进行各类快捷操作的方法:

// 获取可用的permit数量
public int availablePermits() {
    return sync.getPermits();
}

// 获取并返回立即可用的所有许可
public int drainPermits() {
    return sync.drainPermits();
}

// 根据指定的缩减量减小可用许可的数目
protected void reducePermits(int reduction) {
    if (reduction < 0) throw new IllegalArgumentException();
    sync.reducePermits(reduction);
}

Sync的继承有公平同步器FairSync和非公平同步器NonfairSync两种:

// 默认情况下使用的是非公平同步器
public Semaphore(int permits) {
    sync = new NonfairSync(permits);
}

public Semaphore(int permits, boolean fair) {
    // 根据传入的fair参数决定使用哪种同步器
    sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}

acquire操作

Semaphore的acquire(int)方法,当线程调用Semaphore实例的acquire(int)方法时会传入要求的许可量,如果获取不到要求的许可量就会被阻塞,acquire(int)方法存在一个没有参数的重载acquire(),默认传入的是permits参数为1,它们的源码如下:

public void acquire() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

public void acquire(int permits) throws InterruptedException {
    if (permits < 0) throw new IllegalArgumentException();
    sync.acquireSharedInterruptibly(permits);
}

发现都是调用的是Sync的acquireSharedInterruptibly(int)方法,这个方法继承自AQS:

public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

AQS的acquireSharedInterruptibly(int)方法内部先判断了发生中断的情况,然后调用了tryAcquireShared(arg),这个方法是被Semaphore的内部类Sync重写的,即以获取共享锁的方式来控制线程是否成功获取要求的许可量; 以FairSync中重写的为例:

protected int tryAcquireShared(int acquires) {
    for (;;) {
        if (hasQueuedPredecessors())
            // 公平模式下需要先判断同步队列中是否有线程已经等待很久了
            return -1;
        int available = getState();
        // 计算从state中减去获取的许可量后的值remaining
        int remaining = available - acquires;
        /**
         * 如果remaining小于0,则直接返回remaining,注意此时返回的remaining是小于0的,表示获取失败
         * 如果remaining大于等于0,则尝试CAS修改state为remaining,如果修改成功就返回remaining,表示获取成功
         * 否则自旋进入下一次的尝试
         */
        if (remaining < 0 || compareAndSetState(available, remaining))
            return remaining;
    }
}

release操作

release操作用于释放许可量,它在Semaphore中体现为两个重载方法:

// 释放1个许可量
public void release() {
    sync.releaseShared(1);
}

// 释放指定数量的许可量
public void release(int permits) {
    if (permits < 0) throw new IllegalArgumentException();
    sync.releaseShared(permits);
}

与acquire操作非常类似,release操作调用了Sync的releaseShared(int)方法,而这个方法继承自AQS:

public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

tryReleaseShared(arg)则被Sync重写:

protected final boolean tryReleaseShared(int releases) {
    for (;;) {
        // 获取当前state值
        int current = getState();
        // 计算添加释放的许可量后state的值
        int next = current + releases;
        // 判断是否溢出
        if (next < current) // overflow
            throw new Error("Maximum permit count exceeded");
        if (compareAndSetState(current, next))
            // 修改state为新值后返回true表示释放成功
            return true;
    }
}

非公平模式

Semaphore中的非公平同步器的实现是NonFairSync,非公平信号量许可的释放与公平信号量许可的释放是一样的,不同的是它们获取许可量的机制不同,非公平同步器的tryAcquireShared(int)调用了父类Sync中nonfairTryAcquireShared(int):

protected int tryAcquireShared(int acquires) {
    return nonfairTryAcquireShared(acquires);
}

nonfairTryAcquireShared(int)方法的实现如下:

final int nonfairTryAcquireShared(int acquires) {
    for (;;) {
        // 设置可以获得的信号量的许可数
        int available = getState();
        // 设置获得acquires个信号量许可之后,剩余的信号量许可数
        int remaining = available - acquires;
        // 如果剩余的信号量许可数>=0,则设置可以获得的信号量许可数为remaining。
        if (remaining < 0 ||
            compareAndSetState(available, remaining))
            return remaining;
    }
}