Skip to content

Commit

Permalink
feat: use slog handler (#17)
Browse files Browse the repository at this point in the history
* feat:  use slog handler
Using the handler directly allows to correctly set the Record.PC
  • Loading branch information
injeniero authored Mar 16, 2024
1 parent 91d131d commit 5c0b52f
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 25 deletions.
50 changes: 31 additions & 19 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
"runtime"
"time"

"gorm.io/gorm"
Expand Down Expand Up @@ -50,16 +51,16 @@ func New(options ...Option) *logger {
option(&l)
}

if l.slogger == nil {
// If no slogger is defined, use the default Logger
l.slogger = slog.Default()
if l.sloggerHandler == nil {
// If no sloggerHandler is defined, use the default Handler
l.sloggerHandler = slog.Default().Handler()
}

return &l
}

type logger struct {
slogger *slog.Logger
sloggerHandler slog.Handler
ignoreTrace bool
ignoreRecordNotFoundError bool
traceAll bool
Expand All @@ -85,28 +86,39 @@ func (l logger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
}

// Info logs info
func (l logger) Info(ctx context.Context, msg string, args ...any) {
l.log(l.slogger.InfoContext, ctx, msg, args...)
func (l logger) Info(ctx context.Context, format string, args ...any) {
l.log(ctx, slog.LevelInfo, format, args...)
}

// Warn logs warn messages
func (l logger) Warn(ctx context.Context, msg string, args ...any) {
l.log(l.slogger.WarnContext, ctx, msg, args...)
func (l logger) Warn(ctx context.Context, format string, args ...any) {
l.log(ctx, slog.LevelWarn, format, args...)
}

// Error logs error messages
func (l logger) Error(ctx context.Context, msg string, args ...any) {
l.log(l.slogger.ErrorContext, ctx, msg, args...)
func (l logger) Error(ctx context.Context, format string, args ...any) {
l.log(ctx, slog.LevelError, format, args...)
}

// log adds context attributes and logs a message with the given slog function
func (l logger) log(f func(ctx context.Context, msg string, args ...any), ctx context.Context, msg string, args ...any) {
// log adds context attributes and logs a message with the given slog level
func (l logger) log(ctx context.Context, level slog.Level, format string, args ...any) {
if ctx == nil {
ctx = context.Background()
}
if !l.sloggerHandler.Enabled(ctx, level) {
return
}

// Append context attributes
args = l.appendContextAttributes(ctx, args)
// Properly handle the PC for the caller
var pc uintptr
var pcs [1]uintptr
// skip [runtime.Callers, this function, this function's caller]
runtime.Callers(3, pcs[:])
pc = pcs[0]
r := slog.NewRecord(time.Now(), level, fmt.Sprintf(format, args...), pc)
r.Add(l.appendContextAttributes(ctx, nil)...)

// Call slog
f(ctx, msg, args...)
_ = l.sloggerHandler.Handle(ctx, r)
}

// Trace logs sql message
Expand All @@ -129,7 +141,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (sql strin
slog.String(l.sourceField, utils.FileWithLineNum()),
})

l.slogger.Log(ctx, l.logLevel[ErrorLogType], err.Error(), attributes...)
l.log(ctx, l.logLevel[ErrorLogType], err.Error(), attributes...)

case l.slowThreshold != 0 && elapsed > l.slowThreshold:
sql, rows := fc()
Expand All @@ -142,7 +154,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (sql strin
slog.Int64(RowsField, rows),
slog.String(l.sourceField, utils.FileWithLineNum()),
})
l.slogger.Log(ctx, l.logLevel[SlowQueryLogType], fmt.Sprintf("slow sql query [%s >= %v]", elapsed, l.slowThreshold), attributes...)
l.log(ctx, l.logLevel[SlowQueryLogType], fmt.Sprintf("slow sql query [%s >= %v]", elapsed, l.slowThreshold), attributes...)

case l.traceAll || l.gormLevel == gormlogger.Info:
sql, rows := fc()
Expand All @@ -155,7 +167,7 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (sql strin
slog.String(l.sourceField, utils.FileWithLineNum()),
})

l.slogger.Log(ctx, l.logLevel[DefaultLogType], fmt.Sprintf("SQL query executed [%s]", elapsed), attributes...)
l.log(ctx, l.logLevel[DefaultLogType], fmt.Sprintf("SQL query executed [%s]", elapsed), attributes...)
}
}

Expand Down
38 changes: 34 additions & 4 deletions logger_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package slogGorm

import (
"bytes"
"context"
"fmt"
"log/slog"
"runtime"
"testing"
"time"

Expand All @@ -17,18 +19,40 @@ func TestNew(t *testing.T) {
t.Run("Without options", func(t *testing.T) {
l := New()

require.NotNil(t, l.slogger)
assert.Equal(t, slog.Default(), l.slogger)
require.NotNil(t, l.sloggerHandler)
assert.Equal(t, slog.Default().Handler(), l.sloggerHandler)
})

t.Run("WithLogger(nil)", func(t *testing.T) {
l := New(
WithLogger(nil),
)

require.NotNil(t, l.slogger)
assert.Equal(t, slog.Default(), l.slogger)
require.NotNil(t, l.sloggerHandler)
assert.Equal(t, slog.Default().Handler(), l.sloggerHandler)
})

t.Run("WithHandler(nil)", func(t *testing.T) {
l := New(
WithHandler(nil),
)

require.NotNil(t, l.sloggerHandler)
assert.Equal(t, slog.Default().Handler(), l.sloggerHandler)
})
}

func Test_logger_Enabled(t *testing.T) {
buffer := bytes.NewBuffer(nil)
leveler := &slog.LevelVar{}
l := New(WithHandler(slog.NewTextHandler(buffer, &slog.HandlerOptions{Level: leveler})))
leveler.Set(slog.LevelWarn)

l.Info(context.Background(), "an info message")
assert.Equal(t, 0, buffer.Len())

l.Warn(context.Background(), "a warn message")
assert.Greater(t, buffer.Len(), 0)
}

func Test_logger_LogMode(t *testing.T) {
Expand All @@ -49,6 +73,7 @@ func Test_logger(t *testing.T) {
wantMsg string
wantAttributes map[string]slog.Attr
wantLevel slog.Level
wantSource string
}{
{
name: "Info",
Expand Down Expand Up @@ -92,6 +117,11 @@ func Test_logger(t *testing.T) {
require.NotNil(t, receiver.Record)
assert.Equal(t, tt.wantMsg, receiver.Record.Message)
assert.Equal(t, tt.wantLevel, receiver.Record.Level)
pc, _, _, ok := runtime.Caller(0)
assert.True(t, ok)
actualFrame, _ := runtime.CallersFrames([]uintptr{pc}).Next()
frame, _ := runtime.CallersFrames([]uintptr{receiver.Record.PC}).Next()
assert.Equal(t, actualFrame.Function, frame.Function)

if tt.wantAttributes != nil {
for _, v := range tt.wantAttributes {
Expand Down
12 changes: 11 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,19 @@ import (
type Option func(l *logger)

// WithLogger defines a custom logger to use
// Deprecated: Use WithHandler instead
func WithLogger(log *slog.Logger) Option {
return func(l *logger) {
l.slogger = log
if log != nil {
l.sloggerHandler = log.Handler()
}
}
}

// WithHandler defines a custom logger to use
func WithHandler(handler slog.Handler) Option {
return func(l *logger) {
l.sloggerHandler = handler
}
}

Expand Down
11 changes: 10 additions & 1 deletion options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,16 @@ func TestWithLogger(t *testing.T) {

WithLogger(log)(actual)

assert.Equal(t, log, actual.slogger)
assert.Equal(t, log.Handler(), actual.sloggerHandler)
}

func TestWithHandler(t *testing.T) {
actual := &logger{}
handler := slog.Default().Handler()

WithHandler(handler)(actual)

assert.Equal(t, handler, actual.sloggerHandler)
}

func TestSetLogLevel(t *testing.T) {
Expand Down

0 comments on commit 5c0b52f

Please # to comment.