5858import java .util .List ;
5959import java .util .Map ;
6060import java .util .Set ;
61- import java .util .concurrent .CountDownLatch ;
6261import java .util .concurrent .CyclicBarrier ;
6362import java .util .concurrent .Executor ;
6463import java .util .concurrent .TimeUnit ;
@@ -378,28 +377,24 @@ protected void newResponseAsync(
378377 public void testConcurrentlyCompletionAndCancellation () throws InterruptedException {
379378 final var action = getTestTransportNodesAction ();
380379
381- final CountDownLatch onCancelledLatch = new CountDownLatch (1 );
382- final CancellableTask cancellableTask = new CancellableTask (randomLong (), "transport" , "action" , "" , null , emptyMap ()) {
383- @ Override
384- protected void onCancelled () {
385- onCancelledLatch .countDown ();
386- }
387- };
380+ final CancellableTask cancellableTask = new CancellableTask (randomLong (), "transport" , "action" , "" , null , emptyMap ());
388381
389382 final PlainActionFuture <TestNodesResponse > future = new PlainActionFuture <>();
390383 action .execute (cancellableTask , new TestNodesRequest (), future );
391384
392385 final List <TestNodeResponse > nodeResponses = new ArrayList <>();
393386 final CapturingTransport .CapturedRequest [] capturedRequests = transport .getCapturedRequestsAndClear ();
387+ // Complete all but the last request for racing completion with cancellation
394388 for (int i = 0 ; i < capturedRequests .length - 1 ; i ++) {
395389 final var capturedRequest = capturedRequests [i ];
396390 nodeResponses .add (completeOneRequest (capturedRequest ));
397391 }
398392
399393 final var raceBarrier = new CyclicBarrier (3 );
394+ final var lastResponseFuture = new PlainActionFuture <TestNodeResponse >();
400395 final Thread completeThread = new Thread (() -> {
401396 safeAwait (raceBarrier );
402- nodeResponses . add (completeOneRequest (capturedRequests [capturedRequests .length - 1 ]));
397+ lastResponseFuture . onResponse (completeOneRequest (capturedRequests [capturedRequests .length - 1 ]));
403398 });
404399 final Thread cancelThread = new Thread (() -> {
405400 safeAwait (raceBarrier );
@@ -419,8 +414,11 @@ protected void onCancelled() {
419414 assertNotNull ("expect task cancellation exception, but got\n " + ExceptionsHelper .stackTrace (e ), taskCancelledException );
420415 assertThat (e .getMessage (), containsString ("task cancelled [simulated]" ));
421416 assertTrue (cancellableTask .isCancelled ());
422- safeAwait ( onCancelledLatch ); // wait for the latch, the listener for releasing node responses is called before it
417+ // All previously captured responses are released due to cancellation
423418 assertTrue (nodeResponses .stream ().allMatch (r -> r .hasReferences () == false ));
419+ // Wait for the last response to be gathered and assert it is also released by either the concurrent cancellation or
420+ // not tracked in onItemResponse at all due to already cancelled
421+ assertFalse (safeGet (lastResponseFuture ).hasReferences ());
424422 }
425423
426424 completeThread .join (10_000 );
0 commit comments