Skip to content

Commit a402c38

Browse files
committed
Add ConditionValidator Support
Closes gh-8769
1 parent d9d8253 commit a402c38

File tree

2 files changed

+143
-28
lines changed

2 files changed

+143
-28
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
import java.time.Instant;
2222
import java.util.ArrayList;
2323
import java.util.Collection;
24+
import java.util.Collections;
2425
import java.util.HashMap;
2526
import java.util.HashSet;
2627
import java.util.LinkedHashMap;
2728
import java.util.List;
2829
import java.util.Map;
2930
import java.util.Set;
30-
import java.util.function.Consumer;
3131
import java.util.function.Function;
3232
import javax.annotation.Nonnull;
3333

@@ -193,10 +193,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
193193

194194
private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter =
195195
new SignatureTrustEngineConverter();
196-
private Converter<Saml2AuthenticationToken, SAML20AssertionValidator> assertionValidatorConverter =
196+
private Converter<Tuple, SAML20AssertionValidator> assertionValidatorConverter =
197197
new SAML20AssertionValidatorConverter();
198-
private Converter<Saml2AuthenticationToken, ValidationContext> validationContextConverter =
199-
new ValidationContextConverter(params -> {});
198+
private Collection<ConditionValidator> conditionValidators =
199+
Collections.singleton(new AudienceRestrictionConditionValidator());
200+
private Converter<Tuple, ValidationContext> validationContextConverter =
201+
new ValidationContextConverter();
200202
private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
201203

202204
/**
@@ -209,6 +211,33 @@ public OpenSamlAuthenticationProvider() {
209211
this.parserPool = this.registry.getParserPool();
210212
}
211213

214+
/**
215+
* Set the the collection of {@link ConditionValidator}s used when validating an assertion.
216+
*
217+
* @param conditionValidators the collection of validators to use
218+
* @since 5.4
219+
*/
220+
public void setConditionValidators(
221+
Collection<ConditionValidator> conditionValidators) {
222+
223+
Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty");
224+
this.conditionValidators = conditionValidators;
225+
}
226+
227+
/**
228+
* Set the strategy for retrieving the {@link ValidationContext} used when
229+
* validating an assertion.
230+
*
231+
* @param validationContextConverter the strategy to use
232+
* @since 5.4
233+
*/
234+
public void setValidationContextConverter(
235+
Converter<Tuple, ValidationContext> validationContextConverter) {
236+
237+
Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty");
238+
this.validationContextConverter = validationContextConverter;
239+
}
240+
212241
/**
213242
* Sets the {@link Converter} used for extracting assertion attributes that
214243
* can be mapped to authorities.
@@ -238,8 +267,6 @@ public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
238267
*/
239268
public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
240269
this.responseTimeValidationSkew = responseTimeValidationSkew;
241-
this.validationContextConverter = new ValidationContextConverter(
242-
params -> params.put(CLOCK_SKEW, responseTimeValidationSkew.toMillis()));
243270
}
244271

245272
/**
@@ -303,7 +330,7 @@ private void process(Saml2AuthenticationToken token, Response response) {
303330
throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
304331
"Please either sign the response or all of the assertions.");
305332
}
306-
validationExceptions.putAll(validateAssertions(token, assertions));
333+
validationExceptions.putAll(validateAssertions(token, response));
307334

308335
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
309336
NameID nameId = decryptPrincipal(decrypter, firstAssertion);
@@ -392,7 +419,8 @@ private void process(Saml2AuthenticationToken token, Response response) {
392419
}
393420

394421
private Map<String, Saml2AuthenticationException> validateAssertions
395-
(Saml2AuthenticationToken token, List<Assertion> assertions) {
422+
(Saml2AuthenticationToken token, Response response) {
423+
List<Assertion> assertions = response.getAssertions();
396424
if (assertions.isEmpty()) {
397425
throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
398426
}
@@ -401,14 +429,16 @@ private void process(Saml2AuthenticationToken token, Response response) {
401429
if (logger.isDebugEnabled()) {
402430
logger.debug("Validating " + assertions.size() + " assertions");
403431
}
432+
433+
Tuple tuple = new Tuple(token, response);
434+
SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple);
435+
ValidationContext context = this.validationContextConverter.convert(tuple);
404436
for (Assertion assertion : assertions) {
405437
if (logger.isTraceEnabled()) {
406438
logger.trace("Validating assertion " + assertion.getID());
407439
}
408440
try {
409-
ValidationContext context = this.validationContextConverter.convert(token);
410-
ValidationResult result = this.assertionValidatorConverter.convert(token).validate(assertion, context);
411-
if (result != ValidationResult.VALID) {
441+
if (validator.validate(assertion, context) != ValidationResult.VALID) {
412442
String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
413443
assertion.getID(), ((Response) assertion.getParent()).getID(),
414444
context.getValidationFailureMessage());
@@ -512,6 +542,7 @@ private Object getXSAnyObjectValue(XSAny xsAny) {
512542
}
513543

514544
private static class SignatureTrustEngineConverter implements Converter<Saml2AuthenticationToken, SignatureTrustEngine> {
545+
515546
@Override
516547
public SignatureTrustEngine convert(Saml2AuthenticationToken token) {
517548
Set<Credential> credentials = new HashSet<>();
@@ -530,35 +561,27 @@ public SignatureTrustEngine convert(Saml2AuthenticationToken token) {
530561
}
531562
}
532563

533-
private static class ValidationContextConverter implements Converter<Saml2AuthenticationToken, ValidationContext> {
534-
Consumer<Map<String, Object>> validationContextParametersConverter;
535-
536-
ValidationContextConverter(Consumer<Map<String, Object>> validationContextParametersConverter) {
537-
this.validationContextParametersConverter = validationContextParametersConverter;
538-
}
564+
private class ValidationContextConverter implements Converter<Tuple, ValidationContext> {
539565

540566
@Override
541-
public ValidationContext convert(Saml2AuthenticationToken token) {
542-
String audience = token.getRelyingPartyRegistration().getEntityId();
543-
String recipient = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
567+
public ValidationContext convert(Tuple tuple) {
568+
String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId();
569+
String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
544570
Map<String, Object> params = new HashMap<>();
545-
params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis());
571+
params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
546572
params.put(COND_VALID_AUDIENCES, singleton(audience));
547573
params.put(SC_VALID_RECIPIENTS, singleton(recipient));
548574
params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier
549-
this.validationContextParametersConverter.accept(params);
550575
return new ValidationContext(params);
551576
}
552577
}
553578

554-
private class SAML20AssertionValidatorConverter implements Converter<Saml2AuthenticationToken, SAML20AssertionValidator> {
555-
private final Collection<ConditionValidator> conditions = new ArrayList<>();
579+
private class SAML20AssertionValidatorConverter implements Converter<Tuple, SAML20AssertionValidator> {
556580
private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
557581
private final Collection<StatementValidator> statements = new ArrayList<>();
558582
private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
559583

560584
SAML20AssertionValidatorConverter() {
561-
this.conditions.add(new AudienceRestrictionConditionValidator());
562585
this.subjects.add(new BearerSubjectConfirmationValidator() {
563586
@Nonnull
564587
@Override
@@ -571,9 +594,11 @@ protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirma
571594
}
572595

573596
@Override
574-
public SAML20AssertionValidator convert(Saml2AuthenticationToken token) {
575-
return new SAML20AssertionValidator(this.conditions, this.subjects, this.statements,
576-
OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(token),
597+
public SAML20AssertionValidator convert(Tuple tuple) {
598+
Collection<ConditionValidator> conditions =
599+
OpenSamlAuthenticationProvider.this.conditionValidators;
600+
return new SAML20AssertionValidator(conditions, this.subjects, this.statements,
601+
OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication),
577602
this.validator);
578603
}
579604
}
@@ -616,4 +641,27 @@ private static Saml2AuthenticationException authException(String code, String de
616641

617642
return new Saml2AuthenticationException(validationError(code, description), cause);
618643
}
644+
645+
/**
646+
* A tuple containing the authentication token and the associated OpenSAML {@link Response}.
647+
*
648+
* @since 5.4
649+
*/
650+
public static class Tuple {
651+
private final Saml2AuthenticationToken authentication;
652+
private final Response response;
653+
654+
private Tuple(Saml2AuthenticationToken authentication, Response response) {
655+
this.authentication = authentication;
656+
this.response = response;
657+
}
658+
659+
public Saml2AuthenticationToken getAuthentication() {
660+
return this.authentication;
661+
}
662+
663+
public Response getResponse() {
664+
return this.response;
665+
}
666+
}
619667
}

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
import java.time.Instant;
2424
import java.util.Arrays;
2525
import java.util.Collections;
26+
import java.util.HashMap;
2627
import java.util.LinkedHashMap;
2728
import java.util.List;
2829
import java.util.Map;
30+
import javax.xml.namespace.QName;
2931
import javax.xml.parsers.DocumentBuilder;
3032
import javax.xml.parsers.DocumentBuilderFactory;
3133

@@ -42,12 +44,17 @@
4244
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
4345
import org.opensaml.core.xml.io.Marshaller;
4446
import org.opensaml.core.xml.io.MarshallingException;
47+
import org.opensaml.saml.common.assertion.ValidationContext;
48+
import org.opensaml.saml.common.assertion.ValidationResult;
49+
import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator;
4550
import org.opensaml.saml.saml2.core.Assertion;
4651
import org.opensaml.saml.saml2.core.AttributeStatement;
4752
import org.opensaml.saml.saml2.core.AttributeValue;
53+
import org.opensaml.saml.saml2.core.Condition;
4854
import org.opensaml.saml.saml2.core.EncryptedAssertion;
4955
import org.opensaml.saml.saml2.core.EncryptedID;
5056
import org.opensaml.saml.saml2.core.NameID;
57+
import org.opensaml.saml.saml2.core.OneTimeUse;
5158
import org.opensaml.saml.saml2.core.Response;
5259
import org.w3c.dom.Document;
5360
import org.w3c.dom.Element;
@@ -57,14 +64,18 @@
5764
import org.springframework.security.saml2.Saml2Exception;
5865
import org.springframework.security.saml2.credentials.Saml2X509Credential;
5966

67+
import static java.util.Collections.singleton;
6068
import static org.assertj.core.api.Assertions.assertThat;
69+
import static org.assertj.core.api.Assertions.assertThatCode;
6170
import static org.mockito.ArgumentMatchers.any;
6271
import static org.mockito.Mockito.atLeastOnce;
6372
import static org.mockito.Mockito.mock;
6473
import static org.mockito.Mockito.verify;
6574
import static org.mockito.Mockito.when;
6675
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
6776
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
77+
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
78+
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
6879
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
6980
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
7081
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
@@ -353,6 +364,62 @@ public void writeObjectWhenTypeIsSaml2AuthenticationThenNoException() throws IOE
353364
objectOutputStream.flush();
354365
}
355366

367+
@Test
368+
public void authenticateWhenConditionValidatorsCustomizedThenUses() throws Exception {
369+
OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class);
370+
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
371+
provider.setConditionValidators(Collections.singleton(validator));
372+
Response response = response();
373+
Assertion assertion = assertion();
374+
OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME);
375+
assertion.getConditions().getConditions().add(oneTimeUse);
376+
response.getAssertions().add(assertion);
377+
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
378+
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
379+
when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
380+
when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
381+
.thenReturn(ValidationResult.VALID);
382+
provider.authenticate(token);
383+
verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class));
384+
}
385+
386+
@Test
387+
public void authenticateWhenValidationContextCustomizedThenUsers() {
388+
Map<String, Object> parameters = new HashMap<>();
389+
parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION));
390+
parameters.put(SIGNATURE_REQUIRED, false);
391+
ValidationContext context = mock(ValidationContext.class);
392+
when(context.getStaticParameters()).thenReturn(parameters);
393+
OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
394+
provider.setValidationContextConverter(tuple -> context);
395+
Response response = response();
396+
Assertion assertion = assertion();
397+
response.getAssertions().add(assertion);
398+
signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
399+
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
400+
provider.authenticate(token);
401+
verify(context, atLeastOnce()).getStaticParameters();
402+
}
403+
404+
@Test
405+
public void setValidationContextConverterWhenNullThenIllegalArgument() {
406+
assertThatCode(() -> this.provider.setValidationContextConverter(null))
407+
.isInstanceOf(IllegalArgumentException.class);
408+
}
409+
410+
@Test
411+
public void setConditionValidatorsWhenNullOrEmptyThenIllegalArgument() {
412+
assertThatCode(() -> this.provider.setConditionValidators(null))
413+
.isInstanceOf(IllegalArgumentException.class);
414+
415+
assertThatCode(() -> this.provider.setConditionValidators(Collections.emptyList()))
416+
.isInstanceOf(IllegalArgumentException.class);
417+
}
418+
419+
private <T extends XMLObject> T build(QName qName) {
420+
return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
421+
}
422+
356423
private String serialize(XMLObject object) {
357424
try {
358425
Marshaller marshaller = getMarshallerFactory().getMarshaller(object);

0 commit comments

Comments
 (0)