{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Test.Ouroboros.Network.KeepAlive (tests) where

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad (void)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadSay
import Control.Monad.Class.MonadST
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Control.Monad.IOSim
import Control.Tracer
import Data.ByteString.Lazy qualified as BL
import Data.Typeable (Typeable)
import System.Random


import Ouroboros.Network.BlockFetch
import Ouroboros.Network.Channel
import Ouroboros.Network.ControlMessage (ControlMessage (..), ControlMessageSTM)
import Ouroboros.Network.DeltaQ
import Ouroboros.Network.Driver.Limits
import Ouroboros.Network.KeepAlive
import Ouroboros.Network.Protocol.KeepAlive.Client
import Ouroboros.Network.Protocol.KeepAlive.Codec
import Ouroboros.Network.Protocol.KeepAlive.Server

import Test.QuickCheck
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.QuickCheck (testProperty)


-- | Test tree for the keep-alive mini-protocol tests in this module.
tests :: TestTree
tests =
    testGroup "KeepAlive"
        [ testProperty "KeepAlive Convergence" prop_keepAlive_convergence
        ]

-- | Run a keep-alive client for @peer@ over @channel@.
--
-- The client application is run under 'bracketKeepAliveClient', which hands
-- it a @StrictTVar m (Map peer PeerGSV)@ shared through @registry@
-- (presumably registering @peer@ for the lifetime of the client — see
-- 'bracketKeepAliveClient').  The protocol itself is driven by
-- 'runPeerWithLimits' with:
--
-- * no protocol-level tracing ('nullTracer'),
-- * the 'codecKeepAlive_v2' codec,
-- * byte limits measured by lazy 'BL.ByteString' length, and
-- * the standard 'timeLimitsKeepAlive'.
--
-- Returns the client's result together with any leftover bytes.
runKeepAliveClient
    :: forall m peer header block.
        ( MonadAsync m
        , MonadFork m
        , MonadMask m
        , MonadST m
        , MonadTimer m
        , MonadThrow (STM m)
        , Ord peer)
    => Tracer m (TraceKeepAliveClient peer)
    -> StdGen
    -> ControlMessageSTM m
    -> FetchClientRegistry peer header block m
    -> peer
    -> Channel m BL.ByteString
    -> KeepAliveInterval
    -> m ((), Maybe BL.ByteString)
runKeepAliveClient tracer rng controlMessageSTM registry peer channel keepAliveInterval =
    bracketKeepAliveClient registry peer kacApp
  where
    -- Runs the keep-alive client peer, feeding its GSV samples into the
    -- shared map provided by 'bracketKeepAliveClient'.
    kacApp dqCtx =
        runPeerWithLimits
            nullTracer
            codecKeepAlive_v2
            (byteLimitsKeepAlive (fromIntegral . BL.length))
            timeLimitsKeepAlive
            channel
            $ keepAliveClientPeer
            $ keepAliveClient tracer rng controlMessageSTM peer dqCtx keepAliveInterval

-- | Run the keep-alive server over @channel@.
--
-- Uses the same codec, byte limits and time limits as 'runKeepAliveClient'
-- so the two ends are configured symmetrically.  Protocol-level tracing is
-- disabled ('nullTracer').  Returns the server's result together with any
-- leftover bytes.
runKeepAliveServer
    :: forall m.
        ( MonadAsync m
        , MonadFork m
        , MonadMask m
        , MonadST m
        , MonadTimer m
        , MonadThrow (STM m)
        )
    => Channel m BL.ByteString
    -> m ((), Maybe BL.ByteString)
runKeepAliveServer channel =
    runPeerWithLimits
        nullTracer
        codecKeepAlive_v2
        (byteLimitsKeepAlive (fromIntegral . BL.length))
        timeLimitsKeepAlive
        channel
        $ keepAliveServerPeer keepAliveServer

-- | Start a keep-alive client and server, each in its own 'Async'.
--
-- The two ends communicate over a channel pair from 'createConnectedChannels'.
-- The client's end is wrapped with @'delayChannel' nd@ to simulate network
-- delay, and the client's RNG is @'mkStdGen' seed@.
--
-- Returns @(clientAsync, serverAsync)@.  The caller is responsible for
-- stopping the client (via @controlMessageSTM@) and waiting on both asyncs.
runKeepAliveClientAndServer
    :: forall m peer header block.
        ( MonadAsync m
        , MonadDelay m
        , MonadFork m
        , MonadMask m
        , MonadSay m
        , MonadST m
        , MonadTimer m
        , MonadThrow (STM m)
        , Ord peer
        )
    => NetworkDelay
    -> Int
    -> Tracer m (TraceKeepAliveClient peer)
    -> ControlMessageSTM m
    -> FetchClientRegistry peer header block m
    -> peer
    -> KeepAliveInterval
    -> m (Async m ((), Maybe BL.ByteString), Async m ((), Maybe BL.ByteString))
runKeepAliveClientAndServer (NetworkDelay nd) seed tracer controlMessageSTM registry peer keepAliveInterval = do
    (clientChannel, serverChannel) <- createConnectedChannels

    clientAsync <- async $ runKeepAliveClient tracer (mkStdGen seed) controlMessageSTM registry peer
                               (delayChannel nd clientChannel) keepAliveInterval
    serverAsync <- async $ runKeepAliveServer serverChannel
    return (clientAsync, serverAsync)

-- | The simulated network delay used by the keep-alive tests.
--
-- It is applied to the client's channel via 'delayChannel' in
-- 'runKeepAliveClientAndServer'.
newtype NetworkDelay = NetworkDelay {
      unNetworkDelay :: DiffTime
    } deriving Show

-- | Generates delays uniformly between 1 ms and 1000 ms, in whole
-- milliseconds.
instance Arbitrary NetworkDelay where
    arbitrary = do
        m <- choose (1, 1000 :: Int) -- A delay between 1 and 1000 ms
        return $ NetworkDelay $ fromIntegral m / 1000

-- | Run a keep-alive client and server with the given delay and seed, using
-- a 10 s keep-alive interval.
--
-- The protocol runs for 1000 s of (simulated) time.  Then the shared control
-- message is switched to 'Terminate', and the action waits for both the client
-- and the server to finish.
prop_keepAlive_convergenceM
    :: forall m.
        ( MonadAsync m
        , MonadDelay m
        , MonadFork m
        , MonadLabelledSTM m
        , MonadMask m
        , MonadSay m
        , MonadST m
        , MonadTimer m
        , MonadThrow (STM m)
        )
    => Tracer m (TraceKeepAliveClient String)
    -> NetworkDelay
    -> Int
    -> m ()
prop_keepAlive_convergenceM tracer (NetworkDelay nd) seed = do
    registry <- newFetchClientRegistry
    controlMessageV <- newTVarIO Continue
    let controlMessageSTM = readTVar controlMessageV
        clientId = "client"
        keepAliveInterval = 10

    (c_aid, s_aid) <- runKeepAliveClientAndServer (NetworkDelay nd) seed tracer controlMessageSTM
                          registry clientId (KeepAliveInterval keepAliveInterval)

    -- Let the protocol exchange keep-alives for a bounded amount of time.
    -- The control message would otherwise stay 'Continue' forever, so
    -- nothing would tell the client to stop and neither 'wait' below could
    -- return.
    threadDelay 1000 -- Run 1000s of time
    atomically $ writeTVar controlMessageV Terminate

    void $ wait c_aid
    void $ wait s_aid

-- | Test that our estimate of PeerGSV's G terms converge to
-- a given constant delay.
--
-- Runs 'prop_keepAlive_convergenceM' in 'IOSim' and collects at most 1000
-- 'TraceKeepAliveClient' events from the dynamic trace.  The property
-- holds iff:
--
-- * at least one event was traced,
-- * every event's round-trip time equals the configured delay,
-- * every event except the last has inbound and outbound G in
--   @[0, 2 * delay)@, and
-- * the last event has inbound and outbound G within 5% of @delay / 2@.
prop_keepAlive_convergence :: NetworkDelay -> Int -> Property
prop_keepAlive_convergence nd seed =
    verifyConvergence trace
  where
    trace :: [TraceKeepAliveClient String]
    trace = take 1000
          . selectTraceEventsDynamic
          $ runSimTrace
          $ prop_keepAlive_convergenceM dynamicTracer nd seed

    -- Intermediate events use the loose 'validG' bound; the final event
    -- must satisfy the tight 'lastG' bound.  An empty trace fails.
    verifyConvergence :: [TraceKeepAliveClient String] -> Property
    verifyConvergence []     = property False
    verifyConvergence [e]    = property $ validTrace lastG e
    verifyConvergence (e:es)
      | validTrace validG e  = verifyConvergence es
      | otherwise            = property False

    -- The sample's RTT must equal the configured delay, and both directions'
    -- GSV must satisfy the supplied predicate.
    validTrace :: (GSV -> Bool) -> TraceKeepAliveClient String -> Bool
    validTrace vg (AddSample _ rtt PeerGSV{outboundGSV, inboundGSV}) =
        unNetworkDelay nd == rtt && vg outboundGSV && vg inboundGSV

    -- Loose sanity bound for not-yet-converged estimates.
    validG :: GSV -> Bool
    validG (GSV g _ _) = g >= 0 && g < 2 * unNetworkDelay nd

    -- Converged estimate: G within +/-5% of half the configured delay.
    lastG :: GSV -> Bool
    lastG (GSV g _ _) =
        let low  :: DiffTime
            low  = 0.95 * unNetworkDelay nd / 2
            high :: DiffTime
            high = 1.05 * unNetworkDelay nd / 2
        in g >= low && g <= high

-- | A 'Tracer' that records every event in the 'IOSim' trace using 'traceM'.
-- The events can then be read back with 'selectTraceEventsDynamic', as
-- 'prop_keepAlive_convergence' does.
dynamicTracer :: Typeable a => Tracer (IOSim s) a
dynamicTracer = Tracer traceM