{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module DMQ.Diffusion.NodeKernel
  ( NodeKernel (..)
  , withNodeKernel
  ) where

import Control.Concurrent.Class.MonadMVar
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Control.Tracer (Tracer, nullTracer)

import Data.Aeson qualified as Aeson
import Data.Function (on)
import Data.Functor.Contravariant ((>$<))
import Data.Hashable
import Data.Sequence qualified as Seq
import Data.Time.Clock.POSIX (POSIXTime)
import Data.Time.Clock.POSIX qualified as Time
import Data.Void (Void)
import System.Random (StdGen)
import System.Random qualified as Random

import Cardano.KESAgent.KES.Crypto (Crypto (..))

import Ouroboros.Network.BlockFetch (FetchClientRegistry,
           newFetchClientRegistry)
import Ouroboros.Network.ConnectionId (ConnectionId (..))
import Ouroboros.Network.PeerSelection.Governor.Types
           (makePublicPeerSelectionStateVar)
import Ouroboros.Network.PeerSharing (PeerSharingAPI, PeerSharingRegistry,
           newPeerSharingAPI, newPeerSharingRegistry,
           ps_POLICY_PEER_SHARE_MAX_PEERS, ps_POLICY_PEER_SHARE_STICKY_TIME)
import Ouroboros.Network.TxSubmission.Inbound.V2
import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool (..))
import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool

import DMQ.Configuration
import DMQ.Protocol.SigSubmission.Type (Sig (sigExpiresAt), SigId)
import DMQ.Tracer


-- | Per-node state created by @newNodeKernel@ and handed to the callback of
-- 'withNodeKernel'.
data NodeKernel crypto ntnAddr m =
  NodeKernel {
      -- | The fetch client registry, used for the keep alive clients.
      fetchClientRegistry :: !(FetchClientRegistry (ConnectionId ntnAddr) () () m)

      -- | Read the current peer sharing registry, used for interacting with
      -- the PeerSharing protocol.
    , peerSharingRegistry :: !(PeerSharingRegistry ntnAddr m)

      -- | Peer sharing API.  @newNodeKernel@ builds it from the public peer
      -- selection state variable and its own split of the kernel's 'StdGen'.
    , peerSharingAPI      :: !(PeerSharingAPI ntnAddr StdGen m)

      -- | Signature mempool.  While 'withNodeKernel' is running, expired
      -- signatures are pruned from it by @mempoolWorker@.
    , mempool             :: !(Mempool m (Sig crypto))

      -- | Signature-submission channels variable.  'withNodeKernel' passes it
      -- to @decisionLogicThreads@.
    , sigChannelVar       :: !(TxChannelsVar m ntnAddr SigId (Sig crypto))

      -- | Mempool semaphore for signature submission.  It is not used in this
      -- module.  NOTE(review): presumably it serialises mempool additions on
      -- the inbound side; confirm.
    , sigMempoolSem       :: !(TxMempoolSem m)

      -- | Shared signature-submission state.  @newNodeKernel@ seeds it with its
      -- own split of the kernel's 'StdGen', and it is passed to
      -- @decisionLogicThreads@.
    , sigSharedTxStateVar :: !(SharedTxStateVar m ntnAddr SigId (Sig crypto))
    }

-- | Allocate a fresh 'NodeKernel'.
--
-- The supplied generator is split in two:
--
-- * the first half seeds the shared signature-submission state;
-- * the second half goes to the peer sharing API.
newNodeKernel :: ( MonadLabelledSTM m
                 , MonadMVar m
                 , Ord ntnAddr
                 )
              => StdGen
              -> m (NodeKernel crypto ntnAddr m)
newNodeKernel rng = do
  publicStateVar  <- makePublicPeerSelectionStateVar

  fetchRegistry   <- newFetchClientRegistry
  sharingRegistry <- newPeerSharingRegistry

  sigMempool      <- Mempool.empty
  channelsVar     <- newTxChannelsVar
  mempoolSem      <- newTxMempoolSem
  let (txStateGen, peerSharingGen) = Random.split rng
  sharedStateVar  <- newSharedTxStateVar txStateGen

  sharingAPI      <- newPeerSharingAPI publicStateVar
                                       peerSharingGen
                                       ps_POLICY_PEER_SHARE_STICKY_TIME
                                       ps_POLICY_PEER_SHARE_MAX_PEERS

  pure NodeKernel
    { fetchClientRegistry = fetchRegistry
    , peerSharingRegistry = sharingRegistry
    , peerSharingAPI      = sharingAPI
    , mempool             = sigMempool
    , sigChannelVar       = channelsVar
    , sigMempoolSem       = mempoolSem
    , sigSharedTxStateVar = sharedStateVar
    }


-- | Create a 'NodeKernel' and run the callback with it.
--
-- While the callback runs, two background threads are alive:
--
-- * @mempoolWorker@ prunes the signature mempool;
-- * @decisionLogicThreads@ runs the signature-submission decision logic.
--
-- Both threads are linked to the calling thread, so their failures are
-- rethrown here.  The decision-logic tracer is the supplied tracer, tagged
-- @"SigSubmission.Logic"@, when @dmqcSigSubmissionLogicTracer@ is enabled.
-- Otherwise it is 'nullTracer'.
withNodeKernel :: forall crypto ntnAddr m a.
                  ( Crypto crypto
                  , MonadAsync       m
                  , MonadFork        m
                  , MonadDelay       m
                  , MonadLabelledSTM m
                  , MonadMask        m
                  , MonadMVar        m
                  , MonadTime        m
                  , Ord ntnAddr
                  , Show ntnAddr
                  , Hashable ntnAddr
                  )
               => (forall ev. Aeson.ToJSON ev => Tracer m (WithEventType ev))
               -> Configuration
               -> StdGen
               -> (NodeKernel crypto ntnAddr m -> m a)
               -- ^ as soon as the callback exits the `mempoolWorker` and all
               -- decision logic threads will be killed
               -> m a
withNodeKernel tracer
               Configuration { dmqcSigSubmissionLogicTracer = I logicTracingOn }
               rng
               k = do
  kernel@NodeKernel { mempool, sigChannelVar, sigSharedTxStateVar }
    <- newNodeKernel rng

  let logicTracer
        | logicTracingOn = WithEventType "SigSubmission.Logic" >$< tracer
        | otherwise      = nullTracer

  withAsync (mempoolWorker mempool) $ \workerThread ->
    withAsync (decisionLogicThreads logicTracer
                                    nullTracer
                                    defaultSigDecisionPolicy
                                    sigChannelVar
                                    sigSharedTxStateVar) $ \logicThread -> do
      link workerThread
      link logicThread
      k kernel


-- | Background loop that prunes expired signatures from the mempool.
--
-- Each iteration does three things:
--
-- 1. It reads the current POSIX time.
-- 2. It drops every signature whose 'sigExpiresAt' is at or before that time,
--    keeping the order of the surviving signatures.
-- 3. It sleeps for '_MEMPOOL_WORKER_MIN_DELAY'.
--
-- It never returns.  'withNodeKernel' runs it under 'withAsync', so it is
-- cancelled when the node-kernel callback exits.
--
-- The sleep is a fixed poll interval on purpose, for two reasons:
--
-- * Sleeping until the earliest pending expiry would let a signature added in
--   the meantime, with an earlier expiry, linger past its deadline.
-- * Deriving the delay by subtracting two wall-clock reads would let a
--   backwards clock step stretch the sleep by the size of the step.
mempoolWorker :: forall crypto m.
                 ( MonadDelay m
                 , MonadSTM   m
                 , MonadTime  m
                 )
              => Mempool m (Sig crypto)
              -> m Void
mempoolWorker (Mempool v) = loop
  where
    loop :: m Void
    loop = do
      now <- getCurrentPOSIXTime
      atomically $ do
        sigs <- readTVar v
        let live :: Seq.Seq (Sig crypto)
            live = Seq.filter (\sig -> sigExpiresAt sig > now) sigs
        -- Write only when something actually expired.  An unconditional
        -- write would wake every transaction blocked (via retry) on the
        -- mempool TVar on every tick, even though nothing changed.
        if Seq.length live == Seq.length sigs
          then pure ()
          else writeTVar v live

      threadDelay _MEMPOOL_WORKER_MIN_DELAY

      loop



-- | Shortest pause between two consecutive mempool sweeps performed by
-- @mempoolWorker@: 50 milliseconds.
_MEMPOOL_WORKER_MIN_DELAY :: DiffTime
_MEMPOOL_WORKER_MIN_DELAY = 5e-2


--
-- POSIXTime utils
--


-- | Current time as seconds since the POSIX epoch.
--
-- It is derived from the 'UTCTime' returned by 'getCurrentTime', i.e. from
-- the wall clock, not from a monotonic clock.
getCurrentPOSIXTime :: MonadTime m
                    => m POSIXTime
getCurrentPOSIXTime = fmap Time.utcTimeToPOSIXSeconds getCurrentTime


-- | @diffPOSIXTime t t'@ is the signed duration @t - t'@, as a 'DiffTime'.
--
-- Each operand is converted to 'DiffTime' with 'realToFrac' before the
-- subtraction.  The result is negative when @t@ is earlier than @t'@.
diffPOSIXTime :: POSIXTime -> POSIXTime -> DiffTime
diffPOSIXTime = (-) `on` toDiffTime
  where
    toDiffTime :: POSIXTime -> DiffTime
    toDiffTime = realToFrac