From a98f5aceb0451f29803a95959c0f79d54b23ad2f Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 10 Nov 2020 13:47:57 +0900 Subject: [PATCH] Add a mechanism for manually propagating an aws sdk's request's segment through a static map. --- .../xray/interceptors/TracingInterceptor.java | 67 ++++++++++--- .../interceptors/TracingInterceptorTest.java | 93 ++++++++++++++----- 2 files changed, 126 insertions(+), 34 deletions(-) diff --git a/aws-xray-recorder-sdk-aws-sdk-v2/src/main/java/com/amazonaws/xray/interceptors/TracingInterceptor.java b/aws-xray-recorder-sdk-aws-sdk-v2/src/main/java/com/amazonaws/xray/interceptors/TracingInterceptor.java index b07a537c..cd4f75a0 100644 --- a/aws-xray-recorder-sdk-aws-sdk-v2/src/main/java/com/amazonaws/xray/interceptors/TracingInterceptor.java +++ b/aws-xray-recorder-sdk-aws-sdk-v2/src/main/java/com/amazonaws/xray/interceptors/TracingInterceptor.java @@ -21,6 +21,7 @@ import com.amazonaws.xray.entities.EntityDataKeys; import com.amazonaws.xray.entities.EntityHeaderKeys; import com.amazonaws.xray.entities.Namespace; +import com.amazonaws.xray.entities.Segment; import com.amazonaws.xray.entities.Subsegment; import com.amazonaws.xray.entities.TraceHeader; import com.amazonaws.xray.handlers.config.AWSOperationHandler; @@ -34,14 +35,17 @@ import java.io.IOException; import java.net.URL; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.WeakHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import software.amazon.awssdk.awscore.AwsExecutionAttribute; +import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.AwsResponse; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SdkResponse; @@ -65,6 +69,8 @@ public class TracingInterceptor implements ExecutionInterceptor { // TODO(anuraaga): Make private in next major version and rename. public static final ExecutionAttribute entityKey = new ExecutionAttribute("AWS X-Ray Entity"); + private static final Map MANUALLY_PROPAGATED_SEGMENTS = Collections.synchronizedMap(new WeakHashMap<>()); + private static final Log logger = LogFactory.getLog(TracingInterceptor.class); private static final ObjectMapper MAPPER = new ObjectMapper() @@ -82,6 +88,23 @@ public class TracingInterceptor implements ExecutionInterceptor { private AWSXRayRecorder recorder; private final String accountId; + /** + * Adds an external reference to the {@link Segment} from the {@link AwsRequest} which will be used in-place of parenting via + * a {@link ThreadLocal}. This is an advanced API for users that can't rely on thread-local propagation, but it is strongly + * discouraged from using this. + * + *

To manually register a segment to a request, call this method just before sending it. + * + *

{@code
+     * ListTablesRequest request = ListTablesRequest.builder().build();
+     * TracingInterceptor.unsafeAddSegmentToRequest(AWSXray.getCurrentSegment(), request);
+     * client.listTables(request);
+     * }
+ */ + public static void unsafeAddSegmentToRequest(Segment segment, AwsRequest request) { + MANUALLY_PROPAGATED_SEGMENTS.put(request, segment); + } + public TracingInterceptor() { this(null, null, null); } @@ -220,22 +243,42 @@ private HashMap extractResponseParameters( @Override public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) { AWSXRayRecorder recorder = getRecorder(); - Entity origin = recorder.getTraceEntity(); - Subsegment subsegment = recorder.beginSubsegment(executionAttributes.getAttribute(SdkExecutionAttribute.SERVICE_NAME)); - subsegment.setNamespace(Namespace.AWS.toString()); - subsegment.putAws(EntityDataKeys.AWS.OPERATION_KEY, - executionAttributes.getAttribute(SdkExecutionAttribute.OPERATION_NAME)); - Region region = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION); - if (region != null) { - subsegment.putAws(EntityDataKeys.AWS.REGION_KEY, region.id()); + Entity origin = null; + SdkRequest sdkRequest = context.request(); + if (sdkRequest instanceof AwsRequest) { + origin = MANUALLY_PROPAGATED_SEGMENTS.remove(sdkRequest); + } + boolean manuallyPropagatedParent = false; + if (origin != null) { + manuallyPropagatedParent = true; + recorder.setTraceEntity(origin); + } else { + origin = recorder.getTraceEntity(); } - subsegment.putAllAws(extractRequestParameters(context, executionAttributes)); - if (accountId != null) { - subsegment.putAws(EntityDataKeys.AWS.ACCOUNT_ID_SUBSEGMENT_KEY, accountId); + + + Subsegment subsegment = recorder.beginSubsegment(executionAttributes.getAttribute(SdkExecutionAttribute.SERVICE_NAME)); + try { + subsegment.setNamespace(Namespace.AWS.toString()); + subsegment.putAws(EntityDataKeys.AWS.OPERATION_KEY, + executionAttributes.getAttribute(SdkExecutionAttribute.OPERATION_NAME)); + Region region = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION); + if (region != null) { + subsegment.putAws(EntityDataKeys.AWS.REGION_KEY, region.id()); + } + subsegment.putAllAws(extractRequestParameters(context, executionAttributes)); + if (accountId != null) { + subsegment.putAws(EntityDataKeys.AWS.ACCOUNT_ID_SUBSEGMENT_KEY, accountId); + } + } finally { + if (manuallyPropagatedParent) { + recorder.setTraceEntity(null); + } else { + recorder.setTraceEntity(origin); + } } - recorder.setTraceEntity(origin); // store the subsegment in the AWS SDK's executionAttributes so it can be accessed across threads executionAttributes.putAttribute(entityKey, subsegment); } diff --git a/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java b/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java index 59dc6d2e..81984a02 100644 --- a/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java +++ b/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java @@ -15,26 +15,30 @@ package com.amazonaws.xray.interceptors; +import static org.assertj.core.api.Assertions.assertThat; + import com.amazonaws.xray.AWSXRay; import com.amazonaws.xray.AWSXRayRecorderBuilder; import com.amazonaws.xray.emitters.Emitter; import com.amazonaws.xray.entities.Cause; import com.amazonaws.xray.entities.Segment; import com.amazonaws.xray.entities.Subsegment; +import com.amazonaws.xray.strategy.IgnoreErrorContextMissingStrategy; import java.io.ByteArrayInputStream; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.CompletableFuture; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; -import org.junit.FixMethodOrder; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.MethodSorters; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; import org.mockito.Mockito; -import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; import org.mockito.stubbing.Answer; import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; @@ -55,13 +59,15 @@ import software.amazon.awssdk.services.lambda.LambdaClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; -@FixMethodOrder(MethodSorters.JVM) -@RunWith(MockitoJUnitRunner.class) -public class TracingInterceptorTest { +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +class TracingInterceptorTest { + + @Mock + private Emitter blankEmitter; - @Before + @BeforeEach public void setup() { - Emitter blankEmitter = Mockito.mock(Emitter.class); Mockito.doReturn(true).when(blankEmitter).sendSegment(Mockito.anyObject()); AWSXRay.setGlobalRecorder( @@ -73,7 +79,7 @@ public void setup() { AWSXRay.beginSegment("test"); } - @After + @AfterEach public void teardown() { AWSXRay.endSegment(); } @@ -123,7 +129,7 @@ private SdkHttpResponse generateLambdaInvokeResponse(int statusCode) { } @Test - public void testResponseDescriptors() throws Exception { + void testResponseDescriptors() throws Exception { String responseBody = "{\"LastEvaluatedTableName\":\"baz\",\"TableNames\":[\"foo\",\"bar\",\"baz\"]}"; SdkHttpResponse mockResponse = SdkHttpResponse.builder() .statusCode(200) @@ -170,7 +176,50 @@ public void testResponseDescriptors() throws Exception { } @Test - public void testLambdaInvokeSubsegmentContainsFunctionName() throws Exception { + void manualSegmentPropagation() throws Exception { + AWSXRay.setGlobalRecorder( + AWSXRayRecorderBuilder.standard() + .withEmitter(blankEmitter) + .withContextMissingStrategy(new IgnoreErrorContextMissingStrategy()) + .build()); + Segment segment = AWSXRay.beginSegment("test"); + AWSXRay.clearTraceEntity(); + assertThat(AWSXRay.getCurrentSegment()).isNull(); + + String responseBody = "{\"LastEvaluatedTableName\":\"baz\",\"TableNames\":[\"foo\",\"bar\",\"baz\"]}"; + SdkHttpResponse mockResponse = + SdkHttpResponse.builder() + .statusCode(200) + .putHeader("x-amzn-requestid", "1111-2222-3333-4444") + .putHeader("Content-Length", "84") + .putHeader("Content-Type", "application/x-amz-json-1.0") + .build(); + SdkHttpClient mockClient = mockSdkHttpClient(mockResponse, responseBody); + + DynamoDbClient client = + DynamoDbClient.builder() + .httpClient(mockClient) + .endpointOverride(URI.create("http://example.com")) + .region(Region.of("us-west-42")) + .credentialsProvider(StaticCredentialsProvider.create( + AwsSessionCredentials.create("key", "secret", "session"))) + .overrideConfiguration( + ClientOverrideConfiguration.builder() + .addExecutionInterceptor(new TracingInterceptor()) + .build()) + .build(); + + ListTablesRequest request = ListTablesRequest.builder().limit(3).build(); + TracingInterceptor.unsafeAddSegmentToRequest(segment, request); + client.listTables(request); + + Assert.assertEquals(1, segment.getSubsegments().size()); + Subsegment subsegment = segment.getSubsegments().get(0); + assertThat(subsegment.getName()).isEqualTo("DynamoDb"); + } + + @Test + void testLambdaInvokeSubsegmentContainsFunctionName() throws Exception { SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(200)); LambdaClient client = LambdaClient.builder() @@ -213,7 +262,7 @@ public void testLambdaInvokeSubsegmentContainsFunctionName() throws Exception { } @Test - public void testAsyncLambdaInvokeSubsegmentContainsFunctionName() { + void testAsyncLambdaInvokeSubsegmentContainsFunctionName() { SdkAsyncHttpClient mockClient = mockSdkAsyncHttpClient(generateLambdaInvokeResponse(200)); LambdaAsyncClient client = LambdaAsyncClient.builder() @@ -255,7 +304,7 @@ public void testAsyncLambdaInvokeSubsegmentContainsFunctionName() { } @Test - public void test400Exception() throws Exception { + void test400Exception() throws Exception { SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(400)); LambdaClient client = LambdaClient.builder() @@ -307,7 +356,7 @@ public void test400Exception() throws Exception { } @Test - public void testAsync400Exception() { + void testAsync400Exception() { SdkAsyncHttpClient mockClient = mockSdkAsyncHttpClient(generateLambdaInvokeResponse(400)); LambdaAsyncClient client = LambdaAsyncClient.builder() @@ -358,7 +407,7 @@ public void testAsync400Exception() { } @Test - public void testThrottledException() throws Exception { + void testThrottledException() throws Exception { SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(429)); LambdaClient client = LambdaClient.builder() @@ -410,7 +459,7 @@ public void testThrottledException() throws Exception { } @Test - public void testAsyncThrottledException() { + void testAsyncThrottledException() { SdkAsyncHttpClient mockClient = mockSdkAsyncHttpClient(generateLambdaInvokeResponse(429)); LambdaAsyncClient client = LambdaAsyncClient.builder() @@ -461,7 +510,7 @@ public void testAsyncThrottledException() { } @Test - public void test500Exception() throws Exception { + void test500Exception() throws Exception { SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(500)); LambdaClient client = LambdaClient.builder() @@ -513,7 +562,7 @@ public void test500Exception() throws Exception { } @Test - public void testAsync500Exception() { + void testAsync500Exception() { SdkAsyncHttpClient mockClient = mockSdkAsyncHttpClient(generateLambdaInvokeResponse(500)); LambdaAsyncClient client = LambdaAsyncClient.builder()