@@ -25,11 +25,11 @@
 import java.util.List;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.Phaser;
+import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import lombok.RequiredArgsConstructor;
@@ -62,41 +62,77 @@ public abstract class ImportProcessor {
    * @param reader the {@link BufferedReader} used to read the source file
    */
   public void process(int dataChunkSize, int transactionBatchSize, BufferedReader reader) {
-    ExecutorService dataChunkExecutor = Executors.newSingleThreadExecutor();
+    ExecutorService dataChunkReaderExecutor = Executors.newSingleThreadExecutor();
+    ExecutorService dataChunkProcessorExecutor =
+        Executors.newFixedThreadPool(params.getImportOptions().getMaxThreads());
     BlockingQueue<ImportDataChunk> dataChunkQueue =
         new LinkedBlockingQueue<>(params.getImportOptions().getDataChunkQueueSize());

+    // The semaphore bounds concurrent task submissions, allowing a small backlog of up to
+    // twice the number of worker threads.
+    Semaphore taskSemaphore = new Semaphore(params.getImportOptions().getMaxThreads() * 2);
+    // The phaser tracks task completion (it starts with one party for the main thread).
+    Phaser phaser = new Phaser(1);
+
     try {
       CompletableFuture<Void> readerFuture =
           CompletableFuture.runAsync(
-              () -> readDataChunks(reader, dataChunkSize, dataChunkQueue), dataChunkExecutor);
+              () -> readDataChunks(reader, dataChunkSize, dataChunkQueue), dataChunkReaderExecutor);

       while (!(dataChunkQueue.isEmpty() && readerFuture.isDone())) {
         ImportDataChunk dataChunk = dataChunkQueue.poll(100, TimeUnit.MILLISECONDS);
         if (dataChunk != null) {
-          processDataChunk(dataChunk, transactionBatchSize);
+          // Acquire a semaphore permit (blocks while no permit is available).
+          taskSemaphore.acquire();
+          // Register with the phaser before submitting, so the task is counted.
+          phaser.register();
+
+          dataChunkProcessorExecutor.submit(
+              () -> {
+                try {
+                  processDataChunk(dataChunk, transactionBatchSize);
+                } finally {
+                  // Always release the permit and arrive at the phaser.
+                  taskSemaphore.release();
+                  phaser.arriveAndDeregister();
+                }
+              });
         }
       }

       readerFuture.join();
+      // Wait for all submitted tasks to complete.
+      phaser.arriveAndAwaitAdvance();
     } catch (InterruptedException e) {
       Thread.currentThread().interrupt();
       throw new RuntimeException(
           CoreError.DATA_LOADER_DATA_CHUNK_PROCESS_FAILED.buildMessage(e.getMessage()), e);
     } finally {
-      dataChunkExecutor.shutdown();
-      try {
-        if (!dataChunkExecutor.awaitTermination(60, TimeUnit.SECONDS)) {
-          dataChunkExecutor.shutdownNow();
-        }
-      } catch (InterruptedException e) {
-        dataChunkExecutor.shutdownNow();
-        Thread.currentThread().interrupt();
-      }
+      shutdownExecutorGracefully(dataChunkReaderExecutor);
+      shutdownExecutorGracefully(dataChunkProcessorExecutor);
       notifyAllDataChunksCompleted();
     }
   }
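The semaphore-plus-phaser combination above is a standard way to bound the number of in-flight tasks without queueing unbounded work. A minimal, self-contained sketch of the same pattern follows; the pool size, loop, and task body are illustrative stand-ins, not code from this PR.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Phaser;
import java.util.concurrent.Semaphore;

public class BoundedSubmissionSketch {
  public static void main(String[] args) throws InterruptedException {
    int maxThreads = 4;
    ExecutorService pool = Executors.newFixedThreadPool(maxThreads);
    Semaphore permits = new Semaphore(maxThreads * 2); // bounds in-flight tasks
    Phaser phaser = new Phaser(1); // one party for the submitting thread

    for (int i = 0; i < 100; i++) {
      permits.acquire(); // blocks once 2 * maxThreads tasks are in flight
      phaser.register(); // count the task before it is submitted
      pool.submit(
          () -> {
            try {
              // illustrative work; stands in for processDataChunk(...)
            } finally {
              permits.release();
              phaser.arriveAndDeregister();
            }
          });
    }

    phaser.arriveAndAwaitAdvance(); // wait until every submitted task has arrived
    pool.shutdown();
  }
}

Registering with the phaser before submitting, rather than inside the task, matters: if registration happened inside the task, the submitting thread could reach arriveAndAwaitAdvance() while a queued task had not yet registered, and the wait could end early.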
+
+  /**
+   * Shuts down the given {@link ExecutorService} gracefully. This method first invokes
+   * {@code shutdown()} and waits up to 60 seconds for termination. If the executor does not
+   * terminate in time, it forces a shutdown using {@code shutdownNow()}. If interrupted
+   * while waiting, it forces a shutdown and re-interrupts the current thread.
+   *
+   * @param es the {@link ExecutorService} to be shut down gracefully
+   */
+  private void shutdownExecutorGracefully(ExecutorService es) {
+    es.shutdown();
+    try {
+      if (!es.awaitTermination(60, TimeUnit.SECONDS)) {
+        es.shutdownNow();
+      }
+    } catch (InterruptedException e) {
+      es.shutdownNow();
+      Thread.currentThread().interrupt();
+    }
+  }
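The helper mirrors the two-phase shutdown pattern recommended in the ExecutorService javadoc. One caveat worth remembering: shutdownNow() only delivers interrupts, so a task must respond to interruption for the forced shutdown to actually stop it. A small sketch under assumed timeouts (shortened here so the forced path triggers):

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class GracefulShutdownSketch {
  public static void main(String[] args) {
    ExecutorService es = Executors.newSingleThreadExecutor();
    es.submit(
        () -> {
          try {
            TimeUnit.MINUTES.sleep(5); // cooperative: sleep responds to interruption
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the flag and let the task end
          }
        });

    es.shutdown(); // stop accepting new tasks; queued tasks may still run
    try {
      if (!es.awaitTermination(1, TimeUnit.SECONDS)) {
        es.shutdownNow(); // interrupts the sleeping task above
      }
    } catch (InterruptedException e) {
      es.shutdownNow();
      Thread.currentThread().interrupt();
    }
  }
}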

   /**
    * Reads and processes data in chunks from the provided reader.
    *
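The body of readDataChunks is collapsed in this diff. For orientation, here is a rough sketch of what a chunking producer of this shape typically looks like, feeding a bounded BlockingQueue so that a full queue applies back-pressure to the reader; the chunk representation, sizes, and sample input are assumptions, not the PR's actual implementation.

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

public class ChunkReaderSketch {
  public static void main(String[] args) throws Exception {
    // Bounded queue: put() blocks when the queue is full, throttling the reader.
    BlockingQueue<List<String>> queue = new LinkedBlockingQueue<>(10);
    int chunkSize = 3;
    try (BufferedReader reader = new BufferedReader(new StringReader("a\nb\nc\nd\ne"))) {
      List<String> chunk = new ArrayList<>(chunkSize);
      String line;
      while ((line = reader.readLine()) != null) {
        chunk.add(line);
        if (chunk.size() == chunkSize) {
          queue.put(chunk);
          chunk = new ArrayList<>(chunkSize);
        }
      }
      if (!chunk.isEmpty()) {
        queue.put(chunk); // flush the final partial chunk
      }
    }
    System.out.println(queue.size() + " chunks queued"); // prints "2 chunks queued"
  }
}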
@@ -373,46 +409,26 @@ private ImportDataChunkStatus processDataChunkWithTransactions(
     Instant startTime = Instant.now();
     List<ImportTransactionBatch> transactionBatches =
         splitIntoTransactionBatches(dataChunk, transactionBatchSize);
-    ExecutorService transactionBatchExecutor =
-        Executors.newFixedThreadPool(params.getImportOptions().getMaxThreads());
-    List<Future<?>> transactionBatchFutures = new ArrayList<>();
     AtomicInteger successCount = new AtomicInteger(0);
     AtomicInteger failureCount = new AtomicInteger(0);
-    try {
-      for (ImportTransactionBatch transactionBatch : transactionBatches) {
-        Future<?> transactionBatchFuture =
-            transactionBatchExecutor.submit(
-                () -> processTransactionBatch(dataChunk.getDataChunkId(), transactionBatch));
-        transactionBatchFutures.add(transactionBatchFuture);
-      }

-      waitForFuturesToComplete(transactionBatchFutures);
-      transactionBatchFutures.forEach(
-          batchResult -> {
-            try {
-              ImportTransactionBatchResult importTransactionBatchResult =
-                  (ImportTransactionBatchResult) batchResult.get();
-              importTransactionBatchResult
-                  .getRecords()
-                  .forEach(
-                      batchRecords -> {
-                        if (batchRecords.getTargets().stream()
-                            .allMatch(
-                                targetResult ->
-                                    targetResult
-                                        .getStatus()
-                                        .equals(ImportTargetResultStatus.SAVED))) {
-                          successCount.incrementAndGet();
-                        } else {
-                          failureCount.incrementAndGet();
-                        }
-                      });
-            } catch (InterruptedException | ExecutionException e) {
-              throw new RuntimeException(e);
-            }
-          });
-    } finally {
-      transactionBatchExecutor.shutdown();
+    for (ImportTransactionBatch transactionBatch : transactionBatches) {
+      ImportTransactionBatchResult importTransactionBatchResult =
+          processTransactionBatch(dataChunk.getDataChunkId(), transactionBatch);
+
+      importTransactionBatchResult
+          .getRecords()
+          .forEach(
+              batchRecords -> {
+                if (batchRecords.getTargets().stream()
+                    .allMatch(
+                        targetResult ->
+                            targetResult.getStatus().equals(ImportTargetResultStatus.SAVED))) {
+                  successCount.incrementAndGet();
+                } else {
+                  failureCount.incrementAndGet();
+                }
+              });
     }
     Instant endTime = Instant.now();
     int totalDuration = (int) Duration.between(startTime, endTime).toMillis();
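With the per-chunk thread pool gone, the transaction batches of a chunk now run sequentially on the worker thread that owns the chunk; the parallelism lives one level up, in the phaser-tracked dataChunkProcessorExecutor. The AtomicInteger counters therefore no longer race; they are kept mainly because a lambda can only capture effectively final locals. A plain loop could use ordinary ints, as in this sketch, where RecordResult is a made-up stand-in for the PR's batch record type:

import java.util.Arrays;
import java.util.List;

public class SequentialTallySketch {
  // Made-up stand-in for a per-record result; not a type from this PR.
  static class RecordResult {
    final boolean allTargetsSaved;

    RecordResult(boolean allTargetsSaved) {
      this.allTargetsSaved = allTargetsSaved;
    }
  }

  public static void main(String[] args) {
    List<RecordResult> records =
        Arrays.asList(new RecordResult(true), new RecordResult(false), new RecordResult(true));
    int successCount = 0; // plain ints suffice on one thread, outside a lambda
    int failureCount = 0;
    for (RecordResult r : records) {
      if (r.allTargetsSaved) {
        successCount++;
      } else {
        failureCount++;
      }
    }
    System.out.printf("success=%d, failure=%d%n", successCount, failureCount);
  }
}

The same applies to processDataChunkWithoutTransactions below, whose new loop has no lambda at all, so plain counters would compile there as well.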
@@ -440,32 +456,17 @@ private ImportDataChunkStatus processDataChunkWithoutTransactions(ImportDataChunk dataChunk) {
     Instant startTime = Instant.now();
     AtomicInteger successCount = new AtomicInteger(0);
     AtomicInteger failureCount = new AtomicInteger(0);
-    ExecutorService recordExecutor =
-        Executors.newFixedThreadPool(params.getImportOptions().getMaxThreads());
-    List<Future<?>> recordFutures = new ArrayList<>();
-    try {
-      for (ImportRow importRow : dataChunk.getSourceData()) {
-        Future<?> recordFuture =
-            recordExecutor.submit(
-                () -> processStorageRecord(dataChunk.getDataChunkId(), importRow));
-        recordFutures.add(recordFuture);

+    for (ImportRow importRow : dataChunk.getSourceData()) {
+      ImportTaskResult result = processStorageRecord(dataChunk.getDataChunkId(), importRow);
+      boolean allSaved =
+          result.getTargets().stream()
+              .allMatch(t -> t.getStatus().equals(ImportTargetResultStatus.SAVED));
+      if (allSaved) {
+        successCount.incrementAndGet();
+      } else {
+        failureCount.incrementAndGet();
+      }
-      }
-      waitForFuturesToComplete(recordFutures);
-      recordFutures.forEach(
-          r -> {
-            try {
-              ImportTaskResult result = (ImportTaskResult) r.get();
-              boolean allSaved =
-                  result.getTargets().stream()
-                      .allMatch(t -> t.getStatus().equals(ImportTargetResultStatus.SAVED));
-              if (allSaved) successCount.incrementAndGet();
-              else failureCount.incrementAndGet();
-            } catch (InterruptedException | ExecutionException e) {
-              throw new RuntimeException(e);
-            }
-          });
-    } finally {
-      recordExecutor.shutdown();
-    }
+    }
     Instant endTime = Instant.now();
     int totalDuration = (int) Duration.between(startTime, endTime).toMillis();
@@ -480,20 +481,4 @@ private ImportDataChunkStatus processDataChunkWithoutTransactions(ImportDataChunk dataChunk) {
         .status(ImportDataChunkStatusState.COMPLETE)
         .build();
   }
-
-  /**
-   * Waits for all futures in the provided list to complete. Any exceptions during execution are
-   * logged but not propagated.
-   *
-   * @param futures the list of {@link Future} objects to wait for
-   */
-  private void waitForFuturesToComplete(List<Future<?>> futures) {
-    for (Future<?> future : futures) {
-      try {
-        future.get();
-      } catch (Exception e) {
-        LOGGER.error(e.getMessage());
-      }
-    }
-  }
 }