diff --git a/server/session.go b/server/session.go index 2b66f61b..e3a4dfc4 100644 --- a/server/session.go +++ b/server/session.go @@ -17,6 +17,7 @@ package server import ( "context" "fmt" + "io" "log/slog" "net/url" "sync" @@ -30,6 +31,7 @@ import ( // --- Session type session struct { + io.Closer sync.Mutex id SessionId clientIdentity string @@ -75,7 +77,9 @@ func startSession(sessionId SessionId, sessionMetadata *proto.SessionMetadata, s return s } -func (s *session) closeChannels() { +func (s *session) Close() { + s.Lock() + defer s.Unlock() s.cancel() if s.heartbeatCh != nil { close(s.heartbeatCh) @@ -84,11 +88,6 @@ func (s *session) closeChannels() { s.log.Debug("Session channels closed") } -func (s *session) close() error { - s.log.Info("Session closing") - return s.delete() -} - func (s *session) delete() error { // Delete ephemeral data associated with this session sessionKey := SessionKey(s.id) @@ -171,8 +170,7 @@ func (s *session) waitForHeartbeats() { case <-timeoutCh: s.log.Warn("Session expired") - s.Lock() - s.closeChannels() + s.Close() err := s.delete() if err != nil { @@ -181,7 +179,6 @@ func (s *session) waitForHeartbeats() { slog.Any("error", err), ) } - s.Unlock() s.sm.Lock() s.sm.sessions.Remove(s.id) diff --git a/server/session_manager.go b/server/session_manager.go index 8d9bd134..648c9210 100644 --- a/server/session_manager.go +++ b/server/session_manager.go @@ -202,10 +202,10 @@ func (sm *sessionManager) CloseSession(request *proto.CloseSessionRequest) (*pro } sm.sessions.Remove(s.id) sm.Unlock() - s.Lock() - defer s.Unlock() - s.closeChannels() - err = s.close() + + s.log.Info("Session closing") + s.Close() + err = s.delete() if err != nil { return nil, err } @@ -295,9 +295,7 @@ func (sm *sessionManager) Close() error { sm.cancel() for _, s := range sm.sessions.Values() { sm.sessions.Remove(s.id) - s.Lock() - s.closeChannels() - s.Unlock() + s.Close() } sm.activeSessions.Unregister()