Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@
* wraps the chain in before and after observations
*
* @author Josh Cummings
* @author Nikita Konev
* @since 6.0
*/
public final class ObservationFilterChainDecorator implements FilterChainProxy.FilterChainDecorator {

private static final Log logger = LogFactory.getLog(FilterChainProxy.class);

private static final String ATTRIBUTE = ObservationFilterChainDecorator.class + ".observation";
static final String ATTRIBUTE = ObservationFilterChainDecorator.class + ".observation";

static final String UNSECURED_OBSERVATION_NAME = "spring.security.http.unsecured.requests";

Expand Down Expand Up @@ -250,6 +251,16 @@ private void wrapFilter(ServletRequest request, ServletResponse response, Filter
private AroundFilterObservation parent(HttpServletRequest request) {
FilterChainObservationContext beforeContext = FilterChainObservationContext.before();
FilterChainObservationContext afterContext = FilterChainObservationContext.after();

AroundFilterObservation existingParentObservation = (AroundFilterObservation) request
.getAttribute(ATTRIBUTE);
if (existingParentObservation != null) {
beforeContext
.setParentObservation(existingParentObservation.before().getContext().getParentObservation());
afterContext
.setParentObservation(existingParentObservation.after().getContext().getParentObservation());
}

Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry);
Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry);
AroundFilterObservation parent = AroundFilterObservation.create(before, after);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,65 @@ public void doFilterWhenMatchesThenObservationRegistryObserves() throws Exceptio
assertFilterChainObservation(contexts.next(), "after", 1);
}

// gh-12610
@Test
void parentObservationIsTakenIntoAccountDuringDispatchError() throws Exception {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
given(handler.supportsContext(any())).willReturn(true);
ObservationRegistry registry = ObservationRegistry.create();
registry.observationConfig().observationHandler(handler);

given(this.matcher.matches(any())).willReturn(true);
SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter));
FilterChainProxy fcp = new FilterChainProxy(sec);
fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry));
Filter initialFilter = ObservationFilterChainDecorator.FilterObservation
.create(Observation.createNotStarted("wrap", registry))
.wrap(fcp);

ServletRequest initialRequest = new MockHttpServletRequest("GET", "/");
initialFilter.doFilter(initialRequest, new MockHttpServletResponse(), this.chain);

// simulate request attribute copying in case dispatching to ERROR
ObservationFilterChainDecorator.AroundFilterObservation parentObservation = (ObservationFilterChainDecorator.AroundFilterObservation) initialRequest
.getAttribute(ObservationFilterChainDecorator.ATTRIBUTE);
assertThat(parentObservation).isNotNull();

// simulate dispatching error-related request
Filter errorRelatedFilter = ObservationFilterChainDecorator.FilterObservation
.create(Observation.createNotStarted("wrap", registry))
.wrap(fcp);
ServletRequest errorRelatedRequest = new MockHttpServletRequest("GET", "/error");
errorRelatedRequest.setAttribute(ObservationFilterChainDecorator.ATTRIBUTE, parentObservation);
errorRelatedFilter.doFilter(errorRelatedRequest, new MockHttpServletResponse(), this.chain);

ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
verify(handler, times(8)).onStart(captor.capture());
verify(handler, times(8)).onStop(any());
List<Observation.Context> contexts = captor.getAllValues();

Observation.Context initialRequestObservationContextBefore = contexts.get(1);
Observation.Context initialRequestObservationContextAfter = contexts.get(3);
assertFilterChainObservation(initialRequestObservationContextBefore, "before", 1);
assertFilterChainObservation(initialRequestObservationContextAfter, "after", 1);

assertThat(initialRequestObservationContextBefore.getParentObservation()).isNotNull();
assertThat(initialRequestObservationContextBefore.getParentObservation())
.isSameAs(initialRequestObservationContextAfter.getParentObservation());

Observation.Context errorRelatedRequestObservationContextBefore = contexts.get(5);
Observation.Context errorRelatedRequestObservationContextAfter = contexts.get(7);
assertFilterChainObservation(errorRelatedRequestObservationContextBefore, "before", 1);
assertFilterChainObservation(errorRelatedRequestObservationContextAfter, "after", 1);

assertThat(errorRelatedRequestObservationContextBefore.getParentObservation()).isNotNull();
assertThat(errorRelatedRequestObservationContextBefore.getParentObservation())
.isSameAs(initialRequestObservationContextBefore.getParentObservation());
assertThat(errorRelatedRequestObservationContextAfter.getParentObservation()).isNotNull();
assertThat(errorRelatedRequestObservationContextAfter.getParentObservation())
.isSameAs(initialRequestObservationContextBefore.getParentObservation());
}

@Test
public void doFilterWhenMultipleFiltersThenObservationRegistryObserves() throws Exception {
ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
Expand Down