Skip to content

Commit c15cd50

Browse files
committed
Better handling of edge cases (0 or 1 Promise param) in Promises any / anyStrict / atLeast; result of Promises.from doesn't wrap original exception into CompletionException
1 parent 524d7e8 commit c15cd50

File tree

1 file changed

+79
-23
lines changed

1 file changed

+79
-23
lines changed

src/main/java/net/tascalate/concurrent/Promises.java

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
import java.time.Duration;
1919
import java.util.Arrays;
20+
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.List;
23+
import java.util.NoSuchElementException;
2224
import java.util.Objects;
2325
import java.util.Optional;
2426
import java.util.concurrent.Callable;
@@ -28,6 +30,7 @@
2830
import java.util.concurrent.Executor;
2931
import java.util.concurrent.atomic.AtomicReference;
3032
import java.util.function.Consumer;
33+
import java.util.function.Function;
3134
import java.util.function.Supplier;
3235

3336
/**
@@ -90,7 +93,7 @@ public static <T> Promise<T> from(CompletionStage<T> stage) {
9093
return new CompletablePromise<>((CompletableFuture<T>)stage);
9194
}
9295

93-
return COMPLETED_DEPENDENT_PROMISE.thenCombine(stage, (u, v) -> v, PromiseOrigin.PARAM_ONLY).raw();
96+
return transform(stage, Function.identity(), Function.identity());
9497
}
9598

9699
/**
@@ -184,7 +187,22 @@ public static <T> Promise<T> any(boolean cancelRemaining, CompletionStage<? exte
184187
}
185188

186189
public static <T> Promise<T> any(boolean cancelRemaining, List<CompletionStage<? extends T>> promises) {
187-
return unwrap(atLeast(1, (promises == null ? 0 : promises.size()) - 1, cancelRemaining, promises), false);
190+
int size = null == promises ? 0 : promises.size();
191+
switch (size) {
192+
case 0:
193+
@SuppressWarnings("unchecked")
194+
Promise<T> emptyResult = (Promise<T>)EMPTY_AGGREGATE_FAILURE;
195+
return emptyResult;
196+
case 1:
197+
@SuppressWarnings("unchecked")
198+
CompletionStage<T> singleResult = (CompletionStage<T>) promises.get(0);
199+
return transform(singleResult, Function.identity(), Promises::wrapMultitargetException);
200+
default:
201+
return transform(
202+
atLeast(1, size - 1, cancelRemaining, promises),
203+
Promises::extractFirstNonNull, Function.identity() /* DO NOT unwrap multitarget exception */
204+
);
205+
}
188206
}
189207

190208
/**
@@ -233,7 +251,22 @@ public static <T> Promise<T> anyStrict(boolean cancelRemaining, CompletionStage<
233251
}
234252

235253
public static <T> Promise<T> anyStrict(boolean cancelRemaining, List<CompletionStage<? extends T>> promises) {
236-
return unwrap(atLeast(1, 0, cancelRemaining, promises), true);
254+
int size = null == promises ? 0 : promises.size();
255+
switch (size) {
256+
case 0:
257+
@SuppressWarnings("unchecked")
258+
Promise<T> emptyResult = (Promise<T>)EMPTY_AGGREGATE_FAILURE;
259+
return emptyResult;
260+
case 1:
261+
@SuppressWarnings("unchecked")
262+
CompletionStage<T> singleResult = (CompletionStage<T>) promises.get(0);
263+
return from(singleResult);
264+
default:
265+
return transform(
266+
atLeast(1, 0, cancelRemaining, promises),
267+
Promises::extractFirstNonNull, Promises::unwrapMultitargetException
268+
);
269+
}
237270
}
238271

239272
/**
@@ -395,10 +428,7 @@ public static <T> Promise<List<T>> atLeast(int minResultsCount, int maxErrorsCou
395428
} else if (minResultsCount == 0) {
396429
return success(Collections.emptyList());
397430
} else if (size == 1) {
398-
return from(promises.get(0))
399-
.dependent()
400-
.thenApply((T r) -> Collections.singletonList(r), true)
401-
.raw();
431+
return transform(promises.get(0), Collections::singletonList, Function.identity());
402432
} else {
403433
return new AggregatingPromise<>(minResultsCount, maxErrorsCount, cancelRemaining, promises);
404434
}
@@ -531,23 +561,49 @@ static CompletionException wrapException(Throwable e) {
531561
}
532562
}
533563

534-
private static <T> Promise<T> unwrap(CompletionStage<List<T>> original, boolean unwrapException) {
535-
return from(original)
536-
.dependent()
537-
.handle((r, e) -> {
538-
if (null != e) {
539-
if (unwrapException) {
540-
Throwable targetException = unwrapException(e);
541-
if (targetException instanceof MultitargetException) {
542-
throw wrapException( ((MultitargetException)targetException).getFirstException().get() );
543-
}
544-
}
545-
throw wrapException(e);
564+
private static <T, U> Promise<T> transform(CompletionStage<U> original,
565+
Function<? super U, ? extends T> resultMapper,
566+
Function<? super Throwable, ? extends Throwable> errorMapper) {
567+
CompletablePromise<T> result = new CompletablePromise<T>() {
568+
@Override
569+
public boolean cancel(boolean mayInterruptIfRunning) {
570+
if (super.cancel(mayInterruptIfRunning)) {
571+
CompletablePromise.cancelPromise(original, mayInterruptIfRunning);
572+
return true;
546573
} else {
547-
return r.stream().filter(Objects::nonNull).findFirst().get();
574+
return false;
548575
}
549-
}, true)
550-
.raw();
576+
}
577+
};
578+
original.whenComplete((r, e) -> {
579+
if (null == e) {
580+
result.onSuccess( resultMapper.apply(r) );
581+
} else {
582+
result.onFailure( errorMapper.apply(e) );
583+
}
584+
});
585+
return result;
586+
}
587+
588+
private static <T> T extractFirstNonNull(Collection<? extends T> collection) {
589+
return collection.stream().filter(Objects::nonNull).findFirst().get();
590+
}
591+
592+
private static <E extends Throwable> Throwable unwrapMultitargetException(E exception) {
593+
Throwable targetException = unwrapException(exception);
594+
if (targetException instanceof MultitargetException) {
595+
return MultitargetException.class.cast(targetException).getFirstException().get();
596+
} else {
597+
return targetException;
598+
}
599+
}
600+
601+
private static <E extends Throwable> MultitargetException wrapMultitargetException(E exception) {
602+
if (exception instanceof MultitargetException) {
603+
return (MultitargetException)exception;
604+
} else {
605+
return new MultitargetException(Collections.singletonList(exception));
606+
}
551607
}
552608

553609
private static class ObjectRef<T> {
@@ -563,5 +619,5 @@ T dereference() {
563619
}
564620

565621
private static final Object IGNORE = new Object();
566-
private static final DependentPromise<Object> COMPLETED_DEPENDENT_PROMISE = success(null).dependent();
622+
private static final Promise<Object> EMPTY_AGGREGATE_FAILURE = failure(new NoSuchElementException());
567623
}

0 commit comments

Comments
 (0)