Skip to content

Commit

Permalink
Parse credentials and try figure out best action to take
Browse files Browse the repository at this point in the history
  • Loading branch information
louisruch committed Jul 20, 2022
1 parent 57d4316 commit fef22f0
Show file tree
Hide file tree
Showing 5 changed files with 541 additions and 331 deletions.
55 changes: 32 additions & 23 deletions internal/cmd/commands/connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,15 +483,16 @@ func (c *Command) Run(args []string) (retCode int) {

c.listenerAddr = c.listener.Addr().(*net.TCPAddr)

var creds []*targets.SessionCredential
if c.sessionAuthz != nil && len(c.sessionAuthz.Credentials) > 0 {
creds = c.sessionAuthz.Credentials
}
switch c.Func {
case "connect":
if c.Func == "connect" {
// "connect" indicates there is no subcommand to the connect function.
// The only way a user will be able to connect to the session is by
// connecting directly to the port and address we report to them here.

var creds []*targets.SessionCredential
if c.sessionAuthz != nil && len(c.sessionAuthz.Credentials) > 0 {
creds = c.sessionAuthz.Credentials
}

sessInfo := SessionInfo{
Protocol: c.sessionAuthzData.GetType(),
Address: c.listenerAddr.IP.String(),
Expand Down Expand Up @@ -649,18 +650,18 @@ func (c *Command) Run(args []string) (retCode int) {
return
}

func (c *Command) printCredentials() error {
if c.sessionAuthz == nil || len(c.sessionAuthz.Credentials) == 0 {
func (c *Command) printCredentials(creds []*targets.SessionCredential) error {
if len(creds) == 0 {
return nil
}
switch base.Format(c.UI) {
case "table":
c.UI.Output(generateCredentialTableOutput(c.sessionAuthz.Credentials))
c.UI.Output(generateCredentialTableOutput(creds))
case "json":
out, err := json.Marshal(&struct {
Credentials []*targets.SessionCredential `json:"credentials"`
}{
Credentials: c.sessionAuthz.Credentials,
Credentials: creds,
})
if err != nil {
return fmt.Errorf("error marshaling credential information: %w", err)
Expand Down Expand Up @@ -812,7 +813,18 @@ func (c *Command) handleExec(passthroughArgs []string) {
var args []string
var envs []string
var argsErr error
printCreds := true

var creds credentials
if c.sessionAuthz != nil {
var err error
creds, err = parseCredentials(c.sessionAuthz.Credentials)
if err != nil {
c.PrintCliError(fmt.Errorf("Error interpreting secret: %w", err))
c.execCmdReturnValue.Store(int32(3))
return
}
}

switch c.Func {
case "http":
httpArgs, err := c.httpFlags.buildArgs(c, port, ip, addr)
Expand All @@ -824,28 +836,27 @@ func (c *Command) handleExec(passthroughArgs []string) {
args = append(args, httpArgs...)

case "postgres":
pgArgs, pgEnvs, pgErr := c.postgresFlags.buildArgs(c, port, ip, addr)
pgArgs, pgEnvs, pgCreds, pgErr := c.postgresFlags.buildArgs(c, port, ip, addr, creds)
if pgErr != nil {
argsErr = pgErr
break
}
args = append(args, pgArgs...)
envs = append(envs, pgEnvs...)
printCreds = false
creds = pgCreds

case "rdp":
args = append(args, c.rdpFlags.buildArgs(c, port, ip, addr)...)

case "ssh":
sshArgs, sshEnvs, consumedCreds, sshErr := c.sshFlags.buildArgs(c, port, ip, addr)
sshArgs, sshEnvs, sshCreds, sshErr := c.sshFlags.buildArgs(c, port, ip, addr, creds)
if sshErr != nil {
argsErr = sshErr
break
}
args = append(args, sshArgs...)
envs = append(envs, sshEnvs...)
if consumedCreds {
printCreds = false
}
creds = sshCreds

case "kube":
kubeArgs, err := c.kubeFlags.buildArgs(c, port, ip, addr)
Expand All @@ -863,12 +874,10 @@ func (c *Command) handleExec(passthroughArgs []string) {
return
}

if printCreds {
if err := c.printCredentials(); err != nil {
c.PrintCliError(fmt.Errorf("Failed to print credentials: %w", err))
c.execCmdReturnValue.Store(int32(2))
return
}
if err := c.printCredentials(creds.unconsumedSessionCredentials()); err != nil {
c.PrintCliError(fmt.Errorf("Failed to print credentials: %w", err))
c.execCmdReturnValue.Store(int32(2))
return
}

args = append(passthroughArgs, args...)
Expand Down
105 changes: 70 additions & 35 deletions internal/cmd/commands/connect/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,79 +8,114 @@ import (
"github.com/mitchellh/mapstructure"
)

type usernamePasswordCredential struct {
type usernamePassword struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`

raw *targets.SessionCredential
consumed bool
}

type sshPrivateKeyCredential struct {
type sshPrivateKey struct {
Username string `mapstructure:"username"`
PrivateKey string `mapstructure:"private_key"`

raw *targets.SessionCredential
consumed bool
}

type credentials struct {
usernamePassword []usernamePassword
sshPrivateKey []sshPrivateKey
unspecified []*targets.SessionCredential
}

func (c credentials) unconsumedSessionCredentials() []*targets.SessionCredential {
out := make([]*targets.SessionCredential, 0, len(c.sshPrivateKey)+len(c.usernamePassword)+len(c.unspecified))

// Unspecified credentials cannot be consumed
out = append(out, c.unspecified...)

for _, c := range c.sshPrivateKey {
if !c.consumed {
out = append(out, c.raw)
}
}
for _, c := range c.usernamePassword {
if !c.consumed {
out = append(out, c.raw)
}
}
return out
}

func parseCredentials(creds []*targets.SessionCredential) ([]any, error) {
func parseCredentials(creds []*targets.SessionCredential) (credentials, error) {
if creds == nil {
return nil, nil
return credentials{}, nil
}
var out []any
var out credentials
for _, cred := range creds {
if cred.CredentialSource == nil {
return nil, errors.New("missing credential source")
return credentials{}, errors.New("missing credential source")
}

var upCred usernamePasswordCredential
var spkCred sshPrivateKeyCredential
var upCred usernamePassword
var spkCred sshPrivateKey
switch credential.Type(cred.CredentialSource.CredentialType) {
case credential.UsernamePasswordType:
// Decode attributes from credential struct
if err := mapstructure.Decode(cred.Credential, &upCred); err != nil {
return nil, err
return credentials{}, err
}

if upCred.Username != "" && upCred.Password != "" {
out = append(out, upCred)
upCred.raw = cred
out.usernamePassword = append(out.usernamePassword, upCred)
continue
}

case credential.SshPrivateKeyType:
// Decode attributes from credential struct
if err := mapstructure.Decode(cred.Credential, &spkCred); err != nil {
return nil, err
return credentials{}, err
}

if spkCred.Username != "" && spkCred.PrivateKey != "" {
out = append(out, spkCred)
spkCred.raw = cred
out.sshPrivateKey = append(out.sshPrivateKey, spkCred)
continue
}
}

// Credential type is unspecified, make a best effort attempt to parse
// the Decoded field if it exists
if cred.Secret == nil || cred.Secret.Decoded == nil {
// No secret data continue to next credential
continue
}
// a username_password credential from the Decoded field if it exists
if cred.Secret != nil && cred.Secret.Decoded != nil {
switch cred.CredentialSource.Type {
case "vault", "static":
// Attempt unmarshaling into username password creds
if err := mapstructure.Decode(cred.Secret.Decoded, &upCred); err != nil {
return credentials{}, err
}
if upCred.Username != "" && upCred.Password != "" {
upCred.raw = cred
out.usernamePassword = append(out.usernamePassword, upCred)
continue
}

switch cred.CredentialSource.Type {
case "vault", "static":
// Attempt unmarshaling into username password creds
if err := mapstructure.Decode(cred.Secret.Decoded, &upCred); err != nil {
return nil, err
}
if upCred.Username != "" && upCred.Password != "" {
out = append(out, upCred)
continue
}

// Attempt unmarshaling into ssh private key creds
if err := mapstructure.Decode(cred.Secret.Decoded, &spkCred); err != nil {
return nil, err
}
if spkCred.Username != "" && spkCred.PrivateKey != "" {
out = append(out, spkCred)
continue
// Attempt unmarshaling into ssh private key creds
if err := mapstructure.Decode(cred.Secret.Decoded, &spkCred); err != nil {
return credentials{}, err
}
if spkCred.Username != "" && spkCred.PrivateKey != "" {
spkCred.raw = cred
out.sshPrivateKey = append(out.sshPrivateKey, spkCred)
continue
}
}
}

// We could not parse the credential
out.unspecified = append(out.unspecified, cred)
}

return out, nil
Expand Down
Loading

0 comments on commit fef22f0

Please # to comment.