Skip to content

Commit 5dce82c

Browse files
committed
Close Both Observations
Depending on when a request is cancelled, the before and after observation starts and stops may be called out of order due to the order in which their doOnCancel handlers are invoked. To address this, the before filter-wrapper now always closes both the before observation and the after observation. Since the before filter- wrapper wraps the entire request, this ensures that either that was started is stopped, and either that has not been started yet cannot inadvertently be started by any unexpected ordering of events that follows. Closes gh-14031
1 parent 30e3f9d commit 5dce82c

File tree

2 files changed

+140
-4
lines changed

2 files changed

+140
-4
lines changed

web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,13 @@ public Observation error(Throwable ex) {
292292

293293
@Override
294294
public void stop() {
295-
this.currentObservation.get().stop();
295+
this.before.stop();
296+
this.after.stop();
297+
}
298+
299+
private void close() {
300+
this.before.close();
301+
this.after.close();
296302
}
297303

298304
@Override
@@ -357,11 +363,11 @@ public WebFilter wrap(WebFilter filter) {
357363
start();
358364
// @formatter:off
359365
return filter.filter(exchange, chain)
360-
.doOnSuccess((v) -> stop())
361-
.doOnCancel(this::stop)
366+
.doOnSuccess((v) -> close())
367+
.doOnCancel(this::close)
362368
.doOnError((t) -> {
363369
error(t);
364-
stop();
370+
close();
365371
})
366372
.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, this));
367373
// @formatter:on
@@ -433,6 +439,21 @@ private void stop() {
433439
}
434440
}
435441

442+
private void close() {
443+
try {
444+
this.lock.lock();
445+
if (this.state.compareAndSet(1, 3)) {
446+
this.observation.stop();
447+
}
448+
else {
449+
this.state.set(3);
450+
}
451+
}
452+
finally {
453+
this.lock.unlock();
454+
}
455+
}
456+
436457
}
437458

438459
}

web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,98 @@ void decorateWhenNoopThenDoesNotObserve() {
7878
verifyNoInteractions(handler);
7979
}
8080

81+
@Test
82+
void decorateWhenTerminatingFilterThenObserves() {
83+
AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
84+
ObservationRegistry registry = ObservationRegistry.create();
85+
registry.observationConfig().observationHandler(handler);
86+
ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
87+
WebFilterChain chain = mock(WebFilterChain.class);
88+
given(chain.filter(any())).willReturn(Mono.error(() -> new Exception("ack")));
89+
WebFilterChain decorated = decorator.decorate(chain,
90+
List.of(new BasicAuthenticationFilter(), new TerminatingFilter()));
91+
Observation http = Observation.start("http", registry).contextualName("http");
92+
try {
93+
decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
94+
.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
95+
.block();
96+
}
97+
catch (Exception ex) {
98+
http.error(ex);
99+
}
100+
finally {
101+
http.stop();
102+
}
103+
handler.assertSpanStart(0, "http", null);
104+
handler.assertSpanStart(1, "spring.security.filterchains", "http");
105+
handler.assertSpanStop(2, "security filterchain before");
106+
handler.assertSpanStart(3, "spring.security.filterchains", "http");
107+
handler.assertSpanStop(4, "security filterchain after");
108+
handler.assertSpanStop(5, "http");
109+
}
110+
111+
@Test
112+
void decorateWhenFilterErrorThenStopsObservation() {
113+
AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
114+
ObservationRegistry registry = ObservationRegistry.create();
115+
registry.observationConfig().observationHandler(handler);
116+
ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
117+
WebFilterChain chain = mock(WebFilterChain.class);
118+
WebFilterChain decorated = decorator.decorate(chain, List.of(new ErroringFilter()));
119+
Observation http = Observation.start("http", registry).contextualName("http");
120+
try {
121+
decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
122+
.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
123+
.block();
124+
}
125+
catch (Exception ex) {
126+
http.error(ex);
127+
}
128+
finally {
129+
http.stop();
130+
}
131+
handler.assertSpanStart(0, "http", null);
132+
handler.assertSpanStart(1, "spring.security.filterchains", "http");
133+
handler.assertSpanError(2);
134+
handler.assertSpanStop(3, "security filterchain before");
135+
handler.assertSpanError(4);
136+
handler.assertSpanStop(5, "http");
137+
}
138+
139+
@Test
140+
void decorateWhenErrorSignalThenStopsObservation() {
141+
AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
142+
ObservationRegistry registry = ObservationRegistry.create();
143+
registry.observationConfig().observationHandler(handler);
144+
ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
145+
WebFilterChain chain = mock(WebFilterChain.class);
146+
given(chain.filter(any())).willReturn(Mono.error(() -> new Exception("ack")));
147+
WebFilterChain decorated = decorator.decorate(chain, List.of(new BasicAuthenticationFilter()));
148+
Observation http = Observation.start("http", registry).contextualName("http");
149+
try {
150+
decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
151+
.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
152+
.block();
153+
}
154+
catch (Exception ex) {
155+
http.error(ex);
156+
}
157+
finally {
158+
http.stop();
159+
}
160+
handler.assertSpanStart(0, "http", null);
161+
handler.assertSpanStart(1, "spring.security.filterchains", "http");
162+
handler.assertSpanStop(2, "security filterchain before");
163+
handler.assertSpanStart(3, "secured request", "security filterchain before");
164+
handler.assertSpanError(4);
165+
handler.assertSpanStop(5, "secured request");
166+
handler.assertSpanStart(6, "spring.security.filterchains", "http");
167+
handler.assertSpanError(7);
168+
handler.assertSpanStop(8, "security filterchain after");
169+
handler.assertSpanError(9);
170+
handler.assertSpanStop(10, "http");
171+
}
172+
81173
// gh-12849
82174
@Test
83175
void decorateWhenCustomAfterFilterThenObserves() {
@@ -171,6 +263,24 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
171263

172264
}
173265

266+
static class ErroringFilter implements WebFilter {
267+
268+
@Override
269+
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
270+
return Mono.error(() -> new RuntimeException("ack"));
271+
}
272+
273+
}
274+
275+
static class TerminatingFilter implements WebFilter {
276+
277+
@Override
278+
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
279+
return Mono.empty();
280+
}
281+
282+
}
283+
174284
static class AccumulatingObservationHandler implements ObservationHandler<Observation.Context> {
175285

176286
List<Event> contexts = new ArrayList<>();
@@ -246,6 +356,11 @@ private void assertSpanStop(int index, String name) {
246356
}
247357
}
248358

359+
private void assertSpanError(int index) {
360+
Event event = this.contexts.get(index);
361+
assertThat(event.event).isEqualTo("error");
362+
}
363+
249364
static class Event {
250365

251366
String event;

0 commit comments

Comments
 (0)