diff --git a/cmd/common.go b/cmd/common.go index beb558f510f..47a25f3f266 100644 --- a/cmd/common.go +++ b/cmd/common.go @@ -15,7 +15,6 @@ import ( func Init() { bootstrap.InitConfig() bootstrap.Log() - bootstrap.InitHostKey() bootstrap.InitDB() data.InitData() bootstrap.InitIndex() diff --git a/internal/conf/var.go b/internal/conf/var.go index bf17972cb22..f7239d9c0eb 100644 --- a/internal/conf/var.go +++ b/internal/conf/var.go @@ -1,7 +1,6 @@ package conf import ( - "golang.org/x/crypto/ssh" "net/url" "regexp" ) @@ -32,6 +31,4 @@ var ( RawIndexHtml string ManageHtml string IndexHtml string -) - -var SSHSigners []ssh.Signer \ No newline at end of file +) \ No newline at end of file diff --git a/server/sftp.go b/server/sftp.go index 0deb9e78cdd..8efc0221cdf 100644 --- a/server/sftp.go +++ b/server/sftp.go @@ -9,6 +9,7 @@ import ( "github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/ftp" + "github.com/alist-org/alist/v3/server/sftp" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "net/http" @@ -21,6 +22,7 @@ type SftpDriver struct { } func NewSftpDriver() (*SftpDriver, error) { + sftp.InitHostKey() header := &http.Header{} header.Add("User-Agent", setting.GetStr(conf.FTPProxyUserAgent)) return &SftpDriver{ @@ -40,7 +42,7 @@ func (d *SftpDriver) GetConfig() *sftpd.Config { AuthLogCallback: d.AuthLogCallback, BannerCallback: d.GetBanner, } - for _, k := range conf.SSHSigners { + for _, k := range sftp.SSHSigners { serverConfig.AddHostKey(k) } d.config = &sftpd.Config{ @@ -62,7 +64,7 @@ func (d *SftpDriver) GetFileSystem(sc *ssh.ServerConn) (sftpd.FileSystem, error) ctx = context.WithValue(ctx, "meta_pass", "") ctx = context.WithValue(ctx, "client_ip", sc.RemoteAddr().String()) ctx = context.WithValue(ctx, "proxy_header", d.proxyHeader) - return &ftp.SftpDriverAdapter{FtpDriver: ftp.NewAferoAdapter(ctx)}, nil + return &sftp.DriverAdapter{FtpDriver: ftp.NewAferoAdapter(ctx)}, nil } func (d *SftpDriver) Close() { diff --git a/server/sftp/const.go b/server/sftp/const.go new file mode 100644 index 00000000000..1098a13682a --- /dev/null +++ b/server/sftp/const.go @@ -0,0 +1,11 @@ +package sftp + +// From leffss/sftpd +const ( + SSH_FXF_READ = 0x00000001 + SSH_FXF_WRITE = 0x00000002 + SSH_FXF_APPEND = 0x00000004 + SSH_FXF_CREAT = 0x00000008 + SSH_FXF_TRUNC = 0x00000010 + SSH_FXF_EXCL = 0x00000020 +) \ No newline at end of file diff --git a/server/sftp/hostkey.go b/server/sftp/hostkey.go new file mode 100644 index 00000000000..2d24f86b32c --- /dev/null +++ b/server/sftp/hostkey.go @@ -0,0 +1,105 @@ +package sftp + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "github.com/alist-org/alist/v3/cmd/flags" + "github.com/alist-org/alist/v3/pkg/utils" + "golang.org/x/crypto/ssh" + "os" + "path/filepath" +) + +var SSHSigners []ssh.Signer + +func InitHostKey() { + if SSHSigners != nil { + return + } + sshPath := filepath.Join(flags.DataDir, "ssh") + if !utils.Exists(sshPath) { + err := utils.CreateNestedDirectory(sshPath) + if err != nil { + utils.Log.Fatalf("failed to create ssh directory: %+v", err) + return + } + } + SSHSigners = make([]ssh.Signer, 0, 4) + if rsaKey, ok := LoadOrGenerateRSAHostKey(sshPath); ok { + SSHSigners = append(SSHSigners, rsaKey) + } + // TODO Add keys for other encryption algorithms +} + +func LoadOrGenerateRSAHostKey(parentDir string) (ssh.Signer, bool) { + privateKeyPath := filepath.Join(parentDir, "ssh_host_rsa_key") + publicKeyPath := filepath.Join(parentDir, "ssh_host_rsa_key.pub") + privateKeyBytes, err := os.ReadFile(privateKeyPath) + if err == nil { + var privateKey *rsa.PrivateKey + privateKey, err = rsaDecodePrivateKey(privateKeyBytes) + if err == nil { + var ret ssh.Signer + ret, err = ssh.NewSignerFromKey(privateKey) + if err == nil { + return ret, true + } + } + } + _ = os.Remove(privateKeyPath) + _ = os.Remove(publicKeyPath) + privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + utils.Log.Fatalf("failed to generate RSA private key: %+v", err) + return nil, false + } + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + if err != nil { + utils.Log.Fatalf("failed to generate RSA public key: %+v", err) + return nil, false + } + ret, err := ssh.NewSignerFromKey(privateKey) + if err != nil { + utils.Log.Fatalf("failed to generate RSA signer: %+v", err) + return nil, false + } + privateBytes := rsaEncodePrivateKey(privateKey) + publicBytes := ssh.MarshalAuthorizedKey(publicKey) + err = os.WriteFile(privateKeyPath, privateBytes, 0600) + if err != nil { + utils.Log.Fatalf("failed to write RSA private key to file: %+v", err) + return nil, false + } + err = os.WriteFile(publicKeyPath, publicBytes, 0644) + if err != nil { + _ = os.Remove(privateKeyPath) + utils.Log.Fatalf("failed to write RSA public key to file: %+v", err) + return nil, false + } + return ret, true +} + +func rsaEncodePrivateKey(privateKey *rsa.PrivateKey) []byte { + privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) + privateBlock := &pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: privateKeyBytes, + } + return pem.EncodeToMemory(privateBlock) +} + +func rsaDecodePrivateKey(bytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(bytes) + if block == nil { + return nil, fmt.Errorf("failed to parse PEM block containing the key") + } + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + return privateKey, nil +} \ No newline at end of file diff --git a/server/sftp/sftp.go b/server/sftp/sftp.go new file mode 100644 index 00000000000..467f1a730a7 --- /dev/null +++ b/server/sftp/sftp.go @@ -0,0 +1,123 @@ +package sftp + +import ( + "github.com/KirCute/sftpd-alist" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/ftp" + "os" +) + +type DriverAdapter struct { + FtpDriver *ftp.AferoAdapter +} + +func (s *DriverAdapter) OpenFile(_ string, _ uint32, _ *sftpd.Attr) (sftpd.File, error) { + // See also GetHandle + return nil, errs.NotImplement +} + +func (s *DriverAdapter) OpenDir(_ string) (sftpd.Dir, error) { + // See also GetHandle + return nil, errs.NotImplement +} + +func (s *DriverAdapter) Remove(name string) error { + return s.FtpDriver.Remove(name) +} + +func (s *DriverAdapter) Rename(old, new string, _ uint32) error { + return s.FtpDriver.Rename(old, new) +} + +func (s *DriverAdapter) Mkdir(name string, attr *sftpd.Attr) error { + return s.FtpDriver.Mkdir(name, attr.Mode) +} + +func (s *DriverAdapter) Rmdir(name string) error { + return s.Remove(name) +} + +func (s *DriverAdapter) Stat(name string, _ bool) (*sftpd.Attr, error) { + stat, err := s.FtpDriver.Stat(name) + if err != nil { + return nil, err + } + return fileInfoToSftpAttr(stat), nil +} + +func (s *DriverAdapter) SetStat(_ string, _ *sftpd.Attr) error { + return errs.NotSupport +} + +func (s *DriverAdapter) ReadLink(_ string) (string, error) { + return "", errs.NotSupport +} + +func (s *DriverAdapter) CreateLink(_, _ string, _ uint32) error { + return errs.NotSupport +} + +func (s *DriverAdapter) RealPath(path string) (string, error) { + return utils.FixAndCleanPath(path), nil +} + +func (s *DriverAdapter) GetHandle(name string, flags uint32, _ *sftpd.Attr, offset uint64) (sftpd.FileTransfer, error) { + return s.FtpDriver.GetHandle(name, sftpFlagToOpenMode(flags), int64(offset)) +} + +func (s *DriverAdapter) ReadDir(name string) ([]sftpd.NamedAttr, error) { + dir, err := s.FtpDriver.ReadDir(name) + if err != nil { + return nil, err + } + ret := make([]sftpd.NamedAttr, len(dir)) + for i, d := range dir { + ret[i] = *fileInfoToSftpNamedAttr(d) + } + return ret, nil +} + +// From leffss/sftpd +func sftpFlagToOpenMode(flags uint32) int { + mode := 0 + if (flags & SSH_FXF_READ) != 0 { + mode |= os.O_RDONLY + } + if (flags & SSH_FXF_WRITE) != 0 { + mode |= os.O_WRONLY + } + if (flags & SSH_FXF_APPEND) != 0 { + mode |= os.O_APPEND + } + if (flags & SSH_FXF_CREAT) != 0 { + mode |= os.O_CREATE + } + if (flags & SSH_FXF_TRUNC) != 0 { + mode |= os.O_TRUNC + } + if (flags & SSH_FXF_EXCL) != 0 { + mode |= os.O_EXCL + } + return mode +} + +func fileInfoToSftpAttr(stat os.FileInfo) *sftpd.Attr { + ret := &sftpd.Attr{} + ret.Flags |= sftpd.ATTR_SIZE + ret.Size = uint64(stat.Size()) + ret.Flags |= sftpd.ATTR_MODE + ret.Mode = stat.Mode() + ret.Flags |= sftpd.ATTR_TIME + ret.ATime = stat.Sys().(model.Obj).CreateTime() + ret.MTime = stat.ModTime() + return ret +} + +func fileInfoToSftpNamedAttr(stat os.FileInfo) *sftpd.NamedAttr { + return &sftpd.NamedAttr{ + Name: stat.Name(), + Attr: *fileInfoToSftpAttr(stat), + } +} \ No newline at end of file