diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index d523999d604..2928e9418ff 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,6 @@ import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; -import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.schema.XSAny; import org.opensaml.core.xml.schema.XSBoolean; import org.opensaml.core.xml.schema.XSBooleanValue; @@ -89,7 +88,6 @@ import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; @@ -410,16 +408,15 @@ private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2Au if (!StringUtils.hasText(inResponseTo)) { return Saml2ResponseValidatorResult.success(); } - AuthnRequest request = parseRequest(storedRequest); - if (request == null) { + if (storedRequest == null) { String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]" + " but no saved authentication request was found"; return Saml2ResponseValidatorResult .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); } - if (!inResponseTo.equals(request.getID())) { + if (!inResponseTo.equals(storedRequest.getId())) { String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the " - + "authentication request [" + request.getID() + "]"; + + "authentication request [" + storedRequest.getId() + "]"; return Saml2ResponseValidatorResult .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); } @@ -776,37 +773,7 @@ private static boolean assertionContainsInResponseTo(Assertion assertion) { } private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) { - AuthnRequest request = parseRequest(serialized); - if (request == null) { - return null; - } - return request.getID(); - } - - private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) { - if (request == null) { - return null; - } - String samlRequest = request.getSamlRequest(); - if (!StringUtils.hasText(samlRequest)) { - return null; - } - if (request.getBinding() == Saml2MessageBinding.REDIRECT) { - samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest)); - } - else { - samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8); - } - try { - Document document = XMLObjectProviderRegistrySupport.getParserPool() - .parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8))); - Element element = document.getDocumentElement(); - return (AuthnRequest) authnRequestUnmarshaller.unmarshall(element); - } - catch (Exception ex) { - String message = "Failed to deserialize associated authentication request [" + ex.getMessage() + "]"; - throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message, ex); - } + return (serialized != null) ? serialized.getId() : null; } private static class SAML20AssertionValidators { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index eb4de8e7e61..ec3c15783a8 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; -import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.Arrays; @@ -48,7 +47,6 @@ import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeValue; -import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Conditions; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedAttribute; @@ -78,7 +76,6 @@ import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken; import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.util.StringUtils; @@ -228,8 +225,7 @@ public void evaluateInResponseToSucceedsWhenInResponseToInResponseAndAssertionsM response.setInResponseTo("SAML2"); response.getAssertions().add(signed(assertion("SAML2"))); response.getAssertions().add(signed(assertion("SAML2"))); - AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", - Saml2MessageBinding.POST, false); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); this.provider.authenticate(token); } @@ -239,32 +235,18 @@ public void evaluateInResponseToSucceedsWhenInResponseToInAssertionOnlyMatchRequ Response response = response(); response.getAssertions().add(signed(assertion())); response.getAssertions().add(signed(assertion("SAML2"))); - AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", - Saml2MessageBinding.POST, false); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); this.provider.authenticate(token); } - @Test - public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorruptedStoredRequest() { - Response response = response(); - response.getAssertions().add(signed(assertion())); - response.getAssertions().add(signed(assertion("SAML2"))); - AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", - Saml2MessageBinding.POST, true); - Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); - assertThatExceptionOfType(Saml2AuthenticationException.class) - .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data"); - } - @Test public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID() { Response response = response(); response.setInResponseTo("SAML2"); response.getAssertions().add(signed(assertion("SAML2"))); response.getAssertions().add(signed(assertion("BAD"))); - AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", - Saml2MessageBinding.POST, false); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); @@ -275,8 +257,7 @@ public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndMismatchW Response response = response(); response.getAssertions().add(signed(assertion())); response.getAssertions().add(signed(assertion("BAD"))); - AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", - Saml2MessageBinding.POST, false); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); @@ -288,26 +269,12 @@ public void evaluateInResponseToFailsWhenInResponseInToResponseMismatchWithReque response.setInResponseTo("BAD"); response.getAssertions().add(signed(assertion("SAML2"))); response.getAssertions().add(signed(assertion("SAML2"))); - AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", - Saml2MessageBinding.POST, false); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to"); } - @Test - public void evaluateInResponseToFailsWhenInResponseInToResponseAndCorruptedStoredRequest() { - Response response = response(); - response.setInResponseTo("SAML2"); - response.getAssertions().add(signed(assertion())); - response.getAssertions().add(signed(assertion())); - AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", - Saml2MessageBinding.POST, true); - Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); - assertThatExceptionOfType(Saml2AuthenticationException.class) - .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data"); - } - @Test public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest() { Response response = response(); @@ -321,8 +288,7 @@ public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions() { Response response = response(); response.getAssertions().add(signed(assertion())); - AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", - Saml2MessageBinding.POST, false); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2"); Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); this.provider.authenticate(token); } @@ -805,17 +771,6 @@ private Response response(String destination, String issuerEntityId) { return response; } - private AuthnRequest request() { - AuthnRequest request = TestOpenSamlObjects.authnRequest(); - return request; - } - - private String serializedRequest(AuthnRequest request, Saml2MessageBinding binding) { - String xml = serialize(request); - return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)) - : Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); - } - private Assertion assertion(String inResponseTo) { Assertion assertion = TestOpenSamlObjects.assertion(); assertion.setIssueInstant(Instant.now()); @@ -871,19 +826,9 @@ private Saml2AuthenticationToken token(Response response, RelyingPartyRegistrati return new Saml2AuthenticationToken(registration.build(), serialize(response), authenticationRequest); } - private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId, - Saml2MessageBinding binding, boolean corruptRequestString) { - AuthnRequest request = request(); - if (requestId != null) { - request.setID(requestId); - } - String serializedRequest = serializedRequest(request, binding); - if (corruptRequestString) { - serializedRequest = serializedRequest.substring(2, serializedRequest.length() - 2); - } + private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId) { AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); - given(mockAuthenticationRequest.getSamlRequest()).willReturn(serializedRequest); - given(mockAuthenticationRequest.getBinding()).willReturn(binding); + given(mockAuthenticationRequest.getId()).willReturn(requestId); return mockAuthenticationRequest; }