diff --git a/serve.go b/serve.go index e430a8e..aa22c89 100644 --- a/serve.go +++ b/serve.go @@ -15,6 +15,7 @@ import ( "crypto/ecdsa" "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" @@ -51,36 +52,11 @@ func publicKey(priv interface{}) interface{} { } } -func pemBlockForKey(priv interface{}) *pem.Block { - switch k := priv.(type) { - case *rsa.PrivateKey: - return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)} - case *ecdsa.PrivateKey: - b, err := x509.MarshalECPrivateKey(k) - if err != nil { - fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err) - os.Exit(2) - } - return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} - default: - return nil - } -} - -// This is extracted from src/crypto/tls/generate_cert.go, keeping only the -// bare minimum needed to create a usable cert for localhost development. -func generateSelfSignedCert() { +func createCertificateForKey(key *rsa.PrivateKey) ([]byte, error) { host := "localhost" validFor := 365 * 24 * time.Hour - rsaBits := 2048 - var priv interface{} var err error - priv, err = rsa.GenerateKey(rand.Reader, rsaBits) - - if err != nil { - log.Fatalf("failed to generate private key: %s", err) - } notBefore := time.Now() notAfter := notBefore.Add(validFor) @@ -113,30 +89,84 @@ func generateSelfSignedCert() { } } - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) + return x509.CreateCertificate(rand.Reader, &template, &template, publicKey(key), key) +} + +// This is extracted from src/crypto/tls/generate_cert.go, keeping only the +// bare minimum needed to create a usable cert for localhost development. +func generateOnDiskCert(path string) (string, string) { + rsaBits := 2048 + + var key *rsa.PrivateKey + var err error + key, err = rsa.GenerateKey(rand.Reader, rsaBits) + if err != nil { + log.Fatalf("failed to generate private key: %s", err) + } + + derBytes, err := createCertificateForKey(key) if err != nil { log.Fatalf("Failed to create certificate: %s", err) } - certOut, err := os.Create("cert.pem") + certPath := fmt.Sprintf("%s/%s", path, "cert.pem") + keyPath := fmt.Sprintf("%s/%s", path, "key.pem") + + // Skip creating if the files already exists + if _, err := os.Stat(certPath); os.IsNotExist(err) { + _ = os.MkdirAll(path, os.ModePerm) + + certOut, err := os.Create(certPath) + if err != nil { + log.Fatalf("failed to open %s for writing: %s", certPath, err) + } + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + certOut.Close() + log.Printf("written %s\n", certPath) + + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + log.Print("failed to open %s for writing:", keyPath, err) + return "", "" + } + + pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + keyOut.Close() + log.Printf("written %s\n", keyPath) + } + + return certPath, keyPath +} + +func generateInMemoryCert() ([]byte, []byte, error) { + // host := "localhost" + // validFor := 365 * 24 * time.Hour + rsaBits := 2048 + + var key *rsa.PrivateKey + var err error + key, err = rsa.GenerateKey(rand.Reader, rsaBits) if err != nil { - log.Fatalf("failed to open cert.pem for writing: %s", err) + log.Fatalf("failed to generate private key: %s", err) } - pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - certOut.Close() - log.Print("written cert.pem\n") - keyOut, err := os.OpenFile("key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + derBytes, err := createCertificateForKey(key) if err != nil { - log.Print("failed to open key.pem for writing:", err) - return + log.Fatalf("Failed to create certificate: %s", err) } - pem.Encode(keyOut, pemBlockForKey(priv)) - keyOut.Close() - log.Print("written key.pem\n") + + certArray := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyArray := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + return certArray, keyArray, nil } func main() { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Fatal(err) + } + var gzip bool var port int var path string = "./" @@ -144,8 +174,9 @@ func main() { var http2 bool var headerFlags cli.StringSlice var headersMap map[string]string + var certSave bool + var certDir string = fmt.Sprintf("%s/.serve", homeDir) var stop = make(chan os.Signal, 1) - var err error var server *http.Server app := &cli.App{ @@ -185,6 +216,19 @@ func main() { Usage: "custom header(s) to add to the response (can be repeated multiple times)", Destination: &headerFlags, }, + &cli.BoolFlag{ + Name: "cert-save", + Aliases: []string{"sc"}, + Usage: "whether to save the generated self-signed certificates to disk", + Destination: &certSave, + }, + &cli.StringFlag{ + Name: "cert-dir", + Aliases: []string{"cd"}, + Usage: "location to save certificate at if saving to disk", + Value: certDir, + Destination: &certDir, + }, }, Action: func(c *cli.Context) error { @@ -223,11 +267,36 @@ func main() { } serveHttps := func() { - if _, err := os.Stat("cert.pem"); os.IsNotExist(err) { - generateSelfSignedCert() + if certSave { + certPath, keyPath := generateOnDiskCert(certDir) + + err = server.ListenAndServeTLS(certPath, keyPath) + } else { + cert, key, err := generateInMemoryCert() + if err != nil { + log.Fatal("Error: Couldn't create https certs.") + } + + keyPair, err := tls.X509KeyPair(cert, key) + if err != nil { + log.Fatal(err) + log.Fatal("Error: Couldn't create key pair") + } + + var certificates []tls.Certificate + certificates = append(certificates, keyPair) + + cfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + Certificates: certificates, + } + + server.TLSConfig = cfg + + log.Fatal(server.ListenAndServeTLS("", "")) } - err = server.ListenAndServeTLS("cert.pem", "key.pem") if err != nil { log.Fatalf("ListenAndServeTLS error: %s", err) }