diff --git a/main.go b/main.go index 08de605..9d05e98 100644 --- a/main.go +++ b/main.go @@ -128,30 +128,35 @@ func check(response http.ResponseWriter, req *http.Request) { // 获取传入域名 if len(req.Form.Get("domain")) == 0 { + response.WriteHeader(400) fmt.Fprintf(response, "No domain specified.") return } domain = req.Form.Get("domain") // 获取传入文件名 if len(req.Form.Get("file")) == 0 { + response.WriteHeader(400) fmt.Fprintf(response, "No file specified.") return } file = req.Form.Get("file") // 获取传入签名 if len(req.Form.Get("sign")) == 0 { + response.WriteHeader(400) fmt.Fprintf(response, "No sign specified.") return } sign = req.Form.Get("sign") // 获取传入验证码 if len(req.Form.Get("checksum")) == 0 { + response.WriteHeader(400) fmt.Fprintf(response, "No checksum specified.") return } checksum = req.Form.Get("checksum") // 获取传入时间戳 if len(req.Form.Get("t")) == 0 { + response.WriteHeader(400) fmt.Fprintf(response, "No timestamp specified.") return } @@ -162,6 +167,7 @@ func check(response http.ResponseWriter, req *http.Request) { if err != nil { fmt.Println("Access from IP:", ip) fmt.Println("Incoming illegal timestamp:", t) + response.WriteHeader(403) fmt.Fprintf(response, "Timestamp not allowed.") return } @@ -170,6 +176,7 @@ func check(response http.ResponseWriter, req *http.Request) { if expireTime < -timeRange { fmt.Println("Access from IP:", ip) fmt.Println("Incoming illegal timestamp:", expireTime) + response.WriteHeader(403) fmt.Fprintf(response, "Timestamp not allowed.") return } @@ -177,6 +184,7 @@ func check(response http.ResponseWriter, req *http.Request) { if expireTime > timeRange { fmt.Println("Access from IP:", ip) fmt.Println("Incoming expired access:", expireTime) + response.WriteHeader(403) fmt.Fprintf(response, "Timestamp expired.") return } @@ -200,6 +208,7 @@ func check(response http.ResponseWriter, req *http.Request) { // 检测到重放请求 fmt.Println("Access from IP:", ip) fmt.Println("Incoming repeat access:", checksum) + response.WriteHeader(403) fmt.Fprintf(response, "Repeat access.") return } @@ -227,6 +236,7 @@ func check(response http.ResponseWriter, req *http.Request) { // 获取的域名不存在 fmt.Println("Access from IP:", ip) fmt.Println("Incoming illegal domain:", domain) + response.WriteHeader(404) fmt.Fprintf(response, "Domain not exist.") return } @@ -234,6 +244,7 @@ func check(response http.ResponseWriter, req *http.Request) { // 获取的文件不存在 fmt.Println("Access from IP:", ip) fmt.Println("Incoming illegal filename:", file) + response.WriteHeader(404) fmt.Fprintf(response, "File not exist.") return } @@ -246,6 +257,7 @@ func check(response http.ResponseWriter, req *http.Request) { // 签名错误 fmt.Println("Access from IP:", ip) fmt.Println("Incoming unauthorized access:", sign) + response.WriteHeader(401) fmt.Fprintf(response, "Unauthorized access.") } }