Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add provider selection logic, file provider and tests #17

Merged
merged 10 commits into from
Nov 29, 2023
7 changes: 7 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@ require (
github.com/spf13/cast v1.5.1
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

require (
github.com/samber/lo v1.38.1 // indirect
github.com/stretchr/testify v1.8.4
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
Expand All @@ -6,6 +8,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
Expand All @@ -16,5 +20,11 @@ github.com/samber/slog-syslog v1.0.0 h1:4tf8sNv9+qTQ6Fj8+N6U1ZEtUbqbAIzd+q26/Neg
github.com/samber/slog-syslog v1.0.0/go.mod h1:jjupk+yHPVSuXuGhKleoClYc/HEaC+Ro5X4YYeBrt6g=
github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA=
github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
96 changes: 90 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"os/exec"
"os/signal"
"slices"
"strings"
"syscall"
"time"

Expand All @@ -32,8 +33,77 @@ import (
"github.com/spf13/cast"

"github.com/bank-vaults/secret-init/provider"
"github.com/bank-vaults/secret-init/provider/file"
)

func NewProvider(providerName string) (provider.Provider, error) {
csatib02 marked this conversation as resolved.
Show resolved Hide resolved
switch providerName {
case file.ProviderName:
provider, err := file.NewProvider(os.DirFS("/secrets"))
csatib02 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
}

return provider, nil

default:
return nil, errors.New("invalid provider specified")
}
}

func CreateMapOfEnvs() map[string]string {
environ := make(map[string]string, len(os.Environ()))
for _, env := range os.Environ() {
split := strings.SplitN(env, "=", 2)
name := split[0]
value := split[1]
environ[name] = value
}

return environ
}

func ExtractPathsFromEnvs(envs map[string]string) []string {
var secretPaths []string

for _, path := range envs {
if strings.HasPrefix(path, "file:") {
path = strings.TrimPrefix(path, "file://")
secretPaths = append(secretPaths, path)
}
}

return secretPaths
}

func CreateEnvsFromLoadedSecrets(envs map[string]string, secrets []string) ([]string, error) {
// Reverse the map so we can match
// the environment variable key to the secret
// by using the secret path
reversedEnvs := make(map[string]string)
for envKey, path := range envs {
if strings.HasPrefix(path, "file:") {
path = strings.TrimPrefix(path, "file://")
reversedEnvs[path] = envKey
}
}

var secretsEnv []string
for _, secret := range secrets {
split := strings.SplitN(secret, "|", 2)
secretPath := split[0]

secretValue := split[1]
secretKey, ok := reversedEnvs[secretPath]
if !ok {
return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", secretPath)
}
secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", secretKey, secretValue))
}

return secretsEnv, nil
}
csatib02 marked this conversation as resolved.
Show resolved Hide resolved

func main() {
var logger *slog.Logger
{
Expand Down Expand Up @@ -94,8 +164,12 @@ func main() {
slog.SetDefault(logger)
}

// TODO: enable providers
var provider provider.Provider
provider, err := NewProvider(os.Getenv("PROVIDER"))
if err != nil {
logger.Error(fmt.Errorf("failed to create provider: %w", err).Error())

os.Exit(1)
}

if len(os.Args) == 1 {
logger.Error("no command is given, vault-env can't determine the entrypoint (command), please specify it explicitly or let the webhook query it (see documentation)")
Expand All @@ -115,10 +189,20 @@ func main() {
os.Exit(1)
}

environ := CreateMapOfEnvs()
paths := ExtractPathsFromEnvs(environ)
csatib02 marked this conversation as resolved.
Show resolved Hide resolved

ctx := context.Background()
envs, err := provider.LoadSecrets(ctx, os.Environ())
secrets, err := provider.LoadSecrets(ctx, paths)
if err != nil {
logger.Error(fmt.Errorf("failed to load secrets from provider: %w", err).Error())

os.Exit(1)
}

secretsEnv, err := CreateEnvsFromLoadedSecrets(environ, secrets)
if err != nil {
logger.Error("could not retrieve secrets from the provider.", err)
logger.Error(fmt.Errorf("failed to create environment variables from loaded secrets: %w", err).Error())

os.Exit(1)
}
Expand All @@ -135,7 +219,7 @@ func main() {
if daemonMode {
logger.Info("in daemon mode...")
cmd := exec.Command(binary, entrypointCmd[1:]...)
cmd.Env = append(os.Environ(), envs...)
cmd.Env = append(os.Environ(), secretsEnv...)
cmd.Stdin = os.Stdin
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
Expand Down Expand Up @@ -184,7 +268,7 @@ func main() {

os.Exit(cmd.ProcessState.ExitCode())
}
err = syscall.Exec(binary, entrypointCmd, envs)
err = syscall.Exec(binary, entrypointCmd, secretsEnv)
if err != nil {
logger.Error(fmt.Errorf("failed to exec process: %w", err).Error(), slog.String("entrypoint", fmt.Sprint(entrypointCmd)))

Expand Down
96 changes: 96 additions & 0 deletions provider/file/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright © 2023 Bank-Vaults Maintainers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package file

import (
"context"
"fmt"
"io/fs"

"github.com/bank-vaults/secret-init/provider"
)

const ProviderName = "file"

type Provider struct {
fs fs.FS
}

func NewProvider(fs fs.FS) (provider.Provider, error) {
if fs == nil {
return nil, fmt.Errorf("file system is nil")
}

isEmpty, err := isFileSystemEmpty(fs)
if err != nil {
return nil, fmt.Errorf("failed to check if file system is empty: %w", err)
}
if isEmpty {
return nil, fmt.Errorf("file system is empty")
}

csatib02 marked this conversation as resolved.
Show resolved Hide resolved
return &Provider{fs: fs}, nil
}

func (provider *Provider) LoadSecrets(_ context.Context, paths []string) ([]string, error) {
var secrets []string

for i, path := range paths {
secret, err := provider.getSecretFromFile(path)
if err != nil {
return nil, fmt.Errorf("failed to get secret from file: %w", err)
}
// Add the secret path with a "|" separator character
// to the secrets slice along with the secret
// so later we can match it to the environment key
secrets = append(secrets, paths[i]+"|"+secret)
csatib02 marked this conversation as resolved.
Show resolved Hide resolved
}

return secrets, nil
}

func isFileSystemEmpty(fsys fs.FS) (bool, error) {
csatib02 marked this conversation as resolved.
Show resolved Hide resolved
dir, err := fs.ReadDir(fsys, ".")
fmt.Println(dir, err)
if err != nil {
return false, err
}

for _, entry := range dir {
if entry.IsDir() || entry.Type().IsRegular() {
return false, nil
}
}

return true, nil
}

func (provider *Provider) getSecretFromFile(path string) (string, error) {
content, err := provider.readFile(path)
if err != nil {
return "", err
}

return string(content), nil
}

func (provider *Provider) readFile(path string) ([]byte, error) {
content, err := fs.ReadFile(provider.fs, path)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
}

return content, nil
}
csatib02 marked this conversation as resolved.
Show resolved Hide resolved
116 changes: 116 additions & 0 deletions provider/file/file_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright © 2023 Bank-Vaults Maintainers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package file

import (
"context"
"io/fs"
"testing"
"testing/fstest"

"github.com/stretchr/testify/assert"
)

func TestNewProvider(t *testing.T) {
tests := []struct {
name string
fs fs.FS
wantErr bool
wantType bool
}{
{
name: "Valid file system",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
},
wantErr: false,
wantType: true,
},
{
name: "Nil file system",
fs: nil,
wantErr: true,
wantType: false,
},
{
name: "Empty file system",
fs: fstest.MapFS{},
wantErr: true,
wantType: false,
},
}

for _, tt := range tests {
ttp := tt
t.Run(ttp.name, func(t *testing.T) {

prov, err := NewProvider(ttp.fs)
if (err != nil) != ttp.wantErr {
t.Fatalf("NewProvider() error = %v, wantErr %v", err, ttp.wantErr)
return
}
// Use type assertion to check if the provider is of the correct type
_, ok := prov.(*Provider)
if ok != ttp.wantType {
t.Fatalf("NewProvider() = %v, wantType %v", ok, ttp.wantType)
}
})
}
}

func TestLoadSecrets(t *testing.T) {
tests := []struct {
name string
fs fs.FS
paths []string
wantErr bool
wantData []string
}{
{
name: "Load secrets successfully",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
},
paths: []string{"test/secrets/sqlpass.txt", "test/secrets/awsaccess.txt"},
wantErr: false,
wantData: []string{"test/secrets/sqlpass.txt|3xtr3ms3cr3t", "test/secrets/awsaccess.txt|s3cr3t"},
},
{
name: "Fail to load secrets due to invalid path",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
},
paths: []string{"test/secrets/mistake/sqlpass.txt", "test/secrets/mistake/awsaccess.txt"},
wantErr: true,
wantData: nil,
},
}

for _, tt := range tests {
ttp := tt
t.Run(ttp.name, func(t *testing.T) {
provider, err := NewProvider(ttp.fs)
if assert.NoError(t, err, "Unexpected error") {
secrets, err := provider.LoadSecrets(context.Background(), ttp.paths)
assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")

assert.Equal(t, ttp.wantData, secrets, "Unexpected secrets loaded")
}
})
}
}