diff --git a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java index e366ae7c2f3..bc02570da52 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java +++ b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java @@ -35,6 +35,8 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.util.ThreadUtils; +import org.apache.celeborn.common.util.Utils; +import org.apache.celeborn.common.write.PushState; public class DataPusher { private static final Logger logger = LoggerFactory.getLogger(DataPusher.class); @@ -43,6 +45,7 @@ public class DataPusher { private LinkedBlockingQueue idleQueue; // partition -> PushTask Queue + private final PushState pushState; private final DataPushQueue dataPushQueue; private final ReentrantLock idleLock = new ReentrantLock(); private final Condition idleFull = idleLock.newCondition(); @@ -98,6 +101,8 @@ public DataPusher( this.client = client; this.afterPush = afterPush; this.mapStatusLengths = mapStatusLengths; + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + this.pushState = client.getPushState(mapKey); pushThread = ThreadUtils.newDaemonThread( @@ -194,6 +199,9 @@ public void checkException() throws IOException { if (exceptionRef.get() != null) { throw exceptionRef.get(); } + if (pushState.exception.get() != null) { + throw pushState.exception.get(); + } } protected void pushData(PushTask task) throws IOException { @@ -217,6 +225,7 @@ private void waitIdleQueueFullWithLock() throws InterruptedException { while (idleQueue != null && idleQueue.remainingCapacity() > 0 && exceptionRef.get() == null + && pushState.exception.get() == null && (pushThread != null && pushThread.isAlive())) { idleFull.await(WAIT_TIME_NANOS, TimeUnit.NANOSECONDS); } @@ -229,7 +238,9 @@ private void waitIdleQueueFullWithLock() throws InterruptedException { } protected boolean stillRunning() { - return !terminated && !Objects.nonNull(exceptionRef.get()); + return !terminated + && !Objects.nonNull(exceptionRef.get()) + && !Objects.nonNull(pushState.exception.get()); } public DataPushQueue getDataPushQueue() {