Skip to content

Commit

Permalink
support using certs for authenticating azure kv
Browse files Browse the repository at this point in the history
Signed-off-by: sp98 <sapillai@redhat.com>
  • Loading branch information
sp98 committed Feb 14, 2024
1 parent 5f4b25c commit 698ba84
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 9 deletions.
13 changes: 9 additions & 4 deletions azure/azure_kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ const (
AzureClientID = "AZURE_CLIENT_ID"
// AzureClientSecret of service principal account
AzureClientSecret = "AZURE_CLIENT_SECRET"
// AzureClientCertPath is path of a client certificate of service principal account
AzureClientCertPath = "AZURE_CLIENT_CERT_PATH"
// AzureClientCertPassword is the password of the client certificate of service principal account
AzureClientCertPassword = "AZURE_CIENT_CERT_PASSWORD"
// AzureEnviornment to connect
AzureEnviornment = "AZURE_ENVIRONMENT"
// AzureVaultURI of azure key vault
Expand Down Expand Up @@ -61,10 +65,11 @@ func New(
if clientID == "" {
return nil, ErrAzureClientIDNotSet
}

secretID := getAzureKVParams(secretConfig, AzureClientSecret)
if secretID == "" {
return nil, ErrAzureSecretIDNotSet
}
clientCertPath := getAzureKVParams(secretConfig, AzureClientCertPath)
clientCertPassword := getAzureKVParams(secretConfig, AzureClientCertPassword)

envName := getAzureKVParams(secretConfig, AzureEnviornment)
if envName == "" {
// we set back to default AzurePublicCloud
Expand All @@ -75,7 +80,7 @@ func New(
return nil, ErrAzureVaultURLNotSet
}

client, err := getAzureVaultClient(clientID, secretID, tenantID, envName)
client, err := getAzureVaultClient(clientID, secretID, clientCertPath, clientCertPassword, tenantID, envName)
if err != nil {
return nil, err
}
Expand Down
47 changes: 42 additions & 5 deletions azure/azure_kv_helper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package azure

import (
"crypto/rsa"
"crypto/x509"
"fmt"
"net/url"
"os"
"strings"
Expand All @@ -9,6 +12,7 @@ import (
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"golang.org/x/crypto/pkcs12"
)

func getAzureKVParams(secretConfig map[string]interface{}, name string) string {
Expand All @@ -19,7 +23,7 @@ func getAzureKVParams(secretConfig map[string]interface{}, name string) string {
}
}

func getAzureVaultClient(clientID, secretID, tenantID, envName string) (keyvault.BaseClient, error) {
func getAzureVaultClient(clientID, secretID, certPath, certPassword, tenantID, envName string) (keyvault.BaseClient, error) {
var environment *azure.Environment
alternateEndpoint, _ := url.Parse(
"https://#.windows.net/" + tenantID + "/oauth2/token")
Expand All @@ -37,11 +41,44 @@ func getAzureVaultClient(clientID, secretID, tenantID, envName string) (keyvault
}
oauthconfig.AuthorizeEndpoint = *alternateEndpoint

token, err := adal.NewServicePrincipalToken(
*oauthconfig, clientID, secretID, strings.TrimSuffix(environment.KeyVaultEndpoint, "/"))
if err != nil {
return keyClient, err
var token *adal.ServicePrincipalToken
if secretID != "" {
token, err = adal.NewServicePrincipalToken(
*oauthconfig, clientID, secretID, strings.TrimSuffix(environment.KeyVaultEndpoint, "/"))
if err != nil {
return keyClient, err
}
} else if certPath != "" && certPassword != "" {
certData, err := os.ReadFile(certPath)
if err != nil {
return keyClient, fmt.Errorf("reading the client certificate from file %s: %v", certPath, err)
}
certificate, privateKey, err := decodePkcs12(certData, certPassword)
if err != nil {
return keyClient, fmt.Errorf("decoding the client certificate: %v", err)
}
token, err = adal.NewServicePrincipalTokenFromCertificate(
*oauthconfig, clientID, certificate, privateKey, strings.TrimSuffix(environment.KeyVaultEndpoint, "/"))
if err != nil {
return keyClient, err
}
}

keyClient.Authorizer = autorest.NewBearerAuthorizer(token)
return keyClient, nil
}

// decodePkcs12 decodes a PKCS#12 client certificate by extracting the public certificate and
// the private RSA key
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
privateKey, certificate, err := pkcs12.Decode(pkcs, password)
if err != nil {
return nil, nil, fmt.Errorf("decoding the PKCS#12 client certificate: %v", err)
}
rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
if !isRsaKey {
return nil, nil, fmt.Errorf("PKCS#12 certificate must contain a RSA private key")
}

return certificate, rsaPrivateKey, nil
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ require (
github.com/onsi/gomega v1.24.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect
golang.org/x/crypto v0.6.0
golang.org/x/oauth2 v0.0.0-20220524215830-622c5d57e401
golang.org/x/time v0.3.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
Expand Down

0 comments on commit 698ba84

Please # to comment.