diff --git a/src/main/java/org/apache/commons/lang3/concurrent/BackgroundInitializer.java b/src/main/java/org/apache/commons/lang3/concurrent/BackgroundInitializer.java index 2be935ad014..4d46f7ad4eb 100644 --- a/src/main/java/org/apache/commons/lang3/concurrent/BackgroundInitializer.java +++ b/src/main/java/org/apache/commons/lang3/concurrent/BackgroundInitializer.java @@ -345,7 +345,11 @@ public synchronized boolean isInitialized() { try { future.get(); return true; - } catch (CancellationException | ExecutionException | InterruptedException e) { + } catch (CancellationException | ExecutionException e) { + return false; + } catch (InterruptedException e) { + // reset interrupted state + Thread.currentThread().interrupt(); return false; } } diff --git a/src/main/java/org/apache/commons/lang3/concurrent/UncheckedFutureImpl.java b/src/main/java/org/apache/commons/lang3/concurrent/UncheckedFutureImpl.java index 26c42fc34bf..6b92e8e7123 100644 --- a/src/main/java/org/apache/commons/lang3/concurrent/UncheckedFutureImpl.java +++ b/src/main/java/org/apache/commons/lang3/concurrent/UncheckedFutureImpl.java @@ -42,6 +42,7 @@ public V get() { try { return super.get(); } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); throw new UncheckedInterruptedException(e); } catch (final ExecutionException e) { throw new UncheckedExecutionException(e); @@ -53,6 +54,7 @@ public V get(final long timeout, final TimeUnit unit) { try { return super.get(timeout, unit); } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); throw new UncheckedInterruptedException(e); } catch (final ExecutionException e) { throw new UncheckedExecutionException(e); diff --git a/src/test/java/org/apache/commons/lang3/concurrent/UncheckedFutureTest.java b/src/test/java/org/apache/commons/lang3/concurrent/UncheckedFutureTest.java index ab670480468..db8f59fda2c 100644 --- a/src/test/java/org/apache/commons/lang3/concurrent/UncheckedFutureTest.java +++ b/src/test/java/org/apache/commons/lang3/concurrent/UncheckedFutureTest.java @@ -18,14 +18,20 @@ package org.apache.commons.lang3.concurrent; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.commons.lang3.AbstractLangTest; @@ -121,4 +127,62 @@ void testOnFuture() { assertEquals("Z", UncheckedFuture.on(new TestFuture<>("Z")).get()); } + + @Test + void interruptFlagIsPreservedOnGet() throws Exception { + assertInterruptPreserved(UncheckedFuture::get); + } + + @Test + void interruptFlagIsPreservedOnGetWithTimeout() throws Exception { + assertInterruptPreserved(uf -> uf.get(1, TimeUnit.DAYS)); + } + + private static void assertInterruptPreserved(Consumer> call) throws Exception { + final CountDownLatch enteredGet = new CountDownLatch(1); + final Future blockingFuture = new AbstractFutureProxy(ConcurrentUtils.constantFuture(42)) { + private final CountDownLatch neverRelease = new CountDownLatch(1); + + @Override + public Integer get() throws InterruptedException { + enteredGet.countDown(); + neverRelease.await(); + throw new AssertionError("We should not get here"); + } + + @Override + public Integer get(long timeout, TimeUnit unit) throws InterruptedException { + enteredGet.countDown(); + neverRelease.await(); + throw new AssertionError("We should not get here"); + } + + @Override + public boolean isDone() { + return false; + } + + }; + final UncheckedFuture uf = UncheckedFuture.on(blockingFuture); + final AtomicReference thrown = new AtomicReference<>(); + final AtomicBoolean interruptObserved = new AtomicBoolean(false); + final Thread worker = new Thread(() -> { + try { + call.accept(uf); + thrown.set(new AssertionError("We should not get here")); + } catch (Throwable e) { + interruptObserved.set(Thread.currentThread().isInterrupted()); + thrown.set(e); + } + }, "unchecked-future-test-worker"); + worker.start(); + assertTrue(enteredGet.await(2, TimeUnit.SECONDS), "Worker did not enter Future.get() in time"); + worker.interrupt(); + worker.join(); + final Throwable t = thrown.get(); + assertInstanceOf(UncheckedInterruptedException.class, t, "Unexpected exception: " + t); + assertInstanceOf(InterruptedException.class, t.getCause(), "Cause should be InterruptedException"); + assertTrue(interruptObserved.get(), "Interrupt flag was not restored by the wrapper"); + } + }