diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java index a08dda68a67f..e8c198596063 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java @@ -25,12 +25,16 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.security.Principal; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.temporal.TemporalAccessor; +import java.util.*; +import java.util.function.Consumer; import jakarta.servlet.ServletContext; import jakarta.servlet.ServletRequest; @@ -109,7 +113,7 @@ public abstract class AbstractMockHttpServletRequestBuilder headers = new LinkedMultiValueMap<>(); + private final HttpHeaders headers = new HttpHeaders(); private final MultiValueMap parameters = new LinkedMultiValueMap<>(); @@ -342,7 +346,13 @@ public B accept(String... mediaTypes) { * @param values one or more header values */ public B header(String name, Object... values) { - this.headers.addAll(name, Arrays.asList(values)); + /* + TODO ask how it should behave if values is omitted. Currently the header is added with value + 'null'. But, usually 'null' means the header is not present. Add list with 1 null element? + */ + var stringValues = + Arrays.stream(values).map(AbstractMockHttpServletRequestBuilder::objectToString).toList(); + this.headers.addAll(name, stringValues); return self(); } @@ -355,6 +365,20 @@ public B headers(HttpHeaders httpHeaders) { return self(); } + /** + * Manipulate this builder's headers with the given consumer. The + * headers provided to the consumer are "live", so that the consumer can be used to + * {@linkplain HttpHeaders#set(String, String) overwrite} existing header values, + * {@linkplain HttpHeaders#remove(String) remove} values, or use any of the other + * {@link HttpHeaders} methods. + * @param httpHeadersConsumer a function that consumes the {@code HttpHeaders} + * @return this builder + */ + public B headers(Consumer httpHeadersConsumer) { + httpHeadersConsumer.accept(this.headers); + return self(); + } + /** * Add a request parameter to {@link MockHttpServletRequest#getParameterMap()}. *

In the Servlet API, a request parameter may be parsed from the query @@ -665,12 +689,12 @@ public Object merge(@Nullable Object parent) { this.contentType = parentBuilder.contentType; } - for (Map.Entry> entry : parentBuilder.headers.entrySet()) { - String headerName = entry.getKey(); - if (!this.headers.containsKey(headerName)) { - this.headers.put(headerName, entry.getValue()); - } - } + parentBuilder.headers.forEach( + (name, values) -> { + if (!this.headers.containsHeader(name)) { + this.headers.put(name, values); + } + }); for (Map.Entry> entry : parentBuilder.parameters.entrySet()) { String paramName = entry.getKey(); if (!this.parameters.containsKey(paramName)) { @@ -741,6 +765,57 @@ private boolean containsCookie(Cookie cookie) { return false; } + /** + * Convert an object to a RFC7231-compliant string when needed. + * + * @param o the object to convert + * @return the object's {@code toString()} value by default; if {@code o} is a {@link TemporalAccessor}, + * it is formatted to an RFC7231-compliant string. + */ + private static String objectToString(Object o) { + Assert.notNull(o, "'o' must not be null"); + if (o instanceof TemporalAccessor ta) { + return temporalToString(ta); + } else { + return o.toString(); + } + } + + /** + * Try to convert a temporal value to an RFC7231-compliant Internet Message Format (IMF-fixdate) + * string (preferred pattern): {@code EEE, dd MMM yyyy HH:mm:ss zzz}. + *

+ * This method supports {@link Instant}, {@link ZonedDateTime}, {@link OffsetDateTime}, + * {@link LocalDateTime}, {@link LocalDate}, and {@link Date}. + *

+ * If an exact conversion cannot be performed, the object's {@code toString()} value is returned + * as a fallback; that representation is typically ISO‑8601, which is not RFC7231 compliant. + * + * @param temporalAccessor the value to convert + * @return an RFC7231-compliant string when possible, otherwise {@code TemporalAccessor.toString()} + */ + private static String temporalToString(TemporalAccessor temporalAccessor) { + Assert.notNull(temporalAccessor, "'temporalAccessor' must not be null"); + var rfc7231 = DateTimeFormatter.RFC_1123_DATE_TIME; + var utc = ZoneOffset.UTC; + + if (temporalAccessor instanceof Instant instant) { + return rfc7231.format(instant.atZone(utc)); + } else if (temporalAccessor instanceof ZonedDateTime zonedDateTime) { + return rfc7231.format(zonedDateTime.withZoneSameInstant(utc)); + } else if (temporalAccessor instanceof OffsetDateTime offsetDateTime) { + return rfc7231.format(offsetDateTime.atZoneSameInstant(utc)); + } else if (temporalAccessor instanceof LocalDateTime localDateTime) { + return rfc7231.format(localDateTime.atZone(utc)); + } else if (temporalAccessor instanceof LocalDate localDate) { + return rfc7231.format(localDate.atStartOfDay(utc)); + } else if(temporalAccessor instanceof Date date) { + return rfc7231.format(Instant.ofEpochMilli(date.getTime()).atZone(utc)); + } else { + return temporalAccessor.toString(); + } + } + /** * Build a {@link MockHttpServletRequest}. */ @@ -801,15 +876,11 @@ public final MockHttpServletRequest buildRequest(ServletContext servletContext) httpHeaders.forEach((name, values) -> values.forEach(value -> this.headers.add(name, value))); } - this.headers.forEach((name, values) -> { - for (Object value : values) { - request.addHeader(name, value); - } - }); + this.headers.forEach((name, values) -> values.forEach(value -> request.addHeader(name, value))); - if (!ObjectUtils.isEmpty(this.content) && - !this.headers.containsKey(HttpHeaders.CONTENT_LENGTH) && - !this.headers.containsKey(HttpHeaders.TRANSFER_ENCODING)) { + if (!ObjectUtils.isEmpty(this.content) + && !this.headers.containsHeader(HttpHeaders.CONTENT_LENGTH) + && !this.headers.containsHeader(HttpHeaders.TRANSFER_ENCODING)) { request.addHeader(HttpHeaders.CONTENT_LENGTH, this.content.length); } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilderTests.java index 60f25994c3bb..f05cc18ae14c 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilderTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilderTests.java @@ -19,6 +19,11 @@ import java.io.IOException; import java.net.URI; import java.security.Principal; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -543,6 +548,44 @@ void headers() { assertThat(request.getHeader("Content-Type")).isEqualTo(MediaType.APPLICATION_JSON.toString()); } + @Test + void headersConsumer() { + this.builder.header("X-Foo-String", "bar"); + this.builder.header("X-Foo-Date", LocalDate.now()); + this.builder.header("X-Foo-List-Int", List.of(1, 2, 3)); + + this.builder.headers(httpHeaders -> { + httpHeaders.put("X-Baz", Arrays.asList("qux", "quux")); + httpHeaders.remove("X-Foo-Date"); + }); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + List headerNames = Collections.list(request.getHeaderNames()); + + assertThat(headerNames).containsExactly("X-Foo-String", "X-Foo-List-Int", "X-Baz"); + } + + @Test + void headersTemporal() { + this.builder.header( + "X-Foo", + Instant.parse("2024-03-01T00:00:00+01:00"), + ZonedDateTime.parse("2024-03-01T00:00:00+01:00"), + OffsetDateTime.parse("2024-03-01T00:00:00+01:00"), + LocalDateTime.of(2024, 2, 29, 23, 0, 0), + LocalDate.of(2024, 2, 29)); + + MockHttpServletRequest request = this.builder.buildRequest(this.servletContext); + List headers = Collections.list(request.getHeaders("X-Foo")); + + assertThat(headers).containsExactly( + "Thu, 29 Feb 2024 23:00:00 GMT", + "Thu, 29 Feb 2024 23:00:00 GMT", + "Thu, 29 Feb 2024 23:00:00 GMT", + "Thu, 29 Feb 2024 23:00:00 GMT", + "Thu, 29 Feb 2024 00:00:00 GMT"); + } + @Test void cookie() { Cookie cookie1 = new Cookie("foo", "bar");