Skip to content

Commit

Permalink
Don't use raw xml saml authentication request for response validation
Browse files Browse the repository at this point in the history
closes gh-12961
  • Loading branch information
1livv committed Apr 3, 2023
1 parent dd4ce24 commit dd18b56
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 102 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,7 +37,6 @@
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSBooleanValue;
Expand Down Expand Up @@ -89,7 +88,6 @@
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
Expand Down Expand Up @@ -410,16 +408,15 @@ private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2Au
if (!StringUtils.hasText(inResponseTo)) {
return Saml2ResponseValidatorResult.success();
}
AuthnRequest request = parseRequest(storedRequest);
if (request == null) {
if (storedRequest == null) {
String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
+ " but no saved authentication request was found";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
if (!inResponseTo.equals(request.getID())) {
if (!inResponseTo.equals(storedRequest.getId())) {
String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
+ "authentication request [" + request.getID() + "]";
+ "authentication request [" + storedRequest.getId() + "]";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
Expand Down Expand Up @@ -776,37 +773,7 @@ private static boolean assertionContainsInResponseTo(Assertion assertion) {
}

private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) {
AuthnRequest request = parseRequest(serialized);
if (request == null) {
return null;
}
return request.getID();
}

private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) {
if (request == null) {
return null;
}
String samlRequest = request.getSamlRequest();
if (!StringUtils.hasText(samlRequest)) {
return null;
}
if (request.getBinding() == Saml2MessageBinding.REDIRECT) {
samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
}
else {
samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
}
try {
Document document = XMLObjectProviderRegistrySupport.getParserPool()
.parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8)));
Element element = document.getDocumentElement();
return (AuthnRequest) authnRequestUnmarshaller.unmarshall(element);
}
catch (Exception ex) {
String message = "Failed to deserialize associated authentication request [" + ex.getMessage() + "]";
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message, ex);
}
return (serialized != null) ? serialized.getId() : null;
}

private static class SAML20AssertionValidators {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,6 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
Expand Down Expand Up @@ -48,7 +47,6 @@
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AttributeValue;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Conditions;
import org.opensaml.saml.saml2.core.EncryptedAssertion;
import org.opensaml.saml.saml2.core.EncryptedAttribute;
Expand Down Expand Up @@ -78,7 +76,6 @@
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.util.StringUtils;

Expand Down Expand Up @@ -228,8 +225,7 @@ public void evaluateInResponseToSucceedsWhenInResponseToInResponseAndAssertionsM
response.setInResponseTo("SAML2");
response.getAssertions().add(signed(assertion("SAML2")));
response.getAssertions().add(signed(assertion("SAML2")));
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
Saml2MessageBinding.POST, false);
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
this.provider.authenticate(token);
}
Expand All @@ -239,32 +235,18 @@ public void evaluateInResponseToSucceedsWhenInResponseToInAssertionOnlyMatchRequ
Response response = response();
response.getAssertions().add(signed(assertion()));
response.getAssertions().add(signed(assertion("SAML2")));
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
Saml2MessageBinding.POST, false);
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
this.provider.authenticate(token);
}

@Test
public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorruptedStoredRequest() {
Response response = response();
response.getAssertions().add(signed(assertion()));
response.getAssertions().add(signed(assertion("SAML2")));
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
Saml2MessageBinding.POST, true);
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
}

@Test
public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID() {
Response response = response();
response.setInResponseTo("SAML2");
response.getAssertions().add(signed(assertion("SAML2")));
response.getAssertions().add(signed(assertion("BAD")));
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
Saml2MessageBinding.POST, false);
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
Expand All @@ -275,8 +257,7 @@ public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndMismatchW
Response response = response();
response.getAssertions().add(signed(assertion()));
response.getAssertions().add(signed(assertion("BAD")));
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
Saml2MessageBinding.POST, false);
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
Expand All @@ -288,26 +269,12 @@ public void evaluateInResponseToFailsWhenInResponseInToResponseMismatchWithReque
response.setInResponseTo("BAD");
response.getAssertions().add(signed(assertion("SAML2")));
response.getAssertions().add(signed(assertion("SAML2")));
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
Saml2MessageBinding.POST, false);
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to");
}

@Test
public void evaluateInResponseToFailsWhenInResponseInToResponseAndCorruptedStoredRequest() {
Response response = response();
response.setInResponseTo("SAML2");
response.getAssertions().add(signed(assertion()));
response.getAssertions().add(signed(assertion()));
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
Saml2MessageBinding.POST, true);
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
}

@Test
public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest() {
Response response = response();
Expand All @@ -321,8 +288,7 @@ public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest
public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions() {
Response response = response();
response.getAssertions().add(signed(assertion()));
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
Saml2MessageBinding.POST, false);
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
this.provider.authenticate(token);
}
Expand Down Expand Up @@ -805,17 +771,6 @@ private Response response(String destination, String issuerEntityId) {
return response;
}

private AuthnRequest request() {
AuthnRequest request = TestOpenSamlObjects.authnRequest();
return request;
}

private String serializedRequest(AuthnRequest request, Saml2MessageBinding binding) {
String xml = serialize(request);
return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))
: Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
}

private Assertion assertion(String inResponseTo) {
Assertion assertion = TestOpenSamlObjects.assertion();
assertion.setIssueInstant(Instant.now());
Expand Down Expand Up @@ -871,19 +826,9 @@ private Saml2AuthenticationToken token(Response response, RelyingPartyRegistrati
return new Saml2AuthenticationToken(registration.build(), serialize(response), authenticationRequest);
}

private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId,
Saml2MessageBinding binding, boolean corruptRequestString) {
AuthnRequest request = request();
if (requestId != null) {
request.setID(requestId);
}
String serializedRequest = serializedRequest(request, binding);
if (corruptRequestString) {
serializedRequest = serializedRequest.substring(2, serializedRequest.length() - 2);
}
private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId) {
AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
given(mockAuthenticationRequest.getSamlRequest()).willReturn(serializedRequest);
given(mockAuthenticationRequest.getBinding()).willReturn(binding);
given(mockAuthenticationRequest.getId()).willReturn(requestId);
return mockAuthenticationRequest;
}

Expand Down

0 comments on commit dd18b56

Please # to comment.