Skip to content

Commit

Permalink
feat: ajust PreparedStmtDB unlock location and BuildCondition if logic (
Browse files Browse the repository at this point in the history
  • Loading branch information
daheige authored Oct 8, 2021
1 parent c13f301 commit e3fc49a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
19 changes: 11 additions & 8 deletions prepare_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {

func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()

for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
go stmt.Close()
}
}

db.Mux.Unlock()
}

func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
Expand All @@ -51,9 +51,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.RUnlock()

db.Mux.Lock()
defer db.Mux.Unlock()

// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
return stmt, nil
} else if ok {
go stmt.Close()
Expand All @@ -64,7 +65,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
db.PreparedSQL = append(db.PreparedSQL, query)
}
defer db.Mux.Unlock()

return db.Stmts[query], err
}
Expand All @@ -83,9 +83,9 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
db.Mux.Unlock()
}
}
return result, err
Expand All @@ -97,9 +97,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()

go stmt.Close()
delete(db.Stmts, query)
db.Mux.Unlock()
}
}
return rows, err
Expand Down Expand Up @@ -138,9 +139,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()

go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Mux.Unlock()
}
}
return result, err
Expand All @@ -152,9 +154,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()

go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Mux.Unlock()
}
}
return rows, err
Expand Down
12 changes: 9 additions & 3 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,19 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if _, err := strconv.Atoi(s); err != nil {
if s == "" && len(args) == 0 {
return nil
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
}

if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
} else if len(args) > 0 && strings.Contains(s, "@") {
}

if len(args) > 0 && strings.Contains(s, "@") {
// looks like a named query
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
} else if len(args) == 1 {
}

if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
}
}
Expand Down

0 comments on commit e3fc49a

Please # to comment.