Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[ACC-1512] JWT signing optimizations #270

Merged
merged 9 commits into from
Jul 30, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package com.contentgrid.gateway.security.jwt.issuer;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.proc.SimpleSecurityContext;
import com.nimbusds.jose.produce.JWSSignerFactory;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import lombok.SneakyThrows;
import org.springframework.util.ConcurrentLruCache;


public class JwkSourceJwtClaimsSigner implements JwtClaimsSigner {

private final Random random;
private final JWKSource<SecurityContext> jwkSource;
private final Set<JWSAlgorithm> algorithms;


public JwkSourceJwtClaimsSigner(JWKSource<SecurityContext> jwkSource, Set<JWSAlgorithm> algorithms) {
this(new DefaultJWSSignerFactory(), new Random(), jwkSource, algorithms);
}

private ConcurrentLruCache<JWK, JWSSigner> signerCache;

public JwkSourceJwtClaimsSigner(JWSSignerFactory jwsSignerFactory, Random random,
JWKSource<SecurityContext> jwkSource, Set<JWSAlgorithm> algorithms) {
this.random = random;
this.jwkSource = jwkSource;
this.algorithms = algorithms;

signerCache = new ConcurrentLruCache<>(10,
key -> {
try {
return jwsSignerFactory.createJWSSigner(key);
} catch (JOSEException e) {
throw new RuntimeException(e);
}
});
}

@SneakyThrows
private List<JWK> getAllSigningKeys() {
return jwkSource.get(new JWKSelector(new JWKMatcher.Builder()
.keyUse(KeyUse.SIGNATURE)
.build()),
new SimpleSecurityContext());
}

@Override
public JWKSet getSigningKeys() {
return new JWKSet(getAllSigningKeys());
}

@Override
@SneakyThrows
public SignedJWT sign(JWTClaimsSet jwtClaimsSet) {
var jwks = new ArrayList<>(getAllSigningKeys());

Collections.shuffle(jwks, this.random); // Randomly shuffle our keys, so we pick an arbitrary one first

Set<JWSAlgorithm> algorithmsSupportedByKeys = new HashSet<>();

for (JWK selectedKey : jwks) {
if (selectedKey.getExpirationTime() != null && !new Date().before(selectedKey.getExpirationTime())) {
// Skip retired keys
continue;
}

var selectedSigner = getJwsSigner(selectedKey);
algorithmsSupportedByKeys.addAll(selectedSigner.supportedJWSAlgorithms());
var firstSupportedAlgorithm = algorithms
.stream()
.filter(selectedSigner.supportedJWSAlgorithms()::contains)
.findFirst();
if (firstSupportedAlgorithm.isEmpty()) {
// Signer does not support any of the signing algorithms; continue to a next key
continue;
}
var signedJwt = new SignedJWT(new JWSHeader.Builder(firstSupportedAlgorithm.get())
.type(JOSEObjectType.JWT)
.keyID(selectedKey.getKeyID())
.build(),
jwtClaimsSet
);
signedJwt.sign(selectedSigner);
return signedJwt;
}
throw new IllegalStateException(
"No active signing keys support any of the configured algorithms (%s); algorithms that can be used by these keys are %s".formatted(
algorithms,
algorithmsSupportedByKeys
));
}

private JWSSigner getJwsSigner(JWK jwk) {
return signerCache.get(jwk);
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package com.contentgrid.gateway.security.jwt.issuer;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

public interface JwtClaimsSigner {
JWKSet getSigningKeys();
SignedJWT sign(JWTClaimsSet jwtClaimsSet) throws JOSEException;
SignedJWT sign(JWTClaimsSet jwtClaimsSet);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.contentgrid.gateway.security.jwt.issuer;

import com.nimbusds.jose.JWSAlgorithm;
import java.util.Set;

public interface JwtClaimsSignerProperties {

String getActiveKeys();

String getRetiredKeys();

Set<JWSAlgorithm> getAlgorithms();
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.cloud.gateway.config.conditional.ConditionalOnEnabledFilter;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.support.ResourcePatternResolver;

@RequiredArgsConstructor
@Configuration(proxyBeanMethods = false)
@EnableConfigurationProperties(ContentgridGatewayJwtProperties.class)
@Slf4j
public class JwtInternalIssuerConfiguration {

@Bean
JwtSignerRegistry jwtSignerRegistry(ContentgridGatewayJwtProperties properties, ApplicationContext applicationContext) {
return new PropertiesBasedJwtSignerRegistry(properties, applicationContext);
JwtSignerRegistry jwtSignerRegistry(ContentgridGatewayJwtProperties gatewayJwtProperties, ResourcePatternResolver resourcePatternResolver) {
return new PropertiesBasedJwtSignerRegistry(gatewayJwtProperties, resourcePatternResolver);
}

@Bean
Expand Down Expand Up @@ -72,10 +72,10 @@ static class ContentgridGatewayJwtProperties {
@Data
@Builder
@AllArgsConstructor
static class JwtSignerProperties implements PropertiesBasedJwtClaimsSigner.JwtClaimsSignerProperties {
static class JwtSignerProperties implements JwtClaimsSignerProperties {
@NotNull
private String activeKeys;
private String allKeys;
private String retiredKeys;
@Builder.Default
private Set<JWSAlgorithm> algorithms = Set.of(JWSAlgorithm.RS256);

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package com.contentgrid.gateway.security.jwt.issuer;

import com.contentgrid.gateway.security.jwt.issuer.JwtInternalIssuerConfiguration.ContentgridGatewayJwtProperties;
import com.contentgrid.gateway.security.jwt.issuer.jwk.source.FilebasedJWKSetSource;
import com.contentgrid.gateway.security.jwt.issuer.jwk.source.LoggingJWKSetSourceEventListener;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.SecurityContext;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -11,13 +17,14 @@
@RequiredArgsConstructor
class PropertiesBasedJwtSignerRegistry implements JwtSignerRegistry {

private final ContentgridGatewayJwtProperties properties;
private final ResourcePatternResolver resourcePatternResolver;
private final Map<String, JwtClaimsSigner> instantiatedSigners = new ConcurrentHashMap<>();
private final ContentgridGatewayJwtProperties gatewayJwtProperties;
private final ResourcePatternResolver resourcePatternResolver;


@Override
public boolean hasSigner(String signerName) {
return properties.getSigners().containsKey(signerName);
return getJwkSourceMap().containsKey(signerName);
}

@Override
Expand All @@ -42,14 +49,38 @@ public JwtClaimsSigner getRequiredSigner(String signerName) {
return instantiatedSigners.computeIfAbsent(signerName, this::createSigner);
}

private Map<String, JWKSource<SecurityContext>> getJwkSourceMap() {
Map<String, JWKSource<SecurityContext>> jwkSourceMap = new HashMap<>();
gatewayJwtProperties.getSigners().keySet().stream().forEach(
signerName -> {
var signerProperties = gatewayJwtProperties.getSigners().get(signerName);
if (signerProperties == null) {
throw new IllegalArgumentException(
"No JWT signer named '%s'. Available signers are %s".formatted(signerName,
gatewayJwtProperties.getSigners().keySet()));
}
var jwkSource = new FilebasedJWKSetSource(
resourcePatternResolver,
signerProperties.getActiveKeys(),
signerProperties.getRetiredKeys()
);
jwkSourceMap.put(signerName, JWKSourceBuilder.create(jwkSource)
.refreshAheadCache(JWKSourceBuilder.DEFAULT_REFRESH_AHEAD_TIME, true, new LoggingJWKSetSourceEventListener<>())
.build());
}
);

return jwkSourceMap;
}

private JwtClaimsSigner createSigner(String signerName) {
var signerProperties = properties.getSigners().get(signerName);
if (signerProperties == null) {
if (!hasSigner(signerName)) {
throw new IllegalArgumentException(
"No JWT signer named '%s'. Available signers are %s".formatted(signerName,
properties.getSigners().keySet()));
getJwkSourceMap().keySet()));
}
return new PropertiesBasedJwtClaimsSigner(signerProperties, resourcePatternResolver);

return new JwkSourceJwtClaimsSigner(getJwkSourceMap().get(signerName), gatewayJwtProperties.getSigners().get(signerName).getAlgorithms());
}

}
Loading