执行器:如果以递归方式创建任务,如何同步等待所有任务完成?

2022-09-04 20:42:14

我的问题与这里的这个问题密切相关。正如在那里发布的那样,我希望主线程等到工作队列为空并且所有任务都已完成。但是,在我的情况下,问题是每个任务都可能递归地导致提交新任务进行处理。这使得收集所有这些任务的未来有点尴尬。

我们当前的解决方案使用忙等待循环来等待终止:

        do { //Wait until we are done the processing
      try {
        Thread.sleep(200);
      } catch (InterruptedException e) {
        throw new RuntimeException(e);
      }
    } while (!executor.getQueue().isEmpty()
             || numTasks.longValue() > executor.getCompletedTaskCount());

numTasks 是一个值,随着每个新任务的创建而增加。这有效,但我认为由于繁忙的等待,这不是很好。我想知道是否有一种好方法可以使主线程同步等待,直到被显式唤醒。


答案 1

非常感谢您的所有建议!

最后,我选择了我认为相当简单的东西。我发现CountDownLatch几乎是我需要的。它会阻塞,直到计数器达到 0。唯一的问题是它只能倒计时,不能向上倒计时,因此在我拥有的动态设置中不起作用,任务可以提交新任务。因此,我实现了一个提供附加功能的新类。(见下文)然后,我按如下方式使用此类。CountLatch

主线程调用,阻塞直到闩锁达到 0。latch.awaitZero()

任何线程,在调用 调用 之前。executor.execute(..)latch.increment()

任何任务,在完成之前,都会调用 。latch.decrement()

当最后一个任务终止时,计数器将达到 0,从而释放主线程。

我们非常欢迎您提出进一步的建议和反馈!

public class CountLatch {

@SuppressWarnings("serial")
private static final class Sync extends AbstractQueuedSynchronizer {

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    protected int acquireNonBlocking(int acquires) {
        // increment count
        for (;;) {
            int c = getState();
            int nextc = c + 1;
            if (compareAndSetState(c, nextc))
                return 1;
        }
    }

    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c - 1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

private final Sync sync;

public CountLatch(int count) {
    this.sync = new Sync(count);
}

public void awaitZero() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

public boolean awaitZero(long timeout, TimeUnit unit) throws InterruptedException {
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

public void increment() {
    sync.acquireNonBlocking(1);
}

public void decrement() {
    sync.releaseShared(1);
}

public String toString() {
    return super.toString() + "[Count = " + sync.getCount() + "]";
}

}

请注意,/调用可以封装到一个自定义的子类中,例如,由Sami Korhonen建议,或者与impl建议的那样。请参阅此处:increment()decrement()ExecutorbeforeExecuteafterExecute

public class CountingThreadPoolExecutor extends ThreadPoolExecutor {

protected final CountLatch numRunningTasks = new CountLatch(0);

public CountingThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit,
        BlockingQueue<Runnable> workQueue) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
}

@Override
public void execute(Runnable command) {
    numRunningTasks.increment();
    super.execute(command);
}

@Override
protected void afterExecute(Runnable r, Throwable t) {
    numRunningTasks.decrement();
    super.afterExecute(r, t);
}

/**
 * Awaits the completion of all spawned tasks.
 */
public void awaitCompletion() throws InterruptedException {
    numRunningTasks.awaitZero();
}

/**
 * Awaits the completion of all spawned tasks.
 */
public void awaitCompletion(long timeout, TimeUnit unit) throws InterruptedException {
    numRunningTasks.awaitZero(timeout, unit);
}

}

答案 2

Java 7 提供了一个适合此用例的同步器,称为 Phaser。它是CountDownLatch和CyclicBarrier的可重用混合体,既可以增加又可以减少注册方的数量(类似于可增量的CountDownLatch)。

在此方案中使用阶段程序的基本模式是在创建时向阶段程序注册任务,并在任务完成时到达。当到达方的数量与注册的人数匹配时,相位器“前进”到下一阶段,在发生时通知任何等待线程。

下面是我创建的等待递归任务完成的示例。它天真地找到斐波那契数列的前几个数字用于演示目的:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicLong;

/**
 * An example of using a Phaser to wait for the completion of recursive tasks.
 * @author Voxelot
 */
public class PhaserExample {
    /** Workstealing threadpool with reduced queue contention. */
    private static ForkJoinPool executors;

    /**
     * @param args the command line arguments
     */
    public static void main(String[] args) throws InterruptedException {
        executors = new ForkJoinPool();
        List<Long> sequence = new ArrayList<>();
        for (int i = 0; i < 20; i++) {
            sequence.add(fib(i));
        }
        System.out.println(sequence);
    }

    /**
     * Computes the nth Fibonacci number in the Fibonacci sequence.
     * @param n The index of the Fibonacci number to compute
     * @return The computed Fibonacci number
     */
    private static Long fib(int n) throws InterruptedException {
        AtomicLong result = new AtomicLong();
        //Flexible sychronization barrier
        Phaser phaser = new Phaser();
        //Base task
        Task initialTask = new Task(n, result, phaser);
        //Register fib(n) calling thread
        phaser.register();
        //Submit base task
        executors.submit(initialTask);
        //Make the calling thread arrive at the synchronization
        //barrier and wait for all future tasks to arrive.
        phaser.arriveAndAwaitAdvance();
        //Get the result of the parallel computation.
        return result.get();
    }

    private static class Task implements Runnable {
        /** The Fibonacci sequence index of this task. */
        private final int index;
        /** The shared result of the computation. */
        private final AtomicLong result;
        /** The synchronizer. */
        private final Phaser phaser;

        public Task(int n, AtomicLong result, Phaser phaser) {
            index = n;
            this.result = result;
            this.phaser = phaser;
            //Inform synchronizer of additional work to complete.
            phaser.register();
        }

        @Override
        public void run() {
            if (index == 1) {
                result.incrementAndGet();
            } else if (index > 1) {
                //recurrence relation: Fn = Fn-1 + Fn-2
                Task task1 = new Task(index - 1, result, phaser);
                Task task2 = new Task(index - 2, result, phaser);
                executors.submit(task1);
                executors.submit(task2);
            }
            //Notify synchronizer of task completion.
            phaser.arrive();
        }
    }
}

推荐