From 215fae4e21857ed7e00349c4247eec1aa930575c Mon Sep 17 00:00:00 2001 From: imrishuroy Date: Fri, 8 Dec 2023 22:29:46 +0530 Subject: [PATCH] switch db driver from lib/pq to pgx --- api/account.go | 15 +++---- api/account_test.go | 9 +++-- api/token.go | 5 ++- api/transfer.go | 3 +- api/transfer_test.go | 4 +- api/user.go | 15 +++---- api/user_test.go | 4 +- app.env | 1 - db/sqlc/account.sql.go | 17 ++++---- db/sqlc/account_test.go | 17 ++++---- db/sqlc/db.go | 13 +++--- db/sqlc/entry.sql.go | 9 ++--- db/sqlc/entry_test.go | 69 ++++++++++++++++++++++++++++++++ db/sqlc/error.go | 27 +++++++++++++ db/sqlc/exec_tx.go | 25 ++++++++++++ db/sqlc/main_test.go | 20 +++------ db/sqlc/session.sql.go | 4 +- db/sqlc/store.go | 31 +++----------- db/sqlc/store_test.go | 24 +++++------ db/sqlc/transfer.sql.go | 9 ++--- db/sqlc/tx_verify_email.go | 5 ++- db/sqlc/user.sql.go | 21 +++++----- db/sqlc/user_test.go | 26 ++++++------ db/sqlc/verify_email.sql.go | 4 +- gapi/rpc_create_user.go | 8 +--- gapi/rpc_create_user_test.go | 4 +- gapi/rpc_login_user.go | 4 +- gapi/rpc_update_user.go | 13 +++--- gapi/rpc_update_user_test.go | 8 ++-- go.mod | 10 ++++- go.sum | 12 ++++++ main.go | 8 ++-- sqlc.yaml | 7 ++++ util/confg.go | 1 - worker/task_send_verify_email.go | 4 +- 35 files changed, 274 insertions(+), 182 deletions(-) create mode 100644 db/sqlc/entry_test.go create mode 100644 db/sqlc/error.go create mode 100644 db/sqlc/exec_tx.go diff --git a/api/account.go b/api/account.go index 2c58881..a031eb0 100644 --- a/api/account.go +++ b/api/account.go @@ -1,14 +1,12 @@ package api import ( - "database/sql" "errors" "net/http" "github.com/gin-gonic/gin" db "github.com/imrishuroy/simplebank/db/sqlc" "github.com/imrishuroy/simplebank/token" - "github.com/lib/pq" ) type createAccountRequest struct { @@ -33,13 +31,10 @@ func (server *Server) createAccount(ctx *gin.Context) { account, err := server.store.CreateAccount(ctx, arg) if err != nil { - - if pqErr, ok := err.(*pq.Error); ok { - switch pqErr.Code.Name() { - case "foreign_key_violation", "unique_violation": - ctx.JSON(http.StatusForbidden, errorResponse(err)) - return - } + errCode := db.ErrorCode(err) + if errCode == db.ForeignKeyViolation || errCode == db.UniqueViolation { + ctx.JSON(http.StatusForbidden, errorResponse(err)) + return } ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return @@ -63,7 +58,7 @@ func (server *Server) getAccount(ctx *gin.Context) { account, err := server.store.GetAccount(ctx, req.ID) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, db.ErrRecordNotFound) { ctx.JSON(http.StatusNotFound, errorResponse(err)) return } diff --git a/api/account_test.go b/api/account_test.go index 9ed80e2..ce49f60 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -5,7 +5,8 @@ import ( "database/sql" "encoding/json" "fmt" - "io/ioutil" + "io" + "net/http" "net/http/httptest" "testing" @@ -89,7 +90,7 @@ func TestGetAccountAPI(t *testing.T) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). Times(1). - Return(db.Account{}, sql.ErrNoRows) + Return(db.Account{}, db.ErrRecordNotFound) }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusNotFound, recorder.Code) @@ -430,7 +431,7 @@ func randomAccount(owner string) db.Account { } func requireBodyMatchAccount(t *testing.T, body *bytes.Buffer, account db.Account) { - data, err := ioutil.ReadAll(body) + data, err := io.ReadAll(body) require.NoError(t, err) var gotAccount db.Account @@ -440,7 +441,7 @@ func requireBodyMatchAccount(t *testing.T, body *bytes.Buffer, account db.Accoun } func requireBodyMatchAccounts(t *testing.T, body *bytes.Buffer, accounts []db.Account) { - data, err := ioutil.ReadAll(body) + data, err := io.ReadAll(body) require.NoError(t, err) var gotAccounts []db.Account diff --git a/api/token.go b/api/token.go index 04a7d98..0f14899 100644 --- a/api/token.go +++ b/api/token.go @@ -1,12 +1,13 @@ package api import ( - "database/sql" + "errors" "fmt" "net/http" "time" "github.com/gin-gonic/gin" + db "github.com/imrishuroy/simplebank/db/sqlc" ) type renewAccessTokenRequest struct { @@ -33,7 +34,7 @@ func (server *Server) renewAccessToken(ctx *gin.Context) { session, err := server.store.GetSession(ctx, refreshPayload.ID) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, db.ErrRecordNotFound) { ctx.JSON(http.StatusNotFound, errorResponse(err)) return } diff --git a/api/transfer.go b/api/transfer.go index 0d175ad..fa3afb0 100644 --- a/api/transfer.go +++ b/api/transfer.go @@ -1,7 +1,6 @@ package api import ( - "database/sql" "errors" "fmt" "net/http" @@ -62,7 +61,7 @@ func (server *Server) validAccount(ctx *gin.Context, accountId int64, currency s account, err := server.store.GetAccount(ctx, accountId) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, db.ErrRecordNotFound) { ctx.JSON(http.StatusNotFound, errorResponse(err)) return account, false } diff --git a/api/transfer_test.go b/api/transfer_test.go index 8f7f5bd..03fbd71 100644 --- a/api/transfer_test.go +++ b/api/transfer_test.go @@ -116,7 +116,7 @@ func TestTransferAPI(t *testing.T) { addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { - store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows) + store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(db.Account{}, db.ErrRecordNotFound) store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(0) store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0) }, @@ -137,7 +137,7 @@ func TestTransferAPI(t *testing.T) { }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil) - store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows) + store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(1).Return(db.Account{}, db.ErrRecordNotFound) store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0) }, checkResponse: func(recorder *httptest.ResponseRecorder) { diff --git a/api/user.go b/api/user.go index ba85ce3..a7574e2 100644 --- a/api/user.go +++ b/api/user.go @@ -1,7 +1,7 @@ package api import ( - "database/sql" + "errors" "net/http" "time" @@ -9,7 +9,6 @@ import ( "github.com/google/uuid" db "github.com/imrishuroy/simplebank/db/sqlc" "github.com/imrishuroy/simplebank/util" - "github.com/lib/pq" ) type createUserRequest struct { @@ -58,13 +57,9 @@ func (server *Server) createUser(ctx *gin.Context) { user, err := server.store.CreateUser(ctx, arg) if err != nil { - - if pqErr, ok := err.(*pq.Error); ok { - switch pqErr.Code.Name() { - case "unique_violation": - ctx.JSON(http.StatusForbidden, errorResponse(err)) - return - } + if db.ErrorCode(err) == db.UniqueViolation { + ctx.JSON(http.StatusForbidden, errorResponse(err)) + return } ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return @@ -97,7 +92,7 @@ func (server *Server) loginUser(ctx *gin.Context) { } user, err := server.store.GetUser(ctx, req.Username) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, db.ErrRecordNotFound) { ctx.JSON(http.StatusNotFound, errorResponse(err)) return } diff --git a/api/user_test.go b/api/user_test.go index cf629f6..5f4587f 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -16,7 +16,7 @@ import ( mockdb "github.com/imrishuroy/simplebank/db/mock" db "github.com/imrishuroy/simplebank/db/sqlc" "github.com/imrishuroy/simplebank/util" - "github.com/lib/pq" + "github.com/stretchr/testify/require" ) @@ -111,7 +111,7 @@ func TestCreateUserAPI(t *testing.T) { store.EXPECT(). CreateUser(gomock.Any(), gomock.Any()). Times(1). - Return(db.User{}, &pq.Error{Code: "23505"}) + Return(db.User{}, db.ErrUniqueViolation) }, checkResponse: func(recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusForbidden, recorder.Code) diff --git a/app.env b/app.env index c7e5fe4..173be6b 100644 --- a/app.env +++ b/app.env @@ -1,5 +1,4 @@ ENVIRONMENT=development -DB_DRIVER=postgres DB_SOURCE=postgresql://root:Prince2024@localhost:5432/simple_bank?sslmode=disable MIGRATION_URL=file://db/migration HTTP_SERVER_ADDRESS=localhost:8080 diff --git a/db/sqlc/account.sql.go b/db/sqlc/account.sql.go index 252391a..bf2db5c 100644 --- a/db/sqlc/account.sql.go +++ b/db/sqlc/account.sql.go @@ -22,7 +22,7 @@ type AddAccountBalanceParams struct { } func (q *Queries) AddAccountBalance(ctx context.Context, arg AddAccountBalanceParams) (Account, error) { - row := q.db.QueryRowContext(ctx, addAccountBalance, arg.Amount, arg.ID) + row := q.db.QueryRow(ctx, addAccountBalance, arg.Amount, arg.ID) var i Account err := row.Scan( &i.ID, @@ -51,7 +51,7 @@ type CreateAccountParams struct { } func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) { - row := q.db.QueryRowContext(ctx, createAccount, arg.Owner, arg.Balance, arg.Currency) + row := q.db.QueryRow(ctx, createAccount, arg.Owner, arg.Balance, arg.Currency) var i Account err := row.Scan( &i.ID, @@ -69,7 +69,7 @@ WHERE id = $1 ` func (q *Queries) DeleteAccount(ctx context.Context, id int64) error { - _, err := q.db.ExecContext(ctx, deleteAccount, id) + _, err := q.db.Exec(ctx, deleteAccount, id) return err } @@ -79,7 +79,7 @@ WHERE id = $1 LIMIT 1 ` func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { - row := q.db.QueryRowContext(ctx, getAccount, id) + row := q.db.QueryRow(ctx, getAccount, id) var i Account err := row.Scan( &i.ID, @@ -98,7 +98,7 @@ FOR NO KEY UPDATE ` func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, error) { - row := q.db.QueryRowContext(ctx, getAccountForUpdate, id) + row := q.db.QueryRow(ctx, getAccountForUpdate, id) var i Account err := row.Scan( &i.ID, @@ -125,7 +125,7 @@ type ListAccountsParams struct { } func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) { - rows, err := q.db.QueryContext(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) + rows, err := q.db.Query(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) if err != nil { return nil, err } @@ -144,9 +144,6 @@ func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]A } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } @@ -166,7 +163,7 @@ type UpdateAccountParams struct { } func (q *Queries) UpdateAccount(ctx context.Context, arg UpdateAccountParams) (Account, error) { - row := q.db.QueryRowContext(ctx, updateAccount, arg.ID, arg.Balance) + row := q.db.QueryRow(ctx, updateAccount, arg.ID, arg.Balance) var i Account err := row.Scan( &i.ID, diff --git a/db/sqlc/account_test.go b/db/sqlc/account_test.go index 57e8d14..1011fd3 100644 --- a/db/sqlc/account_test.go +++ b/db/sqlc/account_test.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "testing" "time" @@ -11,7 +10,6 @@ import ( ) func createRandomAccount(t *testing.T) Account { - user := createRandomUser(t) arg := CreateAccountParams{ @@ -20,7 +18,7 @@ func createRandomAccount(t *testing.T) Account { Currency: util.RandomCurrency(), } - account, err := testQueries.CreateAccount(context.Background(), arg) + account, err := testStore.CreateAccount(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, account) @@ -40,7 +38,7 @@ func TestCreateAccount(t *testing.T) { func TestGetAccount(t *testing.T) { account1 := createRandomAccount(t) - account2, err := testQueries.GetAccount(context.Background(), account1.ID) + account2, err := testStore.GetAccount(context.Background(), account1.ID) require.NoError(t, err) require.NotEmpty(t, account2) @@ -59,7 +57,7 @@ func TestUpdateAccount(t *testing.T) { Balance: util.RandomMoney(), } - account2, err := testQueries.UpdateAccount(context.Background(), arg) + account2, err := testStore.UpdateAccount(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, account2) @@ -72,18 +70,17 @@ func TestUpdateAccount(t *testing.T) { func TestDeleteAccount(t *testing.T) { account1 := createRandomAccount(t) - err := testQueries.DeleteAccount(context.Background(), account1.ID) + err := testStore.DeleteAccount(context.Background(), account1.ID) require.NoError(t, err) - account2, err := testQueries.GetAccount(context.Background(), account1.ID) + account2, err := testStore.GetAccount(context.Background(), account1.ID) require.Error(t, err) - require.EqualError(t, err, sql.ErrNoRows.Error()) + require.EqualError(t, err, ErrRecordNotFound.Error()) require.Empty(t, account2) } func TestListAccounts(t *testing.T) { var lastAccount Account - for i := 0; i < 10; i++ { lastAccount = createRandomAccount(t) } @@ -94,7 +91,7 @@ func TestListAccounts(t *testing.T) { Offset: 0, } - accounts, err := testQueries.ListAccounts(context.Background(), arg) + accounts, err := testStore.ListAccounts(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, accounts) diff --git a/db/sqlc/db.go b/db/sqlc/db.go index bf8f8e3..fe9b2f7 100644 --- a/db/sqlc/db.go +++ b/db/sqlc/db.go @@ -6,14 +6,15 @@ package db import ( "context" - "database/sql" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type DBTX interface { - ExecContext(context.Context, string, ...interface{}) (sql.Result, error) - PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...interface{}) *sql.Row + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row } func New(db DBTX) *Queries { @@ -24,7 +25,7 @@ type Queries struct { db DBTX } -func (q *Queries) WithTx(tx *sql.Tx) *Queries { +func (q *Queries) WithTx(tx pgx.Tx) *Queries { return &Queries{ db: tx, } diff --git a/db/sqlc/entry.sql.go b/db/sqlc/entry.sql.go index ea1be27..0caea69 100644 --- a/db/sqlc/entry.sql.go +++ b/db/sqlc/entry.sql.go @@ -24,7 +24,7 @@ type CreateEntryParams struct { } func (q *Queries) CreateEntry(ctx context.Context, arg CreateEntryParams) (Entry, error) { - row := q.db.QueryRowContext(ctx, createEntry, arg.AccountID, arg.Amount) + row := q.db.QueryRow(ctx, createEntry, arg.AccountID, arg.Amount) var i Entry err := row.Scan( &i.ID, @@ -41,7 +41,7 @@ WHERE id = $1 LIMIT 1 ` func (q *Queries) GetEntry(ctx context.Context, id int64) (Entry, error) { - row := q.db.QueryRowContext(ctx, getEntry, id) + row := q.db.QueryRow(ctx, getEntry, id) var i Entry err := row.Scan( &i.ID, @@ -67,7 +67,7 @@ type ListEntriesParams struct { } func (q *Queries) ListEntries(ctx context.Context, arg ListEntriesParams) ([]Entry, error) { - rows, err := q.db.QueryContext(ctx, listEntries, arg.AccountID, arg.Limit, arg.Offset) + rows, err := q.db.Query(ctx, listEntries, arg.AccountID, arg.Limit, arg.Offset) if err != nil { return nil, err } @@ -85,9 +85,6 @@ func (q *Queries) ListEntries(ctx context.Context, arg ListEntriesParams) ([]Ent } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } diff --git a/db/sqlc/entry_test.go b/db/sqlc/entry_test.go new file mode 100644 index 0000000..6ced33b --- /dev/null +++ b/db/sqlc/entry_test.go @@ -0,0 +1,69 @@ +package db + +import ( + "context" + "testing" + "time" + + "github.com/imrishuroy/simplebank/util" + "github.com/stretchr/testify/require" +) + +func createRandomEntry(t *testing.T, account Account) Entry { + arg := CreateEntryParams{ + AccountID: account.ID, + Amount: util.RandomMoney(), + } + + entry, err := testStore.CreateEntry(context.Background(), arg) + require.NoError(t, err) + require.NotEmpty(t, entry) + + require.Equal(t, arg.AccountID, entry.AccountID) + require.Equal(t, arg.Amount, entry.Amount) + + require.NotZero(t, entry.ID) + require.NotZero(t, entry.CreatedAt) + + return entry +} + +func TestCreateEntry(t *testing.T) { + account := createRandomAccount(t) + createRandomEntry(t, account) +} + +func TestGetEntry(t *testing.T) { + account := createRandomAccount(t) + entry1 := createRandomEntry(t, account) + entry2, err := testStore.GetEntry(context.Background(), entry1.ID) + require.NoError(t, err) + require.NotEmpty(t, entry2) + + require.Equal(t, entry1.ID, entry2.ID) + require.Equal(t, entry1.AccountID, entry2.AccountID) + require.Equal(t, entry1.Amount, entry2.Amount) + require.WithinDuration(t, entry1.CreatedAt, entry2.CreatedAt, time.Second) +} + +func TestListEntries(t *testing.T) { + account := createRandomAccount(t) + for i := 0; i < 10; i++ { + createRandomEntry(t, account) + } + + arg := ListEntriesParams{ + AccountID: account.ID, + Limit: 5, + Offset: 5, + } + + entries, err := testStore.ListEntries(context.Background(), arg) + require.NoError(t, err) + require.Len(t, entries, 5) + + for _, entry := range entries { + require.NotEmpty(t, entry) + require.Equal(t, arg.AccountID, entry.AccountID) + } +} diff --git a/db/sqlc/error.go b/db/sqlc/error.go new file mode 100644 index 0000000..0291c1b --- /dev/null +++ b/db/sqlc/error.go @@ -0,0 +1,27 @@ +package db + +import ( + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +const ( + ForeignKeyViolation = "23503" + UniqueViolation = "23505" +) + +var ErrRecordNotFound = pgx.ErrNoRows + +var ErrUniqueViolation = &pgconn.PgError{ + Code: UniqueViolation, +} + +func ErrorCode(err error) string { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return pgErr.Code + } + return "" +} diff --git a/db/sqlc/exec_tx.go b/db/sqlc/exec_tx.go new file mode 100644 index 0000000..8e14d4a --- /dev/null +++ b/db/sqlc/exec_tx.go @@ -0,0 +1,25 @@ +package db + +import ( + "context" + "fmt" +) + +// ExecTx executes a function within a database transaction +func (store *SQLStore) execTx(ctx context.Context, fn func(*Queries) error) error { + tx, err := store.connPool.Begin(ctx) + if err != nil { + return err + } + + q := New(tx) + err = fn(q) + if err != nil { + if rbErr := tx.Rollback(ctx); rbErr != nil { + return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr) + } + return err + } + + return tx.Commit(ctx) +} diff --git a/db/sqlc/main_test.go b/db/sqlc/main_test.go index 1bb3296..deaaea6 100644 --- a/db/sqlc/main_test.go +++ b/db/sqlc/main_test.go @@ -1,36 +1,28 @@ package db import ( - "database/sql" + "context" "log" "os" "testing" - "github.com/imrishuroy/simplebank/util" - _ "github.com/lib/pq" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/techschool/simplebank/util" ) -// const ( -// dbDriver = "postgres" -// dbSource = "postgresql://root:Prince2024@localhost:5432/simple_bank?sslmode=disable" -// ) - -var testQueries *Queries -var testDB *sql.DB +var testStore Store func TestMain(m *testing.M) { config, err := util.LoadConfig("../..") if err != nil { log.Fatal("cannot load config:", err) - } - testDB, err = sql.Open(config.DBDriver, config.DBSource) + connPool, err := pgxpool.New(context.Background(), config.DBSource) if err != nil { log.Fatal("cannot connect to db:", err) } - testQueries = New(testDB) - + testStore = NewStore(connPool) os.Exit(m.Run()) } diff --git a/db/sqlc/session.sql.go b/db/sqlc/session.sql.go index ec8146b..b1ebffe 100644 --- a/db/sqlc/session.sql.go +++ b/db/sqlc/session.sql.go @@ -37,7 +37,7 @@ type CreateSessionParams struct { } func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { - row := q.db.QueryRowContext(ctx, createSession, + row := q.db.QueryRow(ctx, createSession, arg.ID, arg.Username, arg.RefreshToken, @@ -66,7 +66,7 @@ WHERE id = $1 LIMIT 1 ` func (q *Queries) GetSession(ctx context.Context, id uuid.UUID) (Session, error) { - row := q.db.QueryRowContext(ctx, getSession, id) + row := q.db.QueryRow(ctx, getSession, id) var i Session err := row.Scan( &i.ID, diff --git a/db/sqlc/store.go b/db/sqlc/store.go index 09c41ba..7a7e49c 100644 --- a/db/sqlc/store.go +++ b/db/sqlc/store.go @@ -2,8 +2,8 @@ package db import ( "context" - "database/sql" - "fmt" + + "github.com/jackc/pgx/v5/pgxpool" ) // Store defines all functions to execute db queries and transactions @@ -16,33 +16,14 @@ type Store interface { // SQLStore provides all functions to execute SQL queries and transactions type SQLStore struct { - db *sql.DB + connPool *pgxpool.Pool *Queries } // NewStore creates a new store -func NewStore(db *sql.DB) Store { +func NewStore(connPool *pgxpool.Pool) Store { return &SQLStore{ - db: db, - Queries: New(db), - } -} - -// ExecTx executes a function within a database transaction -func (store *SQLStore) execTx(ctx context.Context, fn func(*Queries) error) error { - tx, err := store.db.BeginTx(ctx, nil) - if err != nil { - return err + connPool: connPool, + Queries: New(connPool), } - - q := New(tx) - err = fn(q) - if err != nil { - if rbErr := tx.Rollback(); rbErr != nil { - return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr) - } - return err - } - - return tx.Commit() } diff --git a/db/sqlc/store_test.go b/db/sqlc/store_test.go index fac344d..bc3e1f9 100644 --- a/db/sqlc/store_test.go +++ b/db/sqlc/store_test.go @@ -9,8 +9,6 @@ import ( ) func TestTransferTx(t *testing.T) { - store := NewStore(testDB) - account1 := createRandomAccount(t) account2 := createRandomAccount(t) fmt.Println(">> before:", account1.Balance, account2.Balance) @@ -24,7 +22,7 @@ func TestTransferTx(t *testing.T) { // run n concurrent transfer transaction for i := 0; i < n; i++ { go func() { - result, err := store.TransferTx(context.Background(), TransferTxParams{ + result, err := testStore.TransferTx(context.Background(), TransferTxParams{ FromAccountID: account1.ID, ToAccountID: account2.ID, Amount: amount, @@ -54,7 +52,7 @@ func TestTransferTx(t *testing.T) { require.NotZero(t, transfer.ID) require.NotZero(t, transfer.CreatedAt) - _, err = store.GetTransfer(context.Background(), transfer.ID) + _, err = testStore.GetTransfer(context.Background(), transfer.ID) require.NoError(t, err) // check entries @@ -65,7 +63,7 @@ func TestTransferTx(t *testing.T) { require.NotZero(t, fromEntry.ID) require.NotZero(t, fromEntry.CreatedAt) - _, err = store.GetEntry(context.Background(), fromEntry.ID) + _, err = testStore.GetEntry(context.Background(), fromEntry.ID) require.NoError(t, err) toEntry := result.ToEntry @@ -75,7 +73,7 @@ func TestTransferTx(t *testing.T) { require.NotZero(t, toEntry.ID) require.NotZero(t, toEntry.CreatedAt) - _, err = store.GetEntry(context.Background(), toEntry.ID) + _, err = testStore.GetEntry(context.Background(), toEntry.ID) require.NoError(t, err) // check accounts @@ -103,10 +101,10 @@ func TestTransferTx(t *testing.T) { } // check the final updated balance - updatedAccount1, err := store.GetAccount(context.Background(), account1.ID) + updatedAccount1, err := testStore.GetAccount(context.Background(), account1.ID) require.NoError(t, err) - updatedAccount2, err := store.GetAccount(context.Background(), account2.ID) + updatedAccount2, err := testStore.GetAccount(context.Background(), account2.ID) require.NoError(t, err) fmt.Println(">> after:", updatedAccount1.Balance, updatedAccount2.Balance) @@ -116,8 +114,6 @@ func TestTransferTx(t *testing.T) { } func TestTransferTxDeadlock(t *testing.T) { - store := NewStore(testDB) - account1 := createRandomAccount(t) account2 := createRandomAccount(t) fmt.Println(">> before:", account1.Balance, account2.Balance) @@ -136,7 +132,7 @@ func TestTransferTxDeadlock(t *testing.T) { } go func() { - _, err := store.TransferTx(context.Background(), TransferTxParams{ + _, err := testStore.TransferTx(context.Background(), TransferTxParams{ FromAccountID: fromAccountID, ToAccountID: toAccountID, Amount: amount, @@ -152,13 +148,13 @@ func TestTransferTxDeadlock(t *testing.T) { } // check the final updated balance - updatedAccount1, err := store.GetAccount(context.Background(), account1.ID) + updatedAccount1, err := testStore.GetAccount(context.Background(), account1.ID) require.NoError(t, err) - updatedAccount2, err := store.GetAccount(context.Background(), account2.ID) + updatedAccount2, err := testStore.GetAccount(context.Background(), account2.ID) require.NoError(t, err) fmt.Println(">> after:", updatedAccount1.Balance, updatedAccount2.Balance) require.Equal(t, account1.Balance, updatedAccount1.Balance) require.Equal(t, account2.Balance, updatedAccount2.Balance) -} +} \ No newline at end of file diff --git a/db/sqlc/transfer.sql.go b/db/sqlc/transfer.sql.go index 8f48979..9ebf015 100644 --- a/db/sqlc/transfer.sql.go +++ b/db/sqlc/transfer.sql.go @@ -26,7 +26,7 @@ type CreateTransferParams struct { } func (q *Queries) CreateTransfer(ctx context.Context, arg CreateTransferParams) (Transfer, error) { - row := q.db.QueryRowContext(ctx, createTransfer, arg.FromAccountID, arg.ToAccountID, arg.Amount) + row := q.db.QueryRow(ctx, createTransfer, arg.FromAccountID, arg.ToAccountID, arg.Amount) var i Transfer err := row.Scan( &i.ID, @@ -44,7 +44,7 @@ WHERE id = $1 LIMIT 1 ` func (q *Queries) GetTransfer(ctx context.Context, id int64) (Transfer, error) { - row := q.db.QueryRowContext(ctx, getTransfer, id) + row := q.db.QueryRow(ctx, getTransfer, id) var i Transfer err := row.Scan( &i.ID, @@ -74,7 +74,7 @@ type ListTransfersParams struct { } func (q *Queries) ListTransfers(ctx context.Context, arg ListTransfersParams) ([]Transfer, error) { - rows, err := q.db.QueryContext(ctx, listTransfers, + rows, err := q.db.Query(ctx, listTransfers, arg.FromAccountID, arg.ToAccountID, arg.Limit, @@ -98,9 +98,6 @@ func (q *Queries) ListTransfers(ctx context.Context, arg ListTransfersParams) ([ } items = append(items, i) } - if err := rows.Close(); err != nil { - return nil, err - } if err := rows.Err(); err != nil { return nil, err } diff --git a/db/sqlc/tx_verify_email.go b/db/sqlc/tx_verify_email.go index 091afed..3fff5a1 100644 --- a/db/sqlc/tx_verify_email.go +++ b/db/sqlc/tx_verify_email.go @@ -2,7 +2,8 @@ package db import ( "context" - "database/sql" + + "github.com/jackc/pgx/v5/pgtype" ) type VerifyEmailTxParams struct { @@ -31,7 +32,7 @@ func (store *SQLStore) VerifyEmailTx(ctx context.Context, arg VerifyEmailTxParam result.User, err = q.UpdateUser(ctx, UpdateUserParams{ Username: result.VerifyEmail.Username, - IsEmailVerified: sql.NullBool{ + IsEmailVerified: pgtype.Bool{ Bool: true, Valid: true, }, diff --git a/db/sqlc/user.sql.go b/db/sqlc/user.sql.go index 86da0a6..17f7c4e 100644 --- a/db/sqlc/user.sql.go +++ b/db/sqlc/user.sql.go @@ -7,7 +7,8 @@ package db import ( "context" - "database/sql" + + "github.com/jackc/pgx/v5/pgtype" ) const createUser = `-- name: CreateUser :one @@ -29,7 +30,7 @@ type CreateUserParams struct { } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { - row := q.db.QueryRowContext(ctx, createUser, + row := q.db.QueryRow(ctx, createUser, arg.Username, arg.HashedPassword, arg.FullName, @@ -54,7 +55,7 @@ WHERE username = $1 LIMIT 1 ` func (q *Queries) GetUser(ctx context.Context, username string) (User, error) { - row := q.db.QueryRowContext(ctx, getUser, username) + row := q.db.QueryRow(ctx, getUser, username) var i User err := row.Scan( &i.Username, @@ -82,16 +83,16 @@ RETURNING username, hashed_password, full_name, email, password_changed_at, crea ` type UpdateUserParams struct { - HashedPassword sql.NullString `json:"hashed_password"` - PasswordChangedAt sql.NullTime `json:"password_changed_at"` - FullName sql.NullString `json:"full_name"` - Email sql.NullString `json:"email"` - IsEmailVerified sql.NullBool `json:"is_email_verified"` - Username string `json:"username"` + HashedPassword pgtype.Text `json:"hashed_password"` + PasswordChangedAt pgtype.Timestamptz `json:"password_changed_at"` + FullName pgtype.Text `json:"full_name"` + Email pgtype.Text `json:"email"` + IsEmailVerified pgtype.Bool `json:"is_email_verified"` + Username string `json:"username"` } func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) { - row := q.db.QueryRowContext(ctx, updateUser, + row := q.db.QueryRow(ctx, updateUser, arg.HashedPassword, arg.PasswordChangedAt, arg.FullName, diff --git a/db/sqlc/user_test.go b/db/sqlc/user_test.go index 5e9e6f6..fa09eda 100644 --- a/db/sqlc/user_test.go +++ b/db/sqlc/user_test.go @@ -2,11 +2,11 @@ package db import ( "context" - "database/sql" "testing" "time" "github.com/imrishuroy/simplebank/util" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" ) @@ -21,7 +21,7 @@ func createRandomUser(t *testing.T) User { Email: util.RandomEmail(), } - user, err := testQueries.CreateUser(context.Background(), arg) + user, err := testStore.CreateUser(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, user) @@ -41,7 +41,7 @@ func TestCreateUser(t *testing.T) { func TestGetUser(t *testing.T) { user1 := createRandomUser(t) - user2, err := testQueries.GetUser(context.Background(), user1.Username) + user2, err := testStore.GetUser(context.Background(), user1.Username) require.NoError(t, err) require.NotEmpty(t, user2) @@ -57,9 +57,9 @@ func TestUpdateUserOnlyFullName(t *testing.T) { oldUser := createRandomUser(t) newFullName := util.RandomOwner() - updatedUser, err := testQueries.UpdateUser(context.Background(), UpdateUserParams{ + updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ Username: oldUser.Username, - FullName: sql.NullString{ + FullName: pgtype.Text{ String: newFullName, Valid: true, }, @@ -76,9 +76,9 @@ func TestUpdateUserOnlyEmail(t *testing.T) { oldUser := createRandomUser(t) newEmail := util.RandomEmail() - updatedUser, err := testQueries.UpdateUser(context.Background(), UpdateUserParams{ + updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ Username: oldUser.Username, - Email: sql.NullString{ + Email: pgtype.Text{ String: newEmail, Valid: true, }, @@ -98,9 +98,9 @@ func TestUpdateUserOnlyPassword(t *testing.T) { newHashedPassword, err := util.HashPassword(newPassword) require.NoError(t, err) - updatedUser, err := testQueries.UpdateUser(context.Background(), UpdateUserParams{ + updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ Username: oldUser.Username, - HashedPassword: sql.NullString{ + HashedPassword: pgtype.Text{ String: newHashedPassword, Valid: true, }, @@ -122,17 +122,17 @@ func TestUpdateUserAllFields(t *testing.T) { newHashedPassword, err := util.HashPassword(newPassword) require.NoError(t, err) - updatedUser, err := testQueries.UpdateUser(context.Background(), UpdateUserParams{ + updatedUser, err := testStore.UpdateUser(context.Background(), UpdateUserParams{ Username: oldUser.Username, - FullName: sql.NullString{ + FullName: pgtype.Text{ String: newFullName, Valid: true, }, - Email: sql.NullString{ + Email: pgtype.Text{ String: newEmail, Valid: true, }, - HashedPassword: sql.NullString{ + HashedPassword: pgtype.Text{ String: newHashedPassword, Valid: true, }, diff --git a/db/sqlc/verify_email.sql.go b/db/sqlc/verify_email.sql.go index c7a980d..97ac619 100644 --- a/db/sqlc/verify_email.sql.go +++ b/db/sqlc/verify_email.sql.go @@ -26,7 +26,7 @@ type CreateVerifyEmailParams struct { } func (q *Queries) CreateVerifyEmail(ctx context.Context, arg CreateVerifyEmailParams) (VerifyEmail, error) { - row := q.db.QueryRowContext(ctx, createVerifyEmail, arg.Username, arg.Email, arg.SecretCode) + row := q.db.QueryRow(ctx, createVerifyEmail, arg.Username, arg.Email, arg.SecretCode) var i VerifyEmail err := row.Scan( &i.ID, @@ -58,7 +58,7 @@ type UpdateVerifyEmailParams struct { } func (q *Queries) UpdateVerifyEmail(ctx context.Context, arg UpdateVerifyEmailParams) (VerifyEmail, error) { - row := q.db.QueryRowContext(ctx, updateVerifyEmail, arg.ID, arg.SecretCode) + row := q.db.QueryRow(ctx, updateVerifyEmail, arg.ID, arg.SecretCode) var i VerifyEmail err := row.Scan( &i.ID, diff --git a/gapi/rpc_create_user.go b/gapi/rpc_create_user.go index f4817cb..d2c1861 100644 --- a/gapi/rpc_create_user.go +++ b/gapi/rpc_create_user.go @@ -10,7 +10,6 @@ import ( "github.com/imrishuroy/simplebank/util" "github.com/imrishuroy/simplebank/val" "github.com/imrishuroy/simplebank/worker" - "github.com/lib/pq" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -50,11 +49,8 @@ func (server *Server) CreateUser(ctx context.Context, req *pb.CreateUserRequest) txResult, err := server.store.CreateUserTx(ctx, arg) if err != nil { - if pqErr, ok := err.(*pq.Error); ok { - switch pqErr.Code.Name() { - case "unique_violation": - return nil, status.Errorf(codes.AlreadyExists, "username already exists: %s", err) - } + if db.ErrorCode(err) == db.UniqueViolation { + return nil, status.Errorf(codes.AlreadyExists, "%s", err) } return nil, status.Errorf(codes.Internal, "failed to create user: %s", err) } diff --git a/gapi/rpc_create_user_test.go b/gapi/rpc_create_user_test.go index fb75bda..ee0f149 100644 --- a/gapi/rpc_create_user_test.go +++ b/gapi/rpc_create_user_test.go @@ -14,7 +14,7 @@ import ( "github.com/imrishuroy/simplebank/util" "github.com/imrishuroy/simplebank/worker" mockwk "github.com/imrishuroy/simplebank/worker/mock" - "github.com/lib/pq" + "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -152,7 +152,7 @@ func TestCreateUserAPI(t *testing.T) { store.EXPECT(). CreateUserTx(gomock.Any(), gomock.Any()). Times(1). - Return(db.CreateUserTxResult{}, &pq.Error{Code: "23505"}) + Return(db.CreateUserTxResult{}, db.ErrUniqueViolation) taskDistributor.EXPECT(). DistributeTaskSendVerifyEmail(gomock.Any(), gomock.Any(), gomock.Any()). diff --git a/gapi/rpc_login_user.go b/gapi/rpc_login_user.go index d2e656e..243d193 100644 --- a/gapi/rpc_login_user.go +++ b/gapi/rpc_login_user.go @@ -2,7 +2,7 @@ package gapi import ( "context" - "database/sql" + "errors" db "github.com/imrishuroy/simplebank/db/sqlc" "github.com/imrishuroy/simplebank/pb" @@ -22,7 +22,7 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) ( user, err := server.store.GetUser(ctx, req.GetUsername()) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, db.ErrRecordNotFound) { return nil, status.Errorf(codes.NotFound, "user not found") } return nil, status.Errorf(codes.Internal, "failed to find user") diff --git a/gapi/rpc_update_user.go b/gapi/rpc_update_user.go index 6a088b4..1b9ded4 100644 --- a/gapi/rpc_update_user.go +++ b/gapi/rpc_update_user.go @@ -2,13 +2,14 @@ package gapi import ( "context" - "database/sql" + "errors" "time" db "github.com/imrishuroy/simplebank/db/sqlc" "github.com/imrishuroy/simplebank/pb" "github.com/imrishuroy/simplebank/util" "github.com/imrishuroy/simplebank/val" + "github.com/jackc/pgx/v5/pgtype" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -32,11 +33,11 @@ func (server *Server) UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) arg := db.UpdateUserParams{ Username: req.GetUsername(), - FullName: sql.NullString{ + FullName: pgtype.Text{ String: req.GetFullName(), Valid: req.FullName != nil, }, - Email: sql.NullString{ + Email: pgtype.Text{ String: req.GetEmail(), Valid: req.Email != nil, }, @@ -48,12 +49,12 @@ func (server *Server) UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) return nil, status.Errorf(codes.Internal, "failed to hash password: %s", err) } - arg.HashedPassword = sql.NullString{ + arg.HashedPassword = pgtype.Text{ String: hashedPassword, Valid: true, } - arg.PasswordChangedAt = sql.NullTime{ + arg.PasswordChangedAt = pgtype.Timestamptz{ Time: time.Now(), Valid: true, } @@ -61,7 +62,7 @@ func (server *Server) UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) user, err := server.store.UpdateUser(ctx, arg) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, db.ErrRecordNotFound) { return nil, status.Errorf(codes.NotFound, "user not found") } return nil, status.Errorf(codes.Internal, "failed to update user: %s", err) diff --git a/gapi/rpc_update_user_test.go b/gapi/rpc_update_user_test.go index e7a3f71..c9eed13 100644 --- a/gapi/rpc_update_user_test.go +++ b/gapi/rpc_update_user_test.go @@ -2,7 +2,6 @@ package gapi import ( "context" - "database/sql" "testing" "time" @@ -12,6 +11,7 @@ import ( "github.com/imrishuroy/simplebank/pb" "github.com/imrishuroy/simplebank/token" "github.com/imrishuroy/simplebank/util" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -41,11 +41,11 @@ func TestUpdateUserAPI(t *testing.T) { buildStubs: func(store *mockdb.MockStore) { arg := db.UpdateUserParams{ Username: user.Username, - FullName: sql.NullString{ + FullName: pgtype.Text{ String: newName, Valid: true, }, - Email: sql.NullString{ + Email: pgtype.Text{ String: newEmail, Valid: true, }, @@ -87,7 +87,7 @@ func TestUpdateUserAPI(t *testing.T) { store.EXPECT(). UpdateUser(gomock.Any(), gomock.Any()). Times(1). - Return(db.User{}, sql.ErrNoRows) + Return(db.User{}, db.ErrRecordNotFound) }, buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context { return newContextWithBearerToken(t, tokenMaker, user.Username, time.Minute) diff --git a/go.mod b/go.mod index a60ebee..04fa274 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/imrishuroy/simplebank -go 1.21.4 +go 1.21.5 require ( github.com/aead/chacha20poly1305 v0.0.0-20201124145622-1a5aba2a8b29 @@ -12,13 +12,14 @@ require ( github.com/google/uuid v1.4.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 github.com/hibiken/asynq v0.24.1 + github.com/jackc/pgx/v5 v5.5.0 github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible - github.com/lib/pq v1.10.9 github.com/o1egl/paseto v1.0.0 github.com/rakyll/statik v0.1.7 github.com/rs/zerolog v1.31.0 github.com/spf13/viper v1.17.0 github.com/stretchr/testify v1.8.4 + github.com/techschool/simplebank v0.0.0-20231029084543-9544012aa580 golang.org/x/crypto v0.15.0 google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 google.golang.org/genproto/googleapis/rpc v0.0.0-20231030173426-d783a09b4405 @@ -45,9 +46,13 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/leodido/go-urn v1.2.4 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -73,6 +78,7 @@ require ( golang.org/x/arch v0.6.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/net v0.18.0 // indirect + golang.org/x/sync v0.3.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/go.sum b/go.sum index c2129ba..c32c000 100644 --- a/go.sum +++ b/go.sum @@ -212,6 +212,14 @@ github.com/hibiken/asynq v0.24.1 h1:+5iIEAyA9K/lcSPvx3qoPtsKJeKI5u9aOIvUmSsazEw= github.com/hibiken/asynq v0.24.1/go.mod h1:u5qVeSbrnfT+vtG5Mq8ZPzQu/BmCKMHvTGb91uy9Tts= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.0 h1:NxstgwndsTRy7eq9/kqYc/BZh5w2hHJV86wjvO+1xPw= +github.com/jackc/pgx/v5 v5.5.0/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible h1:jdpOPRN1zP63Td1hDQbZW73xKmzDvZHzVdNYxhnTMDA= github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible/go.mod h1:1c7szIrayyPPB/987hsnvNzLushdWf4o/79s3P08L8A= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -316,6 +324,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/techschool/simplebank v0.0.0-20231029084543-9544012aa580 h1:J67NzuOy2Jxi20mZ/klyDBSyZYfRxzLWJbIMQeJTs+8= +github.com/techschool/simplebank v0.0.0-20231029084543-9544012aa580/go.mod h1:y6nH7U3EyXBjs+mLlEWt4ZOMNptF3AdQoPSBsGgjm94= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= @@ -443,6 +453,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/main.go b/main.go index f8c7554..eab1825 100644 --- a/main.go +++ b/main.go @@ -2,12 +2,12 @@ package main import ( "context" - "database/sql" "net" "net/http" "os" "github.com/hibiken/asynq" + "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -22,7 +22,7 @@ import ( "github.com/imrishuroy/simplebank/pb" "github.com/imrishuroy/simplebank/util" "github.com/imrishuroy/simplebank/worker" - _ "github.com/lib/pq" + "github.com/rakyll/statik/fs" "google.golang.org/grpc" "google.golang.org/grpc/reflection" @@ -48,14 +48,14 @@ func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) } - conn, err := sql.Open(config.DBDriver, config.DBSource) + connPool, err := pgxpool.New(context.Background(), config.DBSource) if err != nil { log.Fatal().Msg("cannot connect to db") } runDBMigration(config.MigrationURL, config.DBSource) - store := db.NewStore(conn) + store := db.NewStore(connPool) redisOpt := asynq.RedisClientOpt{ Addr: config.RedisAddress, diff --git a/sqlc.yaml b/sqlc.yaml index 97b0201..8aeffda 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -7,9 +7,16 @@ sql: go: package: "db" out: "db/sqlc" + sql_package: "pgx/v5" emit_json_tags: true emit_interface: true emit_empty_slices: true + overrides: + - db_type: "timestamptz" + go_type: "time.Time" + - db_type: "uuid" + go_type: "github.com/google/uuid.UUID" + # version: "1" diff --git a/util/confg.go b/util/confg.go index 58cf944..814f8d4 100644 --- a/util/confg.go +++ b/util/confg.go @@ -10,7 +10,6 @@ import ( // The values are read by viper from a config file or environment variables type Config struct { Environment string `mapstructure:"ENVIRONMENT"` - DBDriver string `mapstructure:"DB_DRIVER"` DBSource string `mapstructure:"DB_SOURCE"` MigrationURL string `mapstructure:"MIGRATION_URL"` RedisAddress string `mapstructure:"REDIS_ADDRESS"` diff --git a/worker/task_send_verify_email.go b/worker/task_send_verify_email.go index 74c5644..ace1f1c 100644 --- a/worker/task_send_verify_email.go +++ b/worker/task_send_verify_email.go @@ -2,8 +2,8 @@ package worker import ( "context" - "database/sql" "encoding/json" + "errors" "fmt" "github.com/hibiken/asynq" @@ -47,7 +47,7 @@ func (processor *RedisTaskProcessor) ProcessTaskSendVerifyEmail(ctx context.Cont user, err := processor.store.GetUser(ctx, payload.Username) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, db.ErrRecordNotFound) { // if user is not created in DB, we are not allowing to retry return fmt.Errorf("user doesn't exist: %w", asynq.SkipRetry) }