errors) {
+ try {
+ finished(description);
+ } catch (Throwable e) {
+ failure.ifPresent(f -> f.addSuppressed(e)); // ifPresentOrElse() requires Java 9
+ if (!failure.isPresent()) {
+ errors.add(e);
+ }
+ }
+ }
+
+ protected static final TestDescription toTestDescription(Description description) {
+ return new TestDescription() {
+ @Override
+ public String getTestId() {
+ return description.getDisplayName();
+ }
+
+ @Override
+ public String getFilesystemFriendlyName() {
+ return description.getClassName() + "-" + description.getMethodName();
+ }
+ };
+ }
+}
diff --git a/modules/junit-vintage/src/main/java/org/testcontainers/junit/vintage/TemporaryNetwork.java b/modules/junit-vintage/src/main/java/org/testcontainers/junit/vintage/TemporaryNetwork.java
new file mode 100644
index 00000000000..0b43b0e4421
--- /dev/null
+++ b/modules/junit-vintage/src/main/java/org/testcontainers/junit/vintage/TemporaryNetwork.java
@@ -0,0 +1,66 @@
+package org.testcontainers.junit.vintage;
+
+import org.junit.rules.ExternalResource;
+import org.testcontainers.containers.Network;
+
+/**
+ * Integrates {@link Network} with the JUnit4 lifecycle.
+ */
+public final class TemporaryNetwork extends ExternalResource implements Network {
+
+ private final Network network;
+
+ private volatile State state = State.BEFORE_RULE;
+
+ /**
+ * Creates an instance.
+ *
+ * The passed-in network will be closed when the current test completes.
+ *
+ * @param network Network that the rule will delegate to.
+ */
+ public TemporaryNetwork(Network network) {
+ this.network = network;
+ }
+
+ @Override
+ public String getId() {
+ if (state == State.AFTER_RULE) {
+ throw new IllegalStateException("Cannot reference the network after the test completes");
+ }
+ return network.getId();
+ }
+
+ @Override
+ public void close() {
+ switch (state) {
+ case BEFORE_RULE:
+ throw new IllegalStateException("Cannot close the network before the test starts");
+ case INSIDE_RULE:
+ break;
+ case AFTER_RULE:
+ throw new IllegalStateException("Cannot reference the network after the test completes");
+ }
+ network.close();
+ }
+
+ @Override
+ protected void before() throws Throwable {
+ state = State.AFTER_RULE; // Just in case an exception is thrown below.
+ network.getId(); // This has the side-effect of creating the network.
+
+ state = State.INSIDE_RULE;
+ }
+
+ @Override
+ protected void after() {
+ state = State.AFTER_RULE;
+ network.close();
+ }
+
+ private enum State {
+ BEFORE_RULE,
+ INSIDE_RULE,
+ AFTER_RULE,
+ }
+}
diff --git a/modules/junit-vintage/src/main/java/org/testcontainers/junit/vintage/Testcontainers.java b/modules/junit-vintage/src/main/java/org/testcontainers/junit/vintage/Testcontainers.java
new file mode 100644
index 00000000000..0f43e5b1ade
--- /dev/null
+++ b/modules/junit-vintage/src/main/java/org/testcontainers/junit/vintage/Testcontainers.java
@@ -0,0 +1,171 @@
+package org.testcontainers.junit.vintage;
+
+import org.junit.platform.commons.support.AnnotationSupport;
+import org.junit.platform.commons.support.HierarchyTraversalMode;
+import org.junit.platform.commons.support.ModifierSupport;
+import org.junit.platform.commons.support.ReflectionSupport;
+import org.junit.runner.Description;
+import org.junit.runners.model.MultipleFailureException;
+import org.testcontainers.lifecycle.Startable;
+import org.testcontainers.lifecycle.TestDescription;
+import org.testcontainers.lifecycle.TestLifecycleAware;
+
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+
+/**
+ * Integrates Testcontainers with the JUnit4 lifecycle.
+ */
+public final class Testcontainers extends FailureDetectingExternalResource {
+
+ private final Object testInstance;
+
+ private List startedContainers = Collections.emptyList();
+
+ private List lifecycleAwareContainers = Collections.emptyList();
+
+ /**
+ * Constructs an instance for use by {@code @Rule}.
+ *
+ * @param testInstance instance of the current test.
+ */
+ public Testcontainers(Object testInstance) {
+ this.testInstance = Objects.requireNonNull(testInstance);
+ }
+
+ /**
+ * Constructs an instance for use by {@code @ClassRule}.
+ */
+ public Testcontainers() {
+ testInstance = null;
+ }
+
+ @Override
+ protected void starting(Description description) {
+ if (description.isTest()) {
+ if (testInstance == null) {
+ throw new RuntimeException("Testcontainers used as a @Rule without being provided a test instance");
+ }
+ } else if (testInstance != null) {
+ throw new RuntimeException("Testcontainers used as a @ClassRule but was provided a test instance");
+ }
+
+ List containers = findContainers(description);
+ startedContainers = new ArrayList<>(containers.size());
+ containers.forEach(startable -> {
+ startable.start();
+ startedContainers.add(startable);
+ });
+
+ lifecycleAwareContainers =
+ startedContainers
+ .stream()
+ .filter(startable -> startable instanceof TestLifecycleAware)
+ .map(TestLifecycleAware.class::cast)
+ .collect(Collectors.toList());
+ if (!lifecycleAwareContainers.isEmpty()) {
+ TestDescription testDescription = toTestDescription(description);
+ lifecycleAwareContainers.forEach(container -> container.beforeTest(testDescription));
+ }
+ }
+
+ @Override
+ protected void succeeded(Description description) {
+ if (!lifecycleAwareContainers.isEmpty()) {
+ TestDescription testDescription = toTestDescription(description);
+ forEachReversed(
+ lifecycleAwareContainers,
+ container -> container.afterTest(testDescription, Optional.empty())
+ );
+ }
+ }
+
+ @Override
+ protected void failed(Throwable e, Description description) {
+ if (!lifecycleAwareContainers.isEmpty()) {
+ TestDescription testDescription = toTestDescription(description);
+ Optional exception = Optional.of(e);
+ forEachReversed(lifecycleAwareContainers, container -> container.afterTest(testDescription, exception));
+ }
+ }
+
+ @Override
+ protected void finished(Description description) throws Exception {
+ List errors = new ArrayList();
+
+ forEachReversed(
+ startedContainers,
+ startable -> {
+ try {
+ startable.stop();
+ } catch (Throwable e) {
+ errors.add(e);
+ }
+ }
+ );
+
+ MultipleFailureException.assertEmpty(errors);
+ }
+
+ private List findContainers(Description description) {
+ if (description.getTestClass() == null) {
+ return Collections.emptyList();
+ }
+ Predicate isTargetedContainerField = isContainerField();
+ if (testInstance == null) {
+ isTargetedContainerField = isTargetedContainerField.and(ModifierSupport::isStatic);
+ } else {
+ isTargetedContainerField = isTargetedContainerField.and(ModifierSupport::isNotStatic);
+ }
+
+ return ReflectionSupport
+ .streamFields(description.getTestClass(), isTargetedContainerField, HierarchyTraversalMode.TOP_DOWN)
+ .map(this::getContainerInstance)
+ .collect(Collectors.toList());
+ }
+
+ private static Predicate isContainerField() {
+ return field -> {
+ boolean isAnnotatedWithContainer = AnnotationSupport.isAnnotated(field, Container.class);
+ if (isAnnotatedWithContainer) {
+ boolean isStartable = Startable.class.isAssignableFrom(field.getType());
+
+ if (!isStartable) {
+ throw new RuntimeException(
+ String.format("The @Container field '%s' does not implement Startable", field.getName())
+ );
+ }
+ return true;
+ }
+ return false;
+ };
+ }
+
+ private Startable getContainerInstance(Field field) {
+ try {
+ field.setAccessible(true);
+ Startable containerInstance = (Startable) field.get(testInstance);
+ if (containerInstance == null) {
+ throw new RuntimeException("Container " + field.getName() + " needs to be initialized");
+ }
+ return containerInstance;
+ } catch (IllegalAccessException e) {
+ throw new RuntimeException("Cannot access container defined in field " + field.getName());
+ }
+ }
+
+ private static void forEachReversed(List list, Consumer super T> callback) {
+ ListIterator iterator = list.listIterator(list.size());
+ while (iterator.hasPrevious()) {
+ callback.accept(iterator.previous());
+ }
+ }
+}
diff --git a/modules/junit-vintage/src/test/java/org/testcontainers/junit/vintage/TestLifecycleAwareContainerMock.java b/modules/junit-vintage/src/test/java/org/testcontainers/junit/vintage/TestLifecycleAwareContainerMock.java
new file mode 100644
index 00000000000..30dc0283c96
--- /dev/null
+++ b/modules/junit-vintage/src/test/java/org/testcontainers/junit/vintage/TestLifecycleAwareContainerMock.java
@@ -0,0 +1,60 @@
+package org.testcontainers.junit.vintage;
+
+import org.testcontainers.lifecycle.Startable;
+import org.testcontainers.lifecycle.TestDescription;
+import org.testcontainers.lifecycle.TestLifecycleAware;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+public class TestLifecycleAwareContainerMock implements Startable, TestLifecycleAware {
+
+ static final String START = "start";
+
+ static final String BEFORE_TEST = "beforeTest";
+
+ static final String AFTER_TEST = "afterTest";
+
+ static final String STOP = "stop";
+
+ private final List lifecycleMethodCalls = new ArrayList<>();
+
+ private final List lifecycleFilesystemFriendlyNames = new ArrayList<>();
+
+ private Throwable capturedThrowable;
+
+ @Override
+ public void beforeTest(TestDescription description) {
+ lifecycleMethodCalls.add(BEFORE_TEST);
+ lifecycleFilesystemFriendlyNames.add(description.getFilesystemFriendlyName());
+ }
+
+ @Override
+ public void afterTest(TestDescription description, Optional throwable) {
+ lifecycleMethodCalls.add(AFTER_TEST);
+ throwable.ifPresent(capturedThrowable -> this.capturedThrowable = capturedThrowable);
+ }
+
+ List getLifecycleMethodCalls() {
+ return lifecycleMethodCalls;
+ }
+
+ Throwable getCapturedThrowable() {
+ return capturedThrowable;
+ }
+
+ public List getLifecycleFilesystemFriendlyNames() {
+ return lifecycleFilesystemFriendlyNames;
+ }
+
+ @Override
+ public void start() {
+ lifecycleMethodCalls.add(START);
+ }
+
+ @Override
+ public void stop() {
+ lifecycleMethodCalls.add(STOP);
+ }
+}
diff --git a/modules/junit-vintage/src/test/java/org/testcontainers/junit/vintage/TestcontainersTest.java b/modules/junit-vintage/src/test/java/org/testcontainers/junit/vintage/TestcontainersTest.java
new file mode 100644
index 00000000000..0d4c49c8058
--- /dev/null
+++ b/modules/junit-vintage/src/test/java/org/testcontainers/junit/vintage/TestcontainersTest.java
@@ -0,0 +1,206 @@
+package org.testcontainers.junit.vintage;
+
+import org.junit.After;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.RuleChain;
+import org.junit.rules.TestRule;
+import org.junit.rules.TestWatcher;
+import org.junit.runner.Description;
+import org.junit.runner.JUnitCore;
+import org.junit.runner.Result;
+import org.junit.runner.notification.Failure;
+import org.junit.runners.model.MultipleFailureException;
+import org.junit.runners.model.Statement;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+
+public class TestcontainersTest {
+
+ private static final Statement PASSING_STATEMENT = new Statement() {
+ @Override
+ public void evaluate() throws Throwable {}
+ };
+
+ private static final Statement FAILING_STATEMENT = new Statement() {
+ @Override
+ public void evaluate() throws Throwable {
+ throw new TestException();
+ }
+ };
+
+ @After
+ public void resetIntegrationTest() {
+ IntegrationTest.reset();
+ }
+
+ @Test
+ public void statementDelegateDoesNotThrow() throws Throwable {
+ // Arrange
+ FakeTest fakeTest = new FakeTest();
+ Testcontainers containers = new Testcontainers(fakeTest);
+ Statement statement = containers.apply(PASSING_STATEMENT, FakeTest.TEST_DESCRIPTION);
+ assertThat(fakeTest.testContainer.getLifecycleMethodCalls()).isEmpty();
+
+ // Act - evaluate statement
+ statement.evaluate();
+
+ // Assert
+ assertThat(fakeTest.testContainer.getLifecycleMethodCalls())
+ .containsExactly(
+ TestLifecycleAwareContainerMock.START,
+ TestLifecycleAwareContainerMock.BEFORE_TEST,
+ TestLifecycleAwareContainerMock.AFTER_TEST,
+ TestLifecycleAwareContainerMock.STOP
+ );
+ assertThat(fakeTest.testContainer.getCapturedThrowable()).isNull();
+ assertThat(fakeTest.testContainer.getLifecycleFilesystemFriendlyNames()).isEqualTo(FakeTest.FRIENDLY_NAMES);
+ }
+
+ @Test
+ public void statementDelegateThrows() {
+ // Arrange
+ FakeTest fakeTest = new FakeTest();
+ Testcontainers containers = new Testcontainers(fakeTest);
+ Statement statement = containers.apply(FAILING_STATEMENT, FakeTest.TEST_DESCRIPTION);
+ assertThat(fakeTest.testContainer.getLifecycleMethodCalls()).isEmpty();
+
+ // Act - evaluate statement
+ assertThatExceptionOfType(TestException.class).isThrownBy(statement::evaluate);
+
+ // Assert
+ assertThat(fakeTest.testContainer.getLifecycleMethodCalls())
+ .containsExactly(
+ TestLifecycleAwareContainerMock.START,
+ TestLifecycleAwareContainerMock.BEFORE_TEST,
+ TestLifecycleAwareContainerMock.AFTER_TEST,
+ TestLifecycleAwareContainerMock.STOP
+ );
+ assertThat(fakeTest.testContainer.getCapturedThrowable()).isNotNull();
+ assertThat(fakeTest.testContainer.getLifecycleFilesystemFriendlyNames()).isEqualTo(FakeTest.FRIENDLY_NAMES);
+ }
+
+ @Test
+ public void integrationTestsPass() throws Exception {
+ IntegrationTest.enabled = true;
+
+ Result result = JUnitCore.runClasses(IntegrationTest.class);
+
+ verifyNoFailures(result);
+ assertThat(IntegrationTest.testsStarted)
+ .withFailMessage("No tests in IntegrationTests were run")
+ .isGreaterThan(0);
+ }
+
+ /** Test class used for tests that directly call the Testcontainers rule. */
+ private static class FakeTest {
+
+ static final Description TEST_DESCRIPTION = Description.createTestDescription(FakeTest.class, "boom");
+
+ static final List FRIENDLY_NAMES = Collections.singletonList(
+ FailureDetectingExternalResource.toTestDescription(TEST_DESCRIPTION).getFilesystemFriendlyName()
+ );
+
+ @Container
+ private final TestLifecycleAwareContainerMock testContainer = new TestLifecycleAwareContainerMock();
+ }
+
+ /** Integration tests for verifying behavior around container discovery. */
+ public static class IntegrationTest {
+
+ static final List FRIENDLY_NAMES = Collections.singletonList(
+ FailureDetectingExternalResource
+ .toTestDescription(Description.createTestDescription(IntegrationTest.class, "containerStarted"))
+ .getFilesystemFriendlyName()
+ );
+
+ static int testsStarted = 0;
+
+ static boolean enabled = false;
+
+ /** Ensures that the tests in this class are not run directly by gradle. */
+ final TestRule skipWhenDisabled = new TestRule() {
+ @Override
+ public final Statement apply(Statement base, Description description) {
+ return enabled ? base : PASSING_STATEMENT;
+ }
+ };
+
+ /** The class under test; this isn't annotated by @Rule because it is run by "testRuleChain". */
+ final Testcontainers containers = new Testcontainers(this);
+
+ final TestRule verifyPreAndPostConditions = new TestWatcher() {
+ @Override
+ protected void starting(Description description) {
+ forEachTestContainer(container -> {
+ assertThat(container.getLifecycleMethodCalls()).isEmpty();
+ });
+ }
+
+ @Override
+ protected void finished(Description description) {
+ forEachTestContainer(container -> {
+ assertThat(container.getLifecycleMethodCalls())
+ .containsExactly(
+ TestLifecycleAwareContainerMock.START,
+ TestLifecycleAwareContainerMock.BEFORE_TEST,
+ TestLifecycleAwareContainerMock.AFTER_TEST,
+ TestLifecycleAwareContainerMock.STOP
+ );
+ });
+ }
+ };
+
+ @Rule
+ public final RuleChain testRuleChain = RuleChain
+ .outerRule(skipWhenDisabled)
+ .around(verifyPreAndPostConditions)
+ .around(containers);
+
+ static void reset() {
+ enabled = false;
+ testsStarted = 0;
+ }
+
+ @Container
+ private final TestLifecycleAwareContainerMock testContainer1 = new TestLifecycleAwareContainerMock();
+
+ @Container
+ private final TestLifecycleAwareContainerMock testContainer2 = new TestLifecycleAwareContainerMock();
+
+ @Test
+ public void containerStarted() {
+ testsStarted++;
+
+ forEachTestContainer(container -> {
+ assertThat(container.getLifecycleMethodCalls())
+ .containsExactly(
+ TestLifecycleAwareContainerMock.START,
+ TestLifecycleAwareContainerMock.BEFORE_TEST
+ );
+ });
+ }
+
+ private void forEachTestContainer(Consumer callback) {
+ callback.accept(testContainer1);
+ callback.accept(testContainer2);
+ }
+ }
+
+ private static void verifyNoFailures(Result result) throws Exception {
+ List exceptions = result
+ .getFailures()
+ .stream()
+ .map(Failure::getException)
+ .collect(Collectors.toList());
+ MultipleFailureException.assertEmpty(exceptions);
+ }
+
+ static class TestException extends RuntimeException {}
+}