diff --git a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSClientBuilder.java b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSClientBuilder.java new file mode 100644 index 0000000..85bf254 --- /dev/null +++ b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSClientBuilder.java @@ -0,0 +1,93 @@ +package edu.harvard.dbmi.avillach.dataupload.aws; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Profile; +import org.springframework.stereotype.Service; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; +import software.amazon.awssdk.services.sts.model.AssumeRoleResponse; +import software.amazon.awssdk.services.sts.model.Credentials; + +import java.util.Map; +import java.util.Optional; + +@Profile("!dev") +@Service +public class AWSClientBuilder { + + private static final Logger log = LoggerFactory.getLogger(AWSClientBuilder.class); + + private final Map sites; + private final StsClientProvider stsClientProvider; + private final S3ClientBuilder s3ClientBuilder; + private final SdkHttpClient sdkHttpClient; + + @Autowired + public AWSClientBuilder( + Map sites, + StsClientProvider stsClientProvider, + S3ClientBuilder s3ClientBuilder, + @Autowired(required = false) SdkHttpClient sdkHttpClient + ) { + this.sites = sites; + this.stsClientProvider = stsClientProvider; + this.s3ClientBuilder = s3ClientBuilder; + this.sdkHttpClient = sdkHttpClient; + } + + public Optional buildClientForSite(String siteName) { + log.info("Building client for site {}", siteName); + if (!sites.containsKey(siteName)) { + log.warn("Could not find site {}", siteName); + return Optional.empty(); + } + + log.info("Found site, making assume role request"); + SiteAWSInfo site = sites.get(siteName); + AssumeRoleRequest roleRequest = AssumeRoleRequest.builder() + .roleArn(site.roleARN()) + .roleSessionName("test_session" + System.nanoTime()) + .externalId(site.externalId()) + .durationSeconds(60*60) // 1 hour + .build(); + Optional assumeRoleResponse = stsClientProvider.createClient() + .map(c -> c.assumeRole(roleRequest)) + .map(AssumeRoleResponse::credentials); + if (assumeRoleResponse.isEmpty() ) { + log.error("Error assuming role {} , no credentials returned", site.roleARN()); + return Optional.empty(); + } + log.info("Successfully assumed role {} for site {}", site.roleARN(), site.siteName()); + + log.info("Building S3 client for site {}", site.siteName()); + // Use the credentials from the role to create the S3 client + Credentials credentials = assumeRoleResponse.get(); + AwsSessionCredentials sessionCredentials = AwsSessionCredentials.builder() + .accessKeyId(credentials.accessKeyId()) + .secretAccessKey(credentials.secretAccessKey()) + .sessionToken(credentials.sessionToken()) + .expirationTime(credentials.expiration()) + .build(); + StaticCredentialsProvider provider = StaticCredentialsProvider.create(sessionCredentials); + return Optional.of(buildFromProvider(provider)); + } + + private S3Client buildFromProvider(StaticCredentialsProvider provider) { + if (sdkHttpClient == null) { + return s3ClientBuilder.credentialsProvider(provider).build(); + } + log.info("Http proxy detected and added to S3 client"); + return s3ClientBuilder + .credentialsProvider(provider) + .httpClient(sdkHttpClient) + .build(); + + } + +} diff --git a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSConfiguration.java b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSConfiguration.java index 7c7203b..0836dd9 100644 --- a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSConfiguration.java +++ b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSConfiguration.java @@ -11,9 +11,12 @@ import org.springframework.context.annotation.Configuration; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.util.StringUtils; +import org.springframework.web.context.annotation.RequestScope; import software.amazon.awssdk.auth.credentials.*; import software.amazon.awssdk.http.SdkHttpClient; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; import software.amazon.awssdk.services.sts.StsClient; import software.amazon.awssdk.services.sts.StsClientBuilder; import software.amazon.encryption.s3.S3EncryptionClient; @@ -82,4 +85,15 @@ StsClientBuilder stsClientBuilder() { // This is a bean for mocking purposes return StsClient.builder(); } + + @Bean + S3ClientBuilder s3ClientBuilder() { + return S3Client.builder(); + } + + @Bean + @RequestScope + StsClient getStsClient() { + return StsClient.builder().region(Region.US_EAST_1).build(); + } } diff --git a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/S3StateVerifier.java b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/S3StateVerifier.java index 911ce7f..e116a86 100644 --- a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/S3StateVerifier.java +++ b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/S3StateVerifier.java @@ -28,7 +28,7 @@ public class S3StateVerifier { private Map sites; @Autowired - private SelfRefreshingS3Client client; + private AWSClientBuilder clientBuilder; @PostConstruct private void verifyS3Status() { @@ -39,7 +39,7 @@ private void verifyS3Status() { private void asyncVerify(SiteAWSInfo institution) { LOG.info("Checking S3 connection to {} ...", institution.siteName()); createTempFileWithText(institution) - .map(p -> uploadFileFromPath(p, institution)) + .flatMap(p -> uploadFileFromPath(p, institution)) .map(this::waitABit) .flatMap(s1 -> deleteFileFromBucket(s1, institution)) .orElseThrow(); @@ -49,8 +49,10 @@ private void asyncVerify(SiteAWSInfo institution) { private Optional deleteFileFromBucket(String s, SiteAWSInfo info) { LOG.info("Verifying delete capabilities"); DeleteObjectRequest request = DeleteObjectRequest.builder().bucket(info.bucket()).key(s).build(); - DeleteObjectResponse deleteObjectResponse = client.getS3Client(info.siteName()).deleteObject(request); - return deleteObjectResponse.deleteMarker() ? Optional.of(s) : Optional.empty(); + return clientBuilder.buildClientForSite(info.siteName()) + .map(c -> c.deleteObject(request)) + .map(DeleteObjectResponse::deleteMarker) + .map((ignored) -> s); } private String waitABit(String s) { @@ -62,7 +64,7 @@ private String waitABit(String s) { return s; } - private String uploadFileFromPath(Path p, SiteAWSInfo info) { + private Optional uploadFileFromPath(Path p, SiteAWSInfo info) { LOG.info("Verifying upload capabilities"); RequestBody body = RequestBody.fromFile(p.toFile()); PutObjectRequest request = PutObjectRequest.builder() @@ -71,8 +73,9 @@ private String uploadFileFromPath(Path p, SiteAWSInfo info) { .ssekmsKeyId(info.kmsKeyID()) .key(p.getFileName().toString()) .build(); - client.getS3Client(info.siteName()).putObject(request, body); - return p.getFileName().toString(); + return clientBuilder.buildClientForSite(info.siteName()) + .map(client -> client.putObject(request, body)) + .map(resp -> p.getFileName().toString()); } private Optional createTempFileWithText(SiteAWSInfo info) { diff --git a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/SelfRefreshingS3Client.java b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/SelfRefreshingS3Client.java deleted file mode 100644 index ce0d97b..0000000 --- a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/SelfRefreshingS3Client.java +++ /dev/null @@ -1,159 +0,0 @@ -package edu.harvard.dbmi.avillach.dataupload.aws; - -import edu.harvard.dbmi.avillach.dataupload.status.StatusService; -import jakarta.annotation.PostConstruct; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.ConfigurableApplicationContext; -import org.springframework.stereotype.Service; -import org.springframework.util.StringUtils; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.S3ClientBuilder; -import software.amazon.awssdk.services.sts.StsClient; -import software.amazon.awssdk.services.sts.StsClientBuilder; -import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; -import software.amazon.awssdk.services.sts.model.AssumeRoleResponse; -import software.amazon.awssdk.services.sts.model.Credentials; - -import java.time.Duration; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.function.Function; -import java.util.stream.Collectors; - -/** - * In order to make s3 requests across accounts, we have to assume a role in AWS - * These operations last 1 hour, after which point requests will 401/3 - * This class wraps the S3 client in a getter, and runs an automated task to refresh - * the client when the token expires. Requests to the getter will block while this - * refresh task is running - */ -@ConditionalOnProperty(name = "production", havingValue = "true") -@Service -public class SelfRefreshingS3Client { - private static final Logger LOG = LoggerFactory.getLogger(SelfRefreshingS3Client.class); - private Map locks; - private Map clients = new HashMap<>(); - - @Autowired - private ConfigurableApplicationContext context; - - @Autowired - private Map roleARNs; - - @Autowired - StatusService statusService; - - @Autowired(required = false) - private SdkHttpClient sdkHttpClient; - - @Autowired - private AWSCredentialsService credentialsService; - - @Autowired - private StsClientBuilder stsClientBuilder; - - @Value("${http.proxyUser:}") - private String proxyUser; - - @PostConstruct - private void refreshClient() { - locks = roleARNs.keySet().stream() - .collect(Collectors.toMap(Function.identity(), (s) -> new ReentrantReadWriteLock())); - roleARNs.keySet().stream().parallel().forEach(this::refreshClient); - } - - private StsClient createStsClient() { - StsClientBuilder builder = stsClientBuilder - .region(Region.US_EAST_1) - .credentialsProvider(StaticCredentialsProvider.create(credentialsService.constructCredentials())); - - if (StringUtils.hasLength(proxyUser)) { - builder.httpClient(sdkHttpClient); - } - - return builder.build(); - } - - // exposed for testing - void refreshClient(String siteName) { - LOG.info("Starting client refresh for {}", siteName); - - // block further s3 calls while we refresh - LOG.info("Locking s3 client while refreshing session"); - locks.get(siteName).writeLock().lock(); - statusService.setClientStatus("initializing"); - - // assume the role - LOG.info("Attempting to assume data uploader role"); - AssumeRoleRequest roleRequest = AssumeRoleRequest.builder() - .roleArn(roleARNs.get(siteName).roleARN()) - .roleSessionName("test_session" + System.nanoTime()) - .externalId(roleARNs.get(siteName).externalId()) - .durationSeconds(60*60) // 1 hour - .build(); - AssumeRoleResponse assumeRoleResponse = createStsClient().assumeRole(roleRequest); - if (assumeRoleResponse.credentials() == null ) { - LOG.error("Error assuming role, no credentials returned! Exiting!"); - statusService.setClientStatus("error"); - context.close(); - } - LOG.info("Successfully assumed role, using credentials to create new S3 client"); - - // Use the credentials from the role to create the S3 client - Credentials credentials = assumeRoleResponse.credentials(); - AwsSessionCredentials sessionCredentials = AwsSessionCredentials.builder() - .accessKeyId(credentials.accessKeyId()) - .secretAccessKey(credentials.secretAccessKey()) - .sessionToken(credentials.sessionToken()) - .expirationTime(credentials.expiration()) - .build(); - StaticCredentialsProvider provider = StaticCredentialsProvider.create(sessionCredentials); - S3ClientBuilder builder = S3Client.builder() - .credentialsProvider(provider) - .region(Region.US_EAST_1); - if (sdkHttpClient != null) { - builder.httpClient(sdkHttpClient); - } - LOG.info("Created S3 client"); - clients.put(siteName, builder.build()); - // now that client is refreshed, unlock for reading - LOG.info("Unlocking s3 client. Session refreshed"); - locks.get(siteName).writeLock().unlock(); - statusService.setClientStatus("ready"); - - // create virtual thread to handle next refresh, to occur 5 mins before session expires. - Thread.ofVirtual().start(() -> delayedRefresh(credentials.expiration().minus(5, ChronoUnit.MINUTES), siteName)); - } - - private void delayedRefresh(Instant refresh, String siteName) { - LOG.info("Next refresh will be at {}", refresh); - try { - Thread.sleep(Duration.between(Instant.now(), refresh)); - } catch (InterruptedException e) { - LOG.warn("Couldn't wait. Refreshing early", e); - } - LOG.info("Refreshing s3 client"); - refreshClient(siteName); - } - - public S3Client getS3Client(String siteName) { - S3Client client; - locks.get(siteName).readLock().lock(); - client = clients.get(siteName); - locks.get(siteName).readLock().unlock(); - return client; - } -} diff --git a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/StsClientProvider.java b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/StsClientProvider.java new file mode 100644 index 0000000..6ea8652 --- /dev/null +++ b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/aws/StsClientProvider.java @@ -0,0 +1,20 @@ +package edu.harvard.dbmi.avillach.dataupload.aws; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; + +import java.util.Optional; + +@Service +public class StsClientProvider { + + private static final Logger log = LoggerFactory.getLogger(StsClientProvider.class); + + public Optional createClient() { + StsClient client = StsClient.builder().region(Region.US_EAST_1).build(); + return Optional.of(client); + } +} diff --git a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/upload/DataUploadService.java b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/upload/DataUploadService.java index be2096c..a5cadab 100644 --- a/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/upload/DataUploadService.java +++ b/uploader/src/main/java/edu/harvard/dbmi/avillach/dataupload/upload/DataUploadService.java @@ -1,23 +1,22 @@ package edu.harvard.dbmi.avillach.dataupload.upload; -import edu.harvard.dbmi.avillach.dataupload.aws.SelfRefreshingS3Client; +import edu.harvard.dbmi.avillach.dataupload.aws.AWSClientBuilder; import edu.harvard.dbmi.avillach.dataupload.aws.SiteAWSInfo; import edu.harvard.dbmi.avillach.dataupload.hpds.HPDSClient; import edu.harvard.dbmi.avillach.dataupload.hpds.hpdsartifactsdonotchange.Query; import edu.harvard.dbmi.avillach.dataupload.status.DataUploadStatuses; import edu.harvard.dbmi.avillach.dataupload.status.UploadStatus; import edu.harvard.dbmi.avillach.dataupload.status.StatusService; -import jakarta.annotation.PostConstruct; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.annotation.Bean; import org.springframework.stereotype.Service; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3ClientBuilder; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; @@ -44,7 +43,7 @@ public class DataUploadService { private String home; @Autowired - private SelfRefreshingS3Client s3; + private AWSClientBuilder s3ClientBuilder; @Autowired private HPDSClient hpds; @@ -127,11 +126,12 @@ private boolean uploadFileFromPath(Path p, SiteAWSInfo site, String dir) { .ssekmsKeyId(site.kmsKeyID()) .key(Path.of(dir, home + "_" + p.getFileName().toString()).toString()) .build(); - s3.getS3Client(site.siteName()).putObject(request, body); + return s3ClientBuilder.buildClientForSite(site.siteName()) + .map(client -> client.putObject(request, body)) + .isPresent(); } catch (AwsServiceException | SdkClientException e) { LOG.info("Error uploading file from {} to bucket {}", p, site.bucket(), e); return false; } - return true; } } diff --git a/uploader/src/main/resources/application.properties b/uploader/src/main/resources/application.properties index 49fda4d..0daf21c 100644 --- a/uploader/src/main/resources/application.properties +++ b/uploader/src/main/resources/application.properties @@ -21,3 +21,5 @@ institution.name=${HOME_INSTITUTION_NAME} institution.short-display=${HOME_INSTITUTION_DISPLAY} institution.long-display=${HOME_INSTITUTION_LONG_DISPLAY} server.port=${PORT:80} + +spring.profiles.active=prod \ No newline at end of file diff --git a/uploader/src/test/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSClientBuilderTest.java b/uploader/src/test/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSClientBuilderTest.java new file mode 100644 index 0000000..2f39dae --- /dev/null +++ b/uploader/src/test/java/edu/harvard/dbmi/avillach/dataupload/aws/AWSClientBuilderTest.java @@ -0,0 +1,125 @@ +package edu.harvard.dbmi.avillach.dataupload.aws; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.context.annotation.Profile; +import org.springframework.test.context.ActiveProfiles; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; +import software.amazon.awssdk.services.sts.model.AssumeRoleResponse; +import software.amazon.awssdk.services.sts.model.Credentials; + +import java.time.Instant; +import java.util.Map; +import java.util.Optional; + +@ActiveProfiles("aws_mock") +@SpringBootTest +class AWSClientBuilderTest { + + @MockBean + Map sites; + + @MockBean + StsClient stsClient; + + @MockBean + StsClientProvider stsClientProvider; + + @MockBean + S3ClientBuilder s3ClientBuilder; + + @Autowired + AWSClientBuilder subject; + + @Test + void shouldNotBuildClientIfSiteDNE() { + Mockito.when(sites.get("Narnia")) + .thenReturn(null); + + Optional actual = subject.buildClientForSite("Narnia"); + Optional expected = Optional.empty(); + + Assertions.assertEquals(expected, actual); + } + + @Test + void shouldNotBuildClientIfRoleRequestFails() { + SiteAWSInfo siteAWSInfo = new SiteAWSInfo("bch", "aws:arn:420", "external", "bucket", "aws:kms:420"); + Mockito.when(sites.get("bch")) + .thenReturn(siteAWSInfo); + Mockito.when(sites.containsKey("bch")).thenReturn(true); + + ArgumentMatcher requestMatcher = + (r) -> r.roleArn().equals("aws:arn:420") + && r.roleSessionName().startsWith("test_session") + && r.externalId().equals("external") + && r.durationSeconds().equals(3600); + AssumeRoleResponse response = Mockito.mock(AssumeRoleResponse.class); + Mockito.when(stsClient.assumeRole(Mockito.argThat(requestMatcher))) + .thenReturn(response); + Mockito.when(stsClientProvider.createClient()) + .thenReturn(Optional.of(stsClient)); + + Optional actual = subject.buildClientForSite("bch"); + Optional expected = Optional.empty(); + + Assertions.assertEquals(expected, actual); + } + + @Test + void shouldBuildClient() { + SiteAWSInfo siteAWSInfo = new SiteAWSInfo("bch", "aws:arn:420", "external", "bucket", "aws:kms:420"); + Mockito.when(sites.get("bch")) + .thenReturn(siteAWSInfo); + Mockito.when(sites.containsKey("bch")).thenReturn(true); + + Credentials credentials = Mockito.mock(Credentials.class); + Mockito.when(credentials.accessKeyId()).thenReturn("access_key_id"); + Mockito.when(credentials.secretAccessKey()).thenReturn("secret"); + Mockito.when(credentials.sessionToken()).thenReturn("session"); + Mockito.when(credentials.expiration()).thenReturn(Instant.MAX); + AssumeRoleResponse assumeRoleResponse = Mockito.mock(AssumeRoleResponse.class); + Mockito.when(assumeRoleResponse.credentials()) + .thenReturn(credentials); + AwsSessionCredentials sessionCredentials = AwsSessionCredentials.builder() + .accessKeyId(credentials.accessKeyId()) + .secretAccessKey(credentials.secretAccessKey()) + .sessionToken(credentials.sessionToken()) + .expirationTime(credentials.expiration()) + .build(); + ArgumentMatcher requestMatcher = + (r) -> r.roleArn().equals("aws:arn:420") + && r.roleSessionName().startsWith("test_session") + && r.externalId().equals("external") + && r.durationSeconds().equals(3600); + Mockito.when(stsClient.assumeRole(Mockito.argThat(requestMatcher))) + .thenReturn(assumeRoleResponse); + Mockito.when(stsClientProvider.createClient()) + .thenReturn(Optional.of(stsClient)); + + StaticCredentialsProvider provider = StaticCredentialsProvider.create(sessionCredentials); + ArgumentMatcher credsMatcher = (AwsCredentialsProvider p) -> p.toString().equals(provider.toString()); + S3Client s3Client = Mockito.mock(S3Client.class); + Mockito.when(s3ClientBuilder.credentialsProvider(Mockito.argThat(credsMatcher))) + .thenReturn(s3ClientBuilder); + Mockito.when(s3ClientBuilder.build()) + .thenReturn(s3Client); + + Optional actual = subject.buildClientForSite("bch"); + Optional expected = Optional.of(s3Client); + + Assertions.assertEquals(expected, actual); + } +} \ No newline at end of file diff --git a/uploader/src/test/java/edu/harvard/dbmi/avillach/dataupload/upload/DataUploadServiceTest.java b/uploader/src/test/java/edu/harvard/dbmi/avillach/dataupload/upload/DataUploadServiceTest.java index 4777c6c..a4144c1 100644 --- a/uploader/src/test/java/edu/harvard/dbmi/avillach/dataupload/upload/DataUploadServiceTest.java +++ b/uploader/src/test/java/edu/harvard/dbmi/avillach/dataupload/upload/DataUploadServiceTest.java @@ -1,13 +1,12 @@ package edu.harvard.dbmi.avillach.dataupload.upload; -import edu.harvard.dbmi.avillach.dataupload.aws.SelfRefreshingS3Client; +import edu.harvard.dbmi.avillach.dataupload.aws.AWSClientBuilder; import edu.harvard.dbmi.avillach.dataupload.aws.SiteAWSInfo; import edu.harvard.dbmi.avillach.dataupload.hpds.HPDSClient; import edu.harvard.dbmi.avillach.dataupload.hpds.hpdsartifactsdonotchange.Query; import edu.harvard.dbmi.avillach.dataupload.status.StatusService; import edu.harvard.dbmi.avillach.dataupload.status.UploadStatus; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.mockito.InjectMocks; @@ -26,6 +25,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; +import java.util.Optional; import java.util.concurrent.Semaphore; @SpringBootTest @@ -47,7 +47,7 @@ class DataUploadServiceTest { S3Client s3Client; @Mock - SelfRefreshingS3Client s3; + AWSClientBuilder s3; @InjectMocks DataUploadService subject; @@ -101,7 +101,7 @@ void shouldNotUploadDataIfAWSUpset(@TempDir Path tempDir) throws IOException, In Mockito.when(sharingRoot.toString()).thenReturn(tempDir.toString()); Mockito.when(hpds.writePhenotypicData(q)).thenReturn(true); - Mockito.when(s3.getS3Client("bch")).thenReturn(s3Client); + Mockito.when(s3.buildClientForSite("bch")).thenReturn(Optional.of(s3Client)); Mockito.when(s3Client.putObject(Mockito.any(PutObjectRequest.class), Mockito.any(RequestBody.class))) .thenThrow(AwsServiceException.builder().build()); @@ -128,7 +128,7 @@ void shouldUploadData(@TempDir Path tempDir) throws IOException, InterruptedExce Mockito.when(sharingRoot.toString()).thenReturn(tempDir.toString()); Mockito.when(hpds.writePhenotypicData(q)).thenReturn(true); - Mockito.when(s3.getS3Client("bch")).thenReturn(s3Client); + Mockito.when(s3.buildClientForSite("bch")).thenReturn(Optional.of(s3Client)); Mockito.when(s3Client.putObject(Mockito.any(PutObjectRequest.class), Mockito.any(RequestBody.class))) .thenReturn(Mockito.mock(PutObjectResponse.class)); diff --git a/uploader/src/test/resources/application.properties b/uploader/src/test/resources/application.properties index eb615a0..a1283c0 100644 --- a/uploader/src/test/resources/application.properties +++ b/uploader/src/test/resources/application.properties @@ -11,3 +11,5 @@ aws.region=us-east-1 institution.name=bch institution.short-display=BCH institution.long-display=Boston Children's Hospital + +spring.profiles.active=dev \ No newline at end of file