diff --git a/azure/azure_kv.go b/azure/azure_kv.go index 4a7d1e82..ae9b4da0 100644 --- a/azure/azure_kv.go +++ b/azure/azure_kv.go @@ -21,6 +21,10 @@ const ( AzureClientID = "AZURE_CLIENT_ID" // AzureClientSecret of service principal account AzureClientSecret = "AZURE_CLIENT_SECRET" + // AzureClientCertPath is path of a client certificate of service principal account + AzureClientCertPath = "AZURE_CLIENT_CERT_PATH" + // AzureClientCertPassword is the password of the client certificate of service principal account + AzureClientCertPassword = "AZURE_CIENT_CERT_PASSWORD" // AzureEnviornment to connect AzureEnviornment = "AZURE_ENVIRONMENT" // AzureVaultURI of azure key vault @@ -61,10 +65,11 @@ func New( if clientID == "" { return nil, ErrAzureClientIDNotSet } + secretID := getAzureKVParams(secretConfig, AzureClientSecret) - if secretID == "" { - return nil, ErrAzureSecretIDNotSet - } + clientCertPath := getAzureKVParams(secretConfig, AzureClientCertPath) + clientCertPassword := getAzureKVParams(secretConfig, AzureClientCertPassword) + envName := getAzureKVParams(secretConfig, AzureEnviornment) if envName == "" { // we set back to default AzurePublicCloud @@ -75,7 +80,7 @@ func New( return nil, ErrAzureVaultURLNotSet } - client, err := getAzureVaultClient(clientID, secretID, tenantID, envName) + client, err := getAzureVaultClient(clientID, secretID, clientCertPath, clientCertPassword, tenantID, envName) if err != nil { return nil, err } diff --git a/azure/azure_kv_helper.go b/azure/azure_kv_helper.go index 67c97a08..9edc026e 100644 --- a/azure/azure_kv_helper.go +++ b/azure/azure_kv_helper.go @@ -1,6 +1,9 @@ package azure import ( + "crypto/rsa" + "crypto/x509" + "fmt" "net/url" "os" "strings" @@ -9,6 +12,7 @@ import ( "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/azure" + "golang.org/x/crypto/pkcs12" ) func getAzureKVParams(secretConfig map[string]interface{}, name string) string { @@ -19,7 +23,7 @@ func getAzureKVParams(secretConfig map[string]interface{}, name string) string { } } -func getAzureVaultClient(clientID, secretID, tenantID, envName string) (keyvault.BaseClient, error) { +func getAzureVaultClient(clientID, secretID, certPath, certPassword, tenantID, envName string) (keyvault.BaseClient, error) { var environment *azure.Environment alternateEndpoint, _ := url.Parse( "https://login.windows.net/" + tenantID + "/oauth2/token") @@ -37,11 +41,44 @@ func getAzureVaultClient(clientID, secretID, tenantID, envName string) (keyvault } oauthconfig.AuthorizeEndpoint = *alternateEndpoint - token, err := adal.NewServicePrincipalToken( - *oauthconfig, clientID, secretID, strings.TrimSuffix(environment.KeyVaultEndpoint, "/")) - if err != nil { - return keyClient, err + var token *adal.ServicePrincipalToken + if secretID != "" { + token, err = adal.NewServicePrincipalToken( + *oauthconfig, clientID, secretID, strings.TrimSuffix(environment.KeyVaultEndpoint, "/")) + if err != nil { + return keyClient, err + } + } else if certPath != "" && certPassword != "" { + certData, err := os.ReadFile(certPath) + if err != nil { + return keyClient, fmt.Errorf("reading the client certificate from file %s: %v", certPath, err) + } + certificate, privateKey, err := decodePkcs12(certData, certPassword) + if err != nil { + return keyClient, fmt.Errorf("decoding the client certificate: %v", err) + } + token, err = adal.NewServicePrincipalTokenFromCertificate( + *oauthconfig, clientID, certificate, privateKey, strings.TrimSuffix(environment.KeyVaultEndpoint, "/")) + if err != nil { + return keyClient, err + } } + keyClient.Authorizer = autorest.NewBearerAuthorizer(token) return keyClient, nil } + +// decodePkcs12 decodes a PKCS#12 client certificate by extracting the public certificate and +// the private RSA key +func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) { + privateKey, certificate, err := pkcs12.Decode(pkcs, password) + if err != nil { + return nil, nil, fmt.Errorf("decoding the PKCS#12 client certificate: %v", err) + } + rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey) + if !isRsaKey { + return nil, nil, fmt.Errorf("PKCS#12 certificate must contain a RSA private key") + } + + return certificate, rsaPrivateKey, nil +} diff --git a/go.mod b/go.mod index 374fa370..57e3497b 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/onsi/gomega v1.24.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.8.1 // indirect + golang.org/x/crypto v0.6.0 golang.org/x/oauth2 v0.0.0-20220524215830-622c5d57e401 golang.org/x/time v0.3.0 // indirect google.golang.org/protobuf v1.28.1 // indirect