diff --git a/cmd/nps/nps.go b/cmd/nps/nps.go index baa930bb..c3b4d334 100644 --- a/cmd/nps/nps.go +++ b/cmd/nps/nps.go @@ -1,14 +1,6 @@ package main import ( - "ehang.io/nps/lib/crypt" - "ehang.io/nps/lib/file" - "ehang.io/nps/lib/install" - "ehang.io/nps/lib/version" - "ehang.io/nps/server" - "ehang.io/nps/server/connection" - "ehang.io/nps/server/tool" - "ehang.io/nps/web/routers" "flag" "log" "os" @@ -18,7 +10,16 @@ import ( "strings" "sync" + "ehang.io/nps/lib/file" + "ehang.io/nps/lib/install" + "ehang.io/nps/lib/version" + "ehang.io/nps/server" + "ehang.io/nps/server/connection" + "ehang.io/nps/server/tool" + "ehang.io/nps/web/routers" + "ehang.io/nps/lib/common" + "ehang.io/nps/lib/crypt" "ehang.io/nps/lib/daemon" "github.com/astaxie/beego" "github.com/astaxie/beego/logs" @@ -200,7 +201,8 @@ func run() { } logs.Info("the version of server is %s ,allow client core version to be %s", version.VERSION, version.GetVersion()) connection.InitConnectionService() - crypt.InitTls(filepath.Join(common.GetRunPath(), "conf", "server.pem"), filepath.Join(common.GetRunPath(), "conf", "server.key")) + //crypt.InitTls(filepath.Join(common.GetRunPath(), "conf", "server.pem"), filepath.Join(common.GetRunPath(), "conf", "server.key")) + crypt.InitTls() tool.InitAllowPort() tool.StartSystemInfo() go server.StartNewServer(bridgePort, task, beego.AppConfig.String("bridge_type")) diff --git a/lib/crypt/tls.go b/lib/crypt/tls.go index 35a0a748..799a1c04 100644 --- a/lib/crypt/tls.go +++ b/lib/crypt/tls.go @@ -1,22 +1,37 @@ package crypt import ( + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "log" + "math/big" "net" "os" + "time" "github.com/astaxie/beego/logs" ) -var pemPath, keyPath string +var ( + cert tls.Certificate +) -func InitTls(pem, key string) { - pemPath = pem - keyPath = key +func InitTls() { + c, k, err := generateKeyPair("NPS Org") + if err == nil { + cert, err = tls.X509KeyPair(c, k) + } + if err != nil { + log.Fatalln("Error initializing crypto certs", err) + } } func NewTlsServerConn(conn net.Conn) net.Conn { - cert, err := tls.LoadX509KeyPair(pemPath, keyPath) + var err error if err != nil { logs.Error(err) os.Exit(0) @@ -32,3 +47,41 @@ func NewTlsClientConn(conn net.Conn) net.Conn { } return tls.Client(conn, conf) } + +func generateKeyPair(CommonName string) (rawCert, rawKey []byte, err error) { + // Create private key and self-signed certificate + // Adapted from https://golang.org/src/crypto/tls/generate_cert.go + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + validFor := time.Hour * 24 * 365 * 10 // ten years + notBefore := time.Now() + notAfter := notBefore.Add(validFor) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Company Name LTD."}, + CommonName: CommonName, + Country: []string{"US"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return + } + + rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return +} diff --git a/server/proxy/http.go b/server/proxy/http.go index af4ad805..78a26b6e 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -130,7 +130,7 @@ func (s *httpServer) handleHttp(c *conn.Conn, r *http.Request) { defer func() { if connClient != nil { connClient.Close() - }else { + } else { s.writeConnFail(c.Conn) } c.Close() diff --git a/server/test/test.go b/server/test/test.go index a30d03d6..3fce9eed 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -52,10 +52,10 @@ func TestServerConfig() { if port, err := strconv.Atoi(p); err != nil { log.Fatalln("get https port error", err) } else { - if !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("pemPath"))) { + if beego.AppConfig.String("pemPath") != "" && !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("pemPath"))) { log.Fatalf("ssl certFile %s is not exist", beego.AppConfig.String("pemPath")) } - if !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("ketPath"))) { + if beego.AppConfig.String("keyPath") != "" && !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("keyPath"))) { log.Fatalf("ssl keyFile %s is not exist", beego.AppConfig.String("pemPath")) } isInArr(&postTcpArr, port, "http port", "tcp")