{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.KeepAlive.Registry
  ( KeepAliveRegistry (..)
  , newKeepAliveRegistry
  , bracketKeepAliveClient
  , readPeerGSVs
  ) where

import Data.Map (Map)
import Data.Map qualified as Map
import Data.Set (Set)
import Data.Set qualified as Set

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Exception (assert)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadThrow

import Ouroboros.Network.DeltaQ

-- | Shared registry holding the 'PeerGSV' estimates produced from
-- keep-alive measurements, along with the bookkeeping needed to
-- coordinate the clients that consume those estimates.
--
data KeepAliveRegistry peer m = KeepAliveRegistry
  { dqRegistry    :: StrictTVar m (Map peer PeerGSV)
    -- ^ One 'PeerGSV' per peer. An entry is inserted (as 'defaultGSV')
    -- when a keep-alive client is bracketed and removed when it exits.
  , keepRegistry  :: StrictTVar m (Map peer (ThreadId m, StrictTMVar m ()))
    -- ^ Registered fetch clients: the thread to cancel, paired with a
    -- 'StrictTMVar' that is waited on to learn that the client has exited.
  , dyingRegistry :: StrictTVar m (Set peer)
    -- ^ Peers whose fetch client is currently being cancelled.
    -- NOTE(review): presumably consulted by the fetch-client side (not in
    -- this module) to avoid starting a new client meanwhile — confirm.
  }

-- | Allocate a fresh registry in which every table starts out empty.
newKeepAliveRegistry :: MonadSTM m
                     => m (KeepAliveRegistry peer m)
newKeepAliveRegistry = do
    gsvVar     <- newTVarIO Map.empty
    clientsVar <- newTVarIO Map.empty
    dyingVar   <- newTVarIO Set.empty
    pure KeepAliveRegistry { dqRegistry    = gsvVar
                           , keepRegistry  = clientsVar
                           , dyingRegistry = dyingVar
                           }

-- | Run a keep-alive client for @peer@. While it runs, a 'PeerGSV' entry
-- for that peer exists in the registry, and the action receives the
-- registry's 'PeerGSV' table.
--
-- Acquisition blocks until any earlier keep-alive client for the same peer
-- has removed its entry. On release, a registered fetch client for the
-- peer is cancelled and awaited before the 'PeerGSV' entry is dropped.
bracketKeepAliveClient :: forall m a peer.
                          (MonadSTM m, MonadFork m, MonadMask m, Ord peer)
                       => KeepAliveRegistry peer m
                       -> peer
                       -> (StrictTVar m (Map peer PeerGSV) -> m a)
                       -> m a
bracketKeepAliveClient KeepAliveRegistry { dqRegistry    = gsvVar
                                         , keepRegistry  = clientsVar
                                         , dyingRegistry = dyingVar
                                         }
                       peer action =
    bracket_ acquire release (action gsvVar)
  where
    -- The keep-alive client registers a PeerGSV; the block fetch client
    -- waits on it.
    acquire :: m ()
    acquire = atomically $ do
      -- Block until the previous keep-alive client for this peer has
      -- cleaned up after itself.
      current <- readTVar gsvVar
      check (Map.notMember peer current)
      modifyTVar gsvVar addGSV

    -- A keep-alive client may keep running without a fetch client, but a
    -- fetch client must never outlive its keep-alive client. Teardown
    -- therefore runs uninterruptibly.
    release :: m ()
    release = uninterruptibleMask_ $ do
      running <- atomically $ do
        clients <- readTVar clientsVar
        case Map.lookup peer clients of
          -- The fetch client is already gone, so we drop the PeerGSV
          -- entry ourselves.
          Nothing     -> Nothing <$ modifyTVar gsvVar dropGSV
          -- Mark the peer as dying so that no new fetch client is started
          -- while the current one is being killed.
          Just client -> Just client <$ modifyTVar dyingVar addDying
      maybe (pure ()) cancelAndWait running

    -- Cancel the fetch client, wait for it to signal its exit, then remove
    -- both the PeerGSV entry and the dying marker in a single transaction.
    cancelAndWait :: (ThreadId m, StrictTMVar m ()) -> m ()
    cancelAndWait (tid, doneVar) = do
      throwTo tid AsyncCancelled
      atomically $ do
        readTMVar doneVar
        modifyTVar gsvVar dropGSV
        modifyTVar dyingVar removeDying

    -- Map/set updates guarded by the registry's invariants.
    addGSV, dropGSV :: Map peer PeerGSV -> Map peer PeerGSV
    addGSV  gsvs = assert (Map.notMember peer gsvs) (Map.insert peer defaultGSV gsvs)
    dropGSV gsvs = assert (Map.member    peer gsvs) (Map.delete peer gsvs)

    addDying, removeDying :: Set peer -> Set peer
    addDying    dying = assert (Set.notMember peer dying) (Set.insert peer dying)
    removeDying dying = assert (Set.member    peer dying) (Set.delete peer dying)

-- | A read-only 'STM' action returning the 'PeerGSV' of every peer that
-- has both a 'PeerGSV' entry and a registered fetch client in the
-- 'KeepAliveRegistry'.
--
-- A peer whose keep-alive client is running without a fetch client has a
-- 'PeerGSV' entry but no client entry, and is therefore left out.
--
readPeerGSVs :: forall m peer.
                ( MonadSTM m, Ord peer)
             => KeepAliveRegistry peer m
             -> STM m (Map peer PeerGSV)
readPeerGSVs KeepAliveRegistry { dqRegistry, keepRegistry } =
    -- The intersection gives us only the currently hot peers. It keeps the
    -- keys present in both tables and the values from 'dqRegistry'.
    Map.intersection <$> readTVar dqRegistry <*> readTVar keepRegistry