{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
-- | Timeout functions intended to be used by one thread at a time.
--
-- 'withTimeoutSerial' hands its body a 'TimeoutFn'. The implementation is
-- chosen at compile time:
--
-- * On Windows (@mingw32_HOST_OS@) it is 'withTimeoutSerialAlternative'. That
--   version runs a single long-lived monitoring thread for the whole body and
--   registers each timeout's deadline with it.
-- * Elsewhere it is 'withTimeoutSerialNative', which passes
--   @MonadTimer.timeout@ through unchanged.
--
-- The alternative implementation tracks only one timeout at a time.
-- Registering a new timeout overwrites any timeout the monitoring thread has
-- not yet picked up. The 'TimeoutFn' it provides must therefore not be called
-- from several threads concurrently.
module Network.Mux.Timeout
( TimeoutFn
, withTimeoutSerial
, withTimeoutSerialNative
, withTimeoutSerialAlternative
) where
import Control.Concurrent.Class.MonadSTM
import Control.Exception (asyncExceptionFromException,
asyncExceptionToException)
import Control.Monad
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI (MonadTimer, registerDelay)
import Control.Monad.Class.MonadTimer.SI qualified as MonadTimer
-- | Type of a timeout function such as @MonadTimer.timeout@.
--
-- Given a delay and an action, such a function yields @'Just' r@ with the
-- action's result, or 'Nothing' when the delay ran out first. The result type
-- is quantified inside the synonym, so one function can time actions of any
-- result type.
type TimeoutFn m = forall r. DiffTime -> m r -> m (Maybe r)
-- | Run the body with a 'TimeoutFn' chosen per platform.
--
-- > withTimeoutSerial $ \timeout ->
-- >   -- use timeout like System.Timeout.timeout,
-- >   -- but only ever from one thread at a time
--
-- 'withTimeoutSerial' is selected at compile time:
--
-- * On Windows (@mingw32_HOST_OS@) it is 'withTimeoutSerialAlternative'.
-- * Elsewhere it is 'withTimeoutSerialNative', which simply hands
--   @MonadTimer.timeout@ to the body.
withTimeoutSerial, withTimeoutSerialNative
  :: forall m b. (MonadAsync m, MonadFork m,
                  MonadMonotonicTime m, MonadTimer m,
                  MonadMask m, MonadThrow (STM m))
  => (TimeoutFn m -> m b) -> m b

#if defined(mingw32_HOST_OS)
withTimeoutSerial = withTimeoutSerialAlternative
#else
withTimeoutSerial = withTimeoutSerialNative
#endif

-- The native variant: the platform's own timeout, passed through unchanged.
withTimeoutSerialNative body = body MonadTimer.timeout
-- | Run the body with a 'TimeoutFn' backed by a single monitoring thread.
--
-- The monitoring thread is started with 'withAsync', so it is tied to the
-- lifetime of the body. Every call to the @timeout@ passed to the body
-- registers its deadline with that thread rather than starting a timer
-- thread of its own.
--
-- The provided @timeout@ must only be used from one thread at a time.
-- Registering a timeout overwrites any timeout the monitoring thread has not
-- yet picked up.
withTimeoutSerialAlternative
  :: forall m b. (MonadAsync m, MonadFork m,
                  MonadMonotonicTime m, MonadTimer m,
                  MonadMask m, MonadThrow (STM m))
  => (TimeoutFn m -> m b) -> m b
withTimeoutSerialAlternative body = do
    monitorState <- newMonitorState
    withAsync (monitoringThread monitorState) $ \_ ->
      body (timeout monitorState)
-- | State shared between the @timeout@ calls made by the body and the
-- monitoring thread.
data MonitorState m =
     MonitorState {
       -- | The next timeout to arm. Written by 'setNewTimer'; taken and
       -- cleared by 'readNextTimeout'.
       nextTimeoutVar   :: !(TVar m (NextTimeout m)),
       -- | Deadline of the timeout the monitoring thread last picked up.
       curDeadlineVar   :: !(TVar m Time),
       -- | Set by 'setNewTimer' when a new deadline is earlier than
       -- 'curDeadlineVar', so that the monitoring thread stops waiting early.
       deadlineResetVar :: !(TVar m Bool)
     }
-- | A timeout handed from 'setNewTimer' to the monitoring thread, or the
-- absence of one.
data NextTimeout m
  = NoNextTimeout
  | NextTimeout
      !(ThreadId m)          -- thread to interrupt when the deadline passes
      !Time                  -- the deadline itself
      !(TVar m TimeoutState) -- per-call state used to settle races
-- | Allocate a fresh 'MonitorState'. The fields start as:
--
-- * no registered timeout;
-- * a current deadline of @Time 0@;
-- * a cleared reset flag.
newMonitorState :: MonadSTM m => m (MonitorState m)
newMonitorState = do
    nextTimeoutVar   <- newTVarIO NoNextTimeout
    curDeadlineVar   <- newTVarIO (Time 0)
    deadlineResetVar <- newTVarIO False
    return MonitorState { nextTimeoutVar, curDeadlineVar, deadlineResetVar }
-- | Register a deadline at which the given thread should be interrupted.
--
-- The timeout is written to 'nextTimeoutVar', overwriting any timeout not yet
-- picked up by the monitoring thread. If the new deadline is earlier than the
-- one the monitoring thread is currently waiting on, 'deadlineResetVar' is
-- set. This wakes the monitoring thread so it can pick up the new timeout.
setNewTimer :: MonadSTM m
            => MonitorState m
            -> ThreadId m
            -> Time
            -> TVar m TimeoutState
            -> m ()
setNewTimer MonitorState{nextTimeoutVar, curDeadlineVar, deadlineResetVar}
            !tid !deadline !stateVar =
    atomically $ do
      writeTVar nextTimeoutVar (NextTimeout tid deadline stateVar)
      curDeadline <- readTVar curDeadlineVar
      -- The deadline moved backwards: notify the monitoring thread.
      when (deadline < curDeadline) $
        writeTVar deadlineResetVar True
-- | Take the next registered timeout, blocking (via 'retry') until there is
-- one.
--
-- In the same transaction this:
--
-- * clears 'nextTimeoutVar';
-- * records the deadline in 'curDeadlineVar', so 'setNewTimer' can tell
--   whether a later timeout is earlier;
-- * clears 'deadlineResetVar'.
readNextTimeout :: MonadSTM m
                => MonitorState m
                -> m (ThreadId m, Time, TVar m TimeoutState)
readNextTimeout MonitorState{nextTimeoutVar, curDeadlineVar, deadlineResetVar} =
    atomically $ do
      nextTimeout <- readTVar nextTimeoutVar
      case nextTimeout of
        NoNextTimeout -> retry
        NextTimeout tid deadline stateVar -> do
          writeTVar nextTimeoutVar NoNextTimeout
          writeTVar curDeadlineVar deadline
          writeTVar deadlineResetVar False
          return (tid, deadline, stateVar)
-- | Lifecycle of a single @timeout@ call. The state is shared between the
-- calling thread and the monitoring thread.
data TimeoutState
  = TimeoutPending     -- ^ registered; neither side has acted yet
  | TimeoutCancelled   -- ^ the caller settled first; the monitor must not fire
  | TimeoutFired       -- ^ the monitor committed to throwing 'TimeoutException'
  | TimeoutTerminated  -- ^ the monitor's 'throwTo' has returned
-- | Thrown by the monitoring thread (with 'throwTo') at a thread whose
-- timed action overran its deadline.
data TimeoutException = TimeoutException deriving Show
-- | 'TimeoutException' is delivered asynchronously. It is therefore wrapped
-- in and unwrapped from 'SomeAsyncException' so that handlers which only
-- catch synchronous exceptions leave it alone.
instance Exception TimeoutException where
  toException   = asyncExceptionToException
  fromException = asyncExceptionFromException
-- | The @timeout@ function handed to the body by
-- 'withTimeoutSerialAlternative'.
--
-- * A negative delay runs the action with no timeout.
-- * A zero delay does not run the action and returns 'Nothing'.
-- * Otherwise the deadline is registered with the monitoring thread via
--   'setNewTimer', and the action runs with async exceptions restored. If the
--   monitoring thread interrupts it with 'TimeoutException', the result is
--   'Nothing'; otherwise it is @'Just' result@.
--
-- A per-call 'TimeoutState' TVar settles the race between the action
-- finishing and the monitoring thread firing. Whichever side moves it out of
-- 'TimeoutPending' first wins.
--
-- If the action fails with any other exception, a still-pending timeout is
-- cancelled before the exception propagates. Without this, the monitoring
-- thread would later throw 'TimeoutException' at this thread after
-- @timeout@ has been left.
-- NOTE: if the monitoring thread fires in the narrow window between the
-- action failing and the cancellation, the 'TimeoutException' is still
-- delivered afterwards.
--
-- Throws a 'TimeoutAssertion' if the state machine reaches a state that the
-- protocol rules out.
timeout :: forall m a.
           (MonadFork m, MonadMonotonicTime m, MonadTimer m,
            MonadMask m, MonadThrow (STM m))
        => MonitorState m
        -> DiffTime -> m a -> m (Maybe a)
timeout _            delay action | delay <  0 = Just <$> action
timeout _            delay _      | delay == 0 = return Nothing
timeout monitorState delay action =

    -- Async exceptions stay masked except while running the action itself.
    mask $ \restore -> do

      -- Register this call's deadline with the monitoring thread. This
      -- overwrites any previously registered timeout.
      tid <- myThreadId
      timeoutStateVar <- newTVarIO TimeoutPending
      now <- getMonotonicTime
      let deadline = addTime delay now
      setNewTimer monitorState tid deadline timeoutStateVar

      -- Run the action unmasked. A 'TimeoutException' drops straight down
      -- to the 'catch' below. On any exception, a still-pending timeout is
      -- cancelled first so the monitoring thread will not fire it later.
      result <- restore action `onException` cancelIfPending timeoutStateVar

      -- The action finished, but the monitoring thread may be just about to
      -- fire. Resolve the race through the shared state.
      timeoutFired <- atomically $ do
        st <- readTVar timeoutStateVar
        case st of
          TimeoutFired   -> return True
          TimeoutPending -> writeTVar timeoutStateVar TimeoutCancelled
                         >> return False
          _              -> throwSTM TimeoutImpossibleTimeoutState

      if not timeoutFired
        then return (Just result)
        -- The monitoring thread has already moved the state to
        -- 'TimeoutFired' and will (or already did) throw 'TimeoutException'
        -- at us. We are still masked, so we block here interruptibly until
        -- the exception arrives and is handled by the 'catch' below.
        else atomically $ do
          st <- readTVar timeoutStateVar
          case st of
            TimeoutFired      -> retry
            TimeoutTerminated -> throwSTM TimeoutImpossibleReachedTerminated
            _                 -> throwSTM TimeoutImpossibleTimeoutState

  `catch` \TimeoutException -> return Nothing
  where
    -- Move a still-pending timeout to 'TimeoutCancelled'. A timeout that has
    -- already fired (or was already cancelled) is left unchanged.
    cancelIfPending :: TVar m TimeoutState -> m ()
    cancelIfPending stateVar = atomically $ do
      st <- readTVar stateVar
      case st of
        TimeoutPending -> writeTVar stateVar TimeoutCancelled
        _              -> return ()
-- | Body of the monitoring thread started by 'withTimeoutSerialAlternative'.
--
-- The thread labels itself @timeout-monitoring-thread@ and then loops
-- forever:
--
-- 1. Take the next registered timeout.
-- 2. Wait until its deadline. It wakes early if 'setNewTimer' signals an
--    earlier deadline through 'deadlineResetVar'.
-- 3. Unless the caller has already cancelled the timeout:
--
--    * mark it 'TimeoutFired';
--    * throw 'TimeoutException' at the caller's thread;
--    * mark it 'TimeoutTerminated'.
--
-- After waking early, the timeout it was waiting on is expected to have
-- already been cancelled by its caller. This relies on the serial-use
-- constraint: a new timeout is only registered once the previous call has
-- returned.
--
-- Throws 'TimeoutImpossibleMonitorState' if it finds a timeout that is
-- neither pending nor cancelled.
monitoringThread :: (MonadFork m, MonadSTM m,
                     MonadMonotonicTime m,
                     MonadTimer m, MonadThrow (STM m))
                 => MonitorState m -> m ()
monitoringThread monitorState@MonitorState{deadlineResetVar} = do
    threadId <- myThreadId
    labelThread threadId "timeout-monitoring-thread"
    forever $ do
      -- Block until a timeout has been registered.
      (tid, deadline, timeoutStateVar) <- readNextTimeout monitorState

      -- Wait for the deadline. A deadline already in the past gives a
      -- non-positive delay, and the wait is skipped.
      now <- getMonotonicTime
      let delay = diffTime deadline now
      when (delay > 0) $ do
        timerExpired <- registerDelay delay
        atomically $
              -- either the timer expires ...
              (readTVar timerExpired >>= check)
          `orElse`
              -- ... or we are told that an earlier deadline was registered
              (readTVar deadlineResetVar >>= check)

      -- Decide whether to fire: only a still-pending timeout is fired, and
      -- it is marked 'TimeoutFired' in the same transaction.
      cancelled <- atomically $ do
        st <- readTVar timeoutStateVar
        case st of
          TimeoutPending   -> writeTVar timeoutStateVar TimeoutFired
                           >> return False
          TimeoutCancelled -> return True
          _                -> throwSTM TimeoutImpossibleMonitorState

      unless cancelled $ do
        throwTo tid TimeoutException
        atomically $ writeTVar timeoutStateVar TimeoutTerminated
-- | Internal invariant violations of the timeout state machine. These are
-- raised with 'throwSTM' when a 'TimeoutState' is found in a state that the
-- protocol rules out.
data TimeoutAssertion = TimeoutImpossibleReachedTerminated
                      | TimeoutImpossibleTimeoutState
                      | TimeoutImpossibleMonitorState
  deriving Show
-- | 'TimeoutAssertion' uses the default (synchronous) 'Exception' methods.
instance Exception TimeoutAssertion where