{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Ouroboros.Network.KeepAlive.Registry
( KeepAliveRegistry (..)
, newKeepAliveRegistry
, bracketKeepAliveClient
, readPeerGSVs
) where
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Set (Set)
import Data.Set qualified as Set
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Exception (assert)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadThrow
import Ouroboros.Network.DeltaQ
-- | Shared book-keeping for the keep-alive clients, keyed by peer.
data KeepAliveRegistry peer m = KeepAliveRegistry
  { dqRegistry    :: StrictTVar m (Map peer PeerGSV)
    -- ^ Per-peer 'PeerGSV'. 'bracketKeepAliveClient' inserts 'defaultGSV'
    -- for a peer on entry and deletes the entry on exit; the action it runs
    -- is handed this variable (presumably so the estimate can be updated —
    -- confirm against callers).
  , keepRegistry  :: StrictTVar m (Map peer (ThreadId m, StrictTMVar m ()))
    -- ^ Keep-alive client thread per peer, paired with a TMVar that
    -- 'bracketKeepAliveClient' reads to wait for that thread to finish.
    -- NOTE(review): this module only reads this variable; presumably the
    -- client thread registers itself here — confirm.
  , dyingRegistry :: StrictTVar m (Set peer)
    -- ^ Peers whose client thread is being cancelled by
    -- 'bracketKeepAliveClient': a peer is added just before its thread is
    -- sent 'AsyncCancelled' and removed once the thread's done-TMVar has
    -- been read.
  }
-- | Allocate a fresh registry in which all three variables start out empty.
newKeepAliveRegistry :: MonadSTM m
                     => m (KeepAliveRegistry peer m)
newKeepAliveRegistry = do
    gsvs    <- newTVarIO Map.empty
    clients <- newTVarIO Map.empty
    dying   <- newTVarIO Set.empty
    return KeepAliveRegistry
      { dqRegistry    = gsvs
      , keepRegistry  = clients
      , dyingRegistry = dying
      }
-- | Run @action@ (given 'dqRegistry') while @peer@ is registered in it.
--
-- /Acquire:/ in one STM transaction, block (retry) while 'dqRegistry' still
-- holds an entry for @peer@, then insert 'defaultGSV' for it. Within this
-- module such an entry only disappears when an earlier bracket for the same
-- peer has finished its release step, so concurrent brackets for one peer
-- are serialised.
--
-- /Release/ (runs on normal exit and on exceptions, under
-- 'uninterruptibleMask_'):
--
-- * if 'keepRegistry' has no thread for @peer@, the 'dqRegistry' entry is
--   deleted straight away;
--
-- * otherwise @peer@ is added to 'dyingRegistry', the thread is sent
--   'AsyncCancelled', and once its done-TMVar is readable the 'dqRegistry'
--   entry and the 'dyingRegistry' mark are removed in a single transaction.
--
-- NOTE(review): because the release step is uninterruptibly masked, it will
-- block forever if the cancelled thread never fills its done-TMVar.
bracketKeepAliveClient :: forall m a peer.
                          (MonadSTM m, MonadFork m, MonadMask m, Ord peer)
                       => KeepAliveRegistry peer m
                       -> peer
                       -> (StrictTVar m (Map peer PeerGSV) -> m a)
                       -> m a
bracketKeepAliveClient KeepAliveRegistry { dqRegistry, keepRegistry, dyingRegistry }
                       peer action =
    bracket_ register unregister (action dqRegistry)
  where
    -- Wait until the peer has no GSV entry, then seed it with the default.
    register :: m ()
    register = atomically $ do
      readTVar dqRegistry >>= check . Map.notMember peer
      modifyTVar dqRegistry $ \gsvs ->
        assert (Map.notMember peer gsvs) (Map.insert peer defaultGSV gsvs)

    -- Uninterruptible so the clean-up cannot be abandoned half-way through.
    unregister :: m ()
    unregister = uninterruptibleMask_ $ do
      running <- atomically $ do
        entry <- Map.lookup peer <$> readTVar keepRegistry
        case entry of
          Nothing -> forgetGSV   -- no client thread: drop the entry now
          Just _  -> markDying   -- client thread: drop it once the thread is done
        return entry
      mapM_ cancelAndAwait running

    -- Cancel the client thread, wait for its done-TMVar, then clear both the
    -- GSV entry and the dying mark in one transaction.
    cancelAndAwait :: (ThreadId m, StrictTMVar m ()) -> m ()
    cancelAndAwait (tid, doneVar) = do
      throwTo tid AsyncCancelled
      atomically $ readTMVar doneVar >> forgetGSV >> unmarkDying

    -- Remove the peer's GSV entry (asserting that it was present).
    forgetGSV :: STM m ()
    forgetGSV = modifyTVar dqRegistry $ \gsvs ->
      assert (Map.member peer gsvs) (Map.delete peer gsvs)

    -- Record that the peer's client thread is being torn down.
    markDying :: STM m ()
    markDying = modifyTVar dyingRegistry $ \dying ->
      assert (Set.notMember peer dying) (Set.insert peer dying)

    -- Clear the teardown mark set by 'markDying'.
    unmarkDying :: STM m ()
    unmarkDying = modifyTVar dyingRegistry $ \dying ->
      assert (Set.member peer dying) (Set.delete peer dying)
-- | The 'PeerGSV' of every peer that has an entry in both 'dqRegistry' and
-- 'keepRegistry', i.e. every peer with a registered keep-alive client
-- thread. Values are taken from 'dqRegistry' ('Map.intersection' is
-- left-biased). Both variables are read inside the caller's transaction.
-- Peers listed in 'dyingRegistry' are not filtered out.
readPeerGSVs :: forall m peer.
                ( MonadSTM m, Ord peer)
             => KeepAliveRegistry peer m
             -> STM m (Map peer PeerGSV)
readPeerGSVs KeepAliveRegistry { dqRegistry, keepRegistry } =
    Map.intersection <$> readTVar dqRegistry <*> readTVar keepRegistry