Skip to content

Multi sp #40

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ AUTH_IDP_METADATA=https://idp.example.net/metadata \
```
--cert string HTTPS Certificate
--db-connection string Database connection string
--db-prefix string Database table prefix
--debug Enable debug logging
-h, --help help for http-auth-server
--idp-certificate string IdP Certificate/Public Key
Expand Down
10 changes: 10 additions & 0 deletions config_multiple.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
service-providers:
- sp-url: http://localhost:9091/a
sp-cert: ./samlsp.crt
sp-key: ./samlsp.key
idp-metadata: https://mocksaml.com/api/saml/metadata
- name: b
sp-url: http://localhost:9091/b
sp-cert: ./samlsp.crt
sp-key: ./samlsp.key
idp-metadata: https://mocksaml.com/api/saml/metadata
6 changes: 6 additions & 0 deletions config_one.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
service-providers:
- name: one
sp-url: http://localhost:9091
sp-cert: ./samlsp.crt
sp-key: ./samlsp.key
idp-metadata: https://mocksaml.com/api/saml/metadata
3 changes: 3 additions & 0 deletions config_single.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
sp-cert: ./samlsp.crt
sp-key: ./samlsp.key
idp-metadata: https://mocksaml.com/api/saml/metadata
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ require (
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.19.0
gitlab.com/andrewheberle/routerswapper v1.2.0
)

require (
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
gitlab.com/andrewheberle/routerswapper v1.2.0 h1:43e23lnlcTI31DoI/4HP2aw27WCgsghLCcezgCCraz0=
gitlab.com/andrewheberle/routerswapper v1.2.0/go.mod h1:olw/7+vGWD6II0k84qQuevoj46o5DIcG1OvM9MmyW5Q=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
Expand Down
234 changes: 136 additions & 98 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"gitlab.com/andrewheberle/routerswapper"
)

var rootCmd = &cobra.Command{
Expand Down Expand Up @@ -51,13 +50,14 @@ func init() {
rootCmd.Flags().String("idp-certificate", "", "IdP Certificate/Public Key")
rootCmd.Flags().String("db-connection", "", "Database connection string")
rootCmd.Flags().String("db-prefix", "", "Database table prefix")
rootCmd.Flags().StringP("config", "c", "", "Configuration file")
rootCmd.Flags().Bool("debug", false, "Enable debug logging")

// flag requirements
rootCmd.MarkFlagsRequiredTogether("cert", "key")
rootCmd.MarkFlagsRequiredTogether("sp-cert", "sp-key")
rootCmd.MarkFlagRequired("sp-cert")
rootCmd.MarkFlagRequired("sp-key")
rootCmd.MarkFlagsRequiredTogether("cert", "key")
rootCmd.MarkFlagsRequiredTogether("idp-issuer", "idp-sso-endpoint", "idp-certificate")
rootCmd.MarkFlagsMutuallyExclusive("idp-metadata", "idp-issuer")
rootCmd.MarkFlagsMutuallyExclusive("idp-metadata", "idp-sso-endpoint")
Expand All @@ -74,14 +74,45 @@ func initConfig() {
// bind flags to viper
viper.BindPFlags(rootCmd.Flags())

// set any flags found in environment via viper
// load config file if flag is set
if config := viper.GetString("config"); config != "" {
viper.SetConfigFile(config)
if err := viper.ReadInConfig(); err != nil {
slog.Error("problem loading configuration", "error", err)
os.Exit(1)
}

// set sp-cert and sp-key to something just to allow things to work when using multiple SP's
for _, name := range []string{"sp-cert", "sp-key"} {
if !viper.IsSet(name) {
rootCmd.Flags().Set(name, "unused")
}
}
}

// set any flags found in environment/config via viper
rootCmd.Flags().VisitAll(func(f *pflag.Flag) {
if viper.IsSet(f.Name) && viper.GetString(f.Name) != "" {
slog.Info("setting flag", "name", f.Name, "value", viper.GetString(f.Name))
rootCmd.Flags().Set(f.Name, viper.GetString(f.Name))
}
})
}

type serviceProvider struct {
Name string `mapstructure:"name"`
ServiceProviderURL string `mapstructure:"sp-url"`
ServiceProviderClaimMapping map[string]string `mapstructure:"sp-claim-mapping"`
ServiceProviderCertificate string `mapstructure:"sp-cert"`
ServiceProviderKey string `mapstructure:"sp-key"`
IdPMetadata string `mapstructure:"idp-metadata"`
IdPIssuer string `mapstructure:"idp-issuer"`
IdPSSOEndpoint string `mapstructure:"idp-sso-endpoint"`
IdPCertificate string `mapstructure:"idp-certificate"`
DatabaseConnection string `mapstructure:"db-connection"`
DatabaseTablePrefix string `mapstructure:"db-prefix"`
}

func runRootCmd() error {
// logging setup
var logLevel = new(slog.LevelVar)
Expand All @@ -91,76 +122,131 @@ func runRootCmd() error {
logLevel.Set(slog.LevelDebug)
}

// validate service provider root url
root, err := url.Parse(viper.GetString("sp-url"))
if err != nil {
return fmt.Errorf("problem with SP URL: %w", err)
}
// did we load in via a config file
var serviceProviders []serviceProvider
if viper.ConfigFileUsed() != "" {
// has a list of service providers been provided?
if viper.Get("service-providers") != nil {
if err := viper.UnmarshalKey("service-providers", &serviceProviders); err != nil {
return fmt.Errorf("error with service providers list: %w", err)
}
} else {
var sp serviceProvider
if err := viper.Unmarshal(&sp); err != nil {
return fmt.Errorf("error with service provider: %w", err)
}

// set up service provider options
opts := []sp.ServiceProviderOption{
sp.WithClaimMapping(viper.GetStringMapString("sp-claim-mapping")),
serviceProviders = []serviceProvider{sp}
}
}

// handle metadata
if m := viper.GetString("idp-metadata"); m != "" {
metadata, err := url.Parse(m)
// create run group
g := run.Group{}

// new mux
mux := http.NewServeMux()

// set up service provider(s)
for _, spConfig := range serviceProviders {
// validate service provider root url
root, err := url.Parse(spConfig.ServiceProviderURL)
if err != nil {
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
return fmt.Errorf("problem with SP URL: %w", err)
}

opts = append(opts, sp.WithMetadataURL(metadata))
} else {
metadata := sp.ServiceProviderMetadata{
Issuer: viper.GetString("idp-issuer"),
Endpoint: viper.GetString("idp-sso-endpoint"),
NameId: "persistent",
Certificate: viper.GetString("idp-certificate"),
// set up service provider options
opts := []sp.ServiceProviderOption{
sp.WithClaimMapping(spConfig.ServiceProviderClaimMapping),
}

opts = append(opts, sp.WithCustomMetadata(metadata))
}
// handle metadata
if spConfig.IdPMetadata != "" {
metadata, err := url.Parse(spConfig.IdPMetadata)
if err != nil {
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
}

// are we using a database for storing session attributes
if dsn := viper.GetString("db-connection"); dsn != "" {
store, err := sp.NewDbAttributeStore(viper.GetString("db-prefix"), dsn)
opts = append(opts, sp.WithMetadataURL(metadata))
} else {
metadata := sp.ServiceProviderMetadata{
Issuer: spConfig.IdPIssuer,
Endpoint: spConfig.IdPSSOEndpoint,
Certificate: spConfig.IdPCertificate,
}

opts = append(opts, sp.WithCustomMetadata(metadata))
}

// are we using a database for storing session attributes
if dsn := spConfig.DatabaseConnection; dsn != "" {
store, err := sp.NewDbAttributeStore(spConfig.DatabaseTablePrefix, dsn)
if err != nil {
return fmt.Errorf("problem setting up db attribute store: %w", err)
}
defer store.Close()

opts = append(opts, sp.WithAttributeStore(store))
}

// set Service Provider name if provided
if spConfig.Name != "" {
opts = append(opts, sp.WithName(spConfig.Name))
}

// set up auth provider
provider, err := sp.NewServiceProvider(spConfig.ServiceProviderCertificate, spConfig.ServiceProviderKey, root, opts...)
if err != nil {
return fmt.Errorf("problem setting up db attribute store: %w", err)
return fmt.Errorf("problem setting up SP: %w", err)
}
defer store.Close()

opts = append(opts, sp.WithAttributeStore(store))
}
// set up refresh/reload of service provider metdata
if spConfig.IdPMetadata != "" {
quit := make(chan struct{})
g.Add(func() error {
slog.Info("service provider refresh", "action", "started", "next", time.Now().Add(time.Hour*24))
for {
select {
case <-quit:
return nil
default:
if err := provider.RefreshMetadata(); err != nil {
// not a fatal error
slog.Error("saml service provider reload", "error", err)
continue
}
}

// set up auth provider
provider, err := sp.NewServiceProvider(viper.GetString("sp-cert"), viper.GetString("sp-key"), root, opts...)
if err != nil {
return fmt.Errorf("problem setting up SP: %w", err)
}
// some logging
slog.Info("service provider refresh", "action", "refreshed", "next", time.Now().Add(time.Hour*24))
}
}, func(err error) {
slog.Info("service provider refresh", "action", "shutting down")
close(quit)
})
}

// new server mux
mux := sp.NewMux(provider)
// new server mux
if err := provider.NewMux(mux); err != nil {
return fmt.Errorf("error setting up mux: %w", err)
}

// allow swapping of mux
rs := routerswapper.New(mux)
slog.Info("set up service provider",
"acs-url", provider.AcsURL().String(),
"metdata-url", provider.MetadataURL().String(),
"logout-url", provider.LogoutUrl().String(),
"name", spConfig.Name,
)
}

// set up server
srv := &http.Server{
Addr: viper.GetString("listen"),
Handler: rs,
Handler: mux,
ReadTimeout: time.Second * 3,
WriteTimeout: time.Second * 3,
}

slog.Info("starting service",
"listen", srv.Addr,
"sp-acs-url", provider.AcsURL().String(),
"sp-metdata-url", provider.MetadataURL().String(),
"sp-logout-url", provider.LogoutUrl().String(),
)

// create run group
g := run.Group{}
slog.Info("starting service", "listen", srv.Addr)

// add http server
if viper.GetString("cert") == "" && viper.GetString("key") == "" {
Expand Down Expand Up @@ -213,54 +299,6 @@ func runRootCmd() error {
})
}

// set up refresh/reload of service provider metdata
if viper.GetString("idp-metadata") != "" {
quit := make(chan struct{})
g.Add(func() error {
slog.Info("service provider refresh", "action", "started", "next", time.Now().Add(time.Hour*24))
for {
select {
case <-quit:
return nil
default:
time.Sleep(time.Hour * 24)

// parse url
metadata, _ := url.Parse(viper.GetString("idp-metadata"))
if err != nil {
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
}

// set up service provider options
opts := []sp.ServiceProviderOption{
sp.WithClaimMapping(viper.GetStringMapString("sp-claim-mapping")),
sp.WithMetadataURL(metadata),
}

// set up provider
provider, err := sp.NewServiceProvider(viper.GetString("sp-cert"), viper.GetString("sp-key"), root, opts...)
if err != nil {
// not a fatal error
slog.Error("saml service provider reload", "error", err)
continue
}

// new server mux
mux := sp.NewMux(provider)

// swap to new mux
rs.Swap(mux)
}

// some logging
slog.Info("service provider refresh", "action", "refreshed", "next", time.Now().Add(time.Hour*24))
}
}, func(err error) {
slog.Info("service provider refresh", "action", "shutting down")
close(quit)
})
}

if err := g.Run(); err != nil {
return fmt.Errorf("problem while running: %w", err)
}
Expand Down
17 changes: 15 additions & 2 deletions pkg/sp/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type ServiceProviderOption func(*ServiceProvider)

func WithMetadataURL(metadata *url.URL) ServiceProviderOption {
return func(s *ServiceProvider) {
// populate metadata either from a metadata URL or from custom values
// populate metadata from a metadata URL
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

Expand All @@ -26,13 +26,14 @@ func WithMetadataURL(metadata *url.URL) ServiceProviderOption {
}

s.idpMetadata = idpMetadata
s.idpMetadataURL = metadata
}
}

func WithCustomMetadata(metadata ServiceProviderMetadata) ServiceProviderOption {
return func(s *ServiceProvider) {
// build metadata from provided values
b, err := buildMetadata(metadata.Issuer, metadata.Endpoint, metadata.NameId, metadata.Certificate)
b, err := buildMetadata(metadata.Issuer, metadata.Endpoint, metadata.Certificate)
if err != nil {
slog.Error("metadata build error", "error", err)
return
Expand All @@ -59,3 +60,15 @@ func WithAttributeStore(store AttributeStore) ServiceProviderOption {
s.store = store
}
}

func WithMetadataRefreshInterval(d time.Duration) ServiceProviderOption {
return func(s *ServiceProvider) {
s.idpMetadataRefreshInterval = d
}
}

func WithName(name string) ServiceProviderOption {
return func(s *ServiceProvider) {
s.name = name
}
}
Loading
Loading