diff --git a/src/main/java/de/rub/nds/crawler/constant/JobStatus.java b/src/main/java/de/rub/nds/crawler/constant/JobStatus.java index 99c521b..03765fd 100644 --- a/src/main/java/de/rub/nds/crawler/constant/JobStatus.java +++ b/src/main/java/de/rub/nds/crawler/constant/JobStatus.java @@ -15,6 +15,8 @@ public enum JobStatus { /** Job is waiting to be executed. */ TO_BE_EXECUTED(false), + /** Job is currently being executed. Partial results may be available in DB. */ + RUNNING(false), /** The domain was not resolvable. An empty result was written to DB. */ UNRESOLVABLE(true), /** An uncaught exception occurred while resolving the host. */ diff --git a/src/main/java/de/rub/nds/crawler/core/BulkScanWorker.java b/src/main/java/de/rub/nds/crawler/core/BulkScanWorker.java index 11831cc..0a87055 100644 --- a/src/main/java/de/rub/nds/crawler/core/BulkScanWorker.java +++ b/src/main/java/de/rub/nds/crawler/core/BulkScanWorker.java @@ -9,15 +9,17 @@ package de.rub.nds.crawler.core; import de.rub.nds.crawler.data.ScanConfig; +import de.rub.nds.crawler.data.ScanJobDescription; import de.rub.nds.crawler.data.ScanTarget; +import de.rub.nds.crawler.persistence.IPersistenceProvider; import de.rub.nds.crawler.util.CanceallableThreadPoolExecutor; import de.rub.nds.scanner.core.execution.NamedThreadFactory; -import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.bson.Document; @@ -41,6 +43,9 @@ public abstract class BulkScanWorker { /** The scan configuration for this worker */ protected final T scanConfig; + /** The persistence provider for writing partial results */ + protected final IPersistenceProvider persistenceProvider; + /** * Calls the inner scan function and may handle cleanup. This is needed to wrap the scanner into * a future object such that we can handle timeouts properly. @@ -55,10 +60,16 @@ public abstract class BulkScanWorker { * @param scanConfig The scan configuration for this worker * @param parallelScanThreads The number of parallel scan threads to use, i.e., how many {@link * ScanTarget}s to handle in parallel. + * @param persistenceProvider The persistence provider for writing partial results */ - protected BulkScanWorker(String bulkScanId, T scanConfig, int parallelScanThreads) { + protected BulkScanWorker( + String bulkScanId, + T scanConfig, + int parallelScanThreads, + IPersistenceProvider persistenceProvider) { this.bulkScanId = bulkScanId; this.scanConfig = scanConfig; + this.persistenceProvider = persistenceProvider; timeoutExecutor = new CanceallableThreadPoolExecutor( @@ -74,31 +85,59 @@ protected BulkScanWorker(String bulkScanId, T scanConfig, int parallelScanThread * Handles a scan target by submitting it to the executor. If init was not called, it will * initialize itself. In this case it will also clean up itself if all jobs are done. * - * @param scanTarget The target to scan. - * @return A future that resolves to the scan result once the scan is done. + *

Returns a {@link ProgressableFuture} that represents the entire scan lifecycle, allowing + * callers to: + * + *

+ * + * @param jobDescription The job description for this scan. + * @return A ProgressableFuture representing the scan lifecycle */ - public Future handle(ScanTarget scanTarget) { + public ProgressableFuture handle(ScanJobDescription jobDescription) { // if we initialized ourself, we also clean up ourself shouldCleanupSelf.weakCompareAndSetAcquire(false, init()); activeJobs.incrementAndGet(); - return timeoutExecutor.submit( + + ProgressableFuture progressableFuture = new ProgressableFuture<>(); + + // Compose a consumer that both updates the future and persists partial results + Consumer progressConsumer = + partialResult -> { + progressableFuture.updateResult(partialResult); + persistPartialResult(jobDescription, partialResult); + }; + + timeoutExecutor.submit( () -> { - Document result = scan(scanTarget); - if (activeJobs.decrementAndGet() == 0 && shouldCleanupSelf.get()) { - cleanup(); + try { + Document result = scan(jobDescription, progressConsumer); + progressableFuture.complete(result); + } catch (Exception e) { + progressableFuture.completeExceptionally(e); + } finally { + if (activeJobs.decrementAndGet() == 0 && shouldCleanupSelf.get()) { + cleanup(); + } } - return result; }); + + return progressableFuture; } /** * Scans a target and returns the result as a Document. This is the core scanning functionality * that must be implemented by subclasses. * - * @param scanTarget The target to scan + * @param jobDescription The job description containing target and metadata + * @param progressConsumer Consumer to call with partial results during scanning * @return The scan result as a Document */ - public abstract Document scan(ScanTarget scanTarget); + public abstract Document scan( + ScanJobDescription jobDescription, Consumer progressConsumer); /** * Initializes this worker if it hasn't been initialized yet. This method is thread-safe and @@ -161,4 +200,15 @@ public final boolean cleanup() { * specific resources. */ protected abstract void cleanupInternal(); + + /** + * Persists a partial scan result. This method can be called by subclasses during scanning to + * save intermediate results. + * + * @param jobDescription The job description for the scan + * @param partialResult The partial result document to persist + */ + protected void persistPartialResult(ScanJobDescription jobDescription, Document partialResult) { + persistenceProvider.upsertPartialResult(jobDescription, partialResult); + } } diff --git a/src/main/java/de/rub/nds/crawler/core/BulkScanWorkerManager.java b/src/main/java/de/rub/nds/crawler/core/BulkScanWorkerManager.java index 3e78782..0c33f4e 100644 --- a/src/main/java/de/rub/nds/crawler/core/BulkScanWorkerManager.java +++ b/src/main/java/de/rub/nds/crawler/core/BulkScanWorkerManager.java @@ -14,8 +14,8 @@ import de.rub.nds.crawler.data.BulkScanInfo; import de.rub.nds.crawler.data.ScanConfig; import de.rub.nds.crawler.data.ScanJobDescription; +import de.rub.nds.crawler.persistence.IPersistenceProvider; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.exception.UncheckedException; import org.apache.logging.log4j.LogManager; @@ -58,21 +58,27 @@ public static BulkScanWorkerManager getInstance() { /** * Static convenience method to handle a scan job. See also {@link #handle(ScanJobDescription, - * int, int)}. + * int, int, IPersistenceProvider)}. * * @param scanJobDescription The scan job to handle * @param parallelConnectionThreads The number of parallel connection threads to use (used to * create worker if it does not exist) * @param parallelScanThreads The number of parallel scan threads to use (used to create worker * if it does not exist) - * @return A future that returns the scan result when the target is scanned is done + * @param persistenceProvider The persistence provider for writing partial results + * @return A ProgressableFuture representing the scan lifecycle */ - public static Future handleStatic( + public static ProgressableFuture handleStatic( ScanJobDescription scanJobDescription, int parallelConnectionThreads, - int parallelScanThreads) { + int parallelScanThreads, + IPersistenceProvider persistenceProvider) { BulkScanWorkerManager manager = getInstance(); - return manager.handle(scanJobDescription, parallelConnectionThreads, parallelScanThreads); + return manager.handle( + scanJobDescription, + parallelConnectionThreads, + parallelScanThreads, + persistenceProvider); } private final Cache> bulkScanWorkers; @@ -102,6 +108,7 @@ private BulkScanWorkerManager() { * create worker if it does not exist) * @param parallelScanThreads The number of parallel scan threads to use (used to create worker * if it does not exist) + * @param persistenceProvider The persistence provider for writing partial results * @return A bulk scan worker for the specified bulk scan * @throws UncheckedException If a worker cannot be created */ @@ -109,14 +116,18 @@ public BulkScanWorker getBulkScanWorker( String bulkScanId, ScanConfig scanConfig, int parallelConnectionThreads, - int parallelScanThreads) { + int parallelScanThreads, + IPersistenceProvider persistenceProvider) { try { return bulkScanWorkers.get( bulkScanId, () -> { BulkScanWorker ret = scanConfig.createWorker( - bulkScanId, parallelConnectionThreads, parallelScanThreads); + bulkScanId, + parallelConnectionThreads, + parallelScanThreads, + persistenceProvider); ret.init(); return ret; }); @@ -135,19 +146,22 @@ public BulkScanWorker getBulkScanWorker( * create worker if it does not exist) * @param parallelScanThreads The number of parallel scan threads to use (used to create worker * if it does not exist) - * @return A future that returns the scan result when the target is scanned is done + * @param persistenceProvider The persistence provider for writing partial results + * @return A ProgressableFuture representing the scan lifecycle */ - public Future handle( + public ProgressableFuture handle( ScanJobDescription scanJobDescription, int parallelConnectionThreads, - int parallelScanThreads) { + int parallelScanThreads, + IPersistenceProvider persistenceProvider) { BulkScanInfo bulkScanInfo = scanJobDescription.getBulkScanInfo(); BulkScanWorker worker = getBulkScanWorker( bulkScanInfo.getBulkScanId(), bulkScanInfo.getScanConfig(), parallelConnectionThreads, - parallelScanThreads); - return worker.handle(scanJobDescription.getScanTarget()); + parallelScanThreads, + persistenceProvider); + return worker.handle(scanJobDescription); } } diff --git a/src/main/java/de/rub/nds/crawler/core/ProgressableFuture.java b/src/main/java/de/rub/nds/crawler/core/ProgressableFuture.java new file mode 100644 index 0000000..e77a456 --- /dev/null +++ b/src/main/java/de/rub/nds/crawler/core/ProgressableFuture.java @@ -0,0 +1,101 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.core; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * A Future implementation that supports tracking progress through partial results. + * + *

This class extends the standard {@link Future} contract with the ability to: + * + *

    + *
  • Get the current partial result via {@link #getCurrentResult()} + *
  • Update the partial result as work progresses via {@link #updateResult(Object)} + *
  • Wait for the final result via standard Future methods ({@link #get()}, {@link #get(long, + * TimeUnit)}) + *
+ * + * @param The type of result this future produces + */ +public class ProgressableFuture implements Future { + + private volatile T currentResult; + private final CompletableFuture delegate = new CompletableFuture<>(); + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return delegate.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return delegate.isCancelled(); + } + + @Override + public boolean isDone() { + return delegate.isDone(); + } + + @Override + public T get() throws InterruptedException, ExecutionException { + return delegate.get(); + } + + @Override + public T get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return delegate.get(timeout, unit); + } + + /** + * Get the current result. If the operation is still in progress, this returns the latest + * partial result. If the operation is complete, this returns the final result. + * + * @return The current result, or null if no result is available yet + */ + public T getCurrentResult() { + return currentResult; + } + + /** + * Update the current result with a partial result. This is called during processing when new + * partial results are available. + * + * @param partialResult The updated partial result + */ + public void updateResult(T partialResult) { + this.currentResult = partialResult; + } + + /** + * Mark the operation as complete with the final result. This will complete the Future and + * notify any waiting consumers. + * + * @param result The final result + */ + void complete(T result) { + this.currentResult = result; + this.delegate.complete(result); + } + + /** + * Mark the operation as failed with an exception. + * + * @param exception The exception that caused the failure + */ + void completeExceptionally(Throwable exception) { + this.delegate.completeExceptionally(exception); + } +} diff --git a/src/main/java/de/rub/nds/crawler/core/Worker.java b/src/main/java/de/rub/nds/crawler/core/Worker.java index 1608e10..09afe0e 100644 --- a/src/main/java/de/rub/nds/crawler/core/Worker.java +++ b/src/main/java/de/rub/nds/crawler/core/Worker.java @@ -70,20 +70,20 @@ public void start() { } private ScanResult waitForScanResult( - Future resultFuture, ScanJobDescription scanJobDescription) + ProgressableFuture progressableFuture, ScanJobDescription scanJobDescription) throws ExecutionException, InterruptedException, TimeoutException { Document resultDocument; JobStatus jobStatus; try { - resultDocument = resultFuture.get(scanTimeout, TimeUnit.MILLISECONDS); + resultDocument = progressableFuture.get(scanTimeout, TimeUnit.MILLISECONDS); jobStatus = resultDocument != null ? JobStatus.SUCCESS : JobStatus.EMPTY; } catch (TimeoutException e) { LOGGER.info( "Trying to shutdown scan of '{}' because timeout reached", scanJobDescription.getScanTarget()); - resultFuture.cancel(true); + progressableFuture.cancel(true); // after interrupting, the scan should return as soon as possible - resultDocument = resultFuture.get(10, TimeUnit.SECONDS); + resultDocument = progressableFuture.get(10, TimeUnit.SECONDS); jobStatus = JobStatus.CANCELLED; } scanJobDescription.setStatus(jobStatus); @@ -92,15 +92,19 @@ private ScanResult waitForScanResult( private void handleScanJob(ScanJobDescription scanJobDescription) { LOGGER.info("Received scan job for {}", scanJobDescription.getScanTarget()); - Future resultFuture = + ProgressableFuture progressableFuture = BulkScanWorkerManager.handleStatic( - scanJobDescription, parallelConnectionThreads, parallelScanThreads); + scanJobDescription, + parallelConnectionThreads, + parallelScanThreads, + persistenceProvider); + workerExecutor.submit( () -> { ScanResult scanResult = null; boolean persist = true; try { - scanResult = waitForScanResult(resultFuture, scanJobDescription); + scanResult = waitForScanResult(progressableFuture, scanJobDescription); } catch (InterruptedException e) { LOGGER.error("Worker was interrupted - not persisting anything", e); scanJobDescription.setStatus(JobStatus.INTERNAL_ERROR); @@ -118,7 +122,7 @@ private void handleScanJob(ScanJobDescription scanJobDescription) { "Scan of '{}' did not finish in time and did not cancel gracefully", scanJobDescription.getScanTarget()); scanJobDescription.setStatus(JobStatus.CANCELLED); - resultFuture.cancel(true); + progressableFuture.cancel(true); scanResult = ScanResult.fromException(scanJobDescription, e); } catch (Exception e) { LOGGER.error( diff --git a/src/main/java/de/rub/nds/crawler/data/ScanConfig.java b/src/main/java/de/rub/nds/crawler/data/ScanConfig.java index e7bcd72..3cc5a87 100644 --- a/src/main/java/de/rub/nds/crawler/data/ScanConfig.java +++ b/src/main/java/de/rub/nds/crawler/data/ScanConfig.java @@ -9,6 +9,7 @@ package de.rub.nds.crawler.data; import de.rub.nds.crawler.core.BulkScanWorker; +import de.rub.nds.crawler.persistence.IPersistenceProvider; import de.rub.nds.scanner.core.config.ScannerDetail; import de.rub.nds.scanner.core.probe.ProbeType; import java.io.Serializable; @@ -122,8 +123,12 @@ public void setExcludedProbes(List excludedProbes) { * @param bulkScanID The ID of the bulk scan this worker is for * @param parallelConnectionThreads The number of parallel connection threads to use * @param parallelScanThreads The number of parallel scan threads to use + * @param persistenceProvider The persistence provider for writing partial results * @return A worker for this scan configuration */ public abstract BulkScanWorker createWorker( - String bulkScanID, int parallelConnectionThreads, int parallelScanThreads); + String bulkScanID, + int parallelConnectionThreads, + int parallelScanThreads, + IPersistenceProvider persistenceProvider); } diff --git a/src/main/java/de/rub/nds/crawler/data/ScanTarget.java b/src/main/java/de/rub/nds/crawler/data/ScanTarget.java index 0ef2142..eced16e 100644 --- a/src/main/java/de/rub/nds/crawler/data/ScanTarget.java +++ b/src/main/java/de/rub/nds/crawler/data/ScanTarget.java @@ -71,7 +71,6 @@ public static Pair fromTargetString( } if (targetString.startsWith("\"") && targetString.endsWith("\"")) { targetString = targetString.replace("\"", ""); - System.out.println(targetString); } // check if targetString contains port (e.g. "www.example.com:8080" or "[2001:db8::1]:8080") @@ -82,7 +81,7 @@ public static Pair fromTargetString( String portString = targetString.substring(bracketEnd + 2); try { int port = Integer.parseInt(portString); - if (port > 1 && port < 65535) { + if (port >= 1 && port <= 65535) { target.setPort(port); } else { target.setPort(defaultPort); @@ -99,7 +98,7 @@ public static Pair fromTargetString( // Likely IPv4 or hostname with port try { int port = Integer.parseInt(parts[1]); - if (port > 1 && port < 65535) { + if (port >= 1 && port <= 65535) { target.setPort(port); } else { target.setPort(defaultPort); diff --git a/src/main/java/de/rub/nds/crawler/persistence/IPersistenceProvider.java b/src/main/java/de/rub/nds/crawler/persistence/IPersistenceProvider.java index 30d2ffb..aec842e 100644 --- a/src/main/java/de/rub/nds/crawler/persistence/IPersistenceProvider.java +++ b/src/main/java/de/rub/nds/crawler/persistence/IPersistenceProvider.java @@ -73,4 +73,13 @@ public interface IPersistenceProvider { */ ScanResult getScanResultByScanJobDescriptionId( String dbName, String collectionName, String scanJobDescriptionId); + + /** + * Upsert a partial scan result into the database. Uses the job ID as the document ID, so + * subsequent calls with the same job will overwrite the previous partial result. + * + * @param job The scan job description (provides ID, database name, collection name). + * @param partialResult The partial result document to upsert. + */ + void upsertPartialResult(ScanJobDescription job, org.bson.Document partialResult); } diff --git a/src/main/java/de/rub/nds/crawler/persistence/MongoPersistenceProvider.java b/src/main/java/de/rub/nds/crawler/persistence/MongoPersistenceProvider.java index a1278c1..6eab7b6 100644 --- a/src/main/java/de/rub/nds/crawler/persistence/MongoPersistenceProvider.java +++ b/src/main/java/de/rub/nds/crawler/persistence/MongoPersistenceProvider.java @@ -25,6 +25,7 @@ import com.mongodb.client.MongoClients; import com.mongodb.client.MongoDatabase; import com.mongodb.client.model.Indexes; +import com.mongodb.client.model.ReplaceOptions; import com.mongodb.lang.NonNull; import de.rub.nds.crawler.config.delegate.MongoDbDelegate; import de.rub.nds.crawler.constant.JobStatus; @@ -392,4 +393,32 @@ public ScanResult getScanResultByScanJobDescriptionId( e); } } + + @Override + public void upsertPartialResult(ScanJobDescription job, org.bson.Document partialResult) { + String dbName = job.getDbName(); + String collectionName = job.getCollectionName(); + String jobId = job.getId().toString(); + + LOGGER.debug( + "Upserting partial result for job {} into collection: {}.{}", + jobId, + dbName, + collectionName); + + try { + // Get raw MongoDB collection (not JacksonMongoCollection) for Document operations + var collection = databaseCache.getUnchecked(dbName).getCollection(collectionName); + + // Upsert: replace if exists, insert if not + collection.replaceOne( + new org.bson.Document("_id", jobId), + partialResult, + new ReplaceOptions().upsert(true)); + + LOGGER.debug("Upserted partial result for job {}", jobId); + } catch (Exception e) { + LOGGER.warn("Failed to upsert partial result for job {}: {}", jobId, e.getMessage()); + } + } } diff --git a/src/test/java/de/rub/nds/crawler/core/BulkScanWorkerTest.java b/src/test/java/de/rub/nds/crawler/core/BulkScanWorkerTest.java new file mode 100644 index 0000000..35b1dd3 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/core/BulkScanWorkerTest.java @@ -0,0 +1,385 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.core; + +import static org.junit.jupiter.api.Assertions.*; + +import de.rub.nds.crawler.constant.JobStatus; +import de.rub.nds.crawler.data.BulkScan; +import de.rub.nds.crawler.data.ScanConfig; +import de.rub.nds.crawler.data.ScanJobDescription; +import de.rub.nds.crawler.data.ScanResult; +import de.rub.nds.crawler.data.ScanTarget; +import de.rub.nds.crawler.dummy.DummyPersistenceProvider; +import de.rub.nds.crawler.persistence.IPersistenceProvider; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import org.bson.Document; +import org.junit.jupiter.api.Test; + +class BulkScanWorkerTest { + + // Test implementation of ScanConfig + static class TestScanConfig extends ScanConfig implements Serializable { + public TestScanConfig() { + super(de.rub.nds.scanner.core.config.ScannerDetail.NORMAL, 0, 60); + } + + @Override + public BulkScanWorker createWorker( + String bulkScanID, + int parallelConnectionThreads, + int parallelScanThreads, + IPersistenceProvider persistenceProvider) { + return new TestBulkScanWorker( + bulkScanID, this, parallelScanThreads, persistenceProvider); + } + } + + // Test implementation of BulkScanWorker + static class TestBulkScanWorker extends BulkScanWorker { + private boolean initCalled = false; + private boolean cleanupCalled = false; + private ScanJobDescription capturedJobDescription = null; + + TestBulkScanWorker( + String bulkScanId, + TestScanConfig scanConfig, + int parallelScanThreads, + IPersistenceProvider persistenceProvider) { + super(bulkScanId, scanConfig, parallelScanThreads, persistenceProvider); + } + + @Override + public Document scan( + ScanJobDescription jobDescription, Consumer progressConsumer) { + // Capture the job description during scan + capturedJobDescription = jobDescription; + ScanTarget scanTarget = jobDescription.getScanTarget(); + + Document result = new Document(); + result.put("target", scanTarget.getIp()); + result.put("hasJobDescription", jobDescription != null); + if (jobDescription != null) { + result.put("jobId", jobDescription.getId().toString()); + } + return result; + } + + @Override + protected void initInternal() { + initCalled = true; + } + + @Override + protected void cleanupInternal() { + cleanupCalled = true; + } + + public boolean isInitCalled() { + return initCalled; + } + + public boolean isCleanupCalled() { + return cleanupCalled; + } + + public ScanJobDescription getCapturedJobDescription() { + return capturedJobDescription; + } + } + + @Test + void testGetCurrentJobDescriptionReturnsNullOutsideScanContext() { + TestScanConfig config = new TestScanConfig(); + TestBulkScanWorker worker = + new TestBulkScanWorker("test-bulk-id", config, 1, new DummyPersistenceProvider()); + + // getCurrentJobDescription() is protected, so we can't call it directly from test + // But we can verify through the scan() method that it returns null when not in context + assertNull( + worker.getCapturedJobDescription(), + "Job description should be null before any scan"); + } + + @Test + void testGetCurrentJobDescriptionReturnsCorrectJobInScanContext() throws Exception { + TestScanConfig config = new TestScanConfig(); + TestBulkScanWorker worker = + new TestBulkScanWorker("test-bulk-id", config, 1, new DummyPersistenceProvider()); + + ScanTarget target = new ScanTarget(); + target.setIp("192.0.2.1"); // TEST-NET-1 (RFC 5737) + target.setPort(443); + + BulkScan bulkScan = + new BulkScan( + BulkScanWorkerTest.class, + BulkScanWorkerTest.class, + "test-db", + config, + System.currentTimeMillis(), + false, + null); + + ScanJobDescription jobDescription = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + + // Execute the scan + ProgressableFuture future = worker.handle(jobDescription); + Document result = future.get(); + + // Verify the job description was available during scan + assertTrue( + result.getBoolean("hasJobDescription"), + "Job description should be available in scan context"); + assertEquals(jobDescription.getId().toString(), result.getString("jobId")); + + // Verify the captured job description matches + assertNotNull(worker.getCapturedJobDescription()); + assertEquals(jobDescription.getId(), worker.getCapturedJobDescription().getId()); + assertEquals(target, worker.getCapturedJobDescription().getScanTarget()); + + // Simulate the partial results persistence flow + DummyPersistenceProvider persistenceProvider = new DummyPersistenceProvider(); + + // Update job status to SUCCESS (required by ScanResult constructor) + jobDescription.setStatus(JobStatus.SUCCESS); + + // Create ScanResult from the scan result Document and job description + ScanResult scanResult = new ScanResult(jobDescription, result); + + // Verify ScanResult has the correct scanJobDescriptionId + assertEquals( + jobDescription.getId().toString(), + scanResult.getScanJobDescriptionId(), + "ScanResult should use job description UUID as scanJobDescriptionId"); + + // Simulate persisting to MongoDB + persistenceProvider.insertScanResult(scanResult, jobDescription); + + // Simulate retrieving from MongoDB by scanJobDescriptionId + ScanResult retrievedResult = + persistenceProvider.getScanResultByScanJobDescriptionId( + "test-db", "test-collection", jobDescription.getId().toString()); + + // Verify the retrieved result matches + assertNotNull( + retrievedResult, "Should be able to retrieve ScanResult by job description ID"); + assertEquals( + jobDescription.getId().toString(), + retrievedResult.getScanJobDescriptionId(), + "Retrieved result should have matching scanJobDescriptionId"); + assertEquals( + scanResult.getBulkScan(), + retrievedResult.getBulkScan(), + "Retrieved result should have matching bulk scan ID"); + assertEquals( + scanResult.getScanTarget(), + retrievedResult.getScanTarget(), + "Retrieved result should have matching scan target"); + assertEquals( + scanResult.getResult(), + retrievedResult.getResult(), + "Retrieved result should have matching result document"); + } + + @Test + void testThreadLocalIsCleanedUpAfterScan() throws Exception { + TestScanConfig config = new TestScanConfig(); + TestBulkScanWorker worker = + new TestBulkScanWorker("test-bulk-id", config, 1, new DummyPersistenceProvider()); + + ScanTarget target = new ScanTarget(); + target.setIp("192.0.2.1"); // TEST-NET-1 (RFC 5737) + target.setPort(443); + + BulkScan bulkScan = + new BulkScan( + BulkScanWorkerTest.class, + BulkScanWorkerTest.class, + "test-db", + config, + System.currentTimeMillis(), + false, + null); + + ScanJobDescription jobDescription = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + + // Execute the scan + ProgressableFuture future = worker.handle(jobDescription); + future.get(); // Wait for completion + + // After scan completes, verify we can run another scan + ScanTarget newTarget = new ScanTarget(); + newTarget.setIp("192.0.2.2"); // TEST-NET-1 (RFC 5737) + newTarget.setPort(443); + + ScanJobDescription newJobDescription = + new ScanJobDescription(newTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + ProgressableFuture future2 = worker.handle(newJobDescription); + Document result2 = future2.get(); + + // The second scan should have the second job description, not the first + assertEquals(newJobDescription.getId().toString(), result2.getString("jobId")); + assertEquals(newJobDescription.getId(), worker.getCapturedJobDescription().getId()); + } + + @Test + void testMultipleConcurrentScansHaveSeparateContexts() throws Exception { + TestScanConfig config = new TestScanConfig(); + TestBulkScanWorker worker = + new TestBulkScanWorker("test-bulk-id", config, 2, new DummyPersistenceProvider()); + + BulkScan bulkScan = + new BulkScan( + BulkScanWorkerTest.class, + BulkScanWorkerTest.class, + "test-db", + config, + System.currentTimeMillis(), + false, + null); + + // Create multiple job descriptions + List jobDescriptions = new ArrayList<>(); + List> futures = new ArrayList<>(); + + for (int i = 0; i < 5; i++) { + ScanTarget target = new ScanTarget(); + target.setIp("192.0.2." + (i + 1)); // TEST-NET-1 (RFC 5737) + target.setPort(443); + + ScanJobDescription jobDescription = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + jobDescriptions.add(jobDescription); + + futures.add(worker.handle(jobDescription)); + } + + // Wait for all scans to complete and verify each got the correct job description + for (int i = 0; i < 5; i++) { + Document result = futures.get(i).get(); + assertTrue(result.getBoolean("hasJobDescription")); + assertEquals( + jobDescriptions.get(i).getId().toString(), + result.getString("jobId"), + "Scan " + i + " should have its own job description"); + } + } + + @Test + void testInitializationIsCalledOnFirstHandle() throws Exception { + TestScanConfig config = new TestScanConfig(); + TestBulkScanWorker worker = + new TestBulkScanWorker("test-bulk-id", config, 1, new DummyPersistenceProvider()); + + assertFalse(worker.isInitCalled(), "Init should not be called before first handle"); + + ScanTarget target = new ScanTarget(); + target.setIp("192.0.2.1"); // TEST-NET-1 (RFC 5737) + target.setPort(443); + + BulkScan bulkScan = + new BulkScan( + BulkScanWorkerTest.class, + BulkScanWorkerTest.class, + "test-db", + config, + System.currentTimeMillis(), + false, + null); + + ScanJobDescription jobDescription = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + + ProgressableFuture future = worker.handle(jobDescription); + future.get(); + + assertTrue(worker.isInitCalled(), "Init should be called on first handle"); + } + + @Test + void testCleanupIsCalledWhenAllJobsComplete() throws Exception { + TestScanConfig config = new TestScanConfig(); + TestBulkScanWorker worker = + new TestBulkScanWorker("test-bulk-id", config, 1, new DummyPersistenceProvider()); + + ScanTarget target = new ScanTarget(); + target.setIp("192.0.2.1"); // TEST-NET-1 (RFC 5737) + target.setPort(443); + + BulkScan bulkScan = + new BulkScan( + BulkScanWorkerTest.class, + BulkScanWorkerTest.class, + "test-db", + config, + System.currentTimeMillis(), + false, + null); + + ScanJobDescription jobDescription = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + + ProgressableFuture future = worker.handle(jobDescription); + future.get(); + + // Give cleanup a moment to execute (it runs after job completion) + Thread.sleep(100); + + assertTrue(worker.isCleanupCalled(), "Cleanup should be called when all jobs complete"); + } + + @Test + void testManualInitPreventsSelfCleanup() throws Exception { + TestScanConfig config = new TestScanConfig(); + TestBulkScanWorker worker = + new TestBulkScanWorker("test-bulk-id", config, 1, new DummyPersistenceProvider()); + + // Call init manually + worker.init(); + assertTrue(worker.isInitCalled(), "Init should be called"); + + ScanTarget target = new ScanTarget(); + target.setIp("192.0.2.1"); // TEST-NET-1 (RFC 5737) + target.setPort(443); + + BulkScan bulkScan = + new BulkScan( + BulkScanWorkerTest.class, + BulkScanWorkerTest.class, + "test-db", + config, + System.currentTimeMillis(), + false, + null); + + ScanJobDescription jobDescription = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + + ProgressableFuture future = worker.handle(jobDescription); + future.get(); + + // Give cleanup a moment (if it were to execute) + Thread.sleep(100); + + assertFalse( + worker.isCleanupCalled(), + "Cleanup should NOT be called when init was manual (shouldCleanupSelf = false)"); + + // Cleanup should only be called when we explicitly call it + worker.cleanup(); + assertTrue(worker.isCleanupCalled(), "Cleanup should be called when explicitly called"); + } +} diff --git a/src/test/java/de/rub/nds/crawler/core/ControllerTest.java b/src/test/java/de/rub/nds/crawler/core/ControllerTest.java index f8922e4..614a07f 100644 --- a/src/test/java/de/rub/nds/crawler/core/ControllerTest.java +++ b/src/test/java/de/rub/nds/crawler/core/ControllerTest.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -42,8 +43,9 @@ void submitting() throws IOException, InterruptedException { Controller controller = new Controller(config, orchestrationProvider, persistenceProvider); controller.start(); - Thread.sleep(1000); - + Assertions.assertTrue( + orchestrationProvider.waitForJobs(2, 5, TimeUnit.SECONDS), + "Timed out waiting for jobs to be submitted"); Assertions.assertEquals(2, orchestrationProvider.jobQueue.size()); Assertions.assertEquals(0, orchestrationProvider.unackedJobs.size()); } @@ -70,8 +72,9 @@ void submittingWithExcludedProbes() throws IOException, InterruptedException { Controller controller = new Controller(config, orchestrationProvider, persistenceProvider); controller.start(); - Thread.sleep(1000); - + Assertions.assertTrue( + orchestrationProvider.waitForJobs(2, 5, TimeUnit.SECONDS), + "Timed out waiting for jobs to be submitted"); Assertions.assertEquals(2, orchestrationProvider.jobQueue.size()); Assertions.assertEquals(0, orchestrationProvider.unackedJobs.size()); @@ -103,18 +106,16 @@ void submittingWithoutExcludedProbes() throws IOException, InterruptedException Controller controller = new Controller(config, orchestrationProvider, persistenceProvider); controller.start(); - Thread.sleep(1000); - + Assertions.assertTrue( + orchestrationProvider.waitForJobs(1, 5, TimeUnit.SECONDS), + "Timed out waiting for jobs to be submitted"); Assertions.assertEquals(1, orchestrationProvider.jobQueue.size()); ScanJobDescription job = orchestrationProvider.jobQueue.peek(); List jobExcludedProbes = job.getBulkScanInfo().getScanConfig().getExcludedProbes(); - if (jobExcludedProbes == null) { - Assertions.assertNull(jobExcludedProbes, "Expected excluded probes to be null"); - } else { - Assertions.assertTrue( - jobExcludedProbes.isEmpty(), "Expected excluded probes to be empty"); - } + Assertions.assertTrue( + jobExcludedProbes == null || jobExcludedProbes.isEmpty(), + "Expected excluded probes to be null or empty"); } } diff --git a/src/test/java/de/rub/nds/crawler/dummy/DummyControllerCommandConfig.java b/src/test/java/de/rub/nds/crawler/dummy/DummyControllerCommandConfig.java index 0c4f28d..0b01d9e 100644 --- a/src/test/java/de/rub/nds/crawler/dummy/DummyControllerCommandConfig.java +++ b/src/test/java/de/rub/nds/crawler/dummy/DummyControllerCommandConfig.java @@ -11,6 +11,7 @@ import de.rub.nds.crawler.config.ControllerCommandConfig; import de.rub.nds.crawler.core.BulkScanWorker; import de.rub.nds.crawler.data.ScanConfig; +import de.rub.nds.crawler.persistence.IPersistenceProvider; import de.rub.nds.scanner.core.config.ScannerDetail; public class DummyControllerCommandConfig extends ControllerCommandConfig { @@ -20,7 +21,10 @@ public ScanConfig getScanConfig() { return new ScanConfig(ScannerDetail.NORMAL, 1, 1, getExcludedProbes()) { @Override public BulkScanWorker createWorker( - String bulkScanID, int parallelConnectionThreads, int parallelScanThreads) { + String bulkScanID, + int parallelConnectionThreads, + int parallelScanThreads, + IPersistenceProvider persistenceProvider) { return null; } }; diff --git a/src/test/java/de/rub/nds/crawler/dummy/DummyOrchestrationProvider.java b/src/test/java/de/rub/nds/crawler/dummy/DummyOrchestrationProvider.java index 3dcf043..006aa25 100644 --- a/src/test/java/de/rub/nds/crawler/dummy/DummyOrchestrationProvider.java +++ b/src/test/java/de/rub/nds/crawler/dummy/DummyOrchestrationProvider.java @@ -19,6 +19,7 @@ import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import org.apache.logging.log4j.Logger; public class DummyOrchestrationProvider implements IOrchestrationProvider { @@ -86,4 +87,25 @@ public void notifyOfDoneScanJob(ScanJobDescription scanJobDescription) { public void closeConnection() { consumerThread.interrupt(); } + + /** + * Waits until the job queue reaches the expected size or timeout occurs. + * + * @param expectedSize The expected number of jobs in the queue + * @param timeout The maximum time to wait + * @param unit The time unit for the timeout + * @return true if the expected size was reached, false if timeout occurred + */ + public boolean waitForJobs(int expectedSize, long timeout, TimeUnit unit) + throws InterruptedException { + long deadlineNanos = System.nanoTime() + unit.toNanos(timeout); + while (jobQueue.size() < expectedSize) { + long remainingNanos = deadlineNanos - System.nanoTime(); + if (remainingNanos <= 0) { + return false; + } + Thread.sleep(Math.min(50, TimeUnit.NANOSECONDS.toMillis(remainingNanos))); + } + return true; + } } diff --git a/src/test/java/de/rub/nds/crawler/dummy/DummyPersistenceProvider.java b/src/test/java/de/rub/nds/crawler/dummy/DummyPersistenceProvider.java index 501b3d4..1bc3e9a 100644 --- a/src/test/java/de/rub/nds/crawler/dummy/DummyPersistenceProvider.java +++ b/src/test/java/de/rub/nds/crawler/dummy/DummyPersistenceProvider.java @@ -15,10 +15,12 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.List; +import org.bson.Document; public class DummyPersistenceProvider implements IPersistenceProvider { public final List results = new ArrayList<>(); public final List bulkScans = new ArrayList<>(); + public final List partialResults = new ArrayList<>(); @Override public void insertScanResult(ScanResult scanResult, ScanJobDescription job) { @@ -55,4 +57,9 @@ public ScanResult getScanResultByScanJobDescriptionId( .max((r1, r2) -> r1.getTimestamp().compareTo(r2.getTimestamp())) .orElse(null); } + + @Override + public void upsertPartialResult(ScanJobDescription job, Document partialResult) { + partialResults.add(partialResult); + } } diff --git a/src/test/java/de/rub/nds/crawler/dummy/DummyPersistenceProviderTest.java b/src/test/java/de/rub/nds/crawler/dummy/DummyPersistenceProviderTest.java index 2dc8740..c31af5d 100644 --- a/src/test/java/de/rub/nds/crawler/dummy/DummyPersistenceProviderTest.java +++ b/src/test/java/de/rub/nds/crawler/dummy/DummyPersistenceProviderTest.java @@ -17,6 +17,7 @@ import de.rub.nds.crawler.data.ScanJobDescription; import de.rub.nds.crawler.data.ScanResult; import de.rub.nds.crawler.data.ScanTarget; +import de.rub.nds.crawler.persistence.IPersistenceProvider; import de.rub.nds.scanner.core.config.ScannerDetail; import org.bson.Document; import org.junit.jupiter.api.BeforeEach; @@ -38,7 +39,8 @@ void setUp() { public BulkScanWorker createWorker( String bulkScanID, int parallelConnectionThreads, - int parallelScanThreads) { + int parallelScanThreads, + IPersistenceProvider persistenceProvider) { return null; } };