diff --git a/cmd/bank/main.go b/cmd/bank/main.go index e7accda1..b6e0bc64 100644 --- a/cmd/bank/main.go +++ b/cmd/bank/main.go @@ -57,8 +57,8 @@ func run(ctx context.Context) error { } appName := os.Args[1] - db := sqltest.Open("postgres") - defer db.Close() + db, _, close := sqltest.Open("postgres") + defer close() // TODO: return to using SQL // if err := postgres.CreateSchema(ctx, db); err != nil { diff --git a/internal/testing/boltdbtest/db.go b/internal/testing/boltdbtest/db.go index 6de0260e..217a268b 100644 --- a/internal/testing/boltdbtest/db.go +++ b/internal/testing/boltdbtest/db.go @@ -13,7 +13,7 @@ import ( // The returned function must be used to close the database, instead of // DB.Close(). func Open() (*bbolt.DB, func()) { - filename, remove := TempFile() + filename, remove := tempFile() db, err := bbolt.Open(filename, 0600, nil) if err != nil { @@ -26,11 +26,11 @@ func Open() (*bbolt.DB, func()) { } } -// TempFile returns the name of a temporary file to be used for a BoltDB +// tempFile returns the name of a temporary file to be used for a BoltDB // database. // // It returns a function that deletes the temporary file. -func TempFile() (string, func()) { +func tempFile() (string, func()) { f, err := ioutil.TempFile("", "*.boltdb") if err != nil { panic(err) diff --git a/internal/testing/sqltest/db.go b/internal/testing/sqltest/db.go index bc9ccbbc..5750233b 100644 --- a/internal/testing/sqltest/db.go +++ b/internal/testing/sqltest/db.go @@ -3,50 +3,98 @@ package sqltest import ( "database/sql" "fmt" + "io/ioutil" "os" + "sync" _ "github.com/go-sql-driver/mysql" // keep driver import near code that uses it - "github.com/google/uuid" - _ "github.com/lib/pq" // keep driver import near code that uses it - _ "github.com/mattn/go-sqlite3" // keep driver import near code that uses it + _ "github.com/lib/pq" // keep driver import near code that uses it + _ "github.com/mattn/go-sqlite3" // keep driver import near code that uses it ) // DSN returns the DSN for the test database to use with the given SQL driver. -func DSN(driver string) string { - var env, dsn string +// +// The returned function must be used to cleanup any data created for the DSN, +// such as temporary on-disk databases. +func DSN(driver string) (string, func()) { + dsn := dsnFromEnv(driver) + if dsn != "" { + return dsn, func() {} + } switch driver { case "mysql": - env = "DOGMATIQ_TEST_MYSQL_DSN" - dsn = "root:rootpass@tcp(127.0.0.1:3306)/dogmatiq" - case "sqlite3": - env = "DOGMATIQ_TEST_SQLITE_DSN" - dsn = fmt.Sprintf( - "file:sqlite-%s.db?cache=shared&mode=memory", - uuid.New().String(), - ) + return "root:rootpass@tcp(127.0.0.1:3306)/dogmatiq", func() {} case "postgres": - env = "DOGMATIQ_TEST_POSTGRES_DSN" - dsn = "user=postgres password=rootpass sslmode=disable" + return "user=postgres password=rootpass sslmode=disable", func() {} default: - panic("unsupported driver: " + driver) + file, close := tempFile() + return fmt.Sprintf("file:%s?mode=rwc", file), close } +} - if v := os.Getenv(env); v != "" { - return v +// Open returns the test database to use with the given driver. +// +// The returned function must be used to close the database, instead of +// DB.Close(). +func Open( + driver string, +) ( + db *sql.DB, + dsn string, + close func(), +) { + dsn, closeDSN := DSN(driver) + + db, err := sql.Open(driver, dsn) + if err != nil { + panic(err) } - return dsn + return db, dsn, func() { + db.Close() + closeDSN() + } } -// Open returns the test database to use with the given driver. -func Open(driver string) *sql.DB { - dsn := DSN(driver) +// dsnFromEnv returns a DSN for the given driver from an environment variable. +func dsnFromEnv(driver string) string { + switch driver { + case "mysql": + return os.Getenv("DOGMATIQ_TEST_MYSQL_DSN") + case "postgres": + return os.Getenv("DOGMATIQ_TEST_POSTGRES_DSN") + case "sqlite3": + return os.Getenv("DOGMATIQ_TEST_SQLITE_DSN") + default: + panic("unsupported driver: " + driver) + } +} - db, err := sql.Open(driver, dsn) +// tempFile returns the name of a temporary file to be used for an SQLite +// database. +// +// It returns a function that deletes the temporary file. +func tempFile() (string, func()) { + f, err := ioutil.TempFile("", "*.sqlite3") if err != nil { panic(err) } - return db + if err := f.Close(); err != nil { + panic(err) + } + + file := f.Name() + + if err := os.Remove(file); err != nil { + panic(err) + } + + var once sync.Once + return file, func() { + once.Do(func() { + os.Remove(file) + }) + } }