Skip to content

Commit

Permalink
Merge pull request #2291 from vincepri/tls-opts-getcertificate
Browse files Browse the repository at this point in the history
🌱 Handle TLSOpts.GetCertificate in webhook
  • Loading branch information
k8s-ci-robot authored May 2, 2023
2 parents 0ef0753 + bd12701 commit 94bb74b
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 24 deletions.
53 changes: 30 additions & 23 deletions pkg/webhook/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ type Server struct {
CertDir string

// CertName is the server certificate name. Defaults to tls.crt.
//
// Note: This option should only be set when TLSOpts does not override GetCertificate.
CertName string

// KeyName is the server key name. Defaults to tls.key.
//
// Note: This option should only be set when TLSOpts does not override GetCertificate.
KeyName string

// ClientCAName is the CA certificate name which server used to verify remote(client)'s certificate.
Expand Down Expand Up @@ -169,32 +173,40 @@ func (s *Server) Start(ctx context.Context) error {
baseHookLog := log.WithName("webhooks")
baseHookLog.Info("Starting webhook server")

certPath := filepath.Join(s.CertDir, s.CertName)
keyPath := filepath.Join(s.CertDir, s.KeyName)

certWatcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return err
}

go func() {
if err := certWatcher.Start(ctx); err != nil {
log.Error(err, "certificate watcher error")
}
}()

tlsMinVersion, err := tlsVersion(s.TLSMinVersion)
if err != nil {
return err
}

cfg := &tls.Config{ //nolint:gosec
NextProtos: []string{"h2"},
GetCertificate: certWatcher.GetCertificate,
MinVersion: tlsMinVersion,
NextProtos: []string{"h2"},
MinVersion: tlsMinVersion,
}
// fallback TLS config ready, will now mutate if passer wants full control over it
for _, op := range s.TLSOpts {
op(cfg)
}

if cfg.GetCertificate == nil {
certPath := filepath.Join(s.CertDir, s.CertName)
keyPath := filepath.Join(s.CertDir, s.KeyName)

// Create the certificate watcher and
// set the config's GetCertificate on the TLSConfig
certWatcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return err
}
cfg.GetCertificate = certWatcher.GetCertificate

go func() {
if err := certWatcher.Start(ctx); err != nil {
log.Error(err, "certificate watcher error")
}
}()
}

// load CA to verify client certificate
// Load CA to verify client certificate, if configured.
if s.ClientCAName != "" {
certPool := x509.NewCertPool()
clientCABytes, err := os.ReadFile(filepath.Join(s.CertDir, s.ClientCAName))
Expand All @@ -211,11 +223,6 @@ func (s *Server) Start(ctx context.Context) error {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}

// fallback TLS config ready, will now mutate if passer wants full control over it
for _, op := range s.TLSOpts {
op(cfg)
}

listener, err := tls.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)), cfg)
if err != nil {
return err
Expand Down
51 changes: 50 additions & 1 deletion pkg/webhook/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"io"
"net"
"net/http"
"path"
"reflect"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -181,7 +183,7 @@ var _ = Describe("Webhook Server", func() {
}
server.Register("/somepath", &testHandler{})
doneCh := genericStartServer(func(ctx context.Context) {
Expect(server.Start(ctx))
Expect(server.Start(ctx)).To(Succeed())
})

Eventually(func() ([]byte, error) {
Expand All @@ -199,6 +201,53 @@ var _ = Describe("Webhook Server", func() {
ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
})

It("should prefer GetCertificate through TLSOpts", func() {
var finalCfg *tls.Config
finalCert, err := tls.LoadX509KeyPair(
path.Join(servingOpts.LocalServingCertDir, "tls.crt"),
path.Join(servingOpts.LocalServingCertDir, "tls.key"),
)
Expect(err).NotTo(HaveOccurred())
finalGetCertificate := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
return &finalCert, nil
}
server = &webhook.Server{
Host: servingOpts.LocalServingHost,
Port: servingOpts.LocalServingPort,
CertDir: servingOpts.LocalServingCertDir,
TLSMinVersion: "1.2",
TLSOpts: []func(*tls.Config){
func(cfg *tls.Config) {
cfg.GetCertificate = finalGetCertificate
// save cfg after changes to test against
finalCfg = cfg
},
},
}
server.Register("/somepath", &testHandler{})
doneCh := genericStartServer(func(ctx context.Context) {
Expect(server.Start(ctx)).To(Succeed())
})

Eventually(func() ([]byte, error) {
resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}).Should(Equal([]byte("gadzooks!")))
Expect(finalCfg.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
// We can't compare the functions directly, but we can compare their pointers
if reflect.ValueOf(finalCfg.GetCertificate).Pointer() != reflect.ValueOf(finalGetCertificate).Pointer() {
Fail("GetCertificate was not set properly, or overwritten")
}
cert, err := finalCfg.GetCertificate(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cert).To(BeEquivalentTo(&finalCert))

ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
})
})

type testHandler struct {
Expand Down

0 comments on commit 94bb74b

Please # to comment.