{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.KeepAlive
  ( KeepAliveInterval (..)
  , keepAliveClient
  , keepAliveServer
  , TraceKeepAliveClient (..)
  ) where

import Control.Concurrent.Class.MonadSTM qualified as Lazy
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Exception (assert)
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Control.Tracer (Tracer, traceWith)
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust)
import System.Random (StdGen, random)

import Ouroboros.Network.ControlMessage (ControlMessage (..), ControlMessageSTM)
import Ouroboros.Network.DeltaQ
import Ouroboros.Network.Protocol.KeepAlive.Client
import Ouroboros.Network.Protocol.KeepAlive.Server
import Ouroboros.Network.Protocol.KeepAlive.Type


-- | How long 'keepAliveClient' waits, after processing a keep-alive reply,
-- before it sends the next keep-alive message.
newtype KeepAliveInterval = KeepAliveInterval { keepAliveInterval :: DiffTime }

-- | Events traced by 'keepAliveClient'.
data TraceKeepAliveClient peer =
    -- | A keep-alive round trip with @peer@ completed.  Carries the measured
    -- round-trip time and the peer's 'PeerGSV' after the new sample was
    -- folded in.
    AddSample peer DiffTime PeerGSV

-- | Human-readable rendering, of the form
-- @AddSample \<peer\> sample: \<rtt\> gsv: \<gsv\>@.
instance Show peer => Show (TraceKeepAliveClient peer) where
    show (AddSample peer rtt gsv) =
        "AddSample " ++ show peer ++ " sample: " ++ show rtt
          ++ " gsv: " ++ show gsv

-- | Client side of the keep-alive mini-protocol.
--
-- The client immediately sends a 'SendMsgKeepAlive' carrying a random
-- cookie, remembering when it was sent.  Each time the continuation of a
-- 'SendMsgKeepAlive' runs (presumably when the server's reply arrives, since
-- that is where the round-trip time is measured), the client:
--
--   1. measures the round-trip time since the keep-alive was sent;
--   2. folds a new sample into @peer@'s 'PeerGSV' entry in @dqCtx@ and
--      traces it as 'AddSample';
--   3. waits for 'keepAliveInterval' and then sends the next keep-alive with
--      a fresh cookie.  If the control message becomes 'Terminate' while it
--      waits, it sends 'SendMsgDone' instead.
--
-- Precondition: @dqCtx@ already contains an entry for @peer@.  This is only
-- checked with 'assert', which GHC ignores under @-fignore-asserts@ (implied
-- by @-O@).
keepAliveClient
    :: forall m peer.
       ( MonadTimer m
       , Ord peer
       )
    => Tracer m (TraceKeepAliveClient peer)
    -- ^ receives an 'AddSample' event for every completed round trip
    -> StdGen
    -- ^ source of the random keep-alive cookies
    -> ControlMessageSTM m
    -- ^ 'Terminate' stops the client; any other value lets it continue
    -> peer
    -- ^ key of the remote peer in @dqCtx@
    -> StrictTVar m (M.Map peer PeerGSV)
    -- ^ shared GSV estimates; must already contain an entry for @peer@
    -> KeepAliveInterval
    -- ^ delay between handling a reply and sending the next keep-alive
    -> KeepAliveClient m ()
keepAliveClient tracer inRng controlMessageSTM peer dqCtx KeepAliveInterval { keepAliveInterval } =
    let (cookie, rng) = random inRng in
    KeepAliveClient $ do
      startTime <- getMonotonicTime
      return $ SendMsgKeepAlive (Cookie cookie) (go rng startTime)
  where
    -- Payload size passed to 'fromSample' for every sample.
    -- NOTE(review): presumably the size of the 'Word16' cookie; confirm.
    payloadSize = 2

    -- Decide what to do after a round trip.
    --
    -- * 'Terminate' from the control message is returned immediately,
    --   without waiting for the delay.
    -- * Any other control message ('Continue' or 'Quiesce') waits until
    --   @delayVar@ becomes 'True' and then yields 'Continue'.
    --
    -- Consequently this never returns 'Quiesce'.
    decisionSTM :: Lazy.TVar m Bool
                -> STM  m ControlMessage
    decisionSTM delayVar = do
       controlMessage <- controlMessageSTM
       case controlMessage of
            Terminate -> return Terminate

            -- Continue (or Quiesce): wait for the keep-alive interval.
            _  -> do
              done <- Lazy.readTVar delayVar
              if done
                 then return Continue
                 else retry

    -- One protocol step, run when a keep-alive round trip has finished.
    -- It records and traces the sample, then chooses the next client state.
    -- @startTime@ is when the keep-alive being answered was sent.
    go :: StdGen -> Time -> m (KeepAliveClientSt m ())
    go rng startTime = do
      endTime <- getMonotonicTime
      let rtt = diffTime endTime startTime
          sample = fromSample startTime endTime payloadSize
      gsv' <- atomically $ do
          m <- readTVar dqCtx
          assert (peer `M.member` m) $ do
            let (mgsv, m') = M.updateLookupWithKey
                    (\_ a -> if sampleTime a == Time 0 -- Ignore the initial dummy value
                                then Just sample
                                else Just $ sample <> a
                    ) peer m
            writeTVar dqCtx m'
            -- 'updateLookupWithKey' returns the updated value.  It is 'Just'
            -- whenever @peer@ is present, which is the precondition checked
            -- by the 'assert' above.
            return $ fromJust mgsv
      traceWith tracer $ AddSample peer rtt gsv'

      delayVar <- registerDelay keepAliveInterval
      decision <- atomically (decisionSTM delayVar)
      -- The next round trip is timed from just before its keep-alive is sent.
      now <- getMonotonicTime
      case decision of
        -- 'decisionSTM' above cannot return 'Quiesce'
        Quiesce   -> error "keepAliveClient: impossible happened"
        Continue  ->
            let (cookie, rng') = random rng in
            pure (SendMsgKeepAlive (Cookie cookie) $ go rng' now)
        Terminate -> pure (SendMsgDone (pure ()))


-- | Server side of the keep-alive mini-protocol.
--
-- On every received keep-alive it simply continues as itself.  The reply is
-- presumably sent by the protocol driver; confirm in the protocol's server
-- module.  On 'recvMsgDone' it returns @()@.
keepAliveServer
  :: forall m.  Applicative m
  => KeepAliveServer m ()
keepAliveServer = KeepAliveServer {
    recvMsgKeepAlive = pure keepAliveServer,
    recvMsgDone      = pure ()
  }