From 7a87d8d6543123d948679a0dffe394e689a61841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=83=E5=A4=8F=E6=B5=85=E7=AC=91?= Date: Sun, 23 Jan 2022 02:28:35 +0800 Subject: [PATCH] Support TLS & command parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加命令行参数支持 添加 TLS 监听支持 --- README.md | 41 +++++++++++++++++++- main.go | 110 ++++++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 129 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index a03e5dd..332689c 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # acmeDeliver +![GitHub](https://img.shields.io/github/license/julydate/acmeDeliver?style=flat-square) +![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/julydate/acmeDeliver?style=flat-square) + acme.sh 证书分发服务 将 acme.sh 获取的证书通过 http 服务分发到多台服务器 @@ -7,5 +10,39 @@ ## Usage ```bash -./acmeDeliver -``` \ No newline at end of file +$ ./acmeDeliver -h +acmeDeliver version: 0.1.0 +Usage: acmeDeliver [-h] [-p port] [-d dirname] [-k password] [-t time] [-b address] [-tls] [-tlsport port] [-cert filename] [-key filename] + +Options: + -h + 显示帮助信息 + -p string + 服务端口,默认 9090 (default "9090") + -d string + 证书文件所在目录,默认当前目录 (default "./") + -k string + 密码,默认 passwd (default "passwd") + -t int + 时间戳误差,默认 60 秒 (default 60) + -b string + 绑定监听地址,默认绑定所有接口 + -tls + 是否监听 TLS,默认关闭 + -tlsport string + TLS 服务端口,默认 9443 (default "9443") + -cert string + TLS 服务证书文件,默认 cert.pem + -key string + TLS 服务私钥文件,默认 key.pem (default "key.pem") +``` + +## Example + +```bash +./acmeDeliver -p 8080 -d "/tmp/acme" -k "passcode" -t 600 -b 0.0.0.0 -tls -tlsport 8443 -cert server.pem -key server.key +``` + +## Document + +待更新 diff --git a/main.go b/main.go index 9995e6f..4a4bd9d 100644 --- a/main.go +++ b/main.go @@ -2,11 +2,13 @@ package main import ( "crypto/md5" + "flag" "fmt" "io" "io/ioutil" "log" "net/http" + "os" "strconv" "time" @@ -14,34 +16,102 @@ import ( "github.com/zekroTJA/timedmap" ) -// PORT 服务端口 -const PORT string = "9090" +// VERSION 版本号 +const VERSION = "0.1.0" -// DIR 证书文件所在目录 -const DIR string = "./" +// h 帮助信息 +var h bool -// KEY 密码 -const KEY string = "passwd" +// bind 绑定监听地址 +var bind string -// EXPTIME 时间戳误差,单位秒 -const EXPTIME int64 = 86400 +// port 服务端口 +var port string -// 初始化全局变量 +// tls TLS 监听开关 +var tls bool + +// tlsport TLS 服务端口 +var tlsport string + +// certFile TLS 服务证书文件 +var certFile string + +// keyFile TLS 服务私钥文件 +var keyFile string + +// baseDir 证书文件所在目录 +var baseDir string + +// key 密码 +var key string + +// timeRange 时间戳误差,单位秒 +var timeRange int64 + +// 初始化从客户端获取的全局变量 var domain, file, t, checksum, sign string // Creates a new timed map which scans for expired keys every 1 second var tm = timedmap.New(1 * time.Second) +func init() { + // 初始化从命令行获取参数 + flag.BoolVar(&h, "h", false, "显示帮助信息") + flag.StringVar(&bind, "b", "", "绑定监听地址,默认绑定所有接口") + flag.StringVar(&port, "p", "9090", "服务端口,默认 9090") + flag.BoolVar(&tls, "tls", false, "是否监听 TLS,默认关闭") + flag.StringVar(&tlsport, "tlsport", "9443", "TLS 服务端口,默认 9443") + flag.StringVar(&certFile, "cert", "cert.pem", "TLS 服务证书文件,默认 cert.pem") + flag.StringVar(&keyFile, "key", "key.pem", "TLS 服务私钥文件,默认 key.pem") + flag.StringVar(&baseDir, "d", "./", "证书文件所在目录,默认当前目录") + flag.StringVar(&key, "k", "passwd", "密码,默认 passwd") + flag.Int64Var(&timeRange, "t", 60, "时间戳误差,默认 60 秒") + // 修改默认 Usage + flag.Usage = usage +} + func main() { + // 从 arguments 中解析注册的 flag + // 必须在所有 flag 都注册好而未访问其值时执行 + // 未注册却使用 flag -help 时,会返回 ErrHelp + flag.Parse() + + // 显示帮助信息 + if h { + flag.Usage() + return + } + // 设置访问的路由 http.HandleFunc("/", check) - // 设置监听的端口 - err := http.ListenAndServe(":"+PORT, nil) - if err != nil { + + // 启动 TLS 端口监听 + if tls { + // 使用 goroutine 进行监听防阻塞 + go func() { + if err := http.ListenAndServeTLS(bind+":"+tlsport, certFile, keyFile, nil); err != nil { + log.Fatal("ListenAndServeTLS:", err) + } + }() + } + + // 启动端口监听 + if err := http.ListenAndServe(bind+":"+port, nil); err != nil { log.Fatal("ListenAndServe:", err) } } +// 自定义帮助信息 +func usage() { + fmt.Fprintf(os.Stderr, `acmeDeliver version: `+VERSION+` +Usage: acmeDeliver [-h] [-p port] [-d dirname] [-k password] [-t time] [-b address] [-tls] [-tlsport port] [-cert filename] [-key filename] + +Options: +`) + flag.PrintDefaults() +} + func check(response http.ResponseWriter, req *http.Request) { // 解析 url 传递的参数,对于 POST 则解析响应包的主体(request body) err := req.ParseForm() @@ -97,14 +167,14 @@ func check(response http.ResponseWriter, req *http.Request) { } expireTime := time.Now().Unix() - reqTime // 时间戳太超前可以判定为异常访问 - if expireTime < -EXPTIME { + if expireTime < -timeRange { fmt.Println("Access from IP:", ip) fmt.Println("Incoming illegal timestamp:", expireTime) fmt.Fprintf(response, "Timestamp not allowed.") return } // 校验时间戳是否过期 - if expireTime > EXPTIME { + if expireTime > timeRange { fmt.Println("Access from IP:", ip) fmt.Println("Incoming expired access:", expireTime) fmt.Fprintf(response, "Timestamp expired.") @@ -115,7 +185,7 @@ func check(response http.ResponseWriter, req *http.Request) { token := md5.New() io.WriteString(token, domain) io.WriteString(token, file) - io.WriteString(token, KEY) + io.WriteString(token, key) io.WriteString(token, t) io.WriteString(token, checksum) checkToken := fmt.Sprintf("%x", token.Sum(nil)) @@ -124,7 +194,7 @@ func check(response http.ResponseWriter, req *http.Request) { if sign == checkToken { // 检测验证码是否重复 if checkTime, ok := tm.GetValue(checksum).(int64); ok { - if checkTime > 0 && time.Now().Unix()-checkTime > EXPTIME { + if checkTime > 0 && time.Now().Unix()-checkTime > timeRange { tm.Remove(checkTime) } else { // 检测到重放请求 @@ -134,12 +204,12 @@ func check(response http.ResponseWriter, req *http.Request) { return } } else { - tm.Set(checksum, reqTime, time.Duration(EXPTIME)*time.Second) + tm.Set(checksum, reqTime, time.Duration(timeRange)*time.Second) } // 校验域名是否在指定文件夹内 var checkDomain, checkFile bool = false, false - files, _ := ioutil.ReadDir(DIR) + files, _ := ioutil.ReadDir(baseDir) for _, f := range files { if domain == f.Name() { checkDomain = true @@ -147,7 +217,7 @@ func check(response http.ResponseWriter, req *http.Request) { } if checkDomain { // 对应域名的文件夹存在,校验内部文件是否存在 - files, _ := ioutil.ReadDir(DIR + domain) + files, _ := ioutil.ReadDir(baseDir + domain) for _, f := range files { if file == f.Name() { checkFile = true @@ -168,7 +238,7 @@ func check(response http.ResponseWriter, req *http.Request) { return } // 全部校验通过,放行文件 - filepath := DIR + domain + "/" + file + filepath := baseDir + domain + "/" + file fmt.Println("Access from IP:", ip) fmt.Println("Access file:", filepath) http.ServeFile(response, req, filepath)