diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/HttpUtils.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/HttpUtils.java new file mode 100644 index 000000000..f6eba4157 --- /dev/null +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/HttpUtils.java @@ -0,0 +1,68 @@ +package com.amazonaws.serverless.proxy.internal; + +import org.apache.commons.io.Charsets; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.nio.charset.UnsupportedCharsetException; + +public final class HttpUtils { + + static final String HEADER_KEY_VALUE_SEPARATOR = "="; + static final String HEADER_VALUE_SEPARATOR = ";"; + static final String ENCODING_VALUE_KEY = "charset"; + + + static public Charset parseCharacterEncoding(String contentTypeHeader,Charset defaultCharset) { + // we only look at content-type because content-encoding should only be used for + // "binary" requests such as gzip/deflate. + if (contentTypeHeader == null) { + return defaultCharset; + } + + String[] contentTypeValues = contentTypeHeader.split(HEADER_VALUE_SEPARATOR); + if (contentTypeValues.length <= 1) { + return defaultCharset; + } + + for (String contentTypeValue : contentTypeValues) { + if (contentTypeValue.trim().startsWith(ENCODING_VALUE_KEY)) { + String[] encodingValues = contentTypeValue.split(HEADER_KEY_VALUE_SEPARATOR); + if (encodingValues.length <= 1) { + return defaultCharset; + } + try { + return Charsets.toCharset(encodingValues[1]); + } catch (UnsupportedCharsetException ex) { + return defaultCharset; + } + } + } + return defaultCharset; + } + + + static public String appendCharacterEncoding(String currentContentType, String newEncoding) { + if (currentContentType == null || currentContentType.trim().isEmpty()) { + return null; + } + + if (currentContentType.contains(HEADER_VALUE_SEPARATOR)) { + String[] contentTypeValues = currentContentType.split(HEADER_VALUE_SEPARATOR); + StringBuilder contentType = new StringBuilder(contentTypeValues[0]); + + for (int i = 1; i < contentTypeValues.length; i++) { + String contentTypeValue = contentTypeValues[i]; + String contentTypeString = HEADER_VALUE_SEPARATOR + " " + contentTypeValue; + if (contentTypeValue.trim().startsWith(ENCODING_VALUE_KEY)) { + contentTypeString = HEADER_VALUE_SEPARATOR + " " + ENCODING_VALUE_KEY + HEADER_KEY_VALUE_SEPARATOR + newEncoding; + } + contentType.append(contentTypeString); + } + + return contentType.toString(); + } else { + return currentContentType + HEADER_VALUE_SEPARATOR + " " + ENCODING_VALUE_KEY + HEADER_KEY_VALUE_SEPARATOR + newEncoding; + } + } +} diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java index 537e10759..5a2534a2b 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpApiV2ProxyHttpServletRequest.java @@ -12,6 +12,7 @@ */ package com.amazonaws.serverless.proxy.internal.servlet; +import com.amazonaws.serverless.proxy.internal.HttpUtils; import com.amazonaws.serverless.proxy.internal.LambdaContainerHandler; import com.amazonaws.serverless.proxy.internal.SecurityUtils; import com.amazonaws.serverless.proxy.model.ContainerConfig; @@ -32,6 +33,7 @@ import java.io.StringReader; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; +import java.nio.charset.Charset; import java.security.Principal; import java.time.Instant; import java.time.ZonedDateTime; @@ -232,7 +234,8 @@ public String getCharacterEncoding() { if (headers == null) { return config.getDefaultContentCharset(); } - return parseCharacterEncoding(headers.getFirst(HttpHeaders.CONTENT_TYPE)); + Charset charset = HttpUtils.parseCharacterEncoding(headers.getFirst(HttpHeaders.CONTENT_TYPE),null); + return charset != null ? charset.name() : null; } @Override @@ -242,7 +245,7 @@ public void setCharacterEncoding(String s) throws UnsupportedEncodingException { return; } String currentContentType = headers.getFirst(HttpHeaders.CONTENT_TYPE); - headers.putSingle(HttpHeaders.CONTENT_TYPE, appendCharacterEncoding(currentContentType, s)); + headers.putSingle(HttpHeaders.CONTENT_TYPE, HttpUtils.appendCharacterEncoding(currentContentType, s)); } @Override diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java index ea8ef4a1a..97beabe64 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java @@ -369,53 +369,9 @@ protected StringBuffer generateRequestURL(String requestPath) { return new StringBuffer(getScheme() + "://" + url); } - protected String parseCharacterEncoding(String contentTypeHeader) { - // we only look at content-type because content-encoding should only be used for - // "binary" requests such as gzip/deflate. - if (contentTypeHeader == null) { - return null; - } - String[] contentTypeValues = contentTypeHeader.split(HEADER_VALUE_SEPARATOR); - if (contentTypeValues.length <= 1) { - return null; - } - for (String contentTypeValue : contentTypeValues) { - if (contentTypeValue.trim().startsWith(ENCODING_VALUE_KEY)) { - String[] encodingValues = contentTypeValue.split(HEADER_KEY_VALUE_SEPARATOR); - if (encodingValues.length <= 1) { - return null; - } - return encodingValues[1]; - } - } - return null; - } - protected String appendCharacterEncoding(String currentContentType, String newEncoding) { - if (currentContentType == null || currentContentType.trim().isEmpty()) { - return null; - } - - if (currentContentType.contains(HEADER_VALUE_SEPARATOR)) { - String[] contentTypeValues = currentContentType.split(HEADER_VALUE_SEPARATOR); - StringBuilder contentType = new StringBuilder(contentTypeValues[0]); - - for (int i = 1; i < contentTypeValues.length; i++) { - String contentTypeValue = contentTypeValues[i]; - String contentTypeString = HEADER_VALUE_SEPARATOR + " " + contentTypeValue; - if (contentTypeValue.trim().startsWith(ENCODING_VALUE_KEY)) { - contentTypeString = HEADER_VALUE_SEPARATOR + " " + ENCODING_VALUE_KEY + HEADER_KEY_VALUE_SEPARATOR + newEncoding; - } - contentType.append(contentTypeString); - } - - return contentType.toString(); - } else { - return currentContentType + HEADER_VALUE_SEPARATOR + " " + ENCODING_VALUE_KEY + HEADER_KEY_VALUE_SEPARATOR + newEncoding; - } - } protected ServletInputStream bodyStringToInputStream(String body, boolean isBase64Encoded) throws IOException { if (body == null) { diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java index fe514e65d..6406be478 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java @@ -13,6 +13,7 @@ package com.amazonaws.serverless.proxy.internal.servlet; +import com.amazonaws.serverless.proxy.internal.HttpUtils; import com.amazonaws.serverless.proxy.internal.LambdaContainerHandler; import com.amazonaws.serverless.proxy.internal.SecurityUtils; import com.amazonaws.serverless.proxy.model.AwsProxyRequest; @@ -35,6 +36,7 @@ import java.io.IOException; import java.io.StringReader; import java.io.UnsupportedEncodingException; +import java.nio.charset.Charset; import java.security.Principal; import java.time.Instant; import java.time.ZonedDateTime; @@ -273,7 +275,8 @@ public String getCharacterEncoding() { if (request.getMultiValueHeaders() == null) { return config.getDefaultContentCharset(); } - return parseCharacterEncoding(request.getMultiValueHeaders().getFirst(HttpHeaders.CONTENT_TYPE)); + Charset charset = HttpUtils.parseCharacterEncoding(request.getMultiValueHeaders().getFirst(HttpHeaders.CONTENT_TYPE),null); + return charset != null ? charset.name() : null; } @@ -284,12 +287,12 @@ public void setCharacterEncoding(String s) request.setMultiValueHeaders(new Headers()); } String currentContentType = request.getMultiValueHeaders().getFirst(HttpHeaders.CONTENT_TYPE); - if (currentContentType == null || "".equals(currentContentType)) { + if (currentContentType == null || currentContentType.isEmpty()) { log.debug("Called set character encoding to " + SecurityUtils.crlf(s) + " on a request without a content type. Character encoding will not be set"); return; } - request.getMultiValueHeaders().putSingle(HttpHeaders.CONTENT_TYPE, appendCharacterEncoding(currentContentType, s)); + request.getMultiValueHeaders().putSingle(HttpHeaders.CONTENT_TYPE, HttpUtils.appendCharacterEncoding(currentContentType, s)); } diff --git a/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringHttpProcessingUtils.java b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringHttpProcessingUtils.java index c7e507f39..01074f865 100644 --- a/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringHttpProcessingUtils.java +++ b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/AwsSpringHttpProcessingUtils.java @@ -1,24 +1,28 @@ package com.amazonaws.serverless.proxy.spring; import java.io.InputStream; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.Iterator; +import java.nio.charset.UnsupportedCharsetException; +import java.util.Base64; import java.util.Map; import java.util.Map.Entry; -import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import com.amazonaws.serverless.proxy.internal.HttpUtils; +import org.apache.commons.io.Charsets; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.cloud.function.serverless.web.ServerlessHttpServletRequest; import org.springframework.cloud.function.serverless.web.ServerlessMVC; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.util.CollectionUtils; import org.springframework.util.FileCopyUtils; import org.springframework.util.MultiValueMapAdapter; import org.springframework.util.StringUtils; -import com.amazonaws.serverless.proxy.AsyncInitializationWrapper; import com.amazonaws.serverless.proxy.AwsHttpApiV2SecurityContextWriter; import com.amazonaws.serverless.proxy.AwsProxySecurityContextWriter; import com.amazonaws.serverless.proxy.RequestReader; @@ -120,10 +124,12 @@ private static HttpServletRequest generateRequest1(String request, Context lambd MultiValueMapAdapter headers = new MultiValueMapAdapter(v1Request.getMultiValueHeaders()); httpRequest.setHeaders(headers); } - if (StringUtils.hasText(v1Request.getBody())) { - httpRequest.setContentType("application/json"); - httpRequest.setContent(v1Request.getBody().getBytes(StandardCharsets.UTF_8)); - } + populateContentAndContentType( + v1Request.getBody(), + v1Request.getHeaders().get(HttpHeaders.CONTENT_TYPE), + v1Request.isBase64Encoded(), + httpRequest + ); if (v1Request.getRequestContext() != null) { httpRequest.setAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY, v1Request.getRequestContext()); httpRequest.setAttribute(RequestReader.ALB_CONTEXT_PROPERTY, v1Request.getRequestContext().getElb()); @@ -149,11 +155,14 @@ private static HttpServletRequest generateRequest2(String request, Context lambd populateQueryStringparameters(v2Request.getQueryStringParameters(), httpRequest); v2Request.getHeaders().forEach(httpRequest::setHeader); - - if (StringUtils.hasText(v2Request.getBody())) { - httpRequest.setContentType("application/json"); - httpRequest.setContent(v2Request.getBody().getBytes(StandardCharsets.UTF_8)); - } + + populateContentAndContentType( + v2Request.getBody(), + v2Request.getHeaders().get(HttpHeaders.CONTENT_TYPE), + v2Request.isBase64Encoded(), + httpRequest + ); + httpRequest.setAttribute(RequestReader.HTTP_API_CONTEXT_PROPERTY, v2Request.getRequestContext()); httpRequest.setAttribute(RequestReader.HTTP_API_STAGE_VARS_PROPERTY, v2Request.getStageVariables()); httpRequest.setAttribute(RequestReader.HTTP_API_EVENT_PROPERTY, v2Request); @@ -180,4 +189,22 @@ private static T readValue(String json, Class clazz, ObjectMapper mapper) } } + private static void populateContentAndContentType( + String body, + String contentType, + boolean base64Encoded, + ServerlessHttpServletRequest httpRequest) { + if (StringUtils.hasText(body)) { + httpRequest.setContentType(contentType == null ? MediaType.APPLICATION_JSON_VALUE : contentType); + if (base64Encoded) { + httpRequest.setContent(Base64.getMimeDecoder().decode(body)); + } else { + Charset charseEncoding = HttpUtils.parseCharacterEncoding(contentType,StandardCharsets.UTF_8); + httpRequest.setContent(body.getBytes(charseEncoding)); + } + } + } + + + } diff --git a/aws-serverless-java-container-springboot3/src/test/java/com/amazonaws/serverless/proxy/spring/SpringDelegatingLambdaContainerHandlerTests.java b/aws-serverless-java-container-springboot3/src/test/java/com/amazonaws/serverless/proxy/spring/SpringDelegatingLambdaContainerHandlerTests.java index 2fb85e7e7..61957fe24 100644 --- a/aws-serverless-java-container-springboot3/src/test/java/com/amazonaws/serverless/proxy/spring/SpringDelegatingLambdaContainerHandlerTests.java +++ b/aws-serverless-java-container-springboot3/src/test/java/com/amazonaws/serverless/proxy/spring/SpringDelegatingLambdaContainerHandlerTests.java @@ -6,15 +6,11 @@ import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; +import java.util.*; import com.amazonaws.serverless.exceptions.ContainerInitializationException; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.springframework.cloud.function.serverless.web.ServerlessServletContext; import org.springframework.util.CollectionUtils; import com.amazonaws.serverless.proxy.spring.servletapp.MessageData; @@ -214,7 +210,7 @@ public static Collection data() { public void validateComplesrequest(String jsonEvent) throws Exception { initServletAppTest(); InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", - "/foo/male/list/24", "{\"name\":\"bob\"}", null)); + "/foo/male/list/24", "{\"name\":\"bob\"}", false,null)); ByteArrayOutputStream output = new ByteArrayOutputStream(); handler.handleRequest(targetStream, output, null); Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class); @@ -229,7 +225,7 @@ public void validateComplesrequest(String jsonEvent) throws Exception { @ParameterizedTest public void testAsyncPost(String jsonEvent) throws Exception { initServletAppTest(); - InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/async", "{\"name\":\"bob\"}", null)); + InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/async", "{\"name\":\"bob\"}",false, null)); ByteArrayOutputStream output = new ByteArrayOutputStream(); handler.handleRequest(targetStream, output, null); Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class); @@ -242,7 +238,7 @@ public void testAsyncPost(String jsonEvent) throws Exception { public void testValidate400(String jsonEvent) throws Exception { initServletAppTest(); UserData ud = new UserData(); - InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud), null)); + InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud),false, null)); ByteArrayOutputStream output = new ByteArrayOutputStream(); handler.handleRequest(targetStream, output, null); Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class); @@ -258,7 +254,7 @@ public void testValidate200(String jsonEvent) throws Exception { ud.setFirstName("bob"); ud.setLastName("smith"); ud.setEmail("foo@bar.com"); - InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud), null)); + InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud),false, null)); ByteArrayOutputStream output = new ByteArrayOutputStream(); handler.handleRequest(targetStream, output, null); Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class); @@ -266,12 +262,31 @@ public void testValidate200(String jsonEvent) throws Exception { assertEquals("VALID", result.get("body")); } + @MethodSource("data") + @ParameterizedTest + public void testValidate200Base64(String jsonEvent) throws Exception { + initServletAppTest(); + UserData ud = new UserData(); + ud.setFirstName("bob"); + ud.setLastName("smith"); + ud.setEmail("foo@bar.com"); + InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", + Base64.getMimeEncoder().encodeToString(mapper.writeValueAsString(ud).getBytes()),true, null)); + + ByteArrayOutputStream output = new ByteArrayOutputStream(); + handler.handleRequest(targetStream, output, null); + Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class); + assertEquals(200, result.get("statusCode")); + assertEquals("VALID", result.get("body")); + } + + @MethodSource("data") @ParameterizedTest public void messageObject_parsesObject_returnsCorrectMessage(String jsonEvent) throws Exception { initServletAppTest(); InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/message", - mapper.writeValueAsString(new MessageData("test message")), null)); + mapper.writeValueAsString(new MessageData("test message")),false, null)); ByteArrayOutputStream output = new ByteArrayOutputStream(); handler.handleRequest(targetStream, output, null); Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class); @@ -279,6 +294,8 @@ public void messageObject_parsesObject_returnsCorrectMessage(String jsonEvent) t assertEquals("test message", result.get("body")); } + + @SuppressWarnings({"unchecked" }) @MethodSource("data") @ParameterizedTest @@ -289,7 +306,7 @@ void messageObject_propertiesInContentType_returnsCorrectMessage(String jsonEven headers.put(HttpHeaders.CONTENT_TYPE, "application/json;v=1"); headers.put(HttpHeaders.ACCEPT, "application/json;v=1"); InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/message", - mapper.writeValueAsString(new MessageData("test message")), headers)); + mapper.writeValueAsString(new MessageData("test message")),false, headers)); ByteArrayOutputStream output = new ByteArrayOutputStream(); handler.handleRequest(targetStream, output, null); @@ -297,19 +314,20 @@ void messageObject_propertiesInContentType_returnsCorrectMessage(String jsonEven assertEquals("test message", result.get("body")); } - private byte[] generateHttpRequest(String jsonEvent, String method, String path, String body, Map headers) throws Exception { + private byte[] generateHttpRequest(String jsonEvent, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception { Map requestMap = mapper.readValue(jsonEvent, Map.class); if (requestMap.get("version").equals("2.0")) { - return generateHttpRequest2(requestMap, method, path, body, headers); + return generateHttpRequest2(requestMap, method, path, body, isBase64Encoded,headers); } - return generateHttpRequest(requestMap, method, path, body, headers); + return generateHttpRequest(requestMap, method, path, body,isBase64Encoded, headers); } @SuppressWarnings({ "unchecked"}) - private byte[] generateHttpRequest(Map requestMap, String method, String path, String body, Map headers) throws Exception { + private byte[] generateHttpRequest(Map requestMap, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception { requestMap.put("path", path); requestMap.put("httpMethod", method); requestMap.put("body", body); + requestMap.put("isBase64Encoded", isBase64Encoded); if (!CollectionUtils.isEmpty(headers)) { requestMap.put("headers", headers); } @@ -317,12 +335,13 @@ private byte[] generateHttpRequest(Map requestMap, String method, String path, S } @SuppressWarnings({ "unchecked"}) - private byte[] generateHttpRequest2(Map requestMap, String method, String path, String body, Map headers) throws Exception { + private byte[] generateHttpRequest2(Map requestMap, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception { Map map = mapper.readValue(API_GATEWAY_EVENT_V2, Map.class); Map http = (Map) ((Map) map.get("requestContext")).get("http"); http.put("path", path); http.put("method", method); map.put("body", body); + map.put("isBase64Encoded", isBase64Encoded); if (!CollectionUtils.isEmpty(headers)) { map.put("headers", headers); }