diff --git a/README.md b/README.md index c2c4d1e..4fa741f 100644 --- a/README.md +++ b/README.md @@ -2,3 +2,28 @@ switch cloudflare FW to under_attack mode when server load is high [![Go](https://github.com/amnonbc/underattack/actions/workflows/go2.yml/badge.svg)](https://github.com/amnonbc/underattack/actions/workflows/go2.yml) + +This runs on a web host which sits behind cloudflare. + +It is run using a cron like + +``` +*/5 * * * * ${HOME}/bin/underattack -config ${HOME}/etc/underatttack.conf -default_level high >> ${HOME}/logs/underattack.log 2>&1 +``` + +It checks the server's load, free memory and tries to connect to the db. +If the load is too high, the free memory is too low, or if it can't connect to the db, +then it sets CF into under-attack mode, which massively reduces bot traffic which usually allows server resources to recover. +Once the server has recovered, then it turns off under-attack mode. + +The config file looks like this: + +```json +{ + "domain": "mydomain.com", + "apiKey": "cfApiKeyWithZoneZoneReadAmdZoneSecuritySet", + "DbName": "nameOfDb", + "DbUser": "NameOfDbUser", + "DbPassword": "dbUserPassword" +} +``` \ No newline at end of file diff --git a/checkdb.go b/checkdb.go index b16534a..d6bfe3e 100644 --- a/checkdb.go +++ b/checkdb.go @@ -6,9 +6,9 @@ import ( _ "github.com/go-sql-driver/mysql" ) -func checkDb(conf Config) error { +func (a *app) checkDb() error { db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(127.0.0.1:3306)/%s", - conf.DbUser, conf.DbPassword, conf.DbName)) + a.conf.DbUser, a.conf.DbPassword, a.conf.DbName)) if err != nil { return err } diff --git a/under.go b/under.go index 0a994a1..6394bb1 100644 --- a/under.go +++ b/under.go @@ -24,15 +24,19 @@ type Config struct { DbPassword string } -var config Config +type app struct { + conf Config + api *cloudflare.API + zoneId string +} -func loadConfig(fn string) error { +func (a *app) loadConfig(fn string) error { f, err := os.Open(fn) if err != nil { return err } defer f.Close() - return json.NewDecoder(f).Decode(&config) + return json.NewDecoder(f).Decode(&a.conf) } func loadAvg(text string) ([]float64, error) { @@ -54,25 +58,27 @@ func loadAvg(text string) ([]float64, error) { return res, nil } -// setSecurityLevel sets the security level. value must be one of -// off, essentially_off, low, medium, high, under_attack. -func setSecurityLevel(value string) error { - api, err := cloudflare.NewWithAPIToken(config.ApiKey) - if err != nil { - return err - } - zoneID, err := api.ZoneIDByName(config.Domain) +func (a *app) init() error { + var err error + a.api, err = cloudflare.NewWithAPIToken(a.conf.ApiKey) if err != nil { return err } + a.zoneId, err = a.api.ZoneIDByName(a.conf.Domain) + return err +} - currentLevel, err := currentLevel(api, zoneID) +// setSecurityLevel sets the security level. value must be one of +// off, essentially_off, low, medium, high, under_attack. +// If the cf security is already at the reauested level, then do nothing. +func (a *app) setSecurityLevel(value string) error { + currentLevel, err := a.currentLevel() if currentLevel == value { return nil } log.Println("setting security level to", value) - _, err = api.UpdateZoneSettings(context.TODO(), zoneID, []cloudflare.ZoneSetting{ + _, err = a.api.UpdateZoneSettings(context.TODO(), a.zoneId, []cloudflare.ZoneSetting{ { ID: securityLevel, Value: value, @@ -81,15 +87,15 @@ func setSecurityLevel(value string) error { return err } -func mustSetSecurityLevel(value string) { - err := setSecurityLevel(value) +func (a *app) mustSetSecurityLevel(value string) { + err := a.setSecurityLevel(value) if err != nil { log.Fatalln(err) } } -func currentLevel(api *cloudflare.API, zoneID string) (string, error) { - settings, err := api.ZoneSettings(context.TODO(), zoneID) +func (a *app) currentLevel() (string, error) { + settings, err := a.api.ZoneSettings(context.TODO(), a.zoneId) if err != nil { return "", err } @@ -114,7 +120,13 @@ func main() { if err != nil { log.Fatalln(err) } - err = loadConfig(*cf) + var a app + err = a.loadConfig(*cf) + if err != nil { + log.Fatalln(err) + } + + err = a.init() if err != nil { log.Fatalln(err) } @@ -131,23 +143,23 @@ func main() { log.Println("freeMem", humanize.Bytes(freeMem), "load", la) if freeMem < mb { log.Println("free memory is below", *minBytesStr) - mustSetSecurityLevel("under_attack") + a.mustSetSecurityLevel("under_attack") return } - err = checkDb(config) + err = a.checkDb() if err != nil { log.Println("checkDb returned", err) - mustSetSecurityLevel("under_attack") + a.mustSetSecurityLevel("under_attack") return } if la[0] >= *maxLoad { log.Println("Load average is", la, "setting level to under_attack") - mustSetSecurityLevel("under_attack") + a.mustSetSecurityLevel("under_attack") return } if la[0] < *minLoad && la[1] < *minLoad && la[2] < *minLoad { - mustSetSecurityLevel(*defaultSecurityLevel) + a.mustSetSecurityLevel(*defaultSecurityLevel) return } }