Skip to content

Commit

Permalink
Add ability to generate ccache from client
Browse files Browse the repository at this point in the history
  • Loading branch information
tsipinakis committed Oct 7, 2023
1 parent ecb5027 commit 17fbb08
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 21 deletions.
2 changes: 1 addition & 1 deletion v8/client/TGSExchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (cl *Client) GetServiceTicket(spn string) (messages.Ticket, types.Encryptio
realm = cl.Credentials.Realm()
}

tgt, skey, err := cl.sessionTGT(realm)
tgt, _, skey, err := cl.sessionTGT(realm)
if err != nil {
return tkt, skey, err
}
Expand Down
49 changes: 44 additions & 5 deletions v8/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (cl *Client) IsConfigured() (bool, error) {
}
// Client needs to have either a password, keytab or a session already (later when loading from CCache)
if !cl.Credentials.HasPassword() && !cl.Credentials.HasKeytab() {
authTime, _, _, _, err := cl.sessionTimes(cl.Credentials.Domain())
authTime, _, _, _, _, err := cl.sessionTimes(cl.Credentials.Domain())
if err != nil || authTime.IsZero() {
return false, errors.New("client has neither a keytab nor a password set and no session")
}
Expand All @@ -169,7 +169,7 @@ func (cl *Client) Login() error {
return err
}
if !cl.Credentials.HasPassword() && !cl.Credentials.HasKeytab() {
_, endTime, _, _, err := cl.sessionTimes(cl.Credentials.Domain())
_, _, endTime, _, _, err := cl.sessionTimes(cl.Credentials.Domain())
if err != nil {
return krberror.Errorf(err, krberror.KRBMsgError, "no user credentials available and error getting any existing session")
}
Expand All @@ -193,7 +193,7 @@ func (cl *Client) Login() error {

// AffirmLogin will only perform an AS exchange with the KDC if the client does not already have a TGT.
func (cl *Client) AffirmLogin() error {
_, endTime, _, _, err := cl.sessionTimes(cl.Credentials.Domain())
_, _, endTime, _, _, err := cl.sessionTimes(cl.Credentials.Domain())
if err != nil || time.Now().UTC().After(endTime) {
err := cl.Login()
if err != nil {
Expand All @@ -208,14 +208,14 @@ func (cl *Client) realmLogin(realm string) error {
if realm == cl.Credentials.Domain() {
return cl.Login()
}
_, endTime, _, _, err := cl.sessionTimes(cl.Credentials.Domain())
_, _, endTime, _, _, err := cl.sessionTimes(cl.Credentials.Domain())
if err != nil || time.Now().UTC().After(endTime) {
err := cl.Login()
if err != nil {
return fmt.Errorf("could not get valid TGT for client's realm: %v", err)
}
}
tgt, skey, err := cl.sessionTGT(cl.Credentials.Domain())
tgt, _, skey, err := cl.sessionTGT(cl.Credentials.Domain())
if err != nil {
return err
}
Expand Down Expand Up @@ -243,6 +243,45 @@ func (cl *Client) Destroy() {
cl.Log("client destroyed")
}

func (cl *Client) GetCCache() (*credentials.CCache, error) {
tgt, flags, skey, err := cl.sessionTGT(cl.Credentials.Realm())
if err != nil {
return nil, err
}
authTime, startTime, endTime, renewTime, _, err := cl.sessionTimes(cl.Credentials.Realm())
if err != nil {
return nil, fmt.Errorf("failed to get session times while getting cache (%w)", err)
}
tgtMar, err := tgt.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal tft while getting cache (%w)", err)
}
tgtCred := credentials.Credential{
Client: credentials.Principal{
Realm: cl.Credentials.Realm(),
PrincipalName: cl.Credentials.CName(),
},
Server: credentials.Principal{
Realm: tgt.Realm,
PrincipalName: tgt.SName,
},
Key: skey,
AuthTime: authTime,
StartTime: startTime,
EndTime: endTime,
RenewTill: renewTime,
TicketFlags: flags,
Ticket: tgtMar,
}
creds := []credentials.Credential{tgtCred}
ccache := credentials.CCacheFromCredentials(creds)
ccache.DefaultPrincipal = credentials.Principal{
Realm: cl.Credentials.Realm(),
PrincipalName: cl.Credentials.CName(),
}
return ccache, nil
}

// Diagnostics runs a set of checks that the client is properly configured and writes details to the io.Writer provided.
func (cl *Client) Diagnostics(w io.Writer) error {
cl.Print(w)
Expand Down
30 changes: 20 additions & 10 deletions v8/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"sync"
"time"

"github.com/jcmturner/gofork/encoding/asn1"

"github.com/jcmturner/gokrb5/v8/iana/nametype"
"github.com/jcmturner/gokrb5/v8/krberror"
"github.com/jcmturner/gokrb5/v8/messages"
Expand Down Expand Up @@ -64,9 +66,11 @@ func (s *sessions) get(realm string) (*session, bool) {
type session struct {
realm string
authTime time.Time
startTime time.Time
endTime time.Time
renewTill time.Time
tgt messages.Ticket
tgtFlags asn1.BitString
sessionKey types.EncryptionKey
sessionKeyExpiration time.Time
cancel chan bool
Expand All @@ -77,6 +81,7 @@ type session struct {
type jsonSession struct {
Realm string
AuthTime time.Time
StartTime time.Time
EndTime time.Time
RenewTill time.Time
SessionKeyExpiration time.Time
Expand All @@ -93,9 +98,11 @@ func (cl *Client) addSession(tgt messages.Ticket, dep messages.EncKDCRepPart) {
s := &session{
realm: realm,
authTime: dep.AuthTime,
startTime: dep.StartTime,
endTime: dep.EndTime,
renewTill: dep.RenewTill,
tgt: tgt,
tgtFlags: dep.Flags,
sessionKey: dep.Key,
sessionKeyExpiration: dep.KeyExpiration,
}
Expand All @@ -109,9 +116,11 @@ func (s *session) update(tgt messages.Ticket, dep messages.EncKDCRepPart) {
s.mux.Lock()
defer s.mux.Unlock()
s.authTime = dep.AuthTime
s.startTime = dep.StartTime
s.endTime = dep.EndTime
s.renewTill = dep.RenewTill
s.tgt = tgt
s.tgtFlags = dep.Flags
s.sessionKey = dep.Key
s.sessionKeyExpiration = dep.KeyExpiration
}
Expand Down Expand Up @@ -140,17 +149,17 @@ func (s *session) valid() bool {
}

// tgtDetails is a thread safe way to get the session's realm, TGT and session key values
func (s *session) tgtDetails() (string, messages.Ticket, types.EncryptionKey) {
func (s *session) tgtDetails() (string, messages.Ticket, asn1.BitString, types.EncryptionKey) {
s.mux.RLock()
defer s.mux.RUnlock()
return s.realm, s.tgt, s.sessionKey
return s.realm, s.tgt, s.tgtFlags, s.sessionKey
}

// timeDetails is a thread safe way to get the session's validity time values
func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time) {
func (s *session) timeDetails() (string, time.Time, time.Time, time.Time, time.Time, time.Time) {
s.mux.RLock()
defer s.mux.RUnlock()
return s.realm, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
return s.realm, s.startTime, s.authTime, s.endTime, s.renewTill, s.sessionKeyExpiration
}

// JSON return information about the held sessions in a JSON format.
Expand All @@ -164,10 +173,11 @@ func (s *sessions) JSON() (string, error) {
}
sort.Strings(keys)
for _, k := range keys {
r, at, et, rt, kt := s.Entries[k].timeDetails()
r, st, at, et, rt, kt := s.Entries[k].timeDetails()
j := jsonSession{
Realm: r,
AuthTime: at,
StartTime: st,
EndTime: et,
RenewTill: rt,
SessionKeyExpiration: kt,
Expand Down Expand Up @@ -217,7 +227,7 @@ func (cl *Client) enableAutoSessionRenewal(s *session) {

// renewTGT renews the client's TGT session.
func (cl *Client) renewTGT(s *session) error {
realm, tgt, skey := s.tgtDetails()
realm, tgt, _, skey := s.tgtDetails()
spn := types.PrincipalName{
NameType: nametype.KRB_NT_SRV_INST,
NameString: []string{"krbtgt", realm},
Expand Down Expand Up @@ -266,7 +276,7 @@ func (cl *Client) ensureValidSession(realm string) error {
}

// sessionTGTDetails is a thread safe way to get the TGT and session key values for a realm
func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey types.EncryptionKey, err error) {
func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, flags asn1.BitString, sessionKey types.EncryptionKey, err error) {
err = cl.ensureValidSession(realm)
if err != nil {
return
Expand All @@ -276,18 +286,18 @@ func (cl *Client) sessionTGT(realm string) (tgt messages.Ticket, sessionKey type
err = fmt.Errorf("could not find TGT session for %s", realm)
return
}
_, tgt, sessionKey = s.tgtDetails()
_, tgt, flags, sessionKey = s.tgtDetails()
return
}

// sessionTimes provides the timing information with regards to a session for the realm specified.
func (cl *Client) sessionTimes(realm string) (authTime, endTime, renewTime, sessionExp time.Time, err error) {
func (cl *Client) sessionTimes(realm string) (authTime, startTime, endTime, renewTime, sessionExp time.Time, err error) {
s, ok := cl.sessions.get(realm)
if !ok {
err = fmt.Errorf("could not find TGT session for %s", realm)
return
}
_, authTime, endTime, renewTime, sessionExp = s.timeDetails()
_, startTime, authTime, endTime, renewTime, sessionExp = s.timeDetails()
return
}

Expand Down
14 changes: 9 additions & 5 deletions v8/client/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package client
import (
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"os"
"runtime"
"sync"
Expand Down Expand Up @@ -55,12 +55,12 @@ func TestMultiThreadedClientSession(t *testing.T) {
for i := 0; i < 10; i++ {
go func() {
defer wg.Done()
tgt, _, err := cl.sessionTGT("TEST.GOKRB5")
tgt, _, _, err := cl.sessionTGT("TEST.GOKRB5")
if err != nil || tgt.Realm != "TEST.GOKRB5" {
t.Logf("error getting session: %v", err)
}
_, _, _, r, _ := cl.sessionTimes("TEST.GOKRB5")
fmt.Fprintf(io.Discard, "%v", r)
_, _, _, _, r, _ := cl.sessionTimes("TEST.GOKRB5")
fmt.Fprintf(ioutil.Discard, "%v", r)
}()
time.Sleep(time.Second)
}
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestClient_AutoRenew_Goroutine(t *testing.T) {
n := runtime.NumGoroutine()
for i := 0; i < 24; i++ {
time.Sleep(time.Second * 5)
_, endTime, _, _, err := cl.sessionTimes("TEST.GOKRB5")
_, _, endTime, _, _, err := cl.sessionTimes("TEST.GOKRB5")
if err != nil {
t.Errorf("could not get client's session: %v", err)
}
Expand Down Expand Up @@ -123,6 +123,7 @@ func TestSessions_JSON(t *testing.T) {
e := &session{
realm: realm,
authTime: time.Unix(int64(0+i), 0).UTC(),
startTime: time.Unix(int64(1+i), 0).UTC(),
endTime: time.Unix(int64(10+i), 0).UTC(),
renewTill: time.Unix(int64(20+i), 0).UTC(),
sessionKeyExpiration: time.Unix(int64(30+i), 0).UTC(),
Expand All @@ -137,20 +138,23 @@ func TestSessions_JSON(t *testing.T) {
{
"Realm": "test0",
"AuthTime": "1970-01-01T00:00:00Z",
"StartTime": "1970-01-01T00:00:01Z",
"EndTime": "1970-01-01T00:00:10Z",
"RenewTill": "1970-01-01T00:00:20Z",
"SessionKeyExpiration": "1970-01-01T00:00:30Z"
},
{
"Realm": "test1",
"AuthTime": "1970-01-01T00:00:01Z",
"StartTime": "1970-01-01T00:00:02Z",
"EndTime": "1970-01-01T00:00:11Z",
"RenewTill": "1970-01-01T00:00:21Z",
"SessionKeyExpiration": "1970-01-01T00:00:31Z"
},
{
"Realm": "test2",
"AuthTime": "1970-01-01T00:00:02Z",
"StartTime": "1970-01-01T00:00:03Z",
"EndTime": "1970-01-01T00:00:12Z",
"RenewTill": "1970-01-01T00:00:22Z",
"SessionKeyExpiration": "1970-01-01T00:00:32Z"
Expand Down

0 comments on commit 17fbb08

Please # to comment.