Skip to content

Commit 97d1a49

Browse files
committed
Add findUniqueByAssertingPartyEntityId
Closes gh-12848
1 parent 8522e9a commit 97d1a49

File tree

4 files changed

+255
-2
lines changed

4 files changed

+255
-2
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,36 @@
2121
import java.util.Collections;
2222
import java.util.Iterator;
2323
import java.util.LinkedHashMap;
24+
import java.util.List;
2425
import java.util.Map;
2526

2627
import org.springframework.util.Assert;
28+
import org.springframework.util.LinkedMultiValueMap;
29+
import org.springframework.util.MultiValueMap;
2730

2831
/**
29-
* An in-memory implementation of {@link RelyingPartyRegistrationRepository}.
30-
* Also implements {@link Iterable} to simplify the default login page.
32+
* An in-memory implementation of {@link RelyingPartyRegistrationRepository}. Also
33+
* implements {@link Iterable} to simplify the default login page.
3134
*
3235
* @author Filip Hanik
36+
* @author Josh Cummings
3337
* @since 5.2
3438
*/
3539
public class InMemoryRelyingPartyRegistrationRepository
3640
implements RelyingPartyRegistrationRepository, Iterable<RelyingPartyRegistration> {
3741

3842
private final Map<String, RelyingPartyRegistration> byRegistrationId;
3943

44+
private final Map<String, List<RelyingPartyRegistration>> byAssertingPartyEntityId;
45+
4046
public InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistration... registrations) {
4147
this(Arrays.asList(registrations));
4248
}
4349

4450
public InMemoryRelyingPartyRegistrationRepository(Collection<RelyingPartyRegistration> registrations) {
4551
Assert.notEmpty(registrations, "registrations cannot be empty");
4652
this.byRegistrationId = createMappingToIdentityProvider(registrations);
53+
this.byAssertingPartyEntityId = createMappingByAssertingPartyEntityId(registrations);
4754
}
4855

4956
private static Map<String, RelyingPartyRegistration> createMappingToIdentityProvider(
@@ -59,11 +66,32 @@ private static Map<String, RelyingPartyRegistration> createMappingToIdentityProv
5966
return Collections.unmodifiableMap(result);
6067
}
6168

69+
private static Map<String, List<RelyingPartyRegistration>> createMappingByAssertingPartyEntityId(
70+
Collection<RelyingPartyRegistration> rps) {
71+
MultiValueMap<String, RelyingPartyRegistration> result = new LinkedMultiValueMap<>();
72+
for (RelyingPartyRegistration rp : rps) {
73+
result.add(rp.getAssertingPartyDetails().getEntityId(), rp);
74+
}
75+
return Collections.unmodifiableMap(result);
76+
}
77+
6278
@Override
6379
public RelyingPartyRegistration findByRegistrationId(String id) {
6480
return this.byRegistrationId.get(id);
6581
}
6682

83+
@Override
84+
public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
85+
Collection<RelyingPartyRegistration> registrations = this.byAssertingPartyEntityId.get(entityId);
86+
if (registrations == null) {
87+
return null;
88+
}
89+
if (registrations.size() > 1) {
90+
return null;
91+
}
92+
return registrations.iterator().next();
93+
}
94+
6795
@Override
6896
public Iterator<RelyingPartyRegistration> iterator() {
6997
return this.byRegistrationId.values().iterator();

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
* A repository for {@link RelyingPartyRegistration}s
2121
*
2222
* @author Filip Hanik
23+
* @author Josh Cummings
2324
* @since 5.2
2425
*/
2526
public interface RelyingPartyRegistrationRepository {
@@ -32,4 +33,16 @@ public interface RelyingPartyRegistrationRepository {
3233
*/
3334
RelyingPartyRegistration findByRegistrationId(String registrationId);
3435

36+
/**
37+
* Returns the unique relying party registration associated with the asserting party's
38+
* {@code entityId} or {@code null} if there is no unique match.
39+
* @param entityId the asserting party's entity id
40+
* @return the unique {@link RelyingPartyRegistration} associated the given asserting
41+
* party; {@code null} of there is no unique match asserting party
42+
* @since 6.1
43+
*/
44+
default RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
45+
return findByRegistrationId(entityId);
46+
}
47+
3548
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.saml2.provider.service.web;
18+
19+
import java.io.ByteArrayOutputStream;
20+
import java.nio.charset.StandardCharsets;
21+
import java.util.Arrays;
22+
import java.util.Base64;
23+
import java.util.function.Function;
24+
import java.util.zip.Inflater;
25+
import java.util.zip.InflaterOutputStream;
26+
27+
import jakarta.servlet.http.HttpServletRequest;
28+
29+
import org.springframework.http.HttpMethod;
30+
import org.springframework.security.saml2.core.Saml2Error;
31+
import org.springframework.security.saml2.core.Saml2ErrorCodes;
32+
import org.springframework.security.saml2.core.Saml2ParameterNames;
33+
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
34+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
35+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
36+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
37+
import org.springframework.security.web.authentication.AuthenticationConverter;
38+
import org.springframework.util.Assert;
39+
40+
/**
41+
* An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken}
42+
* appropriate for authenticated a SAML 2.0 Assertion against an
43+
* {@link org.springframework.security.authentication.AuthenticationManager}.
44+
*
45+
* @author Josh Cummings
46+
* @since 5.4
47+
*/
48+
public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter {
49+
50+
// MimeDecoder allows extra line-breaks as well as other non-alphabet values.
51+
// This matches the behaviour of the commons-codec decoder.
52+
private static final Base64.Decoder BASE64 = Base64.getMimeDecoder();
53+
54+
private static final Base64Checker BASE_64_CHECKER = new Base64Checker();
55+
56+
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
57+
58+
private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
59+
60+
/**
61+
* Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
62+
* resolving {@link RelyingPartyRegistration}s
63+
* @param relyingPartyRegistrationResolver the strategy for resolving
64+
* {@link RelyingPartyRegistration}s
65+
*/
66+
public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
67+
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
68+
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
69+
this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
70+
}
71+
72+
@Override
73+
public Saml2AuthenticationToken convert(HttpServletRequest request) {
74+
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
75+
String relyingPartyRegistrationId = (authenticationRequest != null)
76+
? authenticationRequest.getRelyingPartyRegistrationId() : null;
77+
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
78+
relyingPartyRegistrationId);
79+
if (relyingPartyRegistration == null) {
80+
return null;
81+
}
82+
String saml2Response = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
83+
if (saml2Response == null) {
84+
return null;
85+
}
86+
byte[] b = samlDecode(saml2Response);
87+
saml2Response = inflateIfRequired(request, b);
88+
return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
89+
}
90+
91+
/**
92+
* Use the given {@link Saml2AuthenticationRequestRepository} to load authentication
93+
* request.
94+
* @param authenticationRequestRepository the
95+
* {@link Saml2AuthenticationRequestRepository} to use
96+
* @since 5.6
97+
*/
98+
public void setAuthenticationRequestRepository(
99+
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
100+
Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
101+
this.loader = authenticationRequestRepository::loadAuthenticationRequest;
102+
}
103+
104+
private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
105+
return this.loader.apply(request);
106+
}
107+
108+
private String inflateIfRequired(HttpServletRequest request, byte[] b) {
109+
if (HttpMethod.GET.matches(request.getMethod())) {
110+
return samlInflate(b);
111+
}
112+
return new String(b, StandardCharsets.UTF_8);
113+
}
114+
115+
private byte[] samlDecode(String base64EncodedPayload) {
116+
try {
117+
BASE_64_CHECKER.checkAcceptable(base64EncodedPayload);
118+
return BASE64.decode(base64EncodedPayload);
119+
}
120+
catch (Exception ex) {
121+
throw new Saml2AuthenticationException(
122+
new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse"), ex);
123+
}
124+
}
125+
126+
private String samlInflate(byte[] b) {
127+
try {
128+
ByteArrayOutputStream out = new ByteArrayOutputStream();
129+
InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, new Inflater(true));
130+
inflaterOutputStream.write(b);
131+
inflaterOutputStream.finish();
132+
return out.toString(StandardCharsets.UTF_8.name());
133+
}
134+
catch (Exception ex) {
135+
throw new Saml2AuthenticationException(
136+
new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), ex);
137+
}
138+
}
139+
140+
static class Base64Checker {
141+
142+
private static final int[] values = genValueMapping();
143+
144+
Base64Checker() {
145+
146+
}
147+
148+
private static int[] genValueMapping() {
149+
byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
150+
.getBytes(StandardCharsets.ISO_8859_1);
151+
152+
int[] values = new int[256];
153+
Arrays.fill(values, -1);
154+
for (int i = 0; i < alphabet.length; i++) {
155+
values[alphabet[i] & 0xff] = i;
156+
}
157+
return values;
158+
}
159+
160+
boolean isAcceptable(String s) {
161+
int goodChars = 0;
162+
int lastGoodCharVal = -1;
163+
164+
// count number of characters from Base64 alphabet
165+
for (int i = 0; i < s.length(); i++) {
166+
int val = values[0xff & s.charAt(i)];
167+
if (val != -1) {
168+
lastGoodCharVal = val;
169+
goodChars++;
170+
}
171+
}
172+
173+
// in cases of an incomplete final chunk, ensure the unused bits are zero
174+
switch (goodChars % 4) {
175+
case 0:
176+
return true;
177+
case 2:
178+
return (lastGoodCharVal & 0b1111) == 0;
179+
case 3:
180+
return (lastGoodCharVal & 0b11) == 0;
181+
default:
182+
return false;
183+
}
184+
}
185+
186+
void checkAcceptable(String ins) {
187+
if (!isAcceptable(ins)) {
188+
throw new IllegalArgumentException("Unaccepted Encoding");
189+
}
190+
}
191+
192+
}
193+
194+
}

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepositoryTests.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,22 @@ void findByRegistrationIdWhenGivenWrongIdThenReturnsNull() {
4242
assertThat(registrations.findByRegistrationId(null)).isNull();
4343
}
4444

45+
@Test
46+
void findByAssertingPartyEntityIdWhenGivenEntityIdThenReturnsMatchingRegistrations() {
47+
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
48+
InMemoryRelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(
49+
registration);
50+
String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId();
51+
assertThat(registrations.findUniqueByAssertingPartyEntityId(assertingPartyEntityId)).isEqualTo(registration);
52+
}
53+
54+
@Test
55+
void findByAssertingPartyEntityIdWhenGivenWrongEntityIdThenReturnsEmpty() {
56+
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
57+
InMemoryRelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(
58+
registration);
59+
String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId();
60+
assertThat(registrations.findUniqueByAssertingPartyEntityId(assertingPartyEntityId + "wrong")).isNull();
61+
}
62+
4563
}

0 commit comments

Comments
 (0)