Skip to content

Commit

Permalink
feat(sftp-server):在首次启用之前不生成主机密钥 (AlistGo#7734))
Browse files Browse the repository at this point in the history
  • Loading branch information
long2005a1 committed Jan 25, 2025
1 parent f51a620 commit 6dcba5e
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 7 deletions.
1 change: 0 additions & 1 deletion cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
func Init() {
bootstrap.InitConfig()
bootstrap.Log()
bootstrap.InitHostKey()
bootstrap.InitDB()
data.InitData()
bootstrap.InitIndex()
Expand Down
5 changes: 1 addition & 4 deletions internal/conf/var.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package conf

import (
"golang.org/x/crypto/ssh"
"net/url"
"regexp"
)
Expand Down Expand Up @@ -32,6 +31,4 @@ var (
RawIndexHtml string
ManageHtml string
IndexHtml string
)

var SSHSigners []ssh.Signer
)
6 changes: 4 additions & 2 deletions server/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/ftp"
"github.com/alist-org/alist/v3/server/sftp"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"net/http"
Expand All @@ -21,6 +22,7 @@ type SftpDriver struct {
}

func NewSftpDriver() (*SftpDriver, error) {
sftp.InitHostKey()
header := &http.Header{}
header.Add("User-Agent", setting.GetStr(conf.FTPProxyUserAgent))
return &SftpDriver{
Expand All @@ -40,7 +42,7 @@ func (d *SftpDriver) GetConfig() *sftpd.Config {
AuthLogCallback: d.AuthLogCallback,
BannerCallback: d.GetBanner,
}
for _, k := range conf.SSHSigners {
for _, k := range sftp.SSHSigners {
serverConfig.AddHostKey(k)
}
d.config = &sftpd.Config{
Expand All @@ -62,7 +64,7 @@ func (d *SftpDriver) GetFileSystem(sc *ssh.ServerConn) (sftpd.FileSystem, error)
ctx = context.WithValue(ctx, "meta_pass", "")
ctx = context.WithValue(ctx, "client_ip", sc.RemoteAddr().String())
ctx = context.WithValue(ctx, "proxy_header", d.proxyHeader)
return &ftp.SftpDriverAdapter{FtpDriver: ftp.NewAferoAdapter(ctx)}, nil
return &sftp.DriverAdapter{FtpDriver: ftp.NewAferoAdapter(ctx)}, nil
}

func (d *SftpDriver) Close() {
Expand Down
11 changes: 11 additions & 0 deletions server/sftp/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package sftp

// From leffss/sftpd
const (
SSH_FXF_READ = 0x00000001
SSH_FXF_WRITE = 0x00000002
SSH_FXF_APPEND = 0x00000004
SSH_FXF_CREAT = 0x00000008
SSH_FXF_TRUNC = 0x00000010
SSH_FXF_EXCL = 0x00000020
)
105 changes: 105 additions & 0 deletions server/sftp/hostkey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package sftp

import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/alist-org/alist/v3/cmd/flags"
"github.com/alist-org/alist/v3/pkg/utils"
"golang.org/x/crypto/ssh"
"os"
"path/filepath"
)

var SSHSigners []ssh.Signer

func InitHostKey() {
if SSHSigners != nil {
return
}
sshPath := filepath.Join(flags.DataDir, "ssh")
if !utils.Exists(sshPath) {
err := utils.CreateNestedDirectory(sshPath)
if err != nil {
utils.Log.Fatalf("failed to create ssh directory: %+v", err)
return
}
}
SSHSigners = make([]ssh.Signer, 0, 4)
if rsaKey, ok := LoadOrGenerateRSAHostKey(sshPath); ok {
SSHSigners = append(SSHSigners, rsaKey)
}
// TODO Add keys for other encryption algorithms
}

func LoadOrGenerateRSAHostKey(parentDir string) (ssh.Signer, bool) {
privateKeyPath := filepath.Join(parentDir, "ssh_host_rsa_key")
publicKeyPath := filepath.Join(parentDir, "ssh_host_rsa_key.pub")
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err == nil {
var privateKey *rsa.PrivateKey
privateKey, err = rsaDecodePrivateKey(privateKeyBytes)
if err == nil {
var ret ssh.Signer
ret, err = ssh.NewSignerFromKey(privateKey)
if err == nil {
return ret, true
}
}
}
_ = os.Remove(privateKeyPath)
_ = os.Remove(publicKeyPath)
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
utils.Log.Fatalf("failed to generate RSA private key: %+v", err)
return nil, false
}
publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
if err != nil {
utils.Log.Fatalf("failed to generate RSA public key: %+v", err)
return nil, false
}
ret, err := ssh.NewSignerFromKey(privateKey)
if err != nil {
utils.Log.Fatalf("failed to generate RSA signer: %+v", err)
return nil, false
}
privateBytes := rsaEncodePrivateKey(privateKey)
publicBytes := ssh.MarshalAuthorizedKey(publicKey)
err = os.WriteFile(privateKeyPath, privateBytes, 0600)
if err != nil {
utils.Log.Fatalf("failed to write RSA private key to file: %+v", err)
return nil, false
}
err = os.WriteFile(publicKeyPath, publicBytes, 0644)
if err != nil {
_ = os.Remove(privateKeyPath)
utils.Log.Fatalf("failed to write RSA public key to file: %+v", err)
return nil, false
}
return ret, true
}

func rsaEncodePrivateKey(privateKey *rsa.PrivateKey) []byte {
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
privateBlock := &pem.Block{
Type: "RSA PRIVATE KEY",
Headers: nil,
Bytes: privateKeyBytes,
}
return pem.EncodeToMemory(privateBlock)
}

func rsaDecodePrivateKey(bytes []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(bytes)
if block == nil {
return nil, fmt.Errorf("failed to parse PEM block containing the key")
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
return privateKey, nil
}
123 changes: 123 additions & 0 deletions server/sftp/sftp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package sftp

import (
"github.com/KirCute/sftpd-alist"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/ftp"
"os"
)

type DriverAdapter struct {
FtpDriver *ftp.AferoAdapter
}

func (s *DriverAdapter) OpenFile(_ string, _ uint32, _ *sftpd.Attr) (sftpd.File, error) {
// See also GetHandle
return nil, errs.NotImplement
}

func (s *DriverAdapter) OpenDir(_ string) (sftpd.Dir, error) {
// See also GetHandle
return nil, errs.NotImplement
}

func (s *DriverAdapter) Remove(name string) error {
return s.FtpDriver.Remove(name)
}

func (s *DriverAdapter) Rename(old, new string, _ uint32) error {
return s.FtpDriver.Rename(old, new)
}

func (s *DriverAdapter) Mkdir(name string, attr *sftpd.Attr) error {
return s.FtpDriver.Mkdir(name, attr.Mode)
}

func (s *DriverAdapter) Rmdir(name string) error {
return s.Remove(name)
}

func (s *DriverAdapter) Stat(name string, _ bool) (*sftpd.Attr, error) {
stat, err := s.FtpDriver.Stat(name)
if err != nil {
return nil, err
}
return fileInfoToSftpAttr(stat), nil
}

func (s *DriverAdapter) SetStat(_ string, _ *sftpd.Attr) error {
return errs.NotSupport
}

func (s *DriverAdapter) ReadLink(_ string) (string, error) {
return "", errs.NotSupport
}

func (s *DriverAdapter) CreateLink(_, _ string, _ uint32) error {
return errs.NotSupport
}

func (s *DriverAdapter) RealPath(path string) (string, error) {
return utils.FixAndCleanPath(path), nil
}

func (s *DriverAdapter) GetHandle(name string, flags uint32, _ *sftpd.Attr, offset uint64) (sftpd.FileTransfer, error) {
return s.FtpDriver.GetHandle(name, sftpFlagToOpenMode(flags), int64(offset))
}

func (s *DriverAdapter) ReadDir(name string) ([]sftpd.NamedAttr, error) {
dir, err := s.FtpDriver.ReadDir(name)
if err != nil {
return nil, err
}
ret := make([]sftpd.NamedAttr, len(dir))
for i, d := range dir {
ret[i] = *fileInfoToSftpNamedAttr(d)
}
return ret, nil
}

// From leffss/sftpd
func sftpFlagToOpenMode(flags uint32) int {
mode := 0
if (flags & SSH_FXF_READ) != 0 {
mode |= os.O_RDONLY
}
if (flags & SSH_FXF_WRITE) != 0 {
mode |= os.O_WRONLY
}
if (flags & SSH_FXF_APPEND) != 0 {
mode |= os.O_APPEND
}
if (flags & SSH_FXF_CREAT) != 0 {
mode |= os.O_CREATE
}
if (flags & SSH_FXF_TRUNC) != 0 {
mode |= os.O_TRUNC
}
if (flags & SSH_FXF_EXCL) != 0 {
mode |= os.O_EXCL
}
return mode
}

func fileInfoToSftpAttr(stat os.FileInfo) *sftpd.Attr {
ret := &sftpd.Attr{}
ret.Flags |= sftpd.ATTR_SIZE
ret.Size = uint64(stat.Size())
ret.Flags |= sftpd.ATTR_MODE
ret.Mode = stat.Mode()
ret.Flags |= sftpd.ATTR_TIME
ret.ATime = stat.Sys().(model.Obj).CreateTime()
ret.MTime = stat.ModTime()
return ret
}

func fileInfoToSftpNamedAttr(stat os.FileInfo) *sftpd.NamedAttr {
return &sftpd.NamedAttr{
Name: stat.Name(),
Attr: *fileInfoToSftpAttr(stat),
}
}

0 comments on commit 6dcba5e

Please # to comment.