diff --git a/ssh/host_key.go b/ssh/host_key.go index ab816088..007b1362 100644 --- a/ssh/host_key.go +++ b/ssh/host_key.go @@ -33,8 +33,8 @@ import ( // clientHostKeyAlgos defines what HostKey algorithms to be // used by the ssh client when using `ssh.Dial`. The default is // empty, which defaults to Golang's preferred HostKey algorithms. -func ScanHostKey(host string, timeout time.Duration, clientHostKeyAlgos []string) ([]byte, error) { - col := &HostKeyCollector{} +func ScanHostKey(host string, timeout time.Duration, clientHostKeyAlgos []string, hashKeys bool) ([]byte, error) { + col := &HostKeyCollector{hashKeys: hashKeys} config := &ssh.ClientConfig{ HostKeyCallback: col.StoreKey(), Timeout: timeout, @@ -54,6 +54,7 @@ func ScanHostKey(host string, timeout time.Duration, clientHostKeyAlgos []string // HostKeyCallBack to collect public keys from an SSH server. type HostKeyCollector struct { knownKeys []byte + hashKeys bool } // StoreKey stores the public key in bytes as returned by the host. @@ -62,9 +63,13 @@ type HostKeyCollector struct { // the algorithm you want to collect. func (c *HostKeyCollector) StoreKey() ssh.HostKeyCallback { return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + h := knownhosts.Normalize(hostname) + if c.hashKeys { + h = knownhosts.HashHostname(h) + } c.knownKeys = append( c.knownKeys, - fmt.Sprintf("%s %s %s\n", knownhosts.Normalize(hostname), key.Type(), base64.StdEncoding.EncodeToString(key.Marshal()))..., + fmt.Sprintf("%s %s %s\n", h, key.Type(), base64.StdEncoding.EncodeToString(key.Marshal()))..., ) return nil } diff --git a/ssh/host_key_test.go b/ssh/host_key_test.go index dfbef364..675e0a7b 100644 --- a/ssh/host_key_test.go +++ b/ssh/host_key_test.go @@ -85,7 +85,7 @@ func TestScanHost(t *testing.T) { go startSSH(listener, sshConfig) - kh, err := ScanHostKey(serverAddr, 5*time.Second, []string{tt.sshKeyTypeName}) + kh, err := ScanHostKey(serverAddr, 5*time.Second, []string{tt.sshKeyTypeName}, false) if tt.wantErr == "" { g.Expect(err).NotTo(HaveOccurred()) // Confirm the returned key is of expected type.