diff --git a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java index 311a8da9900..6e344e917e6 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java +++ b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java @@ -35,6 +35,8 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.util.ThreadExceptionHandler; +import org.apache.celeborn.common.util.Utils; +import org.apache.celeborn.common.write.PushState; public class DataPusher { private static final Logger logger = LoggerFactory.getLogger(DataPusher.class); @@ -43,6 +45,7 @@ public class DataPusher { private LinkedBlockingQueue idleQueue; // partition -> PushTask Queue + private final PushState pushState; private final DataPushQueue dataPushQueue; private final ReentrantLock idleLock = new ReentrantLock(); private final Condition idleFull = idleLock.newCondition(); @@ -98,6 +101,8 @@ public DataPusher( this.client = client; this.afterPush = afterPush; this.mapStatusLengths = mapStatusLengths; + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + this.pushState = client.getPushState(mapKey); pushThread = new Thread("celeborn-client-data-pusher-" + taskId) { @@ -193,6 +198,9 @@ public void checkException() throws IOException { if (exceptionRef.get() != null) { throw exceptionRef.get(); } + if (pushState.exception.get() != null) { + throw pushState.exception.get(); + } } protected void pushData(PushTask task) throws IOException { @@ -216,6 +224,7 @@ private void waitIdleQueueFullWithLock() throws InterruptedException { while (idleQueue != null && idleQueue.remainingCapacity() > 0 && exceptionRef.get() == null + && pushState.exception.get() == null && (pushThread != null && pushThread.isAlive())) { idleFull.await(WAIT_TIME_NANOS, TimeUnit.NANOSECONDS); } @@ -228,7 +237,9 @@ private void waitIdleQueueFullWithLock() throws InterruptedException { } protected boolean stillRunning() { - return !terminated && !Objects.nonNull(exceptionRef.get()); + return !terminated + && !Objects.nonNull(exceptionRef.get()) + && !Objects.nonNull(pushState.exception.get()); } public DataPushQueue getDataPushQueue() {