Skip to content

Commit 2b5a2ee

Browse files
committed
Address Observation Bean Name Collisions
Closes gh-16161
1 parent a550215 commit 2b5a2ee

File tree

6 files changed

+213
-23
lines changed

6 files changed

+213
-23
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Copyright 2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config.annotation.rsocket;
18+
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
22+
import io.rsocket.core.RSocketServer;
23+
import io.rsocket.exceptions.RejectedSetupException;
24+
import io.rsocket.frame.decoder.PayloadDecoder;
25+
import io.rsocket.transport.netty.server.CloseableChannel;
26+
import io.rsocket.transport.netty.server.TcpServerTransport;
27+
import org.junit.jupiter.api.AfterEach;
28+
import org.junit.jupiter.api.BeforeEach;
29+
import org.junit.jupiter.api.Test;
30+
import org.junit.jupiter.api.extension.ExtendWith;
31+
32+
import org.springframework.beans.factory.annotation.Autowired;
33+
import org.springframework.context.annotation.Bean;
34+
import org.springframework.context.annotation.Configuration;
35+
import org.springframework.messaging.handler.annotation.MessageMapping;
36+
import org.springframework.messaging.rsocket.RSocketRequester;
37+
import org.springframework.messaging.rsocket.RSocketStrategies;
38+
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
39+
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
40+
import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
41+
import org.springframework.security.core.userdetails.User;
42+
import org.springframework.security.core.userdetails.UserDetails;
43+
import org.springframework.security.rsocket.core.SecuritySocketAcceptorInterceptor;
44+
import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder;
45+
import org.springframework.stereotype.Controller;
46+
import org.springframework.test.context.ContextConfiguration;
47+
import org.springframework.test.context.junit.jupiter.SpringExtension;
48+
49+
import static org.assertj.core.api.Assertions.assertThat;
50+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
51+
52+
/**
53+
* @author Rob Winch
54+
*/
55+
@ContextConfiguration
56+
@ExtendWith(SpringExtension.class)
57+
public class HelloRSocketWithWebFluxITests {
58+
59+
@Autowired
60+
RSocketMessageHandler handler;
61+
62+
@Autowired
63+
SecuritySocketAcceptorInterceptor interceptor;
64+
65+
@Autowired
66+
ServerController controller;
67+
68+
private CloseableChannel server;
69+
70+
private RSocketRequester requester;
71+
72+
@BeforeEach
73+
public void setup() {
74+
// @formatter:off
75+
this.server = RSocketServer.create()
76+
.payloadDecoder(PayloadDecoder.ZERO_COPY)
77+
.interceptors((registry) ->
78+
registry.forSocketAcceptor(this.interceptor)
79+
)
80+
.acceptor(this.handler.responder())
81+
.bind(TcpServerTransport.create("localhost", 0))
82+
.block();
83+
// @formatter:on
84+
}
85+
86+
@AfterEach
87+
public void dispose() {
88+
this.requester.rsocket().dispose();
89+
this.server.dispose();
90+
this.controller.payloads.clear();
91+
}
92+
93+
// gh-16161
94+
@Test
95+
public void retrieveMonoWhenSecureThenDenied() {
96+
// @formatter:off
97+
this.requester = RSocketRequester.builder()
98+
.rsocketStrategies(this.handler.getRSocketStrategies())
99+
.connectTcp("localhost", this.server.address().getPort())
100+
.block();
101+
// @formatter:on
102+
String data = "rob";
103+
// @formatter:off
104+
assertThatExceptionOfType(Exception.class).isThrownBy(
105+
() -> this.requester.route("secure.retrieve-mono")
106+
.data(data)
107+
.retrieveMono(String.class)
108+
.block()
109+
)
110+
.matches((ex) -> ex instanceof RejectedSetupException
111+
|| ex.getClass().toString().contains("ReactiveException"));
112+
// @formatter:on
113+
assertThat(this.controller.payloads).isEmpty();
114+
}
115+
116+
@Configuration
117+
@EnableRSocketSecurity
118+
@EnableWebFluxSecurity
119+
static class Config {
120+
121+
@Bean
122+
ServerController controller() {
123+
return new ServerController();
124+
}
125+
126+
@Bean
127+
RSocketMessageHandler messageHandler() {
128+
RSocketMessageHandler handler = new RSocketMessageHandler();
129+
handler.setRSocketStrategies(rsocketStrategies());
130+
return handler;
131+
}
132+
133+
@Bean
134+
RSocketStrategies rsocketStrategies() {
135+
return RSocketStrategies.builder().encoder(new BasicAuthenticationEncoder()).build();
136+
}
137+
138+
@Bean
139+
MapReactiveUserDetailsService uds() {
140+
// @formatter:off
141+
UserDetails rob = User.withDefaultPasswordEncoder()
142+
.username("rob")
143+
.password("password")
144+
.roles("USER", "ADMIN")
145+
.build();
146+
// @formatter:on
147+
return new MapReactiveUserDetailsService(rob);
148+
}
149+
150+
}
151+
152+
@Controller
153+
static class ServerController {
154+
155+
private List<String> payloads = new ArrayList<>();
156+
157+
@MessageMapping("**")
158+
String retrieveMono(String payload) {
159+
add(payload);
160+
return "Hi " + payload;
161+
}
162+
163+
private void add(String p) {
164+
this.payloads.add(p);
165+
}
166+
167+
}
168+
169+
}

config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.security.config.annotation.rsocket;
1818

19+
import java.util.Map;
20+
1921
import org.springframework.beans.factory.annotation.Autowired;
2022
import org.springframework.context.ApplicationContext;
2123
import org.springframework.context.annotation.Bean;
@@ -62,8 +64,12 @@ void setPasswordEncoder(PasswordEncoder passwordEncoder) {
6264
}
6365

6466
@Autowired(required = false)
65-
void setAuthenticationManagerPostProcessor(ObjectPostProcessor<ReactiveAuthenticationManager> postProcessor) {
66-
this.postProcessor = postProcessor;
67+
void setAuthenticationManagerPostProcessor(
68+
Map<String, ObjectPostProcessor<ReactiveAuthenticationManager>> postProcessors) {
69+
if (postProcessors.size() == 1) {
70+
this.postProcessor = postProcessors.values().iterator().next();
71+
}
72+
this.postProcessor = postProcessors.get("rSocketAuthenticationManagerPostProcessor");
6773
}
6874

6975
@Bean(name = RSOCKET_SECURITY_BEAN_NAME)

config/src/main/java/org/springframework/security/config/annotation/rsocket/ReactiveObservationConfiguration.java

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
import org.springframework.security.authorization.ReactiveAuthorizationManager;
3030
import org.springframework.security.config.ObjectPostProcessor;
3131
import org.springframework.security.config.observation.SecurityObservationSettings;
32-
import org.springframework.security.web.server.ObservationWebFilterChainDecorator;
33-
import org.springframework.security.web.server.WebFilterChainProxy.WebFilterChainDecorator;
34-
import org.springframework.web.server.ServerWebExchange;
32+
import org.springframework.security.rsocket.api.PayloadExchange;
3533

3634
@Configuration(proxyBeanMethods = false)
3735
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
@@ -45,7 +43,7 @@ class ReactiveObservationConfiguration {
4543

4644
@Bean
4745
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
48-
static ObjectPostProcessor<ReactiveAuthorizationManager<ServerWebExchange>> rSocketAuthorizationManagerPostProcessor(
46+
static ObjectPostProcessor<ReactiveAuthorizationManager<PayloadExchange>> rSocketAuthorizationManagerPostProcessor(
4947
ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
5048
return new ObjectPostProcessor<>() {
5149
@Override
@@ -71,18 +69,4 @@ public ReactiveAuthenticationManager postProcess(ReactiveAuthenticationManager o
7169
};
7270
}
7371

74-
@Bean
75-
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
76-
static ObjectPostProcessor<WebFilterChainDecorator> rSocketFilterChainDecoratorPostProcessor(
77-
ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
78-
return new ObjectPostProcessor<>() {
79-
@Override
80-
public WebFilterChainDecorator postProcess(WebFilterChainDecorator object) {
81-
ObservationRegistry r = registry.getIfUnique(() -> ObservationRegistry.NOOP);
82-
boolean active = !r.isNoop() && predicate.getIfUnique(() -> all).shouldObserveRequests();
83-
return active ? new ObservationWebFilterChainDecorator(r) : object;
84-
}
85-
};
86-
}
87-
8872
}

config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveObservationConfiguration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public ReactiveAuthorizationManager postProcess(ReactiveAuthorizationManager obj
5959

6060
@Bean
6161
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
62-
static ObjectPostProcessor<ReactiveAuthenticationManager> authenticationManagerPostProcessor(
62+
static ObjectPostProcessor<ReactiveAuthenticationManager> reactiveAuthenticationManagerPostProcessor(
6363
ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
6464
return new ObjectPostProcessor<>() {
6565
@Override

config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.security.config.annotation.web.reactive;
1818

19+
import java.util.Map;
20+
1921
import org.springframework.beans.BeansException;
2022
import org.springframework.beans.factory.BeanFactory;
2123
import org.springframework.beans.factory.ObjectProvider;
@@ -96,8 +98,12 @@ void setUserDetailsPasswordService(ReactiveUserDetailsPasswordService userDetail
9698
}
9799

98100
@Autowired(required = false)
99-
void setAuthenticationManagerPostProcessor(ObjectPostProcessor<ReactiveAuthenticationManager> postProcessor) {
100-
this.postProcessor = postProcessor;
101+
void setAuthenticationManagerPostProcessor(
102+
Map<String, ObjectPostProcessor<ReactiveAuthenticationManager>> postProcessors) {
103+
if (postProcessors.size() == 1) {
104+
this.postProcessor = postProcessors.values().iterator().next();
105+
}
106+
this.postProcessor = postProcessors.get("reactiveAuthenticationManagerPostProcessor");
101107
}
102108

103109
@Autowired(required = false)

config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,31 @@ public void getWhenUsingObservationRegistryThenObservesRequest() {
242242
assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after");
243243
}
244244

245+
// gh-16161
246+
@Test
247+
public void getWhenUsingRSocketThenObservesRequest() {
248+
this.spring.register(ObservationRegistryConfig.class, RSocketSecurityConfig.class).autowire();
249+
// @formatter:off
250+
this.webClient
251+
.get()
252+
.uri("/hello")
253+
.headers((headers) -> headers.setBasicAuth("user", "password"))
254+
.exchange()
255+
.expectStatus()
256+
.isNotFound();
257+
// @formatter:on
258+
ObservationHandler<Observation.Context> handler = this.spring.getContext().getBean(ObservationHandler.class);
259+
ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
260+
verify(handler, times(6)).onStart(captor.capture());
261+
Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
262+
assertThat(contexts.next().getContextualName()).isEqualTo("http get");
263+
assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain before");
264+
assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications");
265+
assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations");
266+
assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests");
267+
assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after");
268+
}
269+
245270
@Configuration
246271
static class SubclassConfig extends ServerHttpSecurityConfiguration {
247272

0 commit comments

Comments
 (0)