55import org .junit .platform .commons .support .ModifierSupport ;
66import org .junit .platform .commons .support .ReflectionSupport ;
77import org .junit .runner .Description ;
8+ import org .junit .runners .model .FrameworkMethod ;
89import org .junit .runners .model .MultipleFailureException ;
910import org .testcontainers .lifecycle .Startable ;
1011import org .testcontainers .lifecycle .TestDescription ;
1112import org .testcontainers .lifecycle .TestLifecycleAware ;
1213
14+ import java .lang .reflect .AnnotatedElement ;
1315import java .lang .reflect .Field ;
16+ import java .lang .reflect .Member ;
17+ import java .lang .reflect .Method ;
1418import java .util .ArrayList ;
1519import java .util .Collections ;
1620import java .util .List ;
2024import java .util .function .Consumer ;
2125import java .util .function .Predicate ;
2226import java .util .stream .Collectors ;
27+ import java .util .stream .Stream ;
2328
2429/**
2530 * Integrates Testcontainers with the JUnit4 lifecycle.
2631 */
2732public final class Testcontainers extends FailureDetectingExternalResource {
2833
34+ private static HierarchyTraversalMode TRAVERSAL_MODE = HierarchyTraversalMode .TOP_DOWN ;
35+
2936 private final Object testInstance ;
3037
3138 private List <Startable > startedContainers = Collections .emptyList ();
@@ -116,51 +123,82 @@ protected void finished(Description description) throws Exception {
116123 }
117124
118125 private List <Startable > findContainers (Description description ) {
119- if (description .getTestClass () == null ) {
126+ Class <?> testClass = description .getTestClass ();
127+ if (testClass == null ) {
120128 return Collections .emptyList ();
121129 }
122- Predicate <Field > isTargetedContainer = isContainer ();
123- if (testInstance == null ) {
124- isTargetedContainer = isTargetedContainer .and (ModifierSupport ::isStatic );
125- } else {
126- isTargetedContainer = isTargetedContainer .and (ModifierSupport ::isNotStatic );
127- }
128130
129- return ReflectionSupport
130- .findFields (description .getTestClass (), isTargetedContainer , HierarchyTraversalMode .TOP_DOWN )
131- .stream ()
132- .map (this ::getContainerInstance )
131+ Predicate <Member > hasExpectedModifier = testInstance == null
132+ ? ModifierSupport ::isStatic
133+ : ModifierSupport ::isNotStatic ;
134+
135+ return Stream
136+ .of (
137+ ReflectionSupport
138+ .findMethods (testClass , isContainerMethod ().and (hasExpectedModifier ), TRAVERSAL_MODE )
139+ .stream ()
140+ .map (this ::getContainerInstance ),
141+ ReflectionSupport
142+ .findFields (testClass , isContainerField ().and (hasExpectedModifier ), TRAVERSAL_MODE )
143+ .stream ()
144+ .map (this ::getContainerInstance )
145+ )
146+ .flatMap (s -> s )
133147 .collect (Collectors .toList ());
134148 }
135149
136- private static Predicate <Field > isContainer () {
137- return field -> {
138- boolean isAnnotatedWithContainer = AnnotationSupport .isAnnotated (field , Container .class );
139- if (isAnnotatedWithContainer ) {
140- boolean isStartable = Startable .class .isAssignableFrom (field .getType ());
150+ private static Predicate <Method > isContainerMethod () {
151+ return method -> isAnnotatedWithContainer (method );
152+ }
141153
142- if (!isStartable ) {
143- throw new RuntimeException (
144- String .format ("The @Container field '%s' does not implement Startable" , field .getName ())
145- );
146- }
147- return true ;
148- }
149- return false ;
150- };
154+ private static Predicate <Field > isContainerField () {
155+ return field -> isAnnotatedWithContainer (field );
156+ }
157+
158+ private static boolean isAnnotatedWithContainer (AnnotatedElement element ) {
159+ return AnnotationSupport .isAnnotated (element , Container .class );
160+ }
161+
162+ private Startable getContainerInstance (Method method ) {
163+ if (!Startable .class .isAssignableFrom (method .getReturnType ())) {
164+ throw new RuntimeException (
165+ String .format ("The @Container method '%s()' does not return a Startable" , method .getName ())
166+ );
167+ }
168+
169+ Object container = null ;
170+ try {
171+ method .setAccessible (true );
172+ container = new FrameworkMethod (method ).invokeExplosively (testInstance );
173+ } catch (Throwable e ) {
174+ throwUnchecked (e );
175+ }
176+
177+ if (container == null ) {
178+ throw new RuntimeException (String .format ("The @Container method '%s()' returned null" , method .getName ()));
179+ }
180+ return (Startable ) container ;
151181 }
152182
153183 private Startable getContainerInstance (Field field ) {
184+ if (!Startable .class .isAssignableFrom (field .getType ())) {
185+ throw new RuntimeException (
186+ String .format ("The @Container field '%s' does not implement Startable" , field .getName ())
187+ );
188+ }
189+
190+ Startable container = null ;
154191 try {
155192 field .setAccessible (true );
156- Startable containerInstance = (Startable ) field .get (testInstance );
157- if (containerInstance == null ) {
158- throw new RuntimeException ("Container " + field .getName () + " needs to be initialized" );
159- }
160- return containerInstance ;
193+ container = (Startable ) field .get (testInstance );
161194 } catch (IllegalAccessException e ) {
162- throw new RuntimeException ("Cannot access container defined in field " + field .getName ());
195+ throwUnchecked (e );
196+ }
197+
198+ if (container == null ) {
199+ throw new RuntimeException ("Container " + field .getName () + " needs to be initialized" );
163200 }
201+ return container ;
164202 }
165203
166204 private static <T > void forEachReversed (List <T > list , Consumer <? super T > callback ) {
@@ -169,4 +207,9 @@ private static <T> void forEachReversed(List<T> list, Consumer<? super T> callba
169207 callback .accept (iterator .previous ());
170208 }
171209 }
210+
211+ @ SuppressWarnings ("unchecked" )
212+ private static <T extends Throwable > void throwUnchecked (Throwable e ) throws T {
213+ throw (T ) e ;
214+ }
172215}
0 commit comments