diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go index d1e331e42..621c947a5 100644 --- a/accounts/sql_migration_test.go +++ b/accounts/sql_migration_test.go @@ -39,9 +39,6 @@ func TestAccountStoreMigration(t *testing.T) { *db.TransactionExecutor[SQLQueries]) { testDBStore := NewTestDB(t, clock) - t.Cleanup(func() { - require.NoError(t, testDBStore.Close()) - }) store, ok := testDBStore.(*SQLStore) require.True(t, ok) diff --git a/accounts/test_postgres.go b/accounts/test_postgres.go index 609eeb608..16665030d 100644 --- a/accounts/test_postgres.go +++ b/accounts/test_postgres.go @@ -16,7 +16,7 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,5 +24,5 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/accounts/test_sql.go b/accounts/test_sql.go new file mode 100644 index 000000000..3c1ee7f16 --- /dev/null +++ b/accounts/test_sql.go @@ -0,0 +1,22 @@ +//go:build test_db_postgres || test_db_sqlite + +package accounts + +import ( + "testing" + + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightningnetwork/lnd/clock" + "github.com/stretchr/testify/require" +) + +// createStore is a helper function that creates a new SQLStore and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { + store := NewSQLStore(sqlDB, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go index 0dd042a28..9d899b3e2 100644 --- a/accounts/test_sqlite.go +++ b/accounts/test_sqlite.go @@ -16,7 +16,7 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return NewSQLStore(db.NewTestSqliteDB(t).BaseDB, clock) + return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +24,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return NewSQLStore( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + return createStore( + t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, ) } diff --git a/db/sqlc/kvstores.sql.go b/db/sqlc/kvstores.sql.go index b2e6632f4..c0949d173 100644 --- a/db/sqlc/kvstores.sql.go +++ b/db/sqlc/kvstores.sql.go @@ -257,6 +257,42 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco return err } +const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many +SELECT id, perm, rule_id, session_id, feature_id, entry_key, value +FROM kvstores +` + +func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) { + rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Kvstore + for rows.Next() { + var i Kvstore + if err := rows.Scan( + &i.ID, + &i.Perm, + &i.RuleID, + &i.SessionID, + &i.FeatureID, + &i.EntryKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec UPDATE kvstores SET value = $1 diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index df89d0898..117a1fbc5 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -57,6 +57,7 @@ type Querier interface { ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) ListAllAccounts(ctx context.Context) ([]Account, error) + ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) ListSessions(ctx context.Context) ([]Session, error) ListSessionsByState(ctx context.Context, state int16) ([]Session, error) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) diff --git a/db/sqlc/queries/kvstores.sql b/db/sqlc/queries/kvstores.sql index 7963e46a4..1ebfe3b0d 100644 --- a/db/sqlc/queries/kvstores.sql +++ b/db/sqlc/queries/kvstores.sql @@ -28,6 +28,10 @@ VALUES ($1, $2, $3, $4, $5, $6); DELETE FROM kvstores WHERE perm = false; +-- name: ListAllKVStoresRecords :many +SELECT * +FROM kvstores; + -- name: GetGlobalKVStoreRecord :one SELECT value FROM kvstores diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index c27e53e96..69990c1da 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -28,9 +28,6 @@ func TestActionStorage(t *testing.T) { sessDB := session.NewTestDBWithAccounts(t, clock, accountsDB) db := NewTestDBWithSessionsAndAccounts(t, sessDB, accountsDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Assert that attempting to add an action for a session that does not // exist returns an error. @@ -198,9 +195,6 @@ func TestListActions(t *testing.T) { sessDB := session.NewTestDB(t, clock) db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Add 2 sessions that we can reference. sess1, err := sessDB.NewSession( @@ -466,9 +460,6 @@ func TestListGroupActions(t *testing.T) { } db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // There should not be any actions in group 1 yet. al, _, _, err := db.ListActions(ctx, nil, WithActionGroupID(group1)) diff --git a/firewalldb/db.go b/firewalldb/db.go index b8d9ed06f..a8349a538 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -14,29 +14,21 @@ var ( ErrNoSuchKeyFound = fmt.Errorf("no such key found") ) -// firewallDBs is an interface that groups the RulesDB and PrivacyMapper -// interfaces. -type firewallDBs interface { - RulesDB - PrivacyMapper - ActionDB -} - // DB manages the firewall rules database. type DB struct { started sync.Once stopped sync.Once - firewallDBs + FirewallDBs cancel fn.Option[context.CancelFunc] } // NewDB creates a new firewall database. For now, it only contains the // underlying rules' and privacy mapper databases. -func NewDB(dbs firewallDBs) *DB { +func NewDB(dbs FirewallDBs) *DB { return &DB{ - firewallDBs: dbs, + FirewallDBs: dbs, } } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 5ee729e91..c2955bdc6 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -134,3 +134,11 @@ type ActionDB interface { // and feature name. GetActionsReadDB(groupID session.ID, featureName string) ActionsReadDB } + +// FirewallDBs is an interface that groups the RulesDB, PrivacyMapper and +// ActionDB interfaces. +type FirewallDBs interface { + RulesDB + PrivacyMapper + ActionDB +} diff --git a/firewalldb/kvstores_kvdb.go b/firewalldb/kvstores_kvdb.go index 51721d475..d1e8e35a6 100644 --- a/firewalldb/kvstores_kvdb.go +++ b/firewalldb/kvstores_kvdb.go @@ -16,13 +16,13 @@ the temporary store changes instead of just keeping an in-memory store is that we can then guarantee atomicity if changes are made to both the permanent and temporary stores. -rules -> perm -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} +"rules" -> "perm" -> rule-name -> "global" -> {k:v} + "session-kv-store" -> group ID -> {k:v} + -> "feature-kv-stores" -> feature-name -> {k:v} - -> temp -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} + -> "temp" -> rule-name -> "global" -> {k:v} + "session-kv-store" -> group ID -> {k:v} + -> "feature-kv-stores" -> feature-name -> {k:v} */ var ( diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go new file mode 100644 index 000000000..092b61c8e --- /dev/null +++ b/firewalldb/sql_migration.go @@ -0,0 +1,486 @@ +package firewalldb + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb" + "go.etcd.io/bbolt" +) + +// kvParams is a type alias for the InsertKVStoreRecordParams, to shorten the +// line length in the migration code. +type kvParams = sqlc.InsertKVStoreRecordParams + +// MigrateFirewallDBToSQL runs the migration of the firwalldb stores from the +// bbolt database to a SQL database. The migration is done in a single +// transaction to ensure that all rows in the stores are migrated or none at +// all. +// +// Note that this migration currently only migrates the kvstores, but will be +// extended in the future to also migrate the privacy mapper and action stores. +// +// NOTE: As sessions may contain linked sessions and accounts, the sessions and +// accounts sql migration MUST be run prior to this migration. +func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, + tx SQLQueries) error { + + log.Infof("Starting migration of the rules DB to SQL") + + err := migrateKVStoresDBToSQL(ctx, kvStore, tx) + if err != nil { + return err + } + + log.Infof("The rules DB has been migrated from KV to SQL.") + + // TODO(viktor): Add migration for the privacy mapper and the action + // stores. + + return nil +} + +// migrateKVStoresDBToSQL runs the migration of all KV stores from the KV +// database to the SQL database. The function also asserts that the +// migrated values match the original values in the KV store. +// See the illustration in the firwalldb/kvstores_kvdb.go file to understand +// the structure of the KV stores, and why we process the buckets in the +// order we do. +// Note that this function and the subsequent functions are intentionally +// designed to loop over all buckets and values that exist in the KV store, +// so that we are sure that we actually find all stores and values that +// exist in the KV store, and can be sure that the kv store actually follows +// the expected structure. +func migrateKVStoresDBToSQL(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLQueries) error { + + log.Infof("Starting migration of the KV stores to SQL") + + // allParams will hold all the kvParams that are inserted into the + // SQL database during the migration. + var allParams []kvParams + + err := kvStore.View(func(kvTx *bbolt.Tx) error { + for _, perm := range []bool{true, false} { + mainBucket, err := getMainBucket(kvTx, false, perm) + if err != nil { + return err + } + + if mainBucket == nil { + // If the mainBucket doesn't exist, there are no + // records to migrate under that bucket, + // therefore we don't error, and just proceed + // to not migrate any records under that bucket. + continue + } + + err = mainBucket.ForEach(func(k, v []byte) error { + if v != nil { + return errors.New("expected only " + + "buckets under main bucket") + } + + ruleName := k + ruleNameBucket := mainBucket.Bucket(k) + if ruleNameBucket == nil { + return fmt.Errorf("rule bucket %s "+ + "not found", string(k)) + } + + ruleId, err := sqlTx.GetOrInsertRuleID( + ctx, string(ruleName), + ) + if err != nil { + return err + } + + params, err := processRuleBucket( + ctx, sqlTx, perm, ruleId, + ruleNameBucket, + ) + if err != nil { + return err + } + + allParams = append(allParams, params...) + + return nil + }) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return err + } + + // After the migration is done, we validate that all inserted kvParams + // can match the original values in the KV store. Note that this is done + // after all values have been inserted, in order to ensure that the + // migration doesn't overwrite any values after they were inserted. + for _, param := range allParams { + switch { + case param.FeatureID.Valid && param.SessionID.Valid: + migratedValue, err := sqlTx.GetFeatureKVStoreRecord( + ctx, + sqlc.GetFeatureKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + SessionID: param.SessionID, + FeatureID: param.FeatureID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "feature specific kv store record "+ + "failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated feature specific "+ + "kv record value %x does not match "+ + "original value %x", migratedValue, + param.Value) + } + + case param.SessionID.Valid: + migratedValue, err := sqlTx.GetSessionKVStoreRecord( + ctx, + sqlc.GetSessionKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + SessionID: param.SessionID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "session wide kv store record "+ + "failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated session wide kv "+ + "record value %x does not match "+ + "original value %x", migratedValue, + param.Value) + } + + case !param.FeatureID.Valid && !param.SessionID.Valid: + migratedValue, err := sqlTx.GetGlobalKVStoreRecord( + ctx, + sqlc.GetGlobalKVStoreRecordParams{ + Key: param.EntryKey, + Perm: param.Perm, + RuleID: param.RuleID, + }, + ) + if err != nil { + return fmt.Errorf("retreiving of migrated "+ + "global kv store record failed %w", err) + } + + if !bytes.Equal(migratedValue, param.Value) { + return fmt.Errorf("migrated global kv record "+ + "value %x does not match original "+ + "value %x", migratedValue, param.Value) + } + + default: + return fmt.Errorf("unexpected combination of "+ + "FeatureID and SessionID for: %v", param) + } + } + + log.Infof("Migration of the KV stores to SQL completed. Total number "+ + "of rows migrated: %d", len(allParams)) + + return nil +} + +// processRuleBucket processes a single rule bucket, which contains the +// global and session-kv-store key buckets. +func processRuleBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, ruleBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, ruleBucket.ForEach(func(k, v []byte) error { + switch { + case v != nil: + return errors.New("expected only buckets under " + + "rule-name bucket") + case bytes.Equal(k, globalKVStoreBucketKey): + globalBucket := ruleBucket.Bucket( + globalKVStoreBucketKey, + ) + if globalBucket == nil { + return fmt.Errorf("global bucket %s for rule "+ + "id %d not found", string(k), ruleSqlId) + } + + p, err := processGlobalRuleBucket( + ctx, sqlTx, perm, ruleSqlId, globalBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + case bytes.Equal(k, sessKVStoreBucketKey): + sessionBucket := ruleBucket.Bucket( + sessKVStoreBucketKey, + ) + if sessionBucket == nil { + return fmt.Errorf("session bucket %s for rule "+ + "id %d not found", string(k), ruleSqlId) + } + + p, err := processSessionBucket( + ctx, sqlTx, perm, ruleSqlId, sessionBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + default: + return fmt.Errorf("unexpected bucket %s under "+ + "rule-name bucket", string(k)) + } + }) +} + +// processGlobalRuleBucket processes the global bucket under a rule bucket, +// which contains the global key-value store records for the rule. +// It inserts the records into the SQL database and asserts that +// the migrated values match the original values in the KV store. +func processGlobalRuleBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, globalBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, globalBucket.ForEach(func(k, v []byte) error { + if v == nil { + return errors.New("expected only key-values under " + + "global rule-name bucket") + } + + globalInsertParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + } + + err := sqlTx.InsertKVStoreRecord(ctx, globalInsertParams) + if err != nil { + return fmt.Errorf("inserting global kv store "+ + "record failed %w", err) + } + + params = append(params, globalInsertParams) + + return nil + }) +} + +// processSessionBucket processes the session-kv-store bucket under a rule +// bucket, which contains the group-id buckets for that rule. +func processSessionBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, mainSessionBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, mainSessionBucket.ForEach(func(groupId, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets under "+ + "%s bucket", string(sessKVStoreBucketKey)) + } + + groupBucket := mainSessionBucket.Bucket(groupId) + if groupBucket == nil { + return fmt.Errorf("group bucket for group id %s"+ + "not found", string(groupId)) + } + + p, err := processGroupBucket( + ctx, sqlTx, perm, ruleSqlId, groupId, groupBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + }) +} + +// processGroupBucket processes a single group bucket, which contains the +// session-wide kv records and as well as the feature-kv-stores key bucket for +// that group. For the session-wide kv records, it inserts the records into the +// SQL database and asserts that the migrated values match the original values. +func processGroupBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupAlias []byte, + groupBucket *bbolt.Bucket) ([]kvParams, error) { + + groupSqlId, err := sqlTx.GetSessionIDByAlias( + ctx, groupAlias, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("session with group id %x "+ + "not found in sql db", groupAlias) + } else if err != nil { + return nil, err + } + + var params []kvParams + + return params, groupBucket.ForEach(func(k, v []byte) error { + switch { + case v != nil: + // This is a non-feature specific k:v store for the + // session, i.e. the session-wide store. + sessWideParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + SessionID: sqldb.SQLInt64(groupSqlId), + } + + err := sqlTx.InsertKVStoreRecord(ctx, sessWideParams) + if err != nil { + return fmt.Errorf("inserting session wide kv "+ + "store record failed %w", err) + } + + params = append(params, sessWideParams) + + return nil + case bytes.Equal(k, featureKVStoreBucketKey): + // This is a feature specific k:v store for the + // session, which will be stored under the feature-name + // under this bucket. + + featureStoreBucket := groupBucket.Bucket( + featureKVStoreBucketKey, + ) + if featureStoreBucket == nil { + return fmt.Errorf("feature store bucket %s "+ + "for group id %s not found", + string(featureKVStoreBucketKey), + string(groupAlias)) + } + + p, err := processFeatureStoreBucket( + ctx, sqlTx, perm, ruleSqlId, groupSqlId, + featureStoreBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + default: + return fmt.Errorf("unexpected bucket %s found under "+ + "the %s bucket", string(k), + string(sessKVStoreBucketKey)) + } + }) +} + +// processFeatureStoreBucket processes the feature-kv-store bucket under a +// group bucket, which contains the feature specific buckets for that group. +func processFeatureStoreBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupSqlId int64, + featureStoreBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, featureStoreBucket.ForEach(func(k, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets under " + + "feature stores bucket") + } + + featureName := k + featureNameBucket := featureStoreBucket.Bucket(featureName) + if featureNameBucket == nil { + return fmt.Errorf("feature bucket %s not found", + string(featureName)) + } + + featureSqlId, err := sqlTx.GetOrInsertFeatureID( + ctx, string(featureName), + ) + if err != nil { + return err + } + + p, err := processFeatureNameBucket( + ctx, sqlTx, perm, ruleSqlId, groupSqlId, featureSqlId, + featureNameBucket, + ) + if err != nil { + return err + } + + params = append(params, p...) + + return nil + }) +} + +// processFeatureNameBucket processes a single feature name bucket, which +// contains the feature specific key-value store records for that group. +// It inserts the records into the SQL database and asserts that +// the migrated values match the original values in the KV store. +func processFeatureNameBucket(ctx context.Context, sqlTx SQLQueries, perm bool, + ruleSqlId int64, groupSqlId int64, featureSqlId int64, + featureNameBucket *bbolt.Bucket) ([]kvParams, error) { + + var params []kvParams + + return params, featureNameBucket.ForEach(func(k, v []byte) error { + if v == nil { + return fmt.Errorf("expected only key-values under "+ + "feature name bucket, but found bucket %s", + string(k)) + } + + featureParams := kvParams{ + EntryKey: string(k), + Value: v, + Perm: perm, + RuleID: ruleSqlId, + SessionID: sqldb.SQLInt64(groupSqlId), + FeatureID: sqldb.SQLInt64(featureSqlId), + } + + err := sqlTx.InsertKVStoreRecord(ctx, featureParams) + if err != nil { + return fmt.Errorf("inserting feature specific kv "+ + "store record failed %w", err) + } + + params = append(params, featureParams) + + return nil + }) +} diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go new file mode 100644 index 000000000..c068671cc --- /dev/null +++ b/firewalldb/sql_migration_test.go @@ -0,0 +1,537 @@ +package firewalldb + +import ( + "context" + "database/sql" + "fmt" + "github.com/lightningnetwork/lnd/fn" + "testing" + "time" + + "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" +) + +// kvStoreRecord represents a single KV entry inserted into the BoltDB. +type kvStoreRecord struct { + Perm bool + RuleName string + EntryKey string + Global bool + GroupID *session.ID + FeatureName fn.Option[string] // Set if the record is feature specific + Value []byte +} + +// TestFirewallDBMigration tests the migration of firewalldb from a bolt +// backed to a SQL database. Note that this test does not attempt to be a +// complete migration test. +// This test only tests the migration of the KV stores currently, but will +// be extended in the future to also test the migration of the privacy mapper +// and the actions store in the future. +func TestFirewallDBMigration(t *testing.T) { + t.Parallel() + + ctx := context.Background() + clock := clock.NewTestClock(time.Now()) + + // When using build tags that creates a kvdb store for NewTestDB, we + // skip this test as it is only applicable for postgres and sqlite tags. + store := NewTestDB(t, clock) + if _, ok := store.(*BoltDB); ok { + t.Skipf("Skipping Firewall DB migration test for kvdb build") + } + + makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB, + *db.TransactionExecutor[SQLQueries]) { + + testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) + + store, ok := testDBStore.(*SQLDB) + require.True(t, ok) + + baseDB := store.BaseDB + + genericExecutor := db.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return baseDB.WithTx(tx) + }, + ) + + return store, genericExecutor + } + + // The assertMigrationResults function will currently assert that + // the migrated kv stores records in the SQLDB match the original kv + // stores records in the BoltDB. + assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + kvRecords []kvStoreRecord) { + + var ( + ruleIDs = make(map[string]int64) + groupIDs = make(map[string]int64) + featureIDs = make(map[string]int64) + err error + ) + + getRuleID := func(ruleName string) int64 { + ruleID, ok := ruleIDs[ruleName] + if !ok { + ruleID, err = sqlStore.GetRuleID( + ctx, ruleName, + ) + require.NoError(t, err) + + ruleIDs[ruleName] = ruleID + } + + return ruleID + } + + getGroupID := func(groupAlias []byte) int64 { + groupID, ok := groupIDs[string(groupAlias)] + if !ok { + groupID, err = sqlStore.GetSessionIDByAlias( + ctx, groupAlias, + ) + require.NoError(t, err) + + groupIDs[string(groupAlias)] = groupID + } + + return groupID + } + + getFeatureID := func(featureName string) int64 { + featureID, ok := featureIDs[featureName] + if !ok { + featureID, err = sqlStore.GetFeatureID( + ctx, featureName, + ) + require.NoError(t, err) + + featureIDs[featureName] = featureID + } + + return featureID + } + + // First we extract all migrated kv records from the SQLDB, + // in order to be able to compare them to the original kv + // records, to ensure that the migration was successful. + sqlKvRecords, err := sqlStore.ListAllKVStoresRecords(ctx) + require.NoError(t, err) + require.Equal(t, len(kvRecords), len(sqlKvRecords)) + + for _, kvRecord := range kvRecords { + ruleID := getRuleID(kvRecord.RuleName) + + if kvRecord.Global { + sqlVal, err := sqlStore.GetGlobalKVStoreRecord( + ctx, + sqlc.GetGlobalKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } else if kvRecord.FeatureName.IsNone() { + groupID := getGroupID(kvRecord.GroupID[:]) + + sqlVal, err := sqlStore.GetSessionKVStoreRecord( + ctx, + sqlc.GetSessionKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + SessionID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } else { + groupID := getGroupID(kvRecord.GroupID[:]) + featureID := getFeatureID( + kvRecord.FeatureName.UnwrapOrFail(t), + ) + + sqlVal, err := sqlStore.GetFeatureKVStoreRecord( + ctx, + sqlc.GetFeatureKVStoreRecordParams{ + Key: kvRecord.EntryKey, + Perm: kvRecord.Perm, + RuleID: ruleID, + SessionID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + FeatureID: sql.NullInt64{ + Int64: featureID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + require.Equal(t, kvRecord.Value, sqlVal) + } + } + } + + // The tests slice contains all the tests that we will run for the + // migration of the firewalldb from a BoltDB to a SQLDB. + // Note that the tests currently only test the migration of the KV + // stores, but will be extended in the future to also test the migration + // of the privacy mapper and the actions store. + tests := []struct { + name string + populateDB func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []kvStoreRecord + }{ + { + name: "empty", + populateDB: func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []kvStoreRecord { + + // Don't populate the DB. + return make([]kvStoreRecord, 0) + }, + }, + { + name: "global records", + populateDB: globalRecords, + }, + { + name: "session specific records", + populateDB: sessionSpecificRecords, + }, + { + name: "feature specific records", + populateDB: featureSpecificRecords, + }, + { + name: "records at all levels", + populateDB: recordsAtAllLevels, + }, + { + name: "random records", + populateDB: randomKVRecords, + }, + } + + for _, test := range tests { + tc := test + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // First let's create a sessions store to link to in + // the kvstores DB. In order to create the sessions + // store though, we also need to create an accounts + // store, that we link to the sessions store. + // Note that both of these stores will be sql stores due + // to the build tags enabled when running this test, + // which means we can also pass the sessions store to + // the sql version of the kv stores that we'll create + // in test, without also needing to migrate it. + accountStore := accounts.NewTestDB(t, clock) + sessionsStore := session.NewTestDBWithAccounts( + t, clock, accountStore, + ) + + // Create a new firewall store to populate with test + // data. + firewallStore, err := NewBoltDB( + t.TempDir(), DBFilename, sessionsStore, + accountStore, clock, + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, firewallStore.Close()) + }) + + // Populate the kv store. + records := test.populateDB( + t, ctx, firewallStore, sessionsStore, + ) + + // Create the SQL store that we will migrate the data + // to. + sqlStore, txEx := makeSQLDB(t, sessionsStore) + + // Perform the migration. + var opts sqldb.MigrationTxOptions + err = txEx.ExecTx(ctx, &opts, + func(tx SQLQueries) error { + return MigrateFirewallDBToSQL( + ctx, firewallStore.DB, tx, + ) + }, + ) + require.NoError(t, err) + + // Assert migration results. + assertMigrationResults(t, sqlStore, records) + }) + } +} + +// globalRecords populates the kv store with one global record for the temp +// store, and one for the perm store. +func globalRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, true, fn.None[string](), + ) +} + +// sessionSpecificRecords populates the kv store with one session specific +// record for the local temp store, and one session specific record for the perm +// local store. +func sessionSpecificRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, false, fn.None[string](), + ) +} + +// featureSpecificRecords populates the kv store with one feature specific +// record for the local temp store, and one feature specific record for the perm +// local store. +func featureSpecificRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + return insertTestKVRecords( + t, ctx, boltDB, sessionStore, false, fn.Some("test-feature"), + ) +} + +// recordsAtAllLevels uses adds a record at all possible levels of the kvstores, +// by utilizing all the other helper functions that populates the kvstores at +// different levels. +func recordsAtAllLevels(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + gRecords := globalRecords(t, ctx, boltDB, sessionStore) + sRecords := sessionSpecificRecords(t, ctx, boltDB, sessionStore) + fRecords := featureSpecificRecords(t, ctx, boltDB, sessionStore) + + return append(gRecords, append(sRecords, fRecords...)...) +} + +// insertTestKVRecords populates the kv store with one record for the local temp +// store, and one record for the local store. The records will be feature +// specific if the featureNameOpt is set, otherwise they will be session +// specific. Both of the records will be inserted with the same +// session.GroupID, which is created in this function, as well as the same +// ruleName, entryKey and entryVal. +func insertTestKVRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store, global bool, + featureNameOpt fn.Option[string]) []kvStoreRecord { + + var ( + ruleName = "test-rule" + entryKey = "test1" + entryVal = []byte{1, 2, 3} + ) + + // Create a session that we can reference. + sess, err := sessionStore.NewSession( + ctx, "test", session.TypeAutopilot, + time.Unix(1000, 0), "something", + ) + require.NoError(t, err) + + tempKvRecord := kvStoreRecord{ + RuleName: ruleName, + GroupID: &sess.GroupID, + FeatureName: featureNameOpt, + EntryKey: entryKey, + Value: entryVal, + Perm: false, + Global: global, + } + + insertKvRecord(t, ctx, boltDB, tempKvRecord) + + permKvRecord := kvStoreRecord{ + RuleName: ruleName, + GroupID: &sess.GroupID, + FeatureName: featureNameOpt, + EntryKey: entryKey, + Value: entryVal, + Perm: true, + Global: global, + } + + insertKvRecord(t, ctx, boltDB, permKvRecord) + + return []kvStoreRecord{tempKvRecord, permKvRecord} +} + +// insertTestKVRecords populates the kv store with passed record, and asserts +// that the record is inserted correctly. +func insertKvRecord(t *testing.T, ctx context.Context, + boltDB *BoltDB, record kvStoreRecord) { + + if record.Global && record.FeatureName.IsSome() { + t.Fatalf("cannot set both global and feature specific at the " + + "same time") + } + + kvStores := boltDB.GetKVStores( + record.RuleName, *record.GroupID, + record.FeatureName.UnwrapOr(""), + ) + + err := kvStores.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + + switch { + case record.Global && !record.Perm: + return tx.GlobalTemp().Set( + ctx, record.EntryKey, record.Value, + ) + case record.Global && record.Perm: + return tx.Global().Set( + ctx, record.EntryKey, record.Value, + ) + case !record.Global && !record.Perm: + return tx.LocalTemp().Set( + ctx, record.EntryKey, record.Value, + ) + case !record.Global && record.Perm: + return tx.Local().Set( + ctx, record.EntryKey, record.Value, + ) + default: + return fmt.Errorf("unexpected global/perm "+ + "combination: global=%v, perm=%v", + record.Global, record.Perm) + } + }) + require.NoError(t, err) +} + +// randomKVRecords populates the kv store with random kv records that span +// across all possible combinations of different levels of records in the kv +// store. All values and different bucket names are randomly generated. +func randomKVRecords(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []kvStoreRecord { + + var ( + // We set the number of records to insert to 1000, as that + // should be enough to cover as many different + // combinations of records as possible, while still being + // fast enough to run in a reasonable time. + numberOfRecords = 1000 + insertedRecords = make([]kvStoreRecord, 0) + ruleName = "initial-rule" + groupId *session.ID + featureName = "initial-feature" + ) + + // Create a random session that we can reference for the initial group + // ID. + sess, err := sessionStore.NewSession( + ctx, "initial-session", session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupId = &sess.GroupID + + // Generate random records. Note that many records will use the same + // rule name, group ID and feature name, to simulate the real world + // usage of the kv stores as much as possible. + for i := 0; i < numberOfRecords; i++ { + // On average, we will generate a new rule which will be used + // for the kv store record 10% of the time. + if rand.Intn(10) == 0 { + ruleName = fmt.Sprintf( + "rule-%s-%d", randomString(rand.Intn(30)+1), i, + ) + } + + // On average, we use the global store 25% of the time. + global := rand.Intn(4) == 0 + + // We'll use the perm store 50% of the time. + perm := rand.Intn(2) == 0 + + // For the non-global records, we will generate a new group ID + // 25% of the time. + if !global && rand.Intn(4) == 0 { + newSess, err := sessionStore.NewSession( + ctx, fmt.Sprintf("session-%d", i), + session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), + randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupId = &newSess.GroupID + } + + featureNameOpt := fn.None[string]() + + // For 50% of the non-global records, we insert a feature + // specific record. The other 50% will be session specific + // records. + if !global && rand.Intn(2) == 0 { + // 25% of the time, we will generate a new feature name. + if rand.Intn(4) == 0 { + featureName = fmt.Sprintf( + "feature-%s-%d", + randomString(rand.Intn(30)+1), i, + ) + } + + featureNameOpt = fn.Some(featureName) + } + + kvEntry := kvStoreRecord{ + RuleName: ruleName, + GroupID: groupId, + FeatureName: featureNameOpt, + EntryKey: fmt.Sprintf("key-%d", i), + Perm: perm, + Global: global, + // We'll generate a random value for all records, + Value: []byte(randomString(rand.Intn(100) + 1)), + } + + // Insert the record into the kv store. + insertKvRecord(t, ctx, boltDB, kvEntry) + + // Add the record to the list of inserted records. + insertedRecords = append(insertedRecords, kvEntry) + } + + return insertedRecords +} + +// randomString generates a random string of the passed length n. +func randomString(n int) string { + letterBytes := "abcdefghijklmnopqrstuvwxyz" + + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 6f7a49aa3..c3cd4533a 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -6,34 +6,37 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/require" ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *BoltDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return NewTestDBFromPath(t, t.TempDir(), clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *BoltDB { +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + return newDBFromPathWithSessions(t, dbPath, nil, nil, clock) } // NewTestDBWithSessions creates a new test BoltDB Store with access to an // existing sessions DB. -func NewTestDBWithSessions(t *testing.T, sessStore SessionDB, - clock clock.Clock) *BoltDB { +func NewTestDBWithSessions(t *testing.T, sessStore session.Store, + clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions(t, t.TempDir(), sessStore, nil, clock) } // NewTestDBWithSessionsAndAccounts creates a new test BoltDB Store with access // to an existing sessions DB and accounts DB. -func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *BoltDB { +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore session.Store, + acctStore AccountsDB, clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions( t, t.TempDir(), sessStore, acctStore, clock, @@ -41,7 +44,8 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, } func newDBFromPathWithSessions(t *testing.T, dbPath string, - sessStore SessionDB, acctStore AccountsDB, clock clock.Clock) *BoltDB { + sessStore session.Store, acctStore AccountsDB, + clock clock.Clock) FirewallDBs { store, err := NewBoltDB(dbPath, DBFilename, sessStore, acctStore, clock) require.NoError(t, err) diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go index f5777e4cb..732b19b4a 100644 --- a/firewalldb/test_postgres.go +++ b/firewalldb/test_postgres.go @@ -10,12 +10,12 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 03dcfbebf..a412441f8 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -7,6 +7,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/stretchr/testify/require" @@ -15,18 +16,17 @@ import ( // NewTestDBWithSessions creates a new test SQLDB Store with access to an // existing sessions DB. func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, - clock clock.Clock) *SQLDB { - + clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } // NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access // to an existing sessions DB and accounts DB. func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *SQLDB { + acctStore AccountsDB, clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) @@ -36,7 +36,7 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, require.Equal(t, accounts.BaseDB, sessions.BaseDB) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } func assertEqualActions(t *testing.T, expected, got *Action) { @@ -52,3 +52,14 @@ func assertEqualActions(t *testing.T, expected, got *Action) { expected.AttemptedAt = expectedAttemptedAt got.AttemptedAt = actualAttemptedAt } + +// createStore is a helper function that creates a new SQLDB and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { + store := NewSQLDB(sqlDB, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 5496cb205..49b956d7d 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -10,14 +10,14 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestSqliteDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *SQLDB { - return NewSQLDB( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, +func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { + return createStore( + t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, ) } diff --git a/session/test_kvdb.go b/session/test_kvdb.go index 241448410..cc939159d 100644 --- a/session/test_kvdb.go +++ b/session/test_kvdb.go @@ -11,14 +11,14 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *BoltStore { +func NewTestDB(t *testing.T, clock clock.Clock) Store { return NewTestDBFromPath(t, t.TempDir(), clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, - clock clock.Clock) *BoltStore { + clock clock.Clock) Store { acctStore := accounts.NewTestDB(t, clock) @@ -28,13 +28,13 @@ func NewTestDBFromPath(t *testing.T, dbPath string, // NewTestDBWithAccounts creates a new test session Store with access to an // existing accounts DB. func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, - acctStore accounts.Store) *BoltStore { + acctStore accounts.Store) Store { return newDBFromPathWithAccounts(t, clock, t.TempDir(), acctStore) } func newDBFromPathWithAccounts(t *testing.T, clock clock.Clock, dbPath string, - acctStore accounts.Store) *BoltStore { + acctStore accounts.Store) Store { store, err := NewDB(dbPath, DBFilename, clock, acctStore) require.NoError(t, err) diff --git a/session/test_postgres.go b/session/test_postgres.go index db392fe7f..cb5aa061d 100644 --- a/session/test_postgres.go +++ b/session/test_postgres.go @@ -15,14 +15,14 @@ import ( var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLStore { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) Store { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a // connection to an existing postgres database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, - clock clock.Clock) *SQLStore { + clock clock.Clock) Store { - return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/session/test_sql.go b/session/test_sql.go index ab4b32a6c..a83186069 100644 --- a/session/test_sql.go +++ b/session/test_sql.go @@ -6,15 +6,27 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" "github.com/stretchr/testify/require" ) func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, - acctStore accounts.Store) *SQLStore { + acctStore accounts.Store) Store { accounts, ok := acctStore.(*accounts.SQLStore) require.True(t, ok) - return NewSQLStore(accounts.BaseDB, clock) + return createStore(t, accounts.BaseDB, clock) +} + +// createStore is a helper function that creates a new SQLStore and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { + store := NewSQLStore(sqlDB, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store } diff --git a/session/test_sqlite.go b/session/test_sqlite.go index 87519f4f1..0ceb0e046 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -15,16 +15,16 @@ import ( var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLStore { - return NewSQLStore(db.NewTestSqliteDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) Store { + return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a // connection to an existing sqlite database for testing. func NewTestDBFromPath(t *testing.T, dbPath string, - clock clock.Clock) *SQLStore { + clock clock.Clock) Store { - return NewSQLStore( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + return createStore( + t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, ) }