{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module DMQ.Diffusion.NodeKernel
( NodeKernel (..)
, withNodeKernel
) where
import Control.Concurrent.Class.MonadMVar
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Data.Function (on)
import Data.Sequence qualified as Seq
import Data.Time.Clock.POSIX (POSIXTime)
import Data.Time.Clock.POSIX qualified as Time
import Data.Void (Void)
import System.Random (StdGen)
import System.Random qualified as Random
import Ouroboros.Network.BlockFetch (FetchClientRegistry,
newFetchClientRegistry)
import Ouroboros.Network.ConnectionId (ConnectionId (..))
import Ouroboros.Network.PeerSelection.Governor.Types
(makePublicPeerSelectionStateVar)
import Ouroboros.Network.PeerSharing (PeerSharingAPI, PeerSharingRegistry,
newPeerSharingAPI, newPeerSharingRegistry,
ps_POLICY_PEER_SHARE_MAX_PEERS, ps_POLICY_PEER_SHARE_STICKY_TIME)
import Ouroboros.Network.TxSubmission.Inbound.V2.Registry
import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool (..))
import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool
import DMQ.Protocol.SigSubmission.Type (Sig (..), SigId)
-- | Long-lived state shared by the DMQ node's diffusion components.
--
-- Values are built by 'withNodeKernel'.
data NodeKernel ntnAddr m =
  NodeKernel {
    -- | Registry created with 'newFetchClientRegistry', keyed by
    -- 'ConnectionId'.  NOTE(review): presumably used to track per-connection
    -- clients (e.g. keep-alive); confirm at the use sites.
    fetchClientRegistry :: FetchClientRegistry (ConnectionId ntnAddr) () () m

    -- | Registry created with 'newPeerSharingRegistry', for the PeerSharing
    -- protocol.
  , peerSharingRegistry :: PeerSharingRegistry ntnAddr m

    -- | Peer-sharing API.  It is backed by the node's public peer-selection
    -- state and seeded from its own split of the kernel's 'StdGen'.
  , peerSharingAPI      :: PeerSharingAPI ntnAddr StdGen m

    -- | Mempool of signatures.
  , mempool             :: Mempool m Sig

    -- | Tx-submission (V2 inbound) channels, instantiated for signatures.
  , sigChannelVar       :: TxChannelsVar m ntnAddr SigId Sig

    -- | Tx-submission (V2 inbound) mempool semaphore for signatures.
  , sigMempoolSem       :: TxMempoolSem m

    -- | Tx-submission (V2 inbound) shared state for signatures.
  , sigSharedTxStateVar :: SharedTxStateVar m ntnAddr SigId Sig
  }

-- | Allocate a fresh 'NodeKernel'.
--
-- The supplied 'StdGen' is split in two:
--
-- * one half seeds the shared signature-submission state
--   ('sigSharedTxStateVar');
-- * the other seeds the peer-sharing API.
--
-- The peer-sharing API is built on a newly created public peer-selection
-- state variable.  It uses the 'ps_POLICY_PEER_SHARE_STICKY_TIME' and
-- 'ps_POLICY_PEER_SHARE_MAX_PEERS' policy constants.
newNodeKernel :: ( MonadLabelledSTM m
                 , MonadMVar m
                 , Ord ntnAddr
                 )
              => StdGen
              -> m (NodeKernel ntnAddr m)
newNodeKernel rng = do
  publicPeerSelectionStateVar <- makePublicPeerSelectionStateVar

  fetchClientRegistry <- newFetchClientRegistry
  peerSharingRegistry <- newPeerSharingRegistry

  -- Signature mempool plus the tx-submission V2 inbound state for 'Sig's.
  mempool       <- Mempool.empty
  sigChannelVar <- newTxChannelsVar
  sigMempoolSem <- newTxMempoolSem

  -- Split the generator so the two consumers below draw from independent
  -- random streams.
  let (rngTxState, rngPeerSharing) = Random.split rng
  sigSharedTxStateVar <- newSharedTxStateVar rngTxState

  peerSharingAPI <-
    newPeerSharingAPI
      publicPeerSelectionStateVar
      rngPeerSharing
      ps_POLICY_PEER_SHARE_STICKY_TIME
      ps_POLICY_PEER_SHARE_MAX_PEERS

  pure NodeKernel { fetchClientRegistry
                  , peerSharingRegistry
                  , peerSharingAPI
                  , mempool
                  , sigChannelVar
                  , sigMempoolSem
                  , sigSharedTxStateVar
                  }

-- | Create a 'NodeKernel' and run the continuation with it.
--
-- While the continuation runs, a background thread executes 'mempoolWorker'
-- on the kernel's 'mempool'.  The thread is started with 'withAsync', so it
-- is cancelled when the continuation finishes.  It is also 'link'ed, so an
-- exception raised by the worker is re-thrown in the calling thread.
withNodeKernel :: ( MonadAsync m
                  , MonadFork m
                  , MonadDelay m
                  , MonadLabelledSTM m
                  , MonadMask m
                  , MonadMVar m
                  , MonadTime m
                  , Ord ntnAddr
                  )
               => StdGen
               -> (NodeKernel ntnAddr m -> m a)
               -> m a
withNodeKernel rng k = do
  nodeKernel@NodeKernel { mempool } <- newNodeKernel rng
  withAsync (mempoolWorker mempool) $ \workerThread -> do
    link workerThread
    k nodeKernel

-- | Background loop that evicts expired signatures from the mempool.
--
-- Each pass does three things:
--
-- 1. reads the wall clock;
-- 2. atomically drops every 'Sig' whose 'sigExpiresAt' is at or before that
--    time, keeping the survivors in their original order;
-- 3. sleeps for at least '_MEMPOOL_WORKER_MIN_DELAY'.
--
-- The loop never returns.
--
-- NOTE(review): the fold that computes the resume time is seeded with @now@.
-- Only signatures with @sigExpiresAt > now@ reach the 'min', so the resume
-- time is always @now@.  The worker therefore effectively polls every
-- '_MEMPOOL_WORKER_MIN_DELAY' rather than sleeping until the next expiry.
-- If the wall clock steps backwards between the two clock reads, the sleep
-- is lengthened by the size of the step.  Confirm whether this is intended.
mempoolWorker :: forall m.
                 ( MonadDelay m
                 , MonadSTM   m
                 , MonadTime  m
                 )
              => Mempool m Sig
              -> m Void
mempoolWorker (Mempool v) = loop
  where
    loop :: m Void
    loop = do
      now <- getCurrentPOSIXTime
      rt <- atomically $ do
        (sigs :: Seq.Seq Sig) <- readTVar v
        let sigs' :: Seq.Seq Sig
            (resumeTime, sigs') = foldr (evictAt now) (now, Seq.empty) sigs
        writeTVar v sigs'
        return resumeTime

      now' <- getCurrentPOSIXTime
      threadDelay $ rt `diffPOSIXTime` now' `max` _MEMPOOL_WORKER_MIN_DELAY

      loop

    -- Fold step.  A signature that expired by @cutoff@ is dropped.  Otherwise
    -- it is kept (prepended) and the accumulated resume time is lowered to
    -- its expiry.
    evictAt :: POSIXTime
            -> Sig
            -> (POSIXTime, Seq.Seq Sig)
            -> (POSIXTime, Seq.Seq Sig)
    evictAt cutoff a (rt, kept)
      | sigExpiresAt a <= cutoff = (rt, kept)
      | otherwise                = (rt `min` sigExpiresAt a, a Seq.<| kept)

-- | Minimum sleep between two passes of 'mempoolWorker': 50 milliseconds.
_MEMPOOL_WORKER_MIN_DELAY :: DiffTime
_MEMPOOL_WORKER_MIN_DELAY = 0.05

-- | Current wall-clock time, expressed as seconds since the Unix epoch.
getCurrentPOSIXTime :: MonadTime m
                    => m POSIXTime
getCurrentPOSIXTime = fmap Time.utcTimeToPOSIXSeconds getCurrentTime

-- | Signed difference @a - b@ between two POSIX timestamps, as a 'DiffTime'.
-- The result is negative when @a@ is earlier than @b@.
diffPOSIXTime :: POSIXTime -> POSIXTime -> DiffTime
diffPOSIXTime = (-) `on` posixTimeToDiffTime
  where
    -- Both types count seconds with picosecond resolution, so the conversion
    -- through 'Rational' is exact.
    posixTimeToDiffTime :: POSIXTime -> DiffTime
    posixTimeToDiffTime = realToFrac