Skip to content
This repository has been archived by the owner on Oct 21, 2023. It is now read-only.

Commit

Permalink
fix: fix context key collision by defining type and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
dtomasi committed Sep 6, 2021
1 parent 6cf949c commit 9b5d0f0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
2 changes: 1 addition & 1 deletion container.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewServiceContainer(opts ...Option) *Container {
i.logger = i.injectableLogger.WithName(loggerName)

// wrap container into context
i.ctx = context.WithValue(i.ctx, ContextKey, i) // nolint:staticcheck
i.ctx = context.WithValue(i.ctx, ContextKeyContainer, i)

return i
}
Expand Down
8 changes: 6 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ import (
"github.com/dtomasi/di/internal/errors"
)

const ContextKey = "di.container"
type ContextKey int

const (
ContextKeyContainer ContextKey = iota
)

// GetContainerFromContext tries to get the container instance from given context as value.
func GetContainerFromContext(ctx context.Context) (*Container, error) {
container, ok := ctx.Value(ContextKey).(*Container)
container, ok := ctx.Value(ContextKeyContainer).(*Container)
if !ok {
return container, errors.New("could not get container instance from context")
}
Expand Down
18 changes: 18 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package di_test

import (
"context"
"github.com/dtomasi/di"
"github.com/stretchr/testify/assert"
"testing"
)

func TestGetContainerFromContext(t *testing.T) {
c := di.NewServiceContainer()
ctx := context.WithValue(context.Background(), di.ContextKeyContainer, c)

ctxContainer, err := di.GetContainerFromContext(ctx)
assert.NoError(t, err)
assert.IsType(t, &di.Container{}, ctxContainer)
assert.Equal(t, c, ctxContainer)
}

0 comments on commit 9b5d0f0

Please # to comment.