diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 99e32cfa..ca604198 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -125,34 +125,39 @@ func InitK3sCluster(cluster *types.Cluster) error { logrus.Infof("[%s] successfully created k3s master-%d\n", cluster.Provider, 1) masterErrChan := make(chan error) + masterWaitGroupDone := make(chan bool) masterWaitGroup := &sync.WaitGroup{} masterWaitGroup.Add(len(cluster.MasterNodes) - 1) - defer close(masterErrChan) for i, master := range cluster.MasterNodes { // skip first master nodes if i == 0 { continue } - logrus.Infof("[%s] creating k3s master-%d...\n", cluster.Provider, i+1) - initAdditionalMaster(masterWaitGroup, masterErrChan, k3sScript, k3sMirror, dockerMirror, publicIP, masterExtraArgs, cluster, master, aliCCM) - logrus.Infof("[%s] successfully created k3s master-%d\n", cluster.Provider, i+1) + go func(i int, master types.Node) { + logrus.Infof("[%s] creating k3s master-%d...\n", cluster.Provider, i+1) + initAdditionalMaster(masterWaitGroup, masterErrChan, k3sScript, k3sMirror, dockerMirror, publicIP, masterExtraArgs, cluster, master, aliCCM) + logrus.Infof("[%s] successfully created k3s master-%d\n", cluster.Provider, i+1) + }(i, master) } - masterWaitGroup.Wait() + go func() { + masterWaitGroup.Wait() + close(masterWaitGroupDone) + }() select { - case err, ok := <-masterErrChan: - if ok { - return err - } - default: + case <-masterWaitGroupDone: + break + case err := <-masterErrChan: + close(masterErrChan) + return err } workerErrChan := make(chan error) + workerWaitGroupDone := make(chan bool) workerWaitGroup := &sync.WaitGroup{} workerWaitGroup.Add(len(cluster.WorkerNodes)) - defer close(workerErrChan) for i, worker := range cluster.WorkerNodes { go func(i int, worker types.Node) { @@ -162,14 +167,17 @@ func InitK3sCluster(cluster *types.Cluster) error { }(i, worker) } - workerWaitGroup.Wait() + go func() { + workerWaitGroup.Wait() + close(workerWaitGroupDone) + }() select { - case err, ok := <-workerErrChan: - if ok { - return err - } - default: + case <-workerWaitGroupDone: + break + case err := <-workerErrChan: + close(workerErrChan) + return err } // get k3s cluster config. @@ -283,6 +291,7 @@ func JoinK3sNode(merged, added *types.Cluster) error { } errChan := make(chan error) + waitGroupDone := make(chan bool) waitGroup := &sync.WaitGroup{} waitGroup.Add(len(added.MasterNodes) + len(added.WorkerNodes)) defer close(errChan) @@ -315,14 +324,17 @@ func JoinK3sNode(merged, added *types.Cluster) error { } } - waitGroup.Wait() + go func() { + waitGroup.Wait() + close(waitGroupDone) + }() select { - case err, ok := <-errChan: - if ok { - return err - } - default: + case <-waitGroupDone: + break + case err := <-errChan: + close(errChan) + return err } // sync master & worker numbers. @@ -527,7 +539,7 @@ func initMaster(k3sScript, k3sMirror, dockerMirror, ip, extraArgs string, cluste func initAdditionalMaster(wg *sync.WaitGroup, errChan chan error, k3sScript, k3sMirror, dockerMirror, ip, extraArgs string, cluster *types.Cluster, master types.Node, aliCCM *alibaba.CloudControllerManager) { - defer wg.Done() + if strings.Contains(extraArgs, "--docker") { if _, err := execute(&hosts.Host{Node: master}, fmt.Sprintf(dockerCommand, dockerMirror), false); err != nil { @@ -549,11 +561,13 @@ func initAdditionalMaster(wg *sync.WaitGroup, errChan chan error, k3sScript, k3s strings.TrimSpace(extraArgs), cluster.K3sVersion), false); err != nil { errChan <- err } + + wg.Done() } func initWorker(wg *sync.WaitGroup, errChan chan error, k3sScript, k3sMirror, dockerMirror, extraArgs string, cluster *types.Cluster, worker types.Node, aliCCM *alibaba.CloudControllerManager) { - defer wg.Done() + if strings.Contains(extraArgs, "--docker") { if _, err := execute(&hosts.Host{Node: worker}, fmt.Sprintf(dockerCommand, dockerMirror), false); err != nil { @@ -571,11 +585,12 @@ func initWorker(wg *sync.WaitGroup, errChan chan error, k3sScript, k3sMirror, do strings.TrimSpace(extraArgs), cluster.K3sVersion), false); err != nil { errChan <- err } + + wg.Done() } func joinMaster(wg *sync.WaitGroup, errChan chan error, noFlannel bool, k3sScript, k3sMirror, dockerMirror, extraArgs string, merged *types.Cluster, full types.Node, aliCCM *alibaba.CloudControllerManager) { - defer wg.Done() if !strings.Contains(extraArgs, "server --server") { extraArgs += " server --server " + fmt.Sprintf("https://%s:6443", merged.IP) @@ -615,11 +630,12 @@ func joinMaster(wg *sync.WaitGroup, errChan chan error, noFlannel bool, k3sScrip strings.TrimSpace(extraArgs), merged.K3sVersion), false); err != nil { errChan <- err } + + wg.Done() } func joinWorker(wg *sync.WaitGroup, errChan chan error, k3sScript, k3sMirror, dockerMirror, extraArgs string, merged *types.Cluster, full types.Node, aliCCM *alibaba.CloudControllerManager) { - defer wg.Done() if strings.Contains(extraArgs, "--docker") { if _, err := execute(&hosts.Host{Node: full}, fmt.Sprintf(dockerCommand, dockerMirror), false); err != nil { @@ -637,6 +653,8 @@ func joinWorker(wg *sync.WaitGroup, errChan chan error, k3sScript, k3sMirror, do strings.TrimSpace(extraArgs), merged.K3sVersion), false); err != nil { errChan <- err } + + wg.Done() } func execute(host *hosts.Host, cmd string, print bool) (string, error) { diff --git a/pkg/hosts/dialer.go b/pkg/hosts/dialer.go index af3d69cc..1db36ac7 100644 --- a/pkg/hosts/dialer.go +++ b/pkg/hosts/dialer.go @@ -3,6 +3,7 @@ package hosts import ( "errors" "fmt" + "time" "github.com/cnrancher/autok3s/pkg/common" "github.com/cnrancher/autok3s/pkg/types" @@ -102,10 +103,11 @@ func newDialer(h *Host, kind string) (*Dialer, error) { } func (d *Dialer) getSSHTunnelConnection() (*ssh.Client, error) { - cfg, err := utils.GetSSHConfig(d.username, d.sshKey, d.passphrase, d.sshCert, d.password, d.useSSHAgentAuth) + timeout := time.Duration((common.Backoff.Steps - 1) * int(common.Backoff.Duration)) + cfg, err := utils.GetSSHConfig(d.username, d.sshKey, d.passphrase, d.sshCert, d.password, timeout, d.useSSHAgentAuth) if err != nil { return nil, err } - // Establish connection with SSH server + // establish connection with SSH server. return ssh.Dial(tcpNetProtocol, d.sshAddress, cfg) } diff --git a/pkg/utils/ssh.go b/pkg/utils/ssh.go index 55b4bb4a..7ba66330 100644 --- a/pkg/utils/ssh.go +++ b/pkg/utils/ssh.go @@ -6,6 +6,7 @@ import ( "net" "os" "path/filepath" + "time" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" @@ -36,9 +37,10 @@ func SSHCertificatePath(sshCertPath string) (string, error) { return string(buff), nil } -func GetSSHConfig(username, sshPrivateKeyString, passphrase, sshCert string, password string, useAgentAuth bool) (*ssh.ClientConfig, error) { +func GetSSHConfig(username, sshPrivateKeyString, passphrase, sshCert string, password string, timeout time.Duration, useAgentAuth bool) (*ssh.ClientConfig, error) { config := &ssh.ClientConfig{ User: username, + Timeout: timeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }