Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Avoid race condition #137

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 65 additions & 25 deletions Network/HTTP2/H2/Manager.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ module Network.HTTP2.H2.Manager (
) where

import Data.Foldable
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import qualified Data.Map.Strict as Map
import qualified System.TimeManager as T
import UnliftIO.Concurrent
import UnliftIO.Exception
Expand All @@ -28,11 +28,25 @@ import Imports

----------------------------------------------------------------

data Command = Stop (Maybe SomeException) | Add ThreadId | Delete ThreadId
data Command =
Stop (MVar ()) (Maybe SomeException)
| Add ThreadId
| RegisterTimeout ThreadId T.Handle
| Delete ThreadId

-- | Manager to manage the thread and the timer.
data Manager = Manager (TQueue Command) (TVar Int) T.Manager

data TimeoutHandle =
ThreadWithTimeout T.Handle
| ThreadWithoutTimeout

cancelTimeout :: TimeoutHandle -> IO ()
cancelTimeout (ThreadWithTimeout h) = T.cancel h
cancelTimeout ThreadWithoutTimeout = return ()

type ManagedThreads = Map ThreadId TimeoutHandle

-- | Starting a thread manager.
-- Its action is initially set to 'return ()' and should be set
-- by 'setAction'. This allows that the action can include
Expand All @@ -43,26 +57,37 @@ start timmgr = do
cnt <- newTVarIO 0
void $ forkIO $ do
labelMe "H2 thread manager"
go q Set.empty
go q Map.empty
return $ Manager q cnt timmgr
where
go q tset0 = do
-- This runs in a separate thread whose ThreadId is not known by anyone
-- else, so it cannot be killed by asynchronous exceptions.
go :: TQueue Command -> ManagedThreads -> IO ()
go q threadMap0 = do
x <- atomically $ readTQueue q
case x of
Stop err -> kill tset0 err
Add newtid ->
let tset = add newtid tset0
in go q tset
Delete oldtid ->
let tset = del oldtid tset0
in go q tset
Stop signalTimeoutsDisabled err -> do
kill signalTimeoutsDisabled threadMap0 err
Add newtid -> do
let threadMap = add newtid threadMap0
go q threadMap
RegisterTimeout tid h -> do
let threadMap = registerTimeout tid h threadMap0
go q threadMap
Delete oldtid -> do
threadMap <- del oldtid threadMap0
go q threadMap

-- | Stopping the manager.
stopAfter :: Manager -> IO a -> (Either SomeException a -> IO b) -> IO b
stopAfter (Manager q _ _) action cleanup = do
mask $ \unmask -> do
ma <- try $ unmask action
atomically $ writeTQueue q $ Stop (either Just (const Nothing) ma)
signalTimeoutsDisabled <- newEmptyMVar
atomically $ writeTQueue q $ Stop signalTimeoutsDisabled (either Just (const Nothing) ma)
-- This call to takeMVar /will/ eventually succeed, because the Manager
-- thread cannot be killed (see comment on 'go' in 'start').
takeMVar signalTimeoutsDisabled
cleanup ma

----------------------------------------------------------------
Expand Down Expand Up @@ -112,24 +137,39 @@ deleteMyId (Manager q _ _) = do

----------------------------------------------------------------

add :: ThreadId -> Set ThreadId -> Set ThreadId
add tid set = set'
where
set' = Set.insert tid set
add :: ThreadId -> ManagedThreads -> ManagedThreads
add tid = Map.insert tid ThreadWithoutTimeout

del :: ThreadId -> Set ThreadId -> Set ThreadId
del tid set = set'
where
set' = Set.delete tid set
registerTimeout :: ThreadId -> T.Handle -> ManagedThreads -> ManagedThreads
registerTimeout tid = Map.insert tid . ThreadWithTimeout

del :: ThreadId -> ManagedThreads -> IO ManagedThreads
del tid threadMap = do
forM_ (Map.lookup tid threadMap) cancelTimeout
return $ Map.delete tid threadMap

kill :: Set ThreadId -> Maybe SomeException -> IO ()
kill set err = traverse_ (\tid -> E.throwTo tid $ KilledByHttp2ThreadManager err) set
-- | Kill all threads
--
-- We first remove all threads from the timeout manager, then signal that that
-- is complete, and finally kill all threads. This avoids a race between the
-- timeout manager and our manager: we want to ensure that the exception that
-- gets delivered is 'KilledByHttp2ThreadManager', not 'TimeoutThread'.
kill :: MVar () -> ManagedThreads -> Maybe SomeException -> IO ()
kill signalTimeoutsDisabled threadMap err = do
forM_ (Map.elems threadMap) cancelTimeout
putMVar signalTimeoutsDisabled ()
forM_ (Map.keys threadMap) $ \tid ->
E.throwTo tid $ KilledByHttp2ThreadManager err

-- | Killing the IO action of the second argument on timeout.
timeoutKillThread :: Manager -> (T.Handle -> IO a) -> IO a
timeoutKillThread (Manager _ _ tmgr) action = E.bracket register T.cancel action
timeoutKillThread (Manager q _ tmgr) action = E.bracket register T.cancel action
where
register = T.registerKillThread tmgr (return ())
register = do
h <- T.registerKillThread tmgr (return ())
tid <- myThreadId
atomically $ writeTQueue q (RegisterTimeout tid h)
return h

-- | Registering closer for a resource and
-- returning a timer refresher.
Expand Down
Loading