8
8
{-# LANGUAGE RecordWildCards #-}
9
9
{-# LANGUAGE ScopedTypeVariables #-}
10
10
{-# LANGUAGE TypeFamilies #-}
11
+ {-# LANGUAGE TupleSections #-}
11
12
12
13
module Development.IDE.Graph.Internal.Database (newDatabase , incDatabase , build , getDirtySet , getKeysAndVisitAge ) where
13
14
@@ -32,14 +33,15 @@ import Data.IORef.Extra
32
33
import Data.Maybe
33
34
import Data.Traversable (for )
34
35
import Data.Tuple.Extra
36
+ import Debug.Trace (traceM )
35
37
import Development.IDE.Graph.Classes
36
38
import Development.IDE.Graph.Internal.Rules
37
39
import Development.IDE.Graph.Internal.Types
38
40
import qualified Focus
39
41
import qualified ListT
40
42
import qualified StmContainers.Map as SMap
43
+ import System.Time.Extra (duration , sleep )
41
44
import System.IO.Unsafe
42
- import System.Time.Extra (duration )
43
45
44
46
newDatabase :: Dynamic -> TheRules -> IO Database
45
47
newDatabase databaseExtra databaseRules = do
@@ -120,7 +122,7 @@ builder db@Database{..} stack keys = withRunInIO $ \(RunInIO run) -> do
120
122
pure (id , val)
121
123
122
124
toForceList <- liftIO $ readTVarIO toForce
123
- let waitAll = run $ mapConcurrentlyAIO_ id toForceList
125
+ let waitAll = run $ waitConcurrently_ toForceList
124
126
case toForceList of
125
127
[] -> return $ Left results
126
128
_ -> return $ Right $ do
@@ -170,6 +172,10 @@ compute db@Database{..} stack key mode result = do
170
172
deps | not (null deps)
171
173
&& runChanged /= ChangedNothing
172
174
-> do
175
+ -- IMPORTANT: record the reverse deps **before** marking the key Clean.
176
+ -- If an async exception strikes before the deps have been recorded,
177
+ -- we won't be able to accurately propagate dirtiness for this key
178
+ -- on the next build.
173
179
void $
174
180
updateReverseDeps key db
175
181
(getResultDepsDefault [] previousDeps)
@@ -224,7 +230,8 @@ updateReverseDeps
224
230
-> [Key ] -- ^ Previous direct dependencies of Id
225
231
-> HashSet Key -- ^ Current direct dependencies of Id
226
232
-> IO ()
227
- updateReverseDeps myId db prev new = uninterruptibleMask_ $ do
233
+ -- mask to ensure that all the reverse dependencies are updated
234
+ updateReverseDeps myId db prev new = do
228
235
forM_ prev $ \ d ->
229
236
unless (d `HSet.member` new) $
230
237
doOne (HSet. delete myId) d
@@ -252,20 +259,27 @@ transitiveDirtySet database = flip State.execStateT HSet.empty . traverse_ loop
252
259
next <- lift $ atomically $ getReverseDependencies database x
253
260
traverse_ loop (maybe mempty HSet. toList next)
254
261
255
- -- | IO extended to track created asyncs to clean them up when the thread is killed,
256
- -- generalizing 'withAsync'
262
+ --------------------------------------------------------------------------------
263
+ -- Asynchronous computations with cancellation
264
+
265
+ -- | A simple monad to implement cancellation on top of 'Async',
266
+ -- generalizing 'withAsync' to monadic scopes.
257
267
newtype AIO a = AIO { unAIO :: ReaderT (IORef [Async () ]) IO a }
258
268
deriving newtype (Applicative , Functor , Monad , MonadIO )
259
269
270
+ -- | Run the monadic computation, cancelling all the spawned asyncs if an exception arises
260
271
runAIO :: AIO a -> IO a
261
272
runAIO (AIO act) = do
262
273
asyncs <- newIORef []
263
274
runReaderT act asyncs `onException` cleanupAsync asyncs
264
275
276
+ -- | Like 'async' but with built-in cancellation.
277
+ -- Returns an IO action to wait on the result.
265
278
asyncWithCleanUp :: AIO a -> AIO (IO a )
266
279
asyncWithCleanUp act = do
267
280
st <- AIO ask
268
281
io <- unliftAIO act
282
+ -- mask to make sure we keep track of the spawned async
269
283
liftIO $ uninterruptibleMask $ \ restore -> do
270
284
a <- async $ restore io
271
285
atomicModifyIORef'_ st (void a : )
@@ -284,27 +298,40 @@ withRunInIO k = do
284
298
k $ RunInIO (\ aio -> runReaderT (unAIO aio) st)
285
299
286
300
cleanupAsync :: IORef [Async a ] -> IO ()
287
- cleanupAsync ref = uninterruptibleMask_ $ do
288
- asyncs <- readIORef ref
301
+ -- mask to make sure we interrupt all the asyncs
302
+ cleanupAsync ref = uninterruptibleMask $ \ unmask -> do
303
+ asyncs <- atomicModifyIORef' ref ([] ,)
304
+ -- interrupt all the asyncs without waiting
289
305
mapM_ (\ a -> throwTo (asyncThreadId a) AsyncCancelled ) asyncs
290
- mapM_ waitCatch asyncs
306
+ -- Wait until all the asyncs are done
307
+ -- But if it takes more than 10 seconds, log to stderr
308
+ unless (null asyncs) $ do
309
+ let warnIfTakingTooLong = unmask $ forever $ do
310
+ sleep 10
311
+ traceM " cleanupAsync: waiting for asyncs to finish"
312
+ withAsync warnIfTakingTooLong $ \ _ ->
313
+ mapM_ waitCatch asyncs
314
+
315
+ data Wait
316
+ = Wait { justWait :: ! (IO () )}
317
+ | Spawn { justWait :: ! (IO () )}
291
318
292
- data Wait a
293
- = Wait { justWait :: ! a }
294
- | Spawn { justWait :: ! a }
295
- deriving Functor
319
+ fmapWait :: (IO () -> IO () ) -> Wait -> Wait
320
+ fmapWait f (Wait io) = Wait (f io)
321
+ fmapWait f (Spawn io) = Spawn (f io)
296
322
297
- waitOrSpawn :: Wait ( IO a ) -> IO (Either (IO a ) (Async a ))
323
+ waitOrSpawn :: Wait -> IO (Either (IO () ) (Async () ))
298
324
waitOrSpawn (Wait io) = pure $ Left io
299
325
waitOrSpawn (Spawn io) = Right <$> async io
300
326
301
- mapConcurrentlyAIO_ :: ( a -> IO () ) -> [Wait a ] -> AIO ()
302
- mapConcurrentlyAIO_ _ [] = pure ()
303
- mapConcurrentlyAIO_ f [one] = liftIO $ justWait $ fmap f one
304
- mapConcurrentlyAIO_ f many = do
327
+ waitConcurrently_ :: [Wait ] -> AIO ()
328
+ waitConcurrently_ [] = pure ()
329
+ waitConcurrently_ [one] = liftIO $ justWait one
330
+ waitConcurrently_ many = do
305
331
ref <- AIO ask
306
- waits <- liftIO $ uninterruptibleMask $ \ restore -> do
307
- waits <- liftIO $ traverse (waitOrSpawn . fmap (restore . f)) many
332
+ -- mask to make sure we keep track of all the asyncs
333
+ waits <- liftIO $ uninterruptibleMask $ \ unmask -> do
334
+ waits <- liftIO $ traverse (waitOrSpawn . fmapWait unmask) many
308
335
let asyncs = rights waits
309
336
liftIO $ atomicModifyIORef'_ ref (asyncs ++ )
310
337
return waits
0 commit comments