Skip to content

Commit a0826a2

Browse files
committed
CorsInterceptor at the front of the chain
Closes gh-22459
1 parent d1f888a commit a0826a2

File tree

4 files changed

+53
-51
lines changed

4 files changed

+53
-51
lines changed

spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -91,6 +91,10 @@ public void addInterceptor(HandlerInterceptor interceptor) {
9191
initInterceptorList().add(interceptor);
9292
}
9393

94+
public void addInterceptor(int index, HandlerInterceptor interceptor) {
95+
initInterceptorList().add(index, interceptor);
96+
}
97+
9498
public void addInterceptors(HandlerInterceptor... interceptors) {
9599
if (!ObjectUtils.isEmpty(interceptors)) {
96100
CollectionUtils.mergeArrayIntoCollection(interceptors, initInterceptorList());

spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ protected HandlerExecutionChain getCorsHandlerExecutionChain(HttpServletRequest
526526
chain = new HandlerExecutionChain(new PreFlightHandler(config), interceptors);
527527
}
528528
else {
529-
chain.addInterceptor(new CorsInterceptor(config));
529+
chain.addInterceptor(0, new CorsInterceptor(config));
530530
}
531531
return chain;
532532
}

spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,11 @@
8686
import org.springframework.web.servlet.view.json.MappingJackson2JsonView;
8787
import org.springframework.web.util.UrlPathHelper;
8888

89-
import static com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES;
90-
import static com.fasterxml.jackson.databind.MapperFeature.DEFAULT_VIEW_INCLUSION;
91-
import static org.junit.Assert.assertEquals;
92-
import static org.junit.Assert.assertFalse;
93-
import static org.junit.Assert.assertNotNull;
94-
import static org.junit.Assert.assertSame;
89+
import static com.fasterxml.jackson.databind.DeserializationFeature.*;
90+
import static com.fasterxml.jackson.databind.MapperFeature.*;
91+
import static org.junit.Assert.*;
9592
import static org.mockito.Mockito.*;
96-
import static org.springframework.http.MediaType.APPLICATION_ATOM_XML;
97-
import static org.springframework.http.MediaType.APPLICATION_JSON;
98-
import static org.springframework.http.MediaType.APPLICATION_XML;
93+
import static org.springframework.http.MediaType.*;
9994

10095
/**
10196
* A test fixture with a sub-class of {@link WebMvcConfigurationSupport} that also
@@ -141,9 +136,10 @@ public void handlerMappings() throws Exception {
141136
assertNotNull(chain);
142137
assertNotNull(chain.getInterceptors());
143138
assertEquals(4, chain.getInterceptors().length);
144-
assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[0].getClass());
145-
assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[1].getClass());
146-
assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[2].getClass());
139+
assertEquals("CorsInterceptor", chain.getInterceptors()[0].getClass().getSimpleName());
140+
assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[1].getClass());
141+
assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[2].getClass());
142+
assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[3].getClass());
147143

148144
Map<RequestMappingInfo, HandlerMethod> map = rmHandlerMapping.getHandlerMethods();
149145
assertEquals(2, map.size());
@@ -185,10 +181,11 @@ public void handlerMappings() throws Exception {
185181
assertNotNull(chain);
186182
assertNotNull(chain.getHandler());
187183
assertEquals(Arrays.toString(chain.getInterceptors()), 5, chain.getInterceptors().length);
188-
// PathExposingHandlerInterceptor at chain.getInterceptors()[0]
189-
assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[1].getClass());
190-
assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[2].getClass());
191-
assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[3].getClass());
184+
assertEquals("CorsInterceptor", chain.getInterceptors()[0].getClass().getSimpleName());
185+
// PathExposingHandlerInterceptor at chain.getInterceptors()[1]
186+
assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[2].getClass());
187+
assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[3].getClass());
188+
assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[4].getClass());
192189

193190
handlerMapping = (AbstractHandlerMapping) this.config.defaultServletHandlerMapping();
194191
handlerMapping.setApplicationContext(this.context);

spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,8 @@
1616

1717
package org.springframework.web.servlet.handler;
1818

19-
import static org.junit.Assert.*;
20-
2119
import java.io.IOException;
2220
import java.util.Collections;
23-
2421
import javax.servlet.ServletException;
2522
import javax.servlet.http.HttpServletRequest;
2623
import javax.servlet.http.HttpServletResponse;
@@ -32,6 +29,7 @@
3229
import org.springframework.http.HttpHeaders;
3330
import org.springframework.http.HttpStatus;
3431
import org.springframework.mock.web.test.MockHttpServletRequest;
32+
import org.springframework.util.ObjectUtils;
3533
import org.springframework.web.HttpRequestHandler;
3634
import org.springframework.web.bind.annotation.RequestMethod;
3735
import org.springframework.web.context.support.StaticWebApplicationContext;
@@ -41,6 +39,9 @@
4139
import org.springframework.web.servlet.HandlerInterceptor;
4240
import org.springframework.web.servlet.support.WebContentGenerator;
4341

42+
import static org.junit.Assert.*;
43+
import static org.mockito.Mockito.*;
44+
4445
/**
4546
* Unit tests for CORS-related handling in {@link AbstractHandlerMapping}.
4647
* @author Sebastien Deleuze
@@ -57,6 +58,7 @@ public class CorsAbstractHandlerMappingTests {
5758
public void setup() {
5859
StaticWebApplicationContext context = new StaticWebApplicationContext();
5960
this.handlerMapping = new TestHandlerMapping();
61+
this.handlerMapping.setInterceptors(mock(HandlerInterceptor.class));
6062
this.handlerMapping.setApplicationContext(context);
6163
this.request = new MockHttpServletRequest();
6264
this.request.setRemoteHost("domain1.com");
@@ -69,6 +71,7 @@ public void actualRequestWithoutCorsConfigurationProvider() throws Exception {
6971
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
7072
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
7173
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
74+
7275
assertNotNull(chain);
7376
assertTrue(chain.getHandler() instanceof SimpleHandler);
7477
}
@@ -80,6 +83,7 @@ public void preflightRequestWithoutCorsConfigurationProvider() throws Exception
8083
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
8184
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
8285
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
86+
8387
assertNotNull(chain);
8488
assertTrue(chain.getHandler() instanceof SimpleHandler);
8589
}
@@ -91,11 +95,10 @@ public void actualRequestWithCorsConfigurationProvider() throws Exception {
9195
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
9296
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
9397
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
98+
9499
assertNotNull(chain);
95100
assertTrue(chain.getHandler() instanceof CorsAwareHandler);
96-
CorsConfiguration config = getCorsConfiguration(chain, false);
97-
assertNotNull(config);
98-
assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"});
101+
assertEquals(Collections.singletonList("*"), getRequiredCorsConfiguration(chain, false).getAllowedOrigins());
99102
}
100103

101104
@Test
@@ -105,12 +108,11 @@ public void preflightRequestWithCorsConfigurationProvider() throws Exception {
105108
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
106109
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
107110
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
111+
108112
assertNotNull(chain);
109113
assertNotNull(chain.getHandler());
110-
assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler"));
111-
CorsConfiguration config = getCorsConfiguration(chain, true);
112-
assertNotNull(config);
113-
assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"});
114+
assertEquals("PreFlightHandler", chain.getHandler().getClass().getSimpleName());
115+
assertEquals(Collections.singletonList("*"), getRequiredCorsConfiguration(chain, true).getAllowedOrigins());
114116
}
115117

116118
@Test
@@ -123,11 +125,10 @@ public void actualRequestWithMappedCorsConfiguration() throws Exception {
123125
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
124126
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
125127
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
128+
126129
assertNotNull(chain);
127130
assertTrue(chain.getHandler() instanceof SimpleHandler);
128-
config = getCorsConfiguration(chain, false);
129-
assertNotNull(config);
130-
assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"});
131+
assertEquals(Collections.singletonList("*"), getRequiredCorsConfiguration(chain, false).getAllowedOrigins());
131132
}
132133

133134
@Test
@@ -140,12 +141,11 @@ public void preflightRequestWithMappedCorsConfiguration() throws Exception {
140141
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
141142
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
142143
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
144+
143145
assertNotNull(chain);
144146
assertNotNull(chain.getHandler());
145-
assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler"));
146-
config = getCorsConfiguration(chain, true);
147-
assertNotNull(config);
148-
assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"});
147+
assertEquals("PreFlightHandler", chain.getHandler().getClass().getSimpleName());
148+
assertEquals(Collections.singletonList("*"), getRequiredCorsConfiguration(chain, true).getAllowedOrigins());
149149
}
150150

151151
@Test
@@ -156,11 +156,12 @@ public void actualRequestWithCorsConfigurationSource() throws Exception {
156156
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
157157
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
158158
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
159+
159160
assertNotNull(chain);
160161
assertTrue(chain.getHandler() instanceof SimpleHandler);
161-
CorsConfiguration config = getCorsConfiguration(chain, false);
162+
CorsConfiguration config = getRequiredCorsConfiguration(chain, false);
162163
assertNotNull(config);
163-
assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray());
164+
assertEquals(Collections.singletonList("*"), config.getAllowedOrigins());
164165
assertEquals(true, config.getAllowCredentials());
165166
}
166167

@@ -172,35 +173,35 @@ public void preflightRequestWithCorsConfigurationSource() throws Exception {
172173
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
173174
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
174175
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
176+
175177
assertNotNull(chain);
176178
assertNotNull(chain.getHandler());
177-
assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler"));
178-
CorsConfiguration config = getCorsConfiguration(chain, true);
179+
assertEquals("PreFlightHandler", chain.getHandler().getClass().getSimpleName());
180+
CorsConfiguration config = getRequiredCorsConfiguration(chain, true);
179181
assertNotNull(config);
180-
assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray());
182+
assertEquals(Collections.singletonList("*"), config.getAllowedOrigins());
181183
assertEquals(true, config.getAllowCredentials());
182184
}
183185

184186

185-
private CorsConfiguration getCorsConfiguration(HandlerExecutionChain chain, boolean isPreFlightRequest) {
187+
@SuppressWarnings("ConstantConditions")
188+
private CorsConfiguration getRequiredCorsConfiguration(HandlerExecutionChain chain, boolean isPreFlightRequest) {
189+
CorsConfiguration corsConfig = null;
186190
if (isPreFlightRequest) {
187191
Object handler = chain.getHandler();
188-
assertTrue(handler.getClass().getSimpleName().equals("PreFlightHandler"));
192+
assertEquals("PreFlightHandler", handler.getClass().getSimpleName());
189193
DirectFieldAccessor accessor = new DirectFieldAccessor(handler);
190-
return (CorsConfiguration)accessor.getPropertyValue("config");
194+
corsConfig = (CorsConfiguration) accessor.getPropertyValue("config");
191195
}
192196
else {
193197
HandlerInterceptor[] interceptors = chain.getInterceptors();
194-
if (interceptors != null) {
195-
for (HandlerInterceptor interceptor : interceptors) {
196-
if (interceptor.getClass().getSimpleName().equals("CorsInterceptor")) {
197-
DirectFieldAccessor accessor = new DirectFieldAccessor(interceptor);
198-
return (CorsConfiguration) accessor.getPropertyValue("config");
199-
}
200-
}
198+
if (!ObjectUtils.isEmpty(interceptors)) {
199+
DirectFieldAccessor accessor = new DirectFieldAccessor(interceptors[0]);
200+
corsConfig = (CorsConfiguration) accessor.getPropertyValue("config");
201201
}
202202
}
203-
return null;
203+
assertNotNull(corsConfig);
204+
return corsConfig;
204205
}
205206

206207
public class TestHandlerMapping extends AbstractHandlerMapping {

0 commit comments

Comments
 (0)