{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

module DMQ.Diffusion.NodeKernel
  ( NodeKernel (..)
  , withNodeKernel
  ) where

import Control.Concurrent.Class.MonadMVar
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI

import Data.Function (on)
import Data.Sequence qualified as Seq
import Data.Time.Clock.POSIX (POSIXTime)
import Data.Time.Clock.POSIX qualified as Time
import Data.Void (Void)
import System.Random (StdGen)
import System.Random qualified as Random

import Ouroboros.Network.BlockFetch (FetchClientRegistry,
           newFetchClientRegistry)
import Ouroboros.Network.ConnectionId (ConnectionId (..))
import Ouroboros.Network.PeerSelection.Governor.Types
           (makePublicPeerSelectionStateVar)
import Ouroboros.Network.PeerSharing (PeerSharingAPI, PeerSharingRegistry,
           newPeerSharingAPI, newPeerSharingRegistry,
           ps_POLICY_PEER_SHARE_MAX_PEERS, ps_POLICY_PEER_SHARE_STICKY_TIME)
import Ouroboros.Network.TxSubmission.Inbound.V2.Registry
import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool (..))
import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool

import DMQ.Protocol.SigSubmission.Type (Sig (..), SigId)


-- | State of a DMQ node.  Built internally and handed to the continuation of
-- 'withNodeKernel'.
data NodeKernel ntnAddr m = NodeKernel
  { -- | The fetch client registry, used for the keep alive clients.
    fetchClientRegistry :: FetchClientRegistry (ConnectionId ntnAddr) () () m

    -- | The peer sharing registry, used for interacting with the
    -- PeerSharing protocol.
  , peerSharingRegistry :: PeerSharingRegistry ntnAddr m

    -- | Peer sharing API; it is seeded with its own 'StdGen'.
  , peerSharingAPI      :: PeerSharingAPI ntnAddr StdGen m

    -- | Mempool of signatures.  Expired signatures are evicted by the
    -- background worker started in 'withNodeKernel'.
  , mempool             :: Mempool m Sig

    -- | Tx-submission (v2 inbound) channels variable, instantiated for
    -- signatures ('SigId' / 'Sig').
  , sigChannelVar       :: TxChannelsVar m ntnAddr SigId Sig

    -- | Tx-submission mempool semaphore used for signatures.
  , sigMempoolSem       :: TxMempoolSem m

    -- | Tx-submission shared state variable, instantiated for signatures.
  , sigSharedTxStateVar :: SharedTxStateVar m ntnAddr SigId Sig
  }

-- | Allocate a fresh 'NodeKernel'.
--
-- The supplied generator is split in two.  One half seeds the shared
-- signature-submission state ('newSharedTxStateVar'); the other seeds the
-- peer sharing API.
--
-- NOTE(review): the public peer selection state variable passed to
-- 'newPeerSharingAPI' is created here and is not stored in the 'NodeKernel'.
-- Confirm that the peer sharing API is not expected to observe state written
-- by someone else (e.g. a peer selection governor).
newNodeKernel :: ( MonadLabelledSTM m
                 , MonadMVar m
                 , Ord ntnAddr
                 )
              => StdGen
              -> m (NodeKernel ntnAddr m)
newNodeKernel rng = do
  let (sharedStateRng, peerSharingRng) = Random.split rng

  publicStateVar <- makePublicPeerSelectionStateVar
  fetchRegistry  <- newFetchClientRegistry
  psRegistry     <- newPeerSharingRegistry

  sigMempool     <- Mempool.empty
  channelsVar    <- newTxChannelsVar
  mempoolSem     <- newTxMempoolSem
  sharedStateVar <- newSharedTxStateVar sharedStateRng

  psApi <- newPeerSharingAPI publicStateVar
                             peerSharingRng
                             ps_POLICY_PEER_SHARE_STICKY_TIME
                             ps_POLICY_PEER_SHARE_MAX_PEERS

  pure NodeKernel
    { fetchClientRegistry = fetchRegistry
    , peerSharingRegistry = psRegistry
    , peerSharingAPI      = psApi
    , mempool             = sigMempool
    , sigChannelVar       = channelsVar
    , sigMempoolSem       = mempoolSem
    , sigSharedTxStateVar = sharedStateVar
    }


-- | Create a 'NodeKernel' and run the continuation with it.
--
-- While the continuation runs, a background thread evicts expired signatures
-- from the kernel's mempool.  That thread is 'link'ed to the caller, so an
-- exception it raises is propagated to the calling thread.  It is cancelled
-- as soon as the continuation exits.
withNodeKernel :: ( MonadAsync       m
                  , MonadFork        m
                  , MonadDelay       m
                  , MonadLabelledSTM m
                  , MonadMask        m
                  , MonadMVar        m
                  , MonadTime        m
                  , Ord ntnAddr
                  )
               => StdGen
               -> (NodeKernel ntnAddr m -> m a)
               -- ^ continuation; as soon as it exits the `mempoolWorker` will
               -- be killed
               -> m a
withNodeKernel rng k =
  newNodeKernel rng >>= \nodeKernel ->
    withAsync (mempoolWorker (mempool nodeKernel)) $ \worker -> do
      link worker
      k nodeKernel


-- | Background worker that evicts expired signatures from the mempool.
--
-- Each iteration does the following:
--
-- 1. Reads the current POSIX time @now@.
-- 2. In a single STM transaction, drops every 'Sig' whose 'sigExpiresAt' is
--    at or before @now@, keeping the remaining signatures in their original
--    order.
-- 3. Sleeps before the next pass.
--
-- The mempool variable is only written when at least one signature was
-- evicted.  An idle pass therefore commits no write and does not disturb
-- other STM transactions that read the mempool.
--
-- The sleep is '_MEMPOOL_WORKER_MIN_DELAY' in practice (see the comment at
-- the 'threadDelay' call).  The worker deliberately does not sleep until the
-- earliest expiry of the surviving signatures.  Signatures inserted while it
-- sleeps may expire earlier than that, and they would then stay in the
-- mempool past their expiry.
mempoolWorker :: forall m.
                 ( MonadDelay m
                 , MonadSTM   m
                 , MonadTime  m
                 )
              => Mempool m Sig
              -> m Void
mempoolWorker (Mempool v) = loop
  where
    loop :: m Void
    loop = do
      now <- getCurrentPOSIXTime
      atomically $ do
        sigs <- readTVar v
        -- 'Seq.filter' keeps the relative order of the surviving signatures.
        let live = Seq.filter (\sig -> sigExpiresAt sig > now) sigs
        -- Only commit a write when something actually expired; 'Seq.length'
        -- is O(1).
        if Seq.length live /= Seq.length sigs
          then writeTVar v live
          else pure ()

      -- 'now' was read before 'now'', so this difference is normally
      -- non-positive and the sleep is '_MEMPOOL_WORKER_MIN_DELAY'.  It is
      -- longer only if the wall clock steps backwards between the two
      -- readings.
      now' <- getCurrentPOSIXTime
      threadDelay $ (now `diffPOSIXTime` now') `max` _MEMPOOL_WORKER_MIN_DELAY

      loop



-- | Lower bound on the time the mempool worker sleeps between two eviction
-- passes: 50 milliseconds.
_MEMPOOL_WORKER_MIN_DELAY :: DiffTime
_MEMPOOL_WORKER_MIN_DELAY = 50e-3


--
-- POSIXTime utils
--


-- | Current wall-clock time, obtained through 'MonadTime', expressed as a
-- 'POSIXTime' (time since the Unix epoch).
getCurrentPOSIXTime :: MonadTime m
                    => m POSIXTime
getCurrentPOSIXTime = fmap Time.utcTimeToPOSIXSeconds getCurrentTime


-- | @diffPOSIXTime a b@ is the duration from @b@ to @a@ (i.e. @a - b@) as a
-- 'DiffTime'.  Both timestamps are converted into 'Time' and compared with
-- 'diffTime'.
diffPOSIXTime :: POSIXTime -> POSIXTime -> DiffTime
diffPOSIXTime = diffTime `on` toTime
  where
    -- 'POSIXTime' and 'DiffTime' are both 'Real'/'Fractional', so
    -- 'realToFrac' converts between them.
    toTime :: POSIXTime -> Time
    toTime = Time . realToFrac