diff --git a/heron/common/src/java/com/twitter/heron/common/basics/WakeableLooper.java b/heron/common/src/java/com/twitter/heron/common/basics/WakeableLooper.java index 8744990bda4..e989c2baa3f 100644 --- a/heron/common/src/java/com/twitter/heron/common/basics/WakeableLooper.java +++ b/heron/common/src/java/com/twitter/heron/common/basics/WakeableLooper.java @@ -18,6 +18,8 @@ import java.util.ArrayList; import java.util.List; import java.util.PriorityQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; /** * A WakeableLooper is a class that could: @@ -52,12 +54,15 @@ public abstract class WakeableLooper { // We will also multiple 1000*1000 to convert mill-seconds to nano-seconds private static final Duration INFINITE_FUTURE = Duration.ofMillis(Integer.MAX_VALUE); private volatile boolean exitLoop; + // Used as a flag that the looper has exited after exitLoop() is called. + private final CountDownLatch exitCountDownLatch; public WakeableLooper() { exitLoop = false; tasksOnWakeup = new ArrayList(); timers = new PriorityQueue(); exitTasks = new ArrayList<>(); + exitCountDownLatch = new CountDownLatch(1); } public void clear() { @@ -95,6 +100,21 @@ private void onExit() { for (Runnable r : exitTasks) { r.run(); } + exitCountDownLatch.countDown(); + } + + /** + * After exitLoop() is called, caller can use waitForExit() to make sure + * the looper has finished/skipped all sheduled tasks and the runOnce() function + * won't be called any more. + * @return true if the count down lanch reaches 0, false if the wait times out or is interrupted. + */ + public boolean waitForExit(long timeout, TimeUnit unit) { + try { + return exitCountDownLatch.await(timeout, unit); + } catch (InterruptedException e) { + return false; + } } protected abstract void doWait(); diff --git a/heron/common/src/java/com/twitter/heron/common/network/SocketChannelHelper.java b/heron/common/src/java/com/twitter/heron/common/network/SocketChannelHelper.java index ebdee5581ba..196585b6b90 100644 --- a/heron/common/src/java/com/twitter/heron/common/network/SocketChannelHelper.java +++ b/heron/common/src/java/com/twitter/heron/common/network/SocketChannelHelper.java @@ -21,6 +21,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Queue; +import java.util.concurrent.TimeUnit; import java.util.logging.Logger; import com.twitter.heron.common.basics.ByteAmount; @@ -213,6 +214,10 @@ public void write() { // Force to flush all data in underneath buffer queue to socket with best effort // It is most likely happen when we are handling some unexpected cases, such as exiting public void forceFlushWithBestEffort() { + looper.exitLoop(); + // Wait for NIO loop to confirm stopping process. + looper.waitForExit(10, TimeUnit.SECONDS); + LOG.info("Forcing to flush data to socket with best effort."); while (!outgoingPacketsToWrite.isEmpty()) { int writeState = outgoingPacketsToWrite.poll().writeToChannel(socketChannel); diff --git a/heron/common/tests/java/com/twitter/heron/common/basics/WakeableLooperTest.java b/heron/common/tests/java/com/twitter/heron/common/basics/WakeableLooperTest.java index 3256bac0048..1fd4c8dce4f 100644 --- a/heron/common/tests/java/com/twitter/heron/common/basics/WakeableLooperTest.java +++ b/heron/common/tests/java/com/twitter/heron/common/basics/WakeableLooperTest.java @@ -17,6 +17,7 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.time.Duration; +import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Assert; @@ -27,7 +28,7 @@ * WakeableLooper Tester. */ public class WakeableLooperTest { - private static int globalValue; + private static volatile int globalValue; private WakeableLooper slaveLooper; @Before @@ -41,6 +42,19 @@ public void after() { slaveLooper = null; } + class LooperThread extends Thread { + private WakeableLooper looper; + + LooperThread(WakeableLooper looper) { + super(); + this.looper = looper; + } + + public void run() { + looper.loop(); + } + } + /** * Method: loop() */ @@ -142,6 +156,65 @@ public void run() { Assert.assertEquals(10, globalValue); } + + /** + * Method: waitForExit() + */ + @Test + public void testWaitForExit() { + int sleepTimeMS = 200; + Runnable r = new Runnable() { + @Override + public void run() { + try { + slaveLooper.exitLoop(); // Exit after the first wake up + Thread.sleep(sleepTimeMS); + globalValue = 10; + } catch (InterruptedException e) { + return; + } + } + }; + LooperThread looperThread = new LooperThread(slaveLooper); + looperThread.start(); + long startTime = System.nanoTime(); + slaveLooper.addTasksOnWakeup(r); + // Wait for it to finish. + boolean ret = slaveLooper.waitForExit(sleepTimeMS * 2, TimeUnit.MILLISECONDS); + long endTime = System.nanoTime(); + + Assert.assertTrue(ret); + Assert.assertTrue(endTime - startTime >= sleepTimeMS * 1000); + Assert.assertEquals(10, globalValue); + } + + @Test + public void testWaitForExitTimeout() { + int sleepTimeMS = 200; + Runnable r = new Runnable() { + @Override + public void run() { + try { + slaveLooper.exitLoop(); // Exit after the first wake up + Thread.sleep(sleepTimeMS); + globalValue = 10; + } catch (InterruptedException e) { + return; + } + } + }; + LooperThread looperThread = new LooperThread(slaveLooper); + looperThread.start(); + long startTime = System.nanoTime(); + slaveLooper.addTasksOnWakeup(r); + // Wait for it to finish. + boolean ret = slaveLooper.waitForExit(sleepTimeMS / 10, TimeUnit.MILLISECONDS); + long endTime = System.nanoTime(); + + Assert.assertFalse(ret); + Assert.assertEquals(6, globalValue); + } + /** * Method: getNextTimeoutInterval() */