From 66cb58634d6b170598c09fc8e2c2d84128b6dfea Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Mon, 7 Feb 2022 15:41:57 +0000 Subject: [PATCH] Add a .tsh/config file and add support for configuring custom http headers --- api/client/contextdialer.go | 6 ++- api/client/webclient/webclient.go | 61 ++++++++++++++++---------- api/client/webclient/webclient_test.go | 7 +-- lib/client/api.go | 27 +++++++----- lib/client/keystore.go | 24 +++++++++- lib/client/keystore_test.go | 13 ++++++ lib/reversetunnel/agent.go | 4 +- lib/reversetunnel/resolver.go | 4 +- lib/reversetunnel/transport.go | 3 +- lib/web/apiserver.go | 2 +- tool/tsh/tsh.go | 30 +++++++++++++ tool/tsh/tsh_test.go | 9 ++++ tool/tsh/tshconfig.go | 59 +++++++++++++++++++++++++ 13 files changed, 205 insertions(+), 44 deletions(-) create mode 100644 tool/tsh/tshconfig.go diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index 2c2d08e29e894..158b505931bae 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -57,7 +57,8 @@ func NewDirectDialer(keepAlivePeriod, dialTimeout time.Duration) ContextDialer { func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer { dialer := newTunnelDialer(ssh, keepAlivePeriod, dialTimeout) return ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { - tunnelAddr, err := webclient.GetTunnelAddr(ctx, discoveryAddr, insecure, nil) + tunnelAddr, err := webclient.GetTunnelAddr( + &webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure}) if err != nil { return nil, trace.Wrap(err) } @@ -91,7 +92,8 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur // through the SSH reverse tunnel on the proxy. func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer { return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { - tunnelAddr, err := webclient.GetTunnelAddr(ctx, discoveryAddr, insecure, nil) + tunnelAddr, err := webclient.GetTunnelAddr( + &webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure}) if err != nil { return nil, trace.Wrap(err) } diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index b82e16f7ac9fe..379f6ca8cb633 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -38,13 +38,22 @@ import ( log "github.com/sirupsen/logrus" ) +type Config struct { + Context context.Context + ProxyAddr string + Insecure bool + Pool *x509.CertPool + ConnectorName string + ExtraHeaders map[string]string +} + // newWebClient creates a new client to the HTTPS web proxy. -func newWebClient(insecure bool, pool *x509.CertPool) *http.Client { +func newWebClient(cfg *Config) *http.Client { return &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ - RootCAs: pool, - InsecureSkipVerify: insecure, + RootCAs: cfg.Pool, + InsecureSkipVerify: cfg.Insecure, }, }, } @@ -56,9 +65,13 @@ func newWebClient(insecure bool, pool *x509.CertPool) *http.Client { // * The target host must resolve to the loopback address. // If these conditions are not met, then the plain-HTTP fallback is not allowed, // and a the HTTPS failure will be considered final. -func doWithFallback(clt *http.Client, allowPlainHTTP bool, req *http.Request) (*http.Response, error) { +func doWithFallback(clt *http.Client, allowPlainHTTP bool, extraHeaders map[string]string, req *http.Request) (*http.Response, error) { // first try https and see how that goes req.URL.Scheme = "https" + for k, v := range extraHeaders { + req.Header.Add(k, v) + } + log.Debugf("Attempting %s %s%s", req.Method, req.URL.Host, req.URL.Path) resp, err := clt.Do(req) @@ -88,18 +101,18 @@ func doWithFallback(clt *http.Client, allowPlainHTTP bool, req *http.Request) (* // Find fetches discovery data by connecting to the given web proxy address. // It is designed to fetch proxy public addresses without any inefficiencies. -func Find(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (*PingResponse, error) { - clt := newWebClient(insecure, pool) +func Find(cfg *Config) (*PingResponse, error) { + clt := newWebClient(cfg) defer clt.CloseIdleConnections() - endpoint := fmt.Sprintf("https://%s/webapi/find", proxyAddr) + endpoint := fmt.Sprintf("https://%s/webapi/find", cfg.ProxyAddr) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil) if err != nil { return nil, trace.Wrap(err) } - resp, err := doWithFallback(clt, insecure, req) + resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req) if err != nil { return nil, trace.Wrap(err) } @@ -118,21 +131,21 @@ func Find(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertP // errors before being asked for passwords. The second is to return the form // of authentication that the server supports. This also leads to better user // experience: users only get prompted for the type of authentication the server supports. -func Ping(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool, connectorName string) (*PingResponse, error) { - clt := newWebClient(insecure, pool) +func Ping(cfg *Config) (*PingResponse, error) { + clt := newWebClient(cfg) defer clt.CloseIdleConnections() - endpoint := fmt.Sprintf("https://%s/webapi/ping", proxyAddr) - if connectorName != "" { - endpoint = fmt.Sprintf("%s/%s", endpoint, connectorName) + endpoint := fmt.Sprintf("https://%s/webapi/ping", cfg.ProxyAddr) + if cfg.ConnectorName != "" { + endpoint = fmt.Sprintf("%s/%s", endpoint, cfg.ConnectorName) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil) if err != nil { return nil, trace.Wrap(err) } - resp, err := doWithFallback(clt, insecure, req) + resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req) if err != nil { return nil, trace.Wrap(err) } @@ -147,32 +160,32 @@ func Ping(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertP } // GetTunnelAddr returns the tunnel address either set in an environment variable or retrieved from the web proxy. -func GetTunnelAddr(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (string, error) { +func GetTunnelAddr(cfg *Config) (string, error) { // If TELEPORT_TUNNEL_PUBLIC_ADDR is set, nothing else has to be done, return it. if tunnelAddr := os.Getenv(defaults.TunnelPublicAddrEnvar); tunnelAddr != "" { return extractHostPort(tunnelAddr) } // Ping web proxy to retrieve tunnel proxy address. - pr, err := Find(ctx, proxyAddr, insecure, nil) + pr, err := Find(cfg) if err != nil { return "", trace.Wrap(err) } - return tunnelAddr(proxyAddr, pr.Proxy) + return tunnelAddr(cfg.ProxyAddr, pr.Proxy) } -func GetMOTD(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (*MotD, error) { - clt := newWebClient(insecure, pool) +func GetMOTD(cfg *Config) (*MotD, error) { + clt := newWebClient(cfg) defer clt.CloseIdleConnections() - endpoint := fmt.Sprintf("https://%s/webapi/motd", proxyAddr) + endpoint := fmt.Sprintf("https://%s/webapi/motd", cfg.ProxyAddr) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil) if err != nil { return nil, trace.Wrap(err) } - resp, err := clt.Do(req) + resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req) if err != nil { return nil, trace.Wrap(err) } diff --git a/api/client/webclient/webclient_test.go b/api/client/webclient/webclient_test.go index 556d6cab96478..495df446f8f87 100644 --- a/api/client/webclient/webclient_test.go +++ b/api/client/webclient/webclient_test.go @@ -52,14 +52,15 @@ func TestPlainHttpFallback(t *testing.T) { desc: "Ping", handler: newPingHandler("/webapi/ping"), actionUnderTest: func(addr string, insecure bool) error { - _, err := Ping(context.Background(), addr, insecure, nil /*pool*/, "") + _, err := Ping( + &Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure}) return err }, }, { desc: "Find", handler: newPingHandler("/webapi/find"), actionUnderTest: func(addr string, insecure bool) error { - _, err := Find(context.Background(), addr, insecure, nil /*pool*/) + _, err := Find(&Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure}) return err }, }, @@ -104,7 +105,7 @@ func TestPlainHttpFallback(t *testing.T) { func TestGetTunnelAddr(t *testing.T) { t.Setenv(defaults.TunnelPublicAddrEnvar, "tunnel.example.com:4024") - tunnelAddr, err := GetTunnelAddr(context.Background(), "", true, nil) + tunnelAddr, err := GetTunnelAddr(&Config{Context: context.Background(), ProxyAddr: "", Insecure: false}) require.NoError(t, err) require.Equal(t, "tunnel.example.com:4024", tunnelAddr) } diff --git a/lib/client/api.go b/lib/client/api.go index 59d811d466c82..fb030b72d1ebf 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -348,6 +348,9 @@ type Config struct { // Invited is a list of people invited to a session. Invited []string + + // ExtraProxyHeaders is a collection of http headers to be included in requests to the WebProxy. + ExtraProxyHeaders map[string]string } // CachePolicy defines cache policy for local clients @@ -2544,12 +2547,13 @@ func (tc *TeleportClient) Ping(ctx context.Context) (*webclient.PingResponse, er if tc.lastPing != nil { return tc.lastPing, nil } - pr, err := webclient.Ping( - ctx, - tc.WebProxyAddr, - tc.InsecureSkipVerify, - loopbackPool(tc.WebProxyAddr), - tc.AuthConnector) + pr, err := webclient.Ping(&webclient.Config{ + Context: ctx, + ProxyAddr: tc.WebProxyAddr, + Insecure: tc.InsecureSkipVerify, + Pool: loopbackPool(tc.WebProxyAddr), + ConnectorName: tc.AuthConnector, + ExtraHeaders: tc.ExtraProxyHeaders}) if err != nil { return nil, trace.Wrap(err) } @@ -2581,10 +2585,13 @@ func (tc *TeleportClient) Ping(ctx context.Context) (*webclient.PingResponse, er // confirmation from the user. func (tc *TeleportClient) ShowMOTD(ctx context.Context) error { motd, err := webclient.GetMOTD( - ctx, - tc.WebProxyAddr, - tc.InsecureSkipVerify, - loopbackPool(tc.WebProxyAddr)) + &webclient.Config{ + Context: ctx, + ProxyAddr: tc.WebProxyAddr, + Insecure: tc.InsecureSkipVerify, + Pool: loopbackPool(tc.WebProxyAddr), + ExtraHeaders: tc.ExtraProxyHeaders}) + if err != nil { return trace.Wrap(err) } diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 3b4b31b5776dc..5b47d232b1b78 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -50,6 +50,10 @@ const ( // keyFilePerms is the default permissions applied to key files (.cert, .key, pub) // under ~/.tsh keyFilePerms os.FileMode = 0600 + + // tshConfigFileName is the name of the directory containing the + // tsh config file. + tshConfigFileName = "config" ) // LocalKeyStore interface allows for different storage backends for tsh to @@ -223,9 +227,27 @@ func (fs *FSLocalKeyStore) DeleteUserCerts(idx KeyIndex, opts ...CertOption) err // DeleteKeys removes all session keys. func (fs *FSLocalKeyStore) DeleteKeys() error { - if err := os.RemoveAll(fs.KeyDir); err != nil { + + files, err := os.ReadDir(fs.KeyDir) + if err != nil { return trace.ConvertSystemError(err) } + for _, file := range files { + if file.IsDir() && file.Name() == tshConfigFileName { + continue + } + if file.IsDir() { + err := os.RemoveAll(filepath.Join(fs.KeyDir, file.Name())) + if err != nil { + return trace.ConvertSystemError(err) + } + continue + } + err := os.Remove(filepath.Join(fs.KeyDir, file.Name())) + if err != nil { + return trace.ConvertSystemError(err) + } + } return nil } diff --git a/lib/client/keystore_test.go b/lib/client/keystore_test.go index c6e26402464b5..0a9fe29197b80 100644 --- a/lib/client/keystore_test.go +++ b/lib/client/keystore_test.go @@ -402,6 +402,19 @@ func TestAddKey_withoutSSHCert(t *testing.T) { require.Len(t, keyCopy.DBTLSCerts, 1) } +func TestConfigDirNotDeleted(t *testing.T) { + s, cleanup := newTest(t) + t.Cleanup(cleanup) + idx := KeyIndex{"host.a", "bob", "root"} + s.store.AddKey(s.makeSignedKey(t, idx, false)) + configPath := filepath.Join(s.storeDir, "config") + require.NoError(t, os.Mkdir(configPath, 0700)) + require.NoError(t, s.store.DeleteKeys()) + require.DirExists(t, configPath) + + require.NoDirExists(t, filepath.Join(s.storeDir, "keys")) +} + type keyStoreTest struct { storeDir string store *FSLocalKeyStore diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 22bebc91c6d54..66f440bc5cd49 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -262,7 +262,9 @@ func (a *Agent) getHostCheckers() ([]ssh.PublicKey, error) { // If this is Web Service port check if proxy support ALPN SNI Listener. func (a *Agent) getReverseTunnelDetails() *reverseTunnelDetails { pd := reverseTunnelDetails{TLSRoutingEnabled: false} - resp, err := webclient.Find(a.ctx, a.Addr.Addr, lib.IsInsecureDevMode(), nil) + resp, err := webclient.Find( + &webclient.Config{Context: a.ctx, ProxyAddr: a.Addr.Addr, Insecure: lib.IsInsecureDevMode()}) + if err != nil { // If TLS Routing is disabled the address is the proxy reverse tunnel // address the ping call will always fail. diff --git a/lib/reversetunnel/resolver.go b/lib/reversetunnel/resolver.go index a467f5214577b..528063a0acbc3 100644 --- a/lib/reversetunnel/resolver.go +++ b/lib/reversetunnel/resolver.go @@ -58,7 +58,9 @@ func WebClientResolver(ctx context.Context, addrs []utils.NetAddr, insecureTLS b for _, addr := range addrs { // In insecure mode, any certificate is accepted. In secure mode the hosts // CAs are used to validate the certificate on the proxy. - tunnelAddr, err := webclient.GetTunnelAddr(ctx, addr.String(), insecureTLS, nil) + tunnelAddr, err := webclient.GetTunnelAddr( + &webclient.Config{Context: ctx, ProxyAddr: addr.String(), Insecure: insecureTLS}) + if err != nil { errs = append(errs, err) continue diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 73b593690ba49..ceb930296f067 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -91,7 +91,8 @@ func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Co } // Check if t.ProxyAddr is ProxyWebPort and remote Proxy supports TLS ALPNSNIListener. - resp, err := webclient.Find(ctx, addr.Addr, t.InsecureSkipTLSVerify, nil) + resp, err := webclient.Find( + &webclient.Config{Context: ctx, ProxyAddr: addr.Addr, Insecure: t.InsecureSkipTLSVerify}) if err != nil { // If TLS Routing is disabled the address is the proxy reverse tunnel // address thus the ping call will always fail. diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 291cc746c8109..3bfc4da866be0 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -740,7 +740,7 @@ func defaultAuthenticationSettings(ctx context.Context, authClient auth.ClientI) func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { var err error - + fmt.Println("HONK HONK HONK", r.Header) defaultSettings, err := defaultAuthenticationSettings(r.Context(), h.cfg.ProxyClient) if err != nil { return nil, trace.Wrap(err) diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index d2a0097a4ceb9..eb4123c424e02 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -26,6 +26,7 @@ import ( "os/signal" "path" "path/filepath" + "regexp" "runtime" "sort" "strings" @@ -39,6 +40,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" @@ -283,6 +285,9 @@ type CLIConf struct { // JoinMode is the participant mode someone is joining a session as. JoinMode string + + // ExtraProxyHeaders is configuration read from the .tsh/config/config.yaml file. + ExtraProxyHeaders []ExtraProxyHeaders } // Stdout returns the stdout writer. @@ -642,6 +647,15 @@ func Run(args []string, opts ...cliOption) error { setEnvFlags(&cf, os.Getenv) + confOptions, err := loadConfig(cf.HomePath) + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err, "failed to load tsh config from %s", + filepath.Join(profile.FullProfilePath(cf.HomePath), tshConfigPath)) + } + if confOptions != nil { + cf.ExtraProxyHeaders = confOptions.ExtraHeaders + } + switch command { case ver.FullCommand(): utils.PrintVersion() @@ -1959,6 +1973,22 @@ func makeClient(cf *CLIConf, useProfileLogin bool) (*client.TeleportClient, erro return nil, trace.Wrap(err) } + if c.ExtraProxyHeaders == nil { + c.ExtraProxyHeaders = map[string]string{} + } + for _, proxyHeaders := range cf.ExtraProxyHeaders { + proxyGlob := utils.GlobToRegexp(proxyHeaders.Proxy) + proxyRegexp, err := regexp.Compile(proxyGlob) + if err != nil { + return nil, trace.WrapWithMessage(err, "invalid proxy glob %q in tsh configuration file", proxyGlob) + } + if proxyRegexp.MatchString(c.WebProxyAddr) { + for k, v := range proxyHeaders.Headers { + c.ExtraProxyHeaders[k] = v + } + } + } + if len(fPorts) > 0 { c.LocalForwardPorts = fPorts } diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go index 375554dea5750..b48d09f1b35be 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -384,6 +384,11 @@ func TestMakeClient(t *testing.T) { conf.NodePort = 46528 conf.LocalForwardPorts = []string{"80:remote:180"} conf.DynamicForwardedPorts = []string{":8080"} + conf.ExtraProxyHeaders = []ExtraProxyHeaders{ + {Proxy: "proxy:3080", Headers: map[string]string{"A": "B"}}, + {Proxy: "*roxy:3080", Headers: map[string]string{"C": "D"}}, + {Proxy: "*hello:3080", Headers: map[string]string{"E": "F"}}, // shouldn't get included + } tc, err = makeClient(&conf, true) require.NoError(t, err) require.Equal(t, time.Minute*time.Duration(conf.MinsToLive), tc.Config.KeyTTL) @@ -403,6 +408,10 @@ func TestMakeClient(t *testing.T) { }, }, tc.Config.DynamicForwardedPorts) + require.Equal(t, + map[string]string{"A": "B", "C": "D"}, + tc.ExtraProxyHeaders) + _, proxy := makeTestServers(t) proxyWebAddr, err := proxy.ProxyWebAddr() diff --git a/tool/tsh/tshconfig.go b/tool/tsh/tshconfig.go new file mode 100644 index 0000000000000..f0fe4d533c0da --- /dev/null +++ b/tool/tsh/tshconfig.go @@ -0,0 +1,59 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "os" + "path/filepath" + + "github.com/gravitational/teleport/api/profile" + "github.com/gravitational/trace" + "gopkg.in/yaml.v2" +) + +// .tsh config must go in a subdir as all .yaml files in .tsh get +// parsed automatically by the profile loader and results in yaml +// unmarshal errors. +const tshConfigPath = "config/config.yaml" + +// TshConfig represents configuration laoded from the tsh config file. +type TshConfig struct { + // ExtraHeaders are additional http headers to be included in + // webclient requests. + ExtraHeaders []ExtraProxyHeaders `yaml:"add_headers"` +} + +// ExtraProxyHeaders represents the headers to include with the +// webclient. +type ExtraProxyHeaders struct { + // Proxy is the domain of the proxy for these set of Headers, can contain globs. + Proxy string `yaml:"proxy"` + // Headers are the http header key values. + Headers map[string]string `yaml:"headers,omitempty"` +} + +func loadConfig(homePath string) (*TshConfig, error) { + confPath := filepath.Join(profile.FullProfilePath(homePath), tshConfigPath) + configFile, err := os.Open(confPath) + if err != nil { + return nil, trace.ConvertSystemError(err) + } + defer configFile.Close() + cfg := TshConfig{} + err = yaml.NewDecoder(configFile).Decode(&cfg) + return &cfg, trace.Wrap(err) +}