diff --git a/ssh/ssh.go b/ssh/ssh.go index 12e3e07..695275a 100644 --- a/ssh/ssh.go +++ b/ssh/ssh.go @@ -40,6 +40,11 @@ type SSHContext struct { SkipHostKeyCheck bool } +type FileTransfer struct { + Source string + Destination string +} + func (ctx *SSHContext) Cmd(host Host, parts ...string) (*exec.Cmd, error) { var err error @@ -51,15 +56,20 @@ func (ctx *SSHContext) Cmd(host Host, parts ...string) (*exec.Cmd, error) { return ctx.SudoCmd(host, parts...) } - cmdArgs := ctx.initialSSHArgs(host) + cmd, cmdArgs := ctx.sshArgs(host, nil) cmdArgs = append(cmdArgs, parts...) - command := exec.Command("ssh", cmdArgs...) + command := exec.Command(cmd, cmdArgs...) return command, nil } -func (ctx *SSHContext) initialSSHArgs(host Host) []string { - args := make([]string, 0) +func (ctx *SSHContext) sshArgs(host Host, transfer *FileTransfer) (cmd string, args []string) { + if transfer != nil { + cmd = "scp" + } else { + cmd = "ssh" + } + if ctx.SkipHostKeyCheck { args = append(args, "-o", "StrictHostkeyChecking=No", @@ -69,13 +79,18 @@ func (ctx *SSHContext) initialSSHArgs(host Host) []string { args = append(args, "-i") args = append(args, ctx.IdentityFile) } + var hostAndDestination = host.GetTargetHost() + if transfer != nil { + args = append(args, transfer.Source) + hostAndDestination += ":" + transfer.Destination + } if ctx.Username != "" { - args = append(args, ctx.Username + "@" + host.GetTargetHost()) + args = append(args, ctx.Username+"@"+hostAndDestination) } else { - args = append(args, host.GetTargetHost()) + args = append(args, hostAndDestination) } - return args + return } func (ctx *SSHContext) SudoCmd(host Host, parts ...string) (*exec.Cmd, error) { @@ -92,7 +107,7 @@ func (ctx *SSHContext) SudoCmd(host Host, parts ...string) (*exec.Cmd, error) { } } - cmdArgs := ctx.initialSSHArgs(host) + cmd, cmdArgs := ctx.sshArgs(host, nil) // normalize sudo if parts[0] == "sudo" { @@ -110,7 +125,7 @@ func (ctx *SSHContext) SudoCmd(host Host, parts ...string) (*exec.Cmd, error) { cmdArgs = append(cmdArgs, "-p", "''", "-k", "--") cmdArgs = append(cmdArgs, parts...) - command := exec.Command("ssh", cmdArgs...) + command := exec.Command(cmd, cmdArgs...) if ctx.sudoPassword != "" { err := writeSudoPassword(command, ctx.sudoPassword) if err != nil { @@ -251,25 +266,17 @@ func (ctx *SSHContext) MakeTempFile(host Host) (path string, err error) { } func (ctx *SSHContext) UploadFile(host Host, source string, destination string) (err error) { - destinationAndHost := host.GetTargetHost() + ":" + destination - - parts := make([]string, 0) - if ctx.IdentityFile != "" { - parts = append(parts, "-i", ctx.IdentityFile) - } - if ctx.Username != "" { - destinationAndHost = ctx.Username + "@" + destinationAndHost - } - - parts = append(parts, source, destinationAndHost) - - cmd := exec.Command("scp", parts...) + c, parts := ctx.sshArgs(host, &FileTransfer{ + Source: source, + Destination: destination, + }) + cmd := exec.Command(c, parts...) data, err := cmd.CombinedOutput() if err != nil { errorMessage := fmt.Sprintf( "Error on remote host %s:\nCouldn't upload file: %s -> %s\n\nOriginal error:\n%s", - host.GetTargetHost(), source, destinationAndHost, string(data), + host.GetTargetHost(), source, destination, string(data), ) return errors.New(errorMessage) } @@ -304,7 +311,6 @@ func (ctx *SSHContext) MakeDirs(host Host, path string, parents bool, mode os.Fi return nil } - func (ctx *SSHContext) MoveFile(host Host, source string, destination string) (err error) { cmd, err := ctx.SudoCmd(host, "mv", source, destination) if err != nil {