diff --git a/core/src/main/java/com/scalar/db/transaction/consensuscommit/ParallelExecutor.java b/core/src/main/java/com/scalar/db/transaction/consensuscommit/ParallelExecutor.java index c657c6b306..876e73fe54 100644 --- a/core/src/main/java/com/scalar/db/transaction/consensuscommit/ParallelExecutor.java +++ b/core/src/main/java/com/scalar/db/transaction/consensuscommit/ParallelExecutor.java @@ -147,7 +147,8 @@ public void executeImplicitPreRead(List tasks, String tran } } - private void executeTasks( + @VisibleForTesting + void executeTasks( List tasks, boolean parallel, boolean noWait, @@ -158,14 +159,14 @@ private void executeTasks( if (tasks.size() == 1 && !noWait) { // If there is only one task and noWait is false, we can run it directly without parallel // execution. - executeTasksSerially(tasks, stopOnError, taskName, transactionId); + tasks.get(0).run(); return; } if (parallel) { executeTasksInParallel(tasks, noWait, stopOnError, taskName, transactionId); } else { - executeTasksSerially(tasks, stopOnError, taskName, transactionId); + executeTasksSerially(tasks, stopOnError); } } @@ -180,80 +181,67 @@ private void executeTasksInParallel( CompletionService completionService = new ExecutorCompletionService<>(parallelExecutorService); - tasks.forEach( - t -> - completionService.submit( - () -> { - try { - t.run(); - } catch (Exception e) { - logger.warn( - "Failed to run a {} task. Transaction ID: {}", taskName, transactionId, e); - throw e; - } - return null; - })); - - if (!noWait) { - Exception exception = null; - for (int i = 0; i < tasks.size(); i++) { - Future future = ScalarDbUtils.takeUninterruptibly(completionService); - try { - Uninterruptibles.getUninterruptibly(future); - } catch (java.util.concurrent.ExecutionException e) { - if (e.getCause() instanceof ExecutionException) { - if (!stopOnError) { - exception = (ExecutionException) e.getCause(); - } else { - throw (ExecutionException) e.getCause(); - } - } else if (e.getCause() instanceof ValidationConflictException) { - if (!stopOnError) { - exception = (ValidationConflictException) e.getCause(); - } else { - throw (ValidationConflictException) e.getCause(); - } - } else if (e.getCause() instanceof CrudException) { - if (!stopOnError) { - exception = (CrudException) e.getCause(); - } else { - throw (CrudException) e.getCause(); + // Submit tasks + for (ParallelExecutorTask task : tasks) { + completionService.submit( + () -> { + try { + task.run(); + } catch (Exception e) { + logger.warn( + "Failed to run a {} task. Transaction ID: {}", taskName, transactionId, e); + throw e; } - } else if (e.getCause() instanceof RuntimeException) { - throw (RuntimeException) e.getCause(); - } else if (e.getCause() instanceof Error) { - throw (Error) e.getCause(); + return null; + }); + } + + // Optionally wait for completion + if (noWait) { + return; + } + + Throwable throwable = null; + + for (int i = 0; i < tasks.size(); i++) { + Future future = ScalarDbUtils.takeUninterruptibly(completionService); + try { + Uninterruptibles.getUninterruptibly(future); + } catch (java.util.concurrent.ExecutionException e) { + Throwable cause = e.getCause(); + + if (stopOnError) { + rethrow(cause); + } else { + if (throwable == null) { + throwable = cause; } else { - throw new AssertionError("Can't reach here. Maybe a bug", e); + throwable.addSuppressed(cause); } } } + } - if (!stopOnError && exception != null) { - if (exception instanceof ExecutionException) { - throw (ExecutionException) exception; - } else if (exception instanceof ValidationConflictException) { - throw (ValidationConflictException) exception; - } else { - throw (CrudException) exception; - } - } + // Rethrow exception if necessary + if (!stopOnError && throwable != null) { + rethrow(throwable); } } - private void executeTasksSerially( - List tasks, boolean stopOnError, String taskName, String transactionId) + private void executeTasksSerially(List tasks, boolean stopOnError) throws ExecutionException, ValidationConflictException, CrudException { Exception exception = null; for (ParallelExecutorTask task : tasks) { try { task.run(); } catch (ExecutionException | ValidationConflictException | CrudException e) { - logger.warn("Failed to run a {} task. Transaction ID: {}", taskName, transactionId, e); - if (!stopOnError) { - exception = e; + if (exception == null) { + exception = e; + } else { + exception.addSuppressed(e); + } } else { throw e; } @@ -261,13 +249,24 @@ private void executeTasksSerially( } if (!stopOnError && exception != null) { - if (exception instanceof ExecutionException) { - throw (ExecutionException) exception; - } else if (exception instanceof ValidationConflictException) { - throw (ValidationConflictException) exception; - } else { - throw (CrudException) exception; - } + rethrow(exception); + } + } + + private void rethrow(Throwable cause) + throws ExecutionException, ValidationConflictException, CrudException { + if (cause instanceof ExecutionException) { + throw (ExecutionException) cause; + } else if (cause instanceof ValidationConflictException) { + throw (ValidationConflictException) cause; + } else if (cause instanceof CrudException) { + throw (CrudException) cause; + } else if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } else if (cause instanceof Error) { + throw (Error) cause; + } else { + throw new AssertionError("Unexpected exception type", cause); } } diff --git a/core/src/test/java/com/scalar/db/transaction/consensuscommit/ParallelExecutorTest.java b/core/src/test/java/com/scalar/db/transaction/consensuscommit/ParallelExecutorTest.java index c9a0785321..84ac27206d 100644 --- a/core/src/test/java/com/scalar/db/transaction/consensuscommit/ParallelExecutorTest.java +++ b/core/src/test/java/com/scalar/db/transaction/consensuscommit/ParallelExecutorTest.java @@ -5,6 +5,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.only; import static org.mockito.Mockito.spy; @@ -562,4 +563,203 @@ public void executeImplicitPreRead_ParallelImplicitPreReadEnabled_ShouldExecuteT assertThatThrownBy(() -> parallelExecutor.executeImplicitPreRead(tasks, TX_ID)) .isInstanceOf(CrudException.class); } + + @Test + public void executeTasks_SingleTaskAndNoWaitFalse_ShouldExecuteDirectly() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + List tasks = Collections.singletonList(task); + boolean parallel = true; // Should be ignored + boolean noWait = false; + boolean stopOnError = true; + + // Act + parallelExecutor.executeTasks(tasks, parallel, noWait, stopOnError, "test", TX_ID); + + // Assert + verify(task).run(); + verify(parallelExecutorService, never()).execute(any()); + } + + @Test + public void executeTasks_SingleTaskAndNoWaitTrue_ShouldUseParallelExecution() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + when(config.isParallelPreparationEnabled()).thenReturn(true); + + List tasks = Collections.singletonList(task); + boolean parallel = true; + boolean noWait = true; + boolean stopOnError = false; + + // Act + parallelExecutor.executeTasks(tasks, parallel, noWait, stopOnError, "test", TX_ID); + + // Assert + verify(parallelExecutorService).execute(any()); + } + + @Test + public void executeTasks_ParallelTrue_ShouldExecuteTasksInParallel() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + boolean parallel = true; + boolean noWait = false; + boolean stopOnError = false; + + // Act + parallelExecutor.executeTasks(tasks, parallel, noWait, stopOnError, "test", TX_ID); + + // Assert + verify(parallelExecutorService, times(tasks.size())).execute(any()); + } + + @Test + public void executeTasks_ParallelFalse_ShouldExecuteTasksSerially() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + boolean parallel = false; + boolean noWait = false; + boolean stopOnError = false; + + // Act + parallelExecutor.executeTasks(tasks, parallel, noWait, stopOnError, "test", TX_ID); + + // Assert + verify(task, times(tasks.size())).run(); + verify(parallelExecutorService, never()).execute(any()); + } + + @Test + public void executeTasks_ParallelTrueAndStopOnErrorTrue_ExceptionThrown_ShouldStopExecution() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + boolean parallel = true; + boolean noWait = false; + boolean stopOnError = true; + + doThrow(new ExecutionException("Test exception")).when(task).run(); + + // Act Assert + assertThatThrownBy( + () -> + parallelExecutor.executeTasks(tasks, parallel, noWait, stopOnError, "test", TX_ID)) + .isInstanceOf(ExecutionException.class) + .hasMessage("Test exception"); + + verify(parallelExecutorService, times(tasks.size())).execute(any()); + } + + @Test + public void + executeTasks_ParallelTrueAndStopOnErrorFalse_ExceptionThrown_ShouldContinueOtherTasks() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + boolean parallel = true; + boolean noWait = false; + boolean stopOnError = false; + + ParallelExecutorTask failingTask = mock(ParallelExecutorTask.class); + doThrow(new ExecutionException("Test exception")).when(failingTask).run(); + + List mixedTasks = Arrays.asList(failingTask, task); + + // Act Assert + assertThatThrownBy( + () -> + parallelExecutor.executeTasks( + mixedTasks, parallel, noWait, stopOnError, "test", TX_ID)) + .isInstanceOf(ExecutionException.class); + + verify(parallelExecutorService, times(mixedTasks.size())).execute(any()); + } + + @Test + public void + executeTasks_ParallelTrueAndStopOnErrorFalse_ExceptionThrownByMultipleTasks_ShouldContinueOtherTasks() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + boolean parallel = true; + boolean noWait = false; + boolean stopOnError = false; + + ExecutionException executionException1 = new ExecutionException("Test exception1"); + ParallelExecutorTask failingTask1 = mock(ParallelExecutorTask.class); + doThrow(executionException1).when(failingTask1).run(); + + ExecutionException executionException2 = new ExecutionException("Test exception2"); + ParallelExecutorTask failingTask2 = mock(ParallelExecutorTask.class); + doThrow(executionException2).when(failingTask2).run(); + + List mixedTasks = Arrays.asList(failingTask1, failingTask2, task); + + // Act Assert + assertThatThrownBy( + () -> + parallelExecutor.executeTasks( + mixedTasks, parallel, noWait, stopOnError, "test", TX_ID)) + .isEqualTo(executionException1) + .hasSuppressedException(executionException2); + + verify(parallelExecutorService, times(mixedTasks.size())).execute(any()); + } + + @Test + public void + executeTasks_ParallelFalseAndStopOnErrorFalse_ExceptionThrown_ShouldContinueOtherTasks() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + boolean parallel = false; + boolean noWait = false; + boolean stopOnError = false; + + ParallelExecutorTask failingTask = mock(ParallelExecutorTask.class); + doThrow(new ExecutionException("Test exception")).when(failingTask).run(); + + List mixedTasks = Arrays.asList(failingTask, task); + + // Act Assert + assertThatThrownBy( + () -> + parallelExecutor.executeTasks( + mixedTasks, parallel, noWait, stopOnError, "test", TX_ID)) + .isInstanceOf(ExecutionException.class); + + verify(failingTask, only()).run(); + verify(task, only()).run(); + verify(parallelExecutorService, never()).execute(any()); + } + + @Test + public void + executeTasks_ParallelFalseAndStopOnErrorFalse_ExceptionThrownByMultipleTasks_ShouldContinueOtherTasks() + throws ExecutionException, ValidationConflictException, CrudException { + // Arrange + boolean parallel = false; + boolean noWait = false; + boolean stopOnError = false; + + ExecutionException executionException1 = new ExecutionException("Test exception1"); + ParallelExecutorTask failingTask1 = mock(ParallelExecutorTask.class); + doThrow(executionException1).when(failingTask1).run(); + + ExecutionException executionException2 = new ExecutionException("Test exception2"); + ParallelExecutorTask failingTask2 = mock(ParallelExecutorTask.class); + doThrow(executionException2).when(failingTask2).run(); + + List mixedTasks = Arrays.asList(failingTask1, failingTask2, task); + + // Act Assert + assertThatThrownBy( + () -> + parallelExecutor.executeTasks( + mixedTasks, parallel, noWait, stopOnError, "test", TX_ID)) + .isEqualTo(executionException1) + .hasSuppressedException(executionException2); + + verify(failingTask1, only()).run(); + verify(failingTask2, only()).run(); + verify(task, only()).run(); + verify(parallelExecutorService, never()).execute(any()); + } }