diff --git a/.databases.json.example b/.databases.json.example index dfb09fb..40d773c 100644 --- a/.databases.json.example +++ b/.databases.json.example @@ -14,6 +14,15 @@ "dbName": "prod_database", "dbServer": "prod.db.server.com", "user": "user", - "pass": "pass" + "pass": "pass", + "sqlType": "mysql" + }, + "postgresdb": { + "appServer": "pg.server.com", + "dbName": "pg_database", + "dbServer": "pg.db.server.com", + "user": "postgres", + "pass": "postgres", + "sqlType": "postgres" } } diff --git a/Dockerfile b/Dockerfile index e3e33d1..c03600f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ FROM golang:1.11 -RUN apt-get update && apt-get install -y --no-install-recommends mysql-client && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y --no-install-recommends mysql-client postgresql-client && rm -rf /var/lib/apt/lists/* ENTRYPOINT [ "go", "test", "-v", "." ] diff --git a/README.md b/README.md index 54ad27d..b45421f 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,10 @@ MySQL pipe ## What does it do? -- `sql` allows you to pipe STDIN (hopefully containing SQL) to one or more pre-configured MySQL databases +- `sql` allows you to pipe STDIN (hopefully containing SQL) to one or more pre-configured MySQL or PostgreSQL databases - output comes out in `\t`-separated format, allowing further piping (e.g. works really well with [chart](https://github.com/MarianoGappa/chart)) - when more than one database is queried, the requests are made in parallel -- `sql` can either run `mysql` locally, run `mysql` locally but connecting to a remote host (by configuring a `dbServer`), or `ssh` to a remote host and from there run `mysql` to either a local or remote host (by configuring an `appServer` and a `dbServer`) +- `sql` can either run `mysql/psql` locally, run `mysql/psql` locally but connecting to a remote host (by configuring a `dbServer`), or `ssh` to a remote host and from there run `mysql/psql` to either a local or remote host (by configuring an `appServer` and a `dbServer`) ## Installation @@ -37,6 +37,8 @@ $ compinit Create a `.databases.json` dotfile in your home folder or in any [XDG-compliant](https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html) directory. [This](.databases.json.example) is an example file. +`sql` decides to execute with MySQL or PostgreSQL depending on the `sqlType` property set for a database, *defaulting to to MySQL if not set.* + ## Example usages ``` @@ -50,8 +52,9 @@ sql all "SELECT * FROM users WHERE name = 'John'" ## Notes - when more than one database is queried, the resulting rows are prefixed with the database identifier -- the `all` special keyword means "sql to all configured databases" +- the `all` special keyword means "sql to all configured databases". - `sql` assumes that you have correctly configured SSH keys on all servers you `ssh` to +- `sql` will error if all targeted databases do not have the same sql type. ## Beware! @@ -61,7 +64,7 @@ sql all "SELECT * FROM users WHERE name = 'John'" ## Dependencies -- mysql +- mysql-client and/or postgresql-client - ssh (only if you configure an "appServer") ## Contribute diff --git a/config.go b/config.go index a546af8..2b05f02 100644 --- a/config.go +++ b/config.go @@ -15,6 +15,7 @@ type database struct { DbName string User string Pass string + SQLType string } func mustReadDatabasesConfigFile() map[string]database { diff --git a/main.go b/main.go index d79373d..2eba9e8 100644 --- a/main.go +++ b/main.go @@ -1,14 +1,11 @@ package main import ( - "bufio" "context" "flag" "fmt" "log" "os" - "os/exec" - "strings" "sync" ) @@ -84,11 +81,13 @@ func _main(databases map[string]database, databasesArgs []string, query string, var wg sync.WaitGroup wg.Add(len(targetDatabases)) + sqlRunner := mustNewSQLRunner(quitContext, println, query, len(targetDatabases) > 1) + returnCode := 0 for _, k := range targetDatabases { go func(db database, k string) { defer wg.Done() - if r := runSQL(quitContext, db, query, k, len(targetDatabases) > 1, println); !r { + if r := sqlRunner.runSQL(db, k); !r { returnCode = 1 } }(databases[k], k) @@ -97,78 +96,3 @@ func _main(databases map[string]database, databasesArgs []string, query string, wg.Wait() return returnCode } - -func runSQL(quitContext context.Context, db database, query string, key string, prependKey bool, println func(string)) bool { - userOption := "" - if db.User != "" { - userOption = fmt.Sprintf("-u %v ", db.User) - } - - passOption := "" - if db.Pass != "" { - passOption = fmt.Sprintf("-p%v ", db.Pass) - } - - hostOption := "" - if db.DbServer != "" { - hostOption = fmt.Sprintf("-h %v ", db.DbServer) - } - - prepend := "" - if prependKey { - prepend = key + "\t" - } - - mysql := "mysql" - options := fmt.Sprintf(" -Nsr %v%v%v%v -e ", userOption, passOption, hostOption, db.DbName) - - var cmd *exec.Cmd - if db.AppServer != "" { - escapedQuery := fmt.Sprintf(`'%v'`, strings.Replace(query, `'`, `'"'"'`, -1)) - cmd = exec.CommandContext(quitContext, "ssh", db.AppServer, mysql+options+escapedQuery) - } else { - args := append(trimEmpty(strings.Split(options, " ")), query) - cmd = exec.CommandContext(quitContext, mysql, args...) - } - - stdout, err := cmd.StdoutPipe() - if err != nil { - log.Printf("Cannot create pipe for STDOUT of running command on %v; not running. err=%v\n", key, err) - return false - } - - stderr, err := cmd.StderrPipe() - if err != nil { - log.Printf("Cannot create pipe for STDERR of running command on %v; not running. err=%v\n", key, err) - return false - } - - if err := cmd.Start(); err != nil { - log.Printf("Cannot start command on %v; not running. err=%v\n", key, err) - return false - } - - scanner := bufio.NewScanner(stdout) - for scanner.Scan() { - println(prepend + scanner.Text()) - } - - stderrLines := []string{} - scanner = bufio.NewScanner(stderr) - for scanner.Scan() { - stderrLines = append(stderrLines, scanner.Text()) - } - - cmd.Wait() - - result := true - if len(stderrLines) > 0 { - result = false - log.Println(key + " had errors:") - for _, v := range stderrLines { - log.Println(key + " [ERROR] " + v) - } - } - - return result -} diff --git a/main_test.go b/main_test.go index 0a25732..26bdfbd 100644 --- a/main_test.go +++ b/main_test.go @@ -11,56 +11,126 @@ import ( "time" ) -func TestSQL(t *testing.T) { - var err error - for i := 1; i <= 30; i++ { // Try up to 30 times, because MySQL takes a while to become online - var c = exec.Command("mysql", "-h", "test-mysql", "-u", "root", "-e", "SELECT * FROM db1.table1") - if err = c.Run(); err == nil { - break +type tests []struct { + name string + targetDBs []string + query string + expected []string +} +type testConfig map[string]database + +var baseTests = tests{ + { + name: "reads from one database", + targetDBs: []string{"db1"}, + query: "SELECT id FROM table1", + expected: []string{ + "", + "1", + "2", + "3", + }, + }, + { + name: "reads from two databases", + targetDBs: []string{"db1", "db2"}, + query: "SELECT id FROM table1", + expected: []string{ + "", + "db1\t1", + "db1\t2", + "db1\t3", + "db2\t1", + "db2\t2", + "db2\t3", + }, + }, + { + name: "reads from all databases with the all keyword", + targetDBs: []string{"all"}, + query: "SELECT id FROM table1", + expected: []string{ + "", + "db1\t1", + "db1\t2", + "db1\t3", + "db2\t1", + "db2\t2", + "db2\t3", + "db3\t1", + "db3\t2", + "db3\t3", + }, + }, + { + name: "reads two fields from all databases", + targetDBs: []string{"all"}, + query: "SELECT id, name FROM table1", + expected: []string{ + "", + "db1\t1\tJohn", + "db1\t2\tGeorge", + "db1\t3\tRichard", + "db2\t1\tRob", + "db2\t2\tKen", + "db2\t3\tRobert", + "db3\t1\tAthos", + "db3\t2\tPorthos", + "db3\t3\tAramis", + }, + }, +} + +func Test_MySQL(t *testing.T) { + awaitDB(mySQL, t) + + var ( + testConfig = testConfig{ + "db1": database{DbServer: "test-mysql", DbName: "db1", User: "root", Pass: "", SQLType: "mysql"}, + "db2": database{DbServer: "test-mysql", DbName: "db2", User: "root", Pass: "", SQLType: ""}, + "db3": database{DbServer: "test-mysql", DbName: "db3", User: "root", Pass: "", SQLType: ""}, } - log.Printf("Retrying (%v/30) in 1 sec because MySQL is not yet ready", i) - time.Sleep(1 * time.Second) - } - for err != nil { - t.Errorf("bailing because couldn't connect to MySQL after 30 tries: %v", err) - t.FailNow() - } + ) + runTests(baseTests, testConfig, t) +} + +func Test_PostgreSQL(t *testing.T) { + awaitDB(postgreSQL, t) var ( - testConfig = map[string]database{ - "db1": database{DbServer: "test-mysql", DbName: "db1", User: "root", Pass: ""}, - "db2": database{DbServer: "test-mysql", DbName: "db2", User: "root", Pass: ""}, - "db3": database{DbServer: "test-mysql", DbName: "db3", User: "root", Pass: ""}, + testConfig = testConfig{ + "db1": database{DbServer: "test-postgres", DbName: "db1", User: "root", Pass: "", SQLType: "postgres"}, + "db2": database{DbServer: "test-postgres", DbName: "db2", User: "root", Pass: "", SQLType: "postgres"}, + "db3": database{DbServer: "test-postgres", DbName: "db3", User: "root", Pass: "", SQLType: "postgres"}, } - ts = []struct { - name string - targetDBs []string - query string - expected []string - }{ - { - name: "reads from one database", - targetDBs: []string{"db1"}, - query: "SELECT id FROM table1", - expected: []string{ - "", - "1", - "2", - "3", - }, - }, + ) + runTests(baseTests, testConfig, t) +} + +func Test_Mix_Mysql_PostgreSQL(t *testing.T) { + awaitDB(mySQL, t) + awaitDB(postgreSQL, t) + + var ( + testConfig = testConfig{ + "db1": database{DbServer: "test-postgres", DbName: "db1", User: "root", Pass: "", SQLType: "postgres"}, + "db2": database{DbServer: "test-postgres", DbName: "db2", User: "root", Pass: "", SQLType: "postgres"}, + "db3": database{DbServer: "test-mysql", DbName: "db3", User: "root", Pass: "", SQLType: ""}, + "db4": database{DbServer: "test-mysql", DbName: "db1", User: "root", Pass: "", SQLType: "mysql"}, + } + ts = tests{ { name: "reads from two databases", - targetDBs: []string{"db1", "db2"}, + targetDBs: []string{"db1", "db3"}, query: "SELECT id FROM table1", expected: []string{ "", - "db1 1", - "db1 2", - "db1 3", - "db2 1", - "db2 2", - "db2 3", + "db1\t1", + "db1\t2", + "db1\t3", + "db3\t1", + "db3\t2", + "db3\t3", }, }, { @@ -69,15 +139,18 @@ func TestSQL(t *testing.T) { query: "SELECT id FROM table1", expected: []string{ "", - "db1 1", - "db1 2", - "db1 3", - "db2 1", - "db2 2", - "db2 3", - "db3 1", - "db3 2", - "db3 3", + "db1\t1", + "db1\t2", + "db1\t3", + "db2\t1", + "db2\t2", + "db2\t3", + "db3\t1", + "db3\t2", + "db3\t3", + "db4\t1", + "db4\t2", + "db4\t3", }, }, { @@ -86,19 +159,26 @@ func TestSQL(t *testing.T) { query: "SELECT id, name FROM table1", expected: []string{ "", - "db1 1 John", - "db1 2 George", - "db1 3 Richard", - "db2 1 Rob", - "db2 2 Ken", - "db2 3 Robert", - "db3 1 Athos", - "db3 2 Porthos", - "db3 3 Aramis", + "db1\t1\tJohn", + "db1\t2\tGeorge", + "db1\t3\tRichard", + "db2\t1\tRob", + "db2\t2\tKen", + "db2\t3\tRobert", + "db3\t1\tAthos", + "db3\t2\tPorthos", + "db3\t3\tAramis", + "db4\t1\tJohn", + "db4\t2\tGeorge", + "db4\t3\tRichard", }, }, } ) + runTests(ts, testConfig, t) +} + +func runTests(ts tests, testConfig testConfig, t *testing.T) { for _, tc := range ts { t.Run(tc.name, func(t *testing.T) { var buf = bytes.Buffer{} @@ -106,8 +186,32 @@ func TestSQL(t *testing.T) { var actual = strings.Split(buf.String(), "\n") sort.Strings(actual) if !reflect.DeepEqual(tc.expected, actual) { - t.Errorf("Expected %v but got %v", tc.expected, actual) + t.Errorf("Expected\n%#v\n but got\n%#v\n", tc.expected, actual) } }) } } + +func awaitDB(typ sqlType, t *testing.T) { + var pgTestCmds = []string{"-h", "test-postgres", "-U", "root", "-d", "db1", "-c", "SELECT * FROM table1"} + var msTestCmds = []string{"-h", "test-mysql", "-u", "root", "-e", "SELECT * FROM db1.table1"} + var err error + var c *exec.Cmd + for i := 1; i <= 30; i++ { + if typ == mySQL { + c = exec.Command("mysql", msTestCmds...) + } else { + c = exec.Command("psql", pgTestCmds...) + } + if err = c.Run(); err == nil { + log.Printf("%s ready after %v/30 tries", typ, i) + break + } + log.Printf("Retrying (%v/30) in 1 sec because %s is not yet ready", i, typ) + time.Sleep(1 * time.Second) + } + for err != nil { + t.Errorf("bailing because couldn't connect to %s after 30 tries: %v", typ, err) + t.FailNow() + } +} diff --git a/test_schemas.sql b/mysql_test_schemas.sql similarity index 100% rename from test_schemas.sql rename to mysql_test_schemas.sql diff --git a/postgres_test_schemas.sql b/postgres_test_schemas.sql new file mode 100644 index 0000000..f6ba1d6 --- /dev/null +++ b/postgres_test_schemas.sql @@ -0,0 +1,47 @@ +DROP DATABASE IF EXISTS db1; +CREATE DATABASE db1; + +DROP DATABASE IF EXISTS db2; +CREATE DATABASE db2; + +DROP DATABASE IF EXISTS db3; +CREATE DATABASE db3; + +\c db1 + +CREATE TABLE table1( + id integer NOT NULL, + name varchar(255) NOT NULL +); + +INSERT INTO table1 (id, name) +VALUES +(1, 'John'), +(2, 'George'), +(3, 'Richard'); + +\c db2 + +CREATE TABLE table1( + id integer NOT NULL, + name varchar(255) NOT NULL +); + +INSERT INTO table1 (id, name) +VALUES +(1, 'Rob'), +(2, 'Ken'), +(3, 'Robert'); + +\c db3 + +CREATE TABLE table1( + id integer NOT NULL, + name varchar(255) NOT NULL +); + +INSERT INTO table1 (id, name) +VALUES +(1, 'Athos'), +(2, 'Porthos'), +(3, 'Aramis'); diff --git a/sql_runner.go b/sql_runner.go new file mode 100644 index 0000000..72b634e --- /dev/null +++ b/sql_runner.go @@ -0,0 +1,181 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "log" + "os" + "os/exec" + "strings" +) + +type sqlType int + +const ( + mySQL sqlType = iota + 1 + postgreSQL +) + +func (t sqlType) String() string { + return [...]string{"", "MySQL", "PostgreSQL"}[t] +} + +type exists struct{} + +type sqlOptions struct { + cmd string + user string + host string + pass string + db string + flags string +} + +var validSQLTypes = map[string]sqlType{ + "": mySQL, + "mysql": mySQL, + "postgres": postgreSQL, +} + +var sqlTypeToOptions = map[sqlType]sqlOptions{ + mySQL: { + "mysql", + "-u%v", + "-h%v", + "-p%v", + "%v", + "-Nsre", + }, + postgreSQL: { + "psql", + "-U %v", + "-h%v", + "PGPASSWORD=%v", + "-d %v", + "-tAc", + }, +} + +type sqlRunner struct { + printer func(string) + query string + quitContext context.Context + multi bool +} + +func mustNewSQLRunner(quitContext context.Context, printer func(string), query string, multi bool) *sqlRunner { + return &sqlRunner{ + printer, + query, + quitContext, + multi, + } +} + +func (sr *sqlRunner) runSQL(db database, key string) bool { + typ, ok := validSQLTypes[db.SQLType] + if !ok { + return maybeErrorResult(key, fmt.Sprintf("Unknown sql type %v for %v", db.SQLType, key)) + } + + sqlOptions := sqlTypeToOptions[typ] + + userOption := "" + if db.User != "" { + userOption = fmt.Sprintf(sqlOptions.user, db.User) + } + + passOption := "" + if db.Pass != "" { + passOption = fmt.Sprintf(sqlOptions.pass, db.Pass) + } + + hostOption := "" + if db.DbServer != "" { + hostOption = fmt.Sprintf(sqlOptions.host, db.DbServer) + } + + dbOption := "" + if db.DbName != "" { + dbOption = fmt.Sprintf(sqlOptions.db, db.DbName) + } + + prepend := "" + if sr.multi { + prepend = key + "\t" + } + + options := "" + if typ == postgreSQL { + options = fmt.Sprintf("%v %v %v %v %v", sqlOptions.cmd, userOption, hostOption, dbOption, sqlOptions.flags) + } else { + options = fmt.Sprintf("%v %v %v %v %v %v", sqlOptions.cmd, dbOption, userOption, passOption, hostOption, sqlOptions.flags) + } + + var cmd *exec.Cmd + if db.AppServer != "" { + escapedQuery := fmt.Sprintf(`'%v'`, strings.Replace(sr.query, `'`, `'"'"'`, -1)) + if typ == postgreSQL { + escapedQuery += fmt.Sprintf("-F%s", "\t") + } + + cmd = exec.CommandContext(sr.quitContext, "ssh", db.AppServer, options+escapedQuery) + + } else { + args := append(trimEmpty(strings.Split(options, " ")), sr.query) + if typ == postgreSQL { + args = append(args, fmt.Sprintf("-F%s", "\t")) + } + cmd = exec.CommandContext(sr.quitContext, args[0], args[1:]...) + } + + if typ == postgreSQL { + cmd.Env = os.Environ() + cmd.Env = append(cmd.Env, passOption) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return maybeErrorResult(key, fmt.Sprintf("Cannot create pipe for STDOUT of running command on %v; not running. err=%v\n", key, err)) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return maybeErrorResult(key, fmt.Sprintf("Cannot create pipe for STDERR of running command on %v; not running. err=%v\n", key, err)) + } + + if err := cmd.Start(); err != nil { + return maybeErrorResult(key, fmt.Sprintf("Cannot start command on %v; not running. err=%v\n", key, err)) + } + + scanner := bufio.NewScanner(stdout) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + sr.printer(prepend + strings.TrimSpace(line)) + } + } + + stderrLines := []string{} + scanner = bufio.NewScanner(stderr) + for scanner.Scan() { + stderrLines = append(stderrLines, scanner.Text()) + } + + cmd.Wait() + + return maybeErrorResult(key, stderrLines...) +} + +func maybeErrorResult(key string, stderrLines ...string) bool { + result := true + if len(stderrLines) > 0 { + result = false + log.Println(key + " had errors:") + for _, v := range stderrLines { + log.Println(key + " [ERROR] " + v) + } + } + return result +} diff --git a/test-docker-compose.yml b/test-docker-compose.yml index 1677829..f6b871a 100644 --- a/test-docker-compose.yml +++ b/test-docker-compose.yml @@ -7,10 +7,18 @@ services: working_dir: /go/src/sql depends_on: - test-mysql + - test-postgres test-mysql: image: mysql:5 volumes: - - ./test_schemas.sql:/docker-entrypoint-initdb.d/test_schemas.sql + - ./mysql_test_schemas.sql:/docker-entrypoint-initdb.d/mysql_test_schemas.sql environment: MYSQL_ROOT_PASSWORD: MYSQL_ALLOW_EMPTY_PASSWORD: "yes" + test-postgres: + image: postgres:9-alpine + volumes: + - ./postgres_test_schemas.sql:/docker-entrypoint-initdb.d/postgres_test_schemas.sql + environment: + POSTGRES_USER: root + POSTGRES_PASSWORD: