diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/service/connection/BeanOrigin.java b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/service/connection/BeanOrigin.java index 701a38d1c482..ca071c461f49 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/service/connection/BeanOrigin.java +++ b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/service/connection/BeanOrigin.java @@ -25,18 +25,26 @@ * {@link Origin} backed by a Spring Bean. * * @author Phillip Webb + * @author Yanming Zhou */ class BeanOrigin implements Origin { private final String beanName; + private final BeanDefinition beanDefinition; + private final String resourceDescription; BeanOrigin(String beanName, BeanDefinition beanDefinition) { this.beanName = beanName; + this.beanDefinition = beanDefinition; this.resourceDescription = (beanDefinition != null) ? beanDefinition.getResourceDescription() : null; } + BeanDefinition getBeanDefinition() { + return this.beanDefinition; + } + @Override public boolean equals(Object obj) { if (this == obj) { diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/service/connection/ConnectionDetailsRegistrar.java b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/service/connection/ConnectionDetailsRegistrar.java index 491d61c5d6af..06839fbb46ac 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/service/connection/ConnectionDetailsRegistrar.java +++ b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/service/connection/ConnectionDetailsRegistrar.java @@ -29,6 +29,7 @@ import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; @@ -49,6 +50,7 @@ * @author Moritz Halbritter * @author Andy Wilkinson * @author Phillip Webb + * @author Yanming Zhou */ class ConnectionDetailsRegistrar { @@ -109,9 +111,22 @@ private void registerBeanDefinition(BeanDefinitionRegistry registry, Contain RootBeanDefinition beanDefinition = new RootBeanDefinition(beanType, beanSupplier); beanDefinition.setAttribute(ServiceConnection.class.getName(), true); containerMetadata.addTo(beanDefinition); + if (source.getOrigin() instanceof BeanOrigin beanOrigin) { + inheritQualifiers(beanOrigin.getBeanDefinition(), beanDefinition); + } registry.registerBeanDefinition(beanName, beanDefinition); } + private void inheritQualifiers(BeanDefinition origin, RootBeanDefinition derived) { + derived.setPrimary(origin.isPrimary()); + derived.setFallback(origin.isFallback()); + derived.setAutowireCandidate(origin.isAutowireCandidate()); + if (origin instanceof RootBeanDefinition rbd) { + derived.setDefaultCandidate(rbd.isDefaultCandidate()); + derived.setQualifiedElement(rbd.getQualifiedElement()); + } + } + private String getBeanName(ContainerConnectionSource source, ConnectionDetails connectionDetails) { List parts = new ArrayList<>(); parts.add(ClassUtils.getShortNameAsProperty(connectionDetails.getClass())); diff --git a/spring-boot-project/spring-boot-testcontainers/src/test/java/org/springframework/boot/testcontainers/service/connection/ConnectionDetailsRegistrarTests.java b/spring-boot-project/spring-boot-testcontainers/src/test/java/org/springframework/boot/testcontainers/service/connection/ConnectionDetailsRegistrarTests.java index d40ce9664c35..6a020edbc86d 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/test/java/org/springframework/boot/testcontainers/service/connection/ConnectionDetailsRegistrarTests.java +++ b/spring-boot-project/spring-boot-testcontainers/src/test/java/org/springframework/boot/testcontainers/service/connection/ConnectionDetailsRegistrarTests.java @@ -40,6 +40,7 @@ * Tests for {@link ConnectionDetailsRegistrar}. * * @author Phillip Webb + * @author Yanming Zhou */ class ConnectionDetailsRegistrarTests { @@ -106,6 +107,32 @@ void registerBeanDefinitionsRegistersDefinition() { assertThat(beanFactory.getBean(TestConnectionDetails.class)).isNotNull(); } + @Test + void containerConnectionDetailsBeanShouldInheritQualifiersFromContainerBean() { + RootBeanDefinition originBeanDefinition = new RootBeanDefinition(); + originBeanDefinition.setPrimary(true); + originBeanDefinition.setFallback(false); + originBeanDefinition.setAutowireCandidate(true); + originBeanDefinition.setDefaultCandidate(true); + originBeanDefinition.setQualifiedElement(ConnectionDetailsRegistrarTests.class); + Origin origin = new BeanOrigin("test", originBeanDefinition); + ContainerConnectionSource source = new ContainerConnectionSource<>("test", origin, PostgreSQLContainer.class, + null, this.annotation, () -> this.container); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + ConnectionDetailsRegistrar registrar = new ConnectionDetailsRegistrar(beanFactory, this.factories); + given(this.factories.getConnectionDetails(source, true)) + .willReturn(Map.of(TestConnectionDetails.class, new TestConnectionDetails())); + registrar.registerBeanDefinitions(beanFactory, source); + String[] beanNames = beanFactory.getBeanNamesForType(TestConnectionDetails.class); + assertThat(beanNames).hasSize(1); + RootBeanDefinition beanDefinition = (RootBeanDefinition) beanFactory.getBeanDefinition(beanNames[0]); + assertThat(beanDefinition.isPrimary()).isEqualTo(originBeanDefinition.isPrimary()); + assertThat(beanDefinition.isFallback()).isEqualTo(originBeanDefinition.isFallback()); + assertThat(beanDefinition.isAutowireCandidate()).isEqualTo(originBeanDefinition.isAutowireCandidate()); + assertThat(beanDefinition.isDefaultCandidate()).isEqualTo(originBeanDefinition.isDefaultCandidate()); + assertThat(beanDefinition.getQualifiedElement()).isEqualTo(originBeanDefinition.getQualifiedElement()); + } + static class TestConnectionDetails implements ConnectionDetails { }