From f410aedaba129c11765ba93ca7c1d0c05ab7b984 Mon Sep 17 00:00:00 2001
From: Pham Ba Thong
Date: Mon, 9 Jun 2025 13:48:59 +0900
Subject: [PATCH] Move parallelism level for the importing process from the
 record level to the data chunk level (#2728)

Co-authored-by: Toshihiro Suzuki
Co-authored-by: Peckstadt Yves
---
 .../dataimport/processor/ImportProcessor.java | 169 ++++----
 .../processor/ImportProcessorTest.java        | 400 ++++++++++++++++++
 2 files changed, 477 insertions(+), 92 deletions(-)
 create mode 100644 data-loader/core/src/test/java/com/scalar/db/dataloader/core/dataimport/processor/ImportProcessorTest.java

diff --git a/data-loader/core/src/main/java/com/scalar/db/dataloader/core/dataimport/processor/ImportProcessor.java b/data-loader/core/src/main/java/com/scalar/db/dataloader/core/dataimport/processor/ImportProcessor.java
index 50877873f2..3f191f7259 100644
--- a/data-loader/core/src/main/java/com/scalar/db/dataloader/core/dataimport/processor/ImportProcessor.java
+++ b/data-loader/core/src/main/java/com/scalar/db/dataloader/core/dataimport/processor/ImportProcessor.java
@@ -25,11 +25,11 @@
 import java.util.List;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.Phaser;
+import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import lombok.RequiredArgsConstructor;
@@ -62,41 +62,77 @@ public abstract class ImportProcessor {
    * @param reader the {@link BufferedReader} used to read the source file
    */
   public void process(int dataChunkSize, int transactionBatchSize, BufferedReader reader) {
-    ExecutorService dataChunkExecutor = Executors.newSingleThreadExecutor();
+    ExecutorService dataChunkReaderExecutor = Executors.newSingleThreadExecutor();
+    ExecutorService dataChunkProcessorExecutor =
+        Executors.newFixedThreadPool(params.getImportOptions().getMaxThreads());
     BlockingQueue<ImportDataChunk> dataChunkQueue =
         new LinkedBlockingQueue<>(params.getImportOptions().getDataChunkQueueSize());
+
+    // Semaphore controls concurrent task submissions; permits are twice the number of threads to
+    // keep a small buffer of pending tasks
+    Semaphore taskSemaphore = new Semaphore(params.getImportOptions().getMaxThreads() * 2);
+    // Phaser tracks task completion (starts with 1 for the main thread)
+    Phaser phaser = new Phaser(1);
+
     try {
       CompletableFuture<Void> readerFuture =
           CompletableFuture.runAsync(
-              () -> readDataChunks(reader, dataChunkSize, dataChunkQueue), dataChunkExecutor);
+              () -> readDataChunks(reader, dataChunkSize, dataChunkQueue), dataChunkReaderExecutor);
 
       while (!(dataChunkQueue.isEmpty() && readerFuture.isDone())) {
         ImportDataChunk dataChunk = dataChunkQueue.poll(100, TimeUnit.MILLISECONDS);
         if (dataChunk != null) {
-          processDataChunk(dataChunk, transactionBatchSize);
+          // Acquire a semaphore permit (blocks if no permits are available)
+          taskSemaphore.acquire();
+          // Register with the phaser before submitting
+          phaser.register();
+
+          dataChunkProcessorExecutor.submit(
+              () -> {
+                try {
+                  processDataChunk(dataChunk, transactionBatchSize);
+                } finally {
+                  // Always release the semaphore and arrive at the phaser
+                  taskSemaphore.release();
+                  phaser.arriveAndDeregister();
+                }
+              });
         }
       }
       readerFuture.join();
+      // Wait for all tasks to complete
+      phaser.arriveAndAwaitAdvance();
     } catch (InterruptedException e) {
       Thread.currentThread().interrupt();
       throw new RuntimeException(
           CoreError.DATA_LOADER_DATA_CHUNK_PROCESS_FAILED.buildMessage(e.getMessage()), e);
     } finally {
-      dataChunkExecutor.shutdown();
-      try {
-        if (!dataChunkExecutor.awaitTermination(60, TimeUnit.SECONDS)) {
-          dataChunkExecutor.shutdownNow();
-        }
-      } catch (InterruptedException e) {
-        dataChunkExecutor.shutdownNow();
-        Thread.currentThread().interrupt();
-      }
+      shutdownExecutorGracefully(dataChunkReaderExecutor);
+      shutdownExecutorGracefully(dataChunkProcessorExecutor);
       notifyAllDataChunksCompleted();
     }
   }
 
+  /**
+   * Shuts down the given `ExecutorService` gracefully. This method attempts to cleanly shut down
+   * the executor by first invoking `shutdown` and waiting for termination for up to 60 seconds. If
+   * the executor does not terminate within this time, it forces a shutdown using `shutdownNow`. If
+   * interrupted, it forces a shutdown and interrupts the current thread.
+   *
+   * @param es the `ExecutorService` to be shut down gracefully
+   */
+  private void shutdownExecutorGracefully(ExecutorService es) {
+    es.shutdown();
+    try {
+      if (!es.awaitTermination(60, TimeUnit.SECONDS)) {
+        es.shutdownNow();
+      }
+    } catch (InterruptedException e) {
+      es.shutdownNow();
+      Thread.currentThread().interrupt();
+    }
+  }
+
   /**
    * Reads and processes data in chunks from the provided reader.
    *
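The submission loop above pairs a Semaphore with a Phaser to apply back-pressure: the permits bound how many data chunk tasks can be in flight at once (twice the worker count), and the phaser lets the main thread block until every submitted task has finished. The standalone sketch below reproduces the same pattern outside the data loader; the pool size, permit count, and task body are illustrative placeholders, not code from this patch:

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Phaser;
    import java.util.concurrent.Semaphore;

    public class BoundedSubmissionSketch {
      public static void main(String[] args) throws InterruptedException {
        int maxThreads = 4;
        ExecutorService pool = Executors.newFixedThreadPool(maxThreads);
        // At most 2 * maxThreads tasks may be queued or running at any moment
        Semaphore permits = new Semaphore(maxThreads * 2);
        // Party 1 is the main thread; each task registers itself before submission
        Phaser phaser = new Phaser(1);

        for (int i = 0; i < 100; i++) {
          permits.acquire(); // blocks once too many tasks are in flight
          phaser.register();
          int taskId = i;
          pool.submit(
              () -> {
                try {
                  System.out.println("processing chunk " + taskId); // placeholder work
                } finally {
                  permits.release();
                  phaser.arriveAndDeregister();
                }
              });
        }
        // The main thread arrives and waits until all registered tasks have arrived
        phaser.arriveAndAwaitAdvance();
        pool.shutdown();
      }
    }

Unlike awaiting a list of futures, this combination needs no per-task bookkeeping: the phaser's party count is the only state, which is why the patch can drop the old waitForFuturesToComplete helper.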
@@ -373,46 +409,26 @@ private ImportDataChunkStatus processDataChunkWithTransactions(
     Instant startTime = Instant.now();
     List<ImportTransactionBatch> transactionBatches =
         splitIntoTransactionBatches(dataChunk, transactionBatchSize);
-    ExecutorService transactionBatchExecutor =
-        Executors.newFixedThreadPool(params.getImportOptions().getMaxThreads());
-    List<Future<?>> transactionBatchFutures = new ArrayList<>();
     AtomicInteger successCount = new AtomicInteger(0);
     AtomicInteger failureCount = new AtomicInteger(0);
-    try {
-      for (ImportTransactionBatch transactionBatch : transactionBatches) {
-        Future<?> transactionBatchFuture =
-            transactionBatchExecutor.submit(
-                () -> processTransactionBatch(dataChunk.getDataChunkId(), transactionBatch));
-        transactionBatchFutures.add(transactionBatchFuture);
-      }
-      waitForFuturesToComplete(transactionBatchFutures);
-      transactionBatchFutures.forEach(
-          batchResult -> {
-            try {
-              ImportTransactionBatchResult importTransactionBatchResult =
-                  (ImportTransactionBatchResult) batchResult.get();
-              importTransactionBatchResult
-                  .getRecords()
-                  .forEach(
-                      batchRecords -> {
-                        if (batchRecords.getTargets().stream()
-                            .allMatch(
-                                targetResult ->
-                                    targetResult
-                                        .getStatus()
-                                        .equals(ImportTargetResultStatus.SAVED))) {
-                          successCount.incrementAndGet();
-                        } else {
-                          failureCount.incrementAndGet();
-                        }
-                      });
-            } catch (InterruptedException | ExecutionException e) {
-              throw new RuntimeException(e);
-            }
-          });
-    } finally {
-      transactionBatchExecutor.shutdown();
+    for (ImportTransactionBatch transactionBatch : transactionBatches) {
+      ImportTransactionBatchResult importTransactionBatchResult =
+          processTransactionBatch(dataChunk.getDataChunkId(), transactionBatch);
+
+      importTransactionBatchResult
+          .getRecords()
+          .forEach(
+              batchRecords -> {
+                if (batchRecords.getTargets().stream()
+                    .allMatch(
+                        targetResult ->
+                            targetResult.getStatus().equals(ImportTargetResultStatus.SAVED))) {
+                  successCount.incrementAndGet();
+                } else {
+                  failureCount.incrementAndGet();
+                }
+              });
     }
     Instant endTime = Instant.now();
     int totalDuration = (int) Duration.between(startTime, endTime).toMillis();
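Batches within a chunk now execute sequentially on the worker thread that owns the chunk, so the result bookkeeping reduces to a stream predicate: a record counts as successful only when every one of its targets reports SAVED. A condensed illustration of that allMatch check follows; the status enum and row/target shapes are simplified stand-ins for the real result classes (Java 16+ records are used for brevity):

    import java.util.List;

    public class TallySketch {
      enum Status { SAVED, FAILED }

      record Target(Status status) {}

      record Row(List<Target> targets) {}

      public static void main(String[] args) {
        List<Row> rows =
            List.of(
                new Row(List.of(new Target(Status.SAVED), new Target(Status.SAVED))),
                new Row(List.of(new Target(Status.SAVED), new Target(Status.FAILED))));

        // A row counts as successful only if every target was SAVED
        long success =
            rows.stream()
                .filter(r -> r.targets().stream().allMatch(t -> t.status() == Status.SAVED))
                .count();
        System.out.println("saved=" + success + " failed=" + (rows.size() - success));
      }
    }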
@@ -440,32 +456,17 @@ private ImportDataChunkStatus processDataChunkWithoutTransactions(ImportDataChun
     Instant startTime = Instant.now();
     AtomicInteger successCount = new AtomicInteger(0);
     AtomicInteger failureCount = new AtomicInteger(0);
-    ExecutorService recordExecutor =
-        Executors.newFixedThreadPool(params.getImportOptions().getMaxThreads());
-    List<Future<?>> recordFutures = new ArrayList<>();
-    try {
-      for (ImportRow importRow : dataChunk.getSourceData()) {
-        Future<?> recordFuture =
-            recordExecutor.submit(
-                () -> processStorageRecord(dataChunk.getDataChunkId(), importRow));
-        recordFutures.add(recordFuture);
+
+    for (ImportRow importRow : dataChunk.getSourceData()) {
+      ImportTaskResult result = processStorageRecord(dataChunk.getDataChunkId(), importRow);
+      boolean allSaved =
+          result.getTargets().stream()
+              .allMatch(t -> t.getStatus().equals(ImportTargetResultStatus.SAVED));
+      if (allSaved) {
+        successCount.incrementAndGet();
+      } else {
+        failureCount.incrementAndGet();
       }
-      waitForFuturesToComplete(recordFutures);
-      recordFutures.forEach(
-          r -> {
-            try {
-              ImportTaskResult result = (ImportTaskResult) r.get();
-              boolean allSaved =
-                  result.getTargets().stream()
-                      .allMatch(t -> t.getStatus().equals(ImportTargetResultStatus.SAVED));
-              if (allSaved) successCount.incrementAndGet();
-              else failureCount.incrementAndGet();
-            } catch (InterruptedException | ExecutionException e) {
-              throw new RuntimeException(e);
-            }
-          });
-    } finally {
-      recordExecutor.shutdown();
     }
     Instant endTime = Instant.now();
     int totalDuration = (int) Duration.between(startTime, endTime).toMillis();
@@ -480,20 +481,4 @@ private ImportDataChunkStatus processDataChunkWithoutTransactions(ImportDataChun
         .status(ImportDataChunkStatusState.COMPLETE)
         .build();
   }
-
-  /**
-   * Waits for all futures in the provided list to complete. Any exceptions during execution are
-   * logged but not propagated.
-   *
-   * @param futures the list of {@link Future} objects to wait for
-   */
-  private void waitForFuturesToComplete(List<Future<?>> futures) {
-    for (Future<?> future : futures) {
-      try {
-        future.get();
-      } catch (Exception e) {
-        LOGGER.error(e.getMessage());
-      }
-    }
-  }
 }
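The changes to ImportProcessor.java all lean on the bounded dataChunkQueue: because the LinkedBlockingQueue is constructed with a capacity, the reader thread's put() blocks whenever the processors fall behind, so memory stays bounded regardless of source file size. The sketch below reproduces that producer/consumer handoff in isolation; the capacity, item type, and counts are arbitrary demo values:

    import java.util.concurrent.BlockingQueue;
    import java.util.concurrent.LinkedBlockingQueue;
    import java.util.concurrent.TimeUnit;

    public class BoundedQueueSketch {
      public static void main(String[] args) throws InterruptedException {
        BlockingQueue<String> queue = new LinkedBlockingQueue<>(10); // bounded capacity

        Thread reader =
            new Thread(
                () -> {
                  try {
                    for (int i = 0; i < 100; i++) {
                      queue.put("chunk-" + i); // blocks while the queue is full
                    }
                  } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                  }
                });
        reader.start();

        // The consumer mirrors the poll-with-timeout loop in process()
        while (reader.isAlive() || !queue.isEmpty()) {
          String chunk = queue.poll(100, TimeUnit.MILLISECONDS);
          if (chunk != null) {
            System.out.println("processing " + chunk);
          }
        }
        reader.join();
      }
    }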
diff --git a/data-loader/core/src/test/java/com/scalar/db/dataloader/core/dataimport/processor/ImportProcessorTest.java b/data-loader/core/src/test/java/com/scalar/db/dataloader/core/dataimport/processor/ImportProcessorTest.java
new file mode 100644
index 0000000000..b5163eadb9
--- /dev/null
+++ b/data-loader/core/src/test/java/com/scalar/db/dataloader/core/dataimport/processor/ImportProcessorTest.java
@@ -0,0 +1,400 @@
+package com.scalar.db.dataloader.core.dataimport.processor;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.scalar.db.api.DistributedStorage;
+import com.scalar.db.api.DistributedTransaction;
+import com.scalar.db.api.DistributedTransactionManager;
+import com.scalar.db.api.TableMetadata;
+import com.scalar.db.dataloader.core.ScalarDbMode;
+import com.scalar.db.dataloader.core.UnitTestUtils;
+import com.scalar.db.dataloader.core.dataimport.ImportEventListener;
+import com.scalar.db.dataloader.core.dataimport.ImportOptions;
+import com.scalar.db.dataloader.core.dataimport.dao.ScalarDbDao;
+import com.scalar.db.dataloader.core.dataimport.datachunk.ImportDataChunk;
+import com.scalar.db.dataloader.core.dataimport.datachunk.ImportDataChunkStatus;
+import com.scalar.db.dataloader.core.dataimport.datachunk.ImportRow;
+import com.scalar.db.dataloader.core.dataimport.task.result.ImportTaskResult;
+import com.scalar.db.dataloader.core.dataimport.transactionbatch.ImportTransactionBatchResult;
+import com.scalar.db.dataloader.core.dataimport.transactionbatch.ImportTransactionBatchStatus;
+import com.scalar.db.exception.transaction.TransactionException;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.StringReader;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import lombok.Getter;
+import lombok.Setter;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+/**
+ * Unit tests for the ImportProcessor class.
+ *
+ * <p>These tests verify that the process method correctly handles different scenarios including
+ * storage mode, transaction mode, empty data, and large data chunks.
+ *
+ * <p>Additionally, this class tests the thread executor behavior in ImportProcessor, including
+ * proper shutdown, waiting for tasks to complete, handling interruptions, and task distribution.
+ */
+@ExtendWith(MockitoExtension.class)
+class ImportProcessorTest {
+
+  @Mock private ImportProcessorParams params;
+  @Mock private ImportOptions importOptions;
+  @Mock private ScalarDbDao dao;
+  @Mock private DistributedStorage distributedStorage;
+  @Mock private DistributedTransactionManager distributedTransactionManager;
+  @Mock private DistributedTransaction distributedTransaction;
+  @Mock private TableColumnDataTypes tableColumnDataTypes;
+  @Mock private ImportEventListener eventListener;
+
+  private Map<String, TableMetadata> tableMetadataByTableName;
+
+  @BeforeEach
+  void setUp() {
+    // Only set up the common mocks that are used by all tests
+    tableMetadataByTableName = new HashMap<>();
+    tableMetadataByTableName.put("namespace.table", UnitTestUtils.createTestTableMetadata());
+
+    when(importOptions.getMaxThreads()).thenReturn(2);
+    when(importOptions.getDataChunkQueueSize()).thenReturn(10);
+    when(params.getImportOptions()).thenReturn(importOptions);
+  }
+
+  @Test
+  void process_withStorageMode_shouldProcessAllDataChunks() {
+    // Arrange
+    BufferedReader reader = new BufferedReader(new StringReader("test data"));
+    when(params.getScalarDbMode()).thenReturn(ScalarDbMode.STORAGE);
+    when(params.getDao()).thenReturn(dao);
+    when(params.getDistributedStorage()).thenReturn(distributedStorage);
+    when(params.getTableColumnDataTypes()).thenReturn(tableColumnDataTypes);
+
+    TestImportProcessor processor = new TestImportProcessor(params);
+    processor.addListener(eventListener);
+
+    // Act
+    processor.process(2, 1, reader);
+
+    // Assert
+    verify(eventListener, times(1)).onAllDataChunksCompleted();
+    // Verify that data chunks were processed
+    verify(eventListener, times(1)).onDataChunkCompleted(any(ImportDataChunkStatus.class));
+  }
+
+  @Test
+  void process_withTransactionMode_shouldProcessAllDataChunks() throws TransactionException {
+    // Arrange
+    BufferedReader reader = new BufferedReader(new StringReader("test data"));
+    when(params.getScalarDbMode()).thenReturn(ScalarDbMode.TRANSACTION);
+    when(params.getDao()).thenReturn(dao);
+    when(params.getTableColumnDataTypes()).thenReturn(tableColumnDataTypes);
+    when(params.getTableMetadataByTableName()).thenReturn(tableMetadataByTableName);
+    when(params.getDistributedTransactionManager()).thenReturn(distributedTransactionManager);
+    when(distributedTransactionManager.start()).thenReturn(distributedTransaction);
+
+    TestImportProcessor processor = new TestImportProcessor(params);
+    processor.addListener(eventListener);
+
+    // Act
+    processor.process(2, 1, reader);
+
+    // Assert
+    verify(eventListener, times(1)).onAllDataChunksCompleted();
+    // Verify that data chunks were processed
+    verify(eventListener, times(1)).onDataChunkCompleted(any(ImportDataChunkStatus.class));
+  }
+
+  @Test
+  void process_withEmptyData_shouldNotProcessAnyDataChunks() {
+    // Arrange
+    TestImportProcessor processor = new TestImportProcessor(params);
+    processor.addListener(eventListener);
+
+    BufferedReader reader = new BufferedReader(new StringReader(""));
+
+    // Act
+    processor.process(2, 1, reader);
+
+    // Assert
+    verify(eventListener, times(1)).onAllDataChunksCompleted();
+    // Verify that no data chunks were processed
+    verify(eventListener, times(0)).onDataChunkCompleted(any());
+  }
+
+  // Thread executor behavior tests
+
+  @Test
+  void process_withMultipleDataChunks_shouldUseThreadPool() {
+    // Arrange
+    final int maxThreads = 4;
+    when(importOptions.getMaxThreads()).thenReturn(maxThreads);
+    when(params.getDao()).thenReturn(dao);
+    when(params.getDistributedStorage()).thenReturn(distributedStorage);
+    when(params.getTableColumnDataTypes()).thenReturn(tableColumnDataTypes);
+    when(params.getTableMetadataByTableName()).thenReturn(tableMetadataByTableName);
+
+    // Create test data with multiple chunks
+    StringBuilder testData = new StringBuilder();
+    for (int i = 0; i < 20; i++) {
+      testData.append("test data line ").append(i).append("\n");
+    }
+    BufferedReader reader = new BufferedReader(new StringReader(testData.toString()));
+
+    // Create a latch to ensure tasks take some time to complete
+    CountDownLatch latch = new CountDownLatch(1);
+
+    // Create a TestImportProcessor
+    TestImportProcessor processor = new TestImportProcessor(params);
+    processor.addListener(eventListener);
+    processor.setProcessingLatch(latch);
+
+    // Act
+    processor.process(2, 1, reader);
+
+    // Assert
+    // Verify that multiple threads were used but not more than maxThreads
+    assertTrue(processor.getMaxConcurrentThreads().get() > 1, "Should use multiple threads");
+    assertTrue(
+        processor.getMaxConcurrentThreads().get() <= maxThreads, "Should not exceed max threads");
+
+    // Verify that all data chunks were processed
+    verify(eventListener, times(1)).onAllDataChunksCompleted();
+  }
+
+  @Test
+  void process_withInterruption_shouldShutdownGracefully() {
+    // Arrange
+    BufferedReader reader = new BufferedReader(new StringReader("test data\nmore data\n"));
+
+    // Create a processor that will be interrupted
+    TestImportProcessor processor = new TestImportProcessor(params);
+    processor.addListener(eventListener);
+    processor.setSimulateInterruption(true);
+
+    // Act & Assert
+    assertThrows(RuntimeException.class, () -> processor.process(2, 1, reader));
+
+    // Verify that onAllDataChunksCompleted was still called (in finally block)
+    verify(eventListener, times(1)).onAllDataChunksCompleted();
+  }
+
+  @Test
+  void process_withLargeNumberOfTasks_shouldWaitForAllTasksToComplete() {
+    // Arrange
+    final int maxThreads = 2;
+    when(importOptions.getMaxThreads()).thenReturn(maxThreads);
+    when(params.getDao()).thenReturn(dao);
+    when(params.getDistributedStorage()).thenReturn(distributedStorage);
+    when(params.getTableColumnDataTypes()).thenReturn(tableColumnDataTypes);
+    when(params.getTableMetadataByTableName()).thenReturn(tableMetadataByTableName);
+
+    // Create test data with many chunks
+    StringBuilder testData = new StringBuilder();
+    for (int i = 0; i < 50; i++) {
+      testData.append("test data line ").append(i).append("\n");
+    }
+    BufferedReader reader = new BufferedReader(new StringReader(testData.toString()));
+
+    // Create a TestImportProcessor with a small processing delay
+    TestImportProcessor processor = new TestImportProcessor(params);
+    processor.addListener(eventListener);
+    processor.setProcessingDelayMs(10); // 10ms delay per chunk
+
+    // Act
+    processor.process(2, 1, reader);
+
+    // Assert
+    // Verify that all tasks were completed
+    assertTrue(processor.getProcessedChunksCount().get() > 0, "All tasks should be completed");
+    verify(eventListener, times(1)).onAllDataChunksCompleted();
+  }
+
+  @Test
+  void process_withShutdown_shouldShutdownExecutorsGracefully() {
+    // Arrange
+    when(params.getScalarDbMode()).thenReturn(ScalarDbMode.STORAGE);
+    when(params.getDao()).thenReturn(dao);
+    when(params.getDistributedStorage()).thenReturn(distributedStorage);
+    when(params.getTableColumnDataTypes()).thenReturn(tableColumnDataTypes);
+    when(params.getTableMetadataByTableName()).thenReturn(tableMetadataByTableName);
+
+    BufferedReader reader =
+        new BufferedReader(new StringReader("test data\nmore data\neven more data\n"));
+
+    // Create a TestImportProcessor with a longer processing delay
+    TestImportProcessor processor = new TestImportProcessor(params);
+    processor.addListener(eventListener);
+    processor.setProcessingDelayMs(50); // 50ms delay per chunk
+
+    // Act
+    processor.process(1, 1, reader);
+
+    // Assert
+    // Verify that all data chunks were processed and executors were shut down gracefully
+    verify(eventListener, times(1)).onAllDataChunksCompleted();
+    assertEquals(3, processor.getProcessedChunksCount().get(), "All chunks should be processed");
+  }
+
+  /**
+   * A simple implementation of ImportProcessor for testing purposes. This class is used to test
+   * the thread executor behavior in ImportProcessor.
+   */
+  static class TestImportProcessor extends ImportProcessor {
+
+    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+    // Tracking for testing
+    @Getter private final AtomicInteger processedChunksCount = new AtomicInteger(0);
+
+    @Getter private final AtomicInteger maxConcurrentThreads = new AtomicInteger(0);
+
+    private final AtomicInteger currentConcurrentThreads = new AtomicInteger(0);
+
+    private final AtomicBoolean simulateInterruption = new AtomicBoolean(false);
+
+    @Setter private CountDownLatch processingLatch;
+
+    @Setter private long processingDelayMs = 0;
+
+    public TestImportProcessor(ImportProcessorParams params) {
+      super(params);
+      // Add our tracking listener
+      addTrackingListener();
+    }
+
+    /** Sets whether to simulate an interruption during processing. */
+    public void setSimulateInterruption(boolean value) {
+      this.simulateInterruption.set(value);
+    }
+
+    @Override
+    protected void readDataChunks(
+        BufferedReader reader, int dataChunkSize, BlockingQueue<ImportDataChunk> dataChunkQueue) {
+      try {
+        List<ImportRow> rows = new ArrayList<>();
+        String line;
+        int rowNumber = 0;
+
+        while ((line = reader.readLine()) != null) {
+          if (!line.trim().isEmpty()) {
+            // Create a simple JsonNode from the line
+            JsonNode jsonNode = OBJECT_MAPPER.readTree("{\"data\":\"" + line + "\"}");
+            rows.add(new ImportRow(rowNumber++, jsonNode));
+
+            if (rows.size() >= dataChunkSize) {
+              ImportDataChunk dataChunk =
+                  ImportDataChunk.builder()
+                      .dataChunkId(rowNumber / dataChunkSize)
+                      .sourceData(rows)
+                      .build();
+              dataChunkQueue.put(dataChunk);
+              rows = new ArrayList<>();
+
+              // Simulate interruption if requested (in the reader thread)
+              if (simulateInterruption.get()) {
+                Thread.currentThread().interrupt();
+                throw new InterruptedException("Simulated interruption in reader");
+              }
+            }
+          }
+        }
+
+        // Add any remaining rows
+        if (!rows.isEmpty()) {
+          ImportDataChunk dataChunk =
+              ImportDataChunk.builder()
+                  .dataChunkId(rowNumber / dataChunkSize + 1)
+                  .sourceData(rows)
+                  .build();
+          dataChunkQueue.put(dataChunk);
+        }
+      } catch (IOException | InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw new RuntimeException("Error reading data chunks", e);
+      }
+    }
+
+    // Add a tracking listener to monitor thread behavior
+    private void addTrackingListener() {
+      super.addListener(
+          new ImportEventListener() {
+            @Override
+            public void onDataChunkStarted(ImportDataChunkStatus status) {
+              // Track concurrent threads
+              int current = currentConcurrentThreads.incrementAndGet();
+              maxConcurrentThreads.set(Math.max(current, maxConcurrentThreads.get()));
+
+              // Add processing delay if specified
+              if (processingDelayMs > 0) {
+                try {
+                  Thread.sleep(processingDelayMs);
+                } catch (InterruptedException e) {
+                  Thread.currentThread().interrupt();
+                }
+              }
+
+              // Wait on latch if provided
+              if (processingLatch != null) {
+                try {
+                  processingLatch.await(100, TimeUnit.MILLISECONDS);
+                } catch (InterruptedException e) {
+                  Thread.currentThread().interrupt();
+                }
+              }
+
+              // Simulate interruption if requested (in worker threads)
+              if (simulateInterruption.get()) {
+                Thread.currentThread().interrupt();
+                throw new RuntimeException("Simulated interruption in worker");
+              }
+            }
+
+            @Override
+            public void onDataChunkCompleted(ImportDataChunkStatus status) {
+              processedChunksCount.incrementAndGet();
+              currentConcurrentThreads.decrementAndGet();
+            }
+
+            @Override
+            public void onAllDataChunksCompleted() {
+              // No action needed
+            }
+
+            @Override
+            public void onTransactionBatchStarted(ImportTransactionBatchStatus batchStatus) {
+              // No action needed
+            }
+
+            @Override
+            public void onTransactionBatchCompleted(ImportTransactionBatchResult batchResult) {
+              // No action needed
+            }
+
+            @Override
+            public void onTaskComplete(ImportTaskResult taskResult) {
+              // No action needed
+            }
+          });
+    }
+  }
+}
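One caveat in the test's tracking listener: maxConcurrentThreads.set(Math.max(current, maxConcurrentThreads.get())) is a separate read and write, so two workers can interleave and record a stale maximum. AtomicInteger offers accumulateAndGet, which performs the same read-modify-write atomically; a standalone demo of the atomic variant (thread counts and values here are arbitrary):

    import java.util.concurrent.atomic.AtomicInteger;

    public class AtomicMaxSketch {
      public static void main(String[] args) throws InterruptedException {
        AtomicInteger max = new AtomicInteger(0);

        Runnable worker =
            () -> {
              for (int i = 0; i < 10_000; i++) {
                int observed = (int) (Math.random() * 100);
                // accumulateAndGet retries on contention, so the update is atomic
                max.accumulateAndGet(observed, Math::max);
              }
            };

        Thread t1 = new Thread(worker);
        Thread t2 = new Thread(worker);
        t1.start();
        t2.start();
        t1.join();
        t2.join();

        System.out.println("max observed = " + max.get());
      }
    }

For these tests the race is benign, since the assertions only need a lower bound on concurrency, but the atomic form costs nothing extra.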