{-# LANGUAGE CPP                        #-}
{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}

module Ouroboros.Network.RawBearer.Test.Utils where

import Ouroboros.Network.RawBearer
import Ouroboros.Network.Snocket

import Control.Concurrent.Class.MonadMVar
import Control.Exception (Exception)
import Control.Monad (when)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork (labelThisThread)
import Control.Monad.Class.MonadST (MonadST, stToIO)
import Control.Monad.Class.MonadThrow (MonadThrow, bracket, finally, throwIO)
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Control.Tracer (Tracer (..), traceWith)
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Foreign.Marshal (copyBytes, free, mallocBytes)
import Foreign.Ptr (castPtr, plusPtr)

import Test.QuickCheck

-- | Payload exchanged between the sender and the receiver in the raw bearer
-- tests. It is a thin wrapper around the bytes that go over the wire.
newtype Message = Message
  { messageBytes :: ByteString
    -- ^ The raw bytes carried by the message.
  }
  deriving stock (Show, Eq, Ord)

-- | Generated messages are never empty: 'arbitrary' builds them with
-- 'listOf1', and 'shrink' discards any candidate whose payload is empty, so
-- shrinking cannot break that invariant.
instance Arbitrary Message where
  shrink =
      filter (not . BS.null . messageBytes)  -- keep only non-empty candidates
    . fmap (Message . BS.pack)
    . shrink                                 -- reuse list shrinking on the bytes
    . BS.unpack
    . messageBytes
  arbitrary = Message . BS.pack <$> listOf1 arbitrary

-- | Failure raised by the test helpers in this module; the 'String' is a
-- human-readable description of what went wrong.
newtype TestError = TestError String
  deriving stock (Show)

-- | 'TestError' can be thrown as an exception; all class methods use their
-- defaults.
instance Exception TestError

-- | Send a 'Message' through a raw bearer and check that it arrives intact.
--
-- A listening socket is opened for the address family of the server address,
-- bound to it and put into listening mode. Two actions then run concurrently:
--
-- * the /sender/ opens a socket, optionally binds it to the client address,
--   connects to the server address, copies the message into a freshly
--   allocated buffer and calls 'send' repeatedly until no bytes are left;
--
-- * the /receiver/ accepts connections in a loop and, for each one, calls
--   'recv' repeatedly until as many bytes as the message holds have been
--   read, then stores them (only the first stored result is kept).
--
-- Both are raced against an action that waits for the received bytes and for
-- the sender to signal completion. The returned 'Property' holds when that
-- waiting action wins the race and the received bytes equal
-- @'messageBytes' msg@.
--
-- A 'TestError' is thrown if 'send' or 'recv' reports zero bytes while data
-- is still outstanding; an accept failure is rethrown as is.
rawBearerSendAndReceive :: forall m fd addr
                         . ( MonadST m
                           , MonadThrow m
                           , MonadAsync m
                           , MonadMVar m
                           , Show addr
                           )
                        => Tracer m String
                        -- ^ receives progress messages from both sides
                        -> Snocket m fd addr
                        -> MakeRawBearer m fd
                        -> addr
                        -- ^ address the receiver listens on
                        -> Maybe addr
                        -- ^ optional address the sender binds to before connecting
                        -> Message
                        -- ^ payload to transfer
                        -> m Property
rawBearerSendAndReceive tracer snocket mkrb serverAddr mclientAddr msg = do
    let -- Run an 'IO' action in @m@ by routing it through 'ST'.
        -- NOTE(review): this relies on 'unsafeIOToST'; presumably it is only
        -- sound for monads whose 'ST' state is the real world -- confirm.
        io :: IO a -> m a
        io = stToIO . unsafeIOToST
        size = BS.length (messageBytes msg)
    -- Filled (at most once) by the receiver with the bytes it has read.
    retVar <- newEmptyMVar
    -- Filled by the sender once the whole message has been handed to 'send'.
    senderDone <- newEmptyMVar
    let sender =
          bracket (openToConnect snocket serverAddr)
                  (\s -> traceWith tracer "sender: closing" >> close snocket s) $ \s -> do
            case mclientAddr of
              Nothing -> return ()
              Just clientAddr -> do
                traceWith tracer $ "sender: binding to " ++ show clientAddr
                bind snocket s clientAddr
            traceWith tracer $ "sender: connecting to " ++ show serverAddr
            connect snocket s serverAddr
            traceWith tracer "sender: connected"
            bearer <- getRawBearer mkrb s
            bracket (io $ mallocBytes size) (io . free) $ \srcBuf -> do
              -- Copy the message into the raw buffer that 'send' reads from.
              io $ BS.useAsCStringLen (messageBytes msg)
                     (uncurry (copyBytes srcBuf))
              let -- Send the remaining @n@ bytes starting at @buf@; 'send'
                  -- may accept fewer bytes than requested, so loop.
                  go _ 0 =
                    traceWith tracer "sender: done"
                  go _ n | n < 0 =
                    error "sender: negative byte count"
                  go buf n = do
                    traceWith tracer $ "sender: " ++ show n ++ " bytes left"
                    bytesSent <- send bearer buf n
                    when (bytesSent == 0) (throwIO $ TestError "sender: premature hangup")
                    let n' = n - bytesSent
                    traceWith tracer $ "sender: " ++ show bytesSent ++ " bytes sent, " ++ show n' ++ " remaining"
                    go (plusPtr buf bytesSent) n'
              go (castPtr srcBuf) size
              putMVar senderDone ()

        receiver s = do
          let acceptLoop :: Accept m fd addr -> m ()
              acceptLoop accept0 = do
                traceWith tracer "receiver: accepting connection"
                (accepted, acceptNext) <- runAccept accept0
                case accepted :: Accepted fd addr of
                  AcceptFailure err ->
                    throwIO err
                  Accepted s' _ -> do
                    labelThisThread "accept"
                    traceWith tracer "receiver: connection accepted"
                    let -- Always close the accepted socket, even on failure.
                        closeConnection = do
                          traceWith tracer "receiver: closing connection"
                          close snocket s'
                          traceWith tracer "receiver: connection closed"
                    flip finally closeConnection $ do
                      bearer <- getRawBearer mkrb s'
                      retval <- bracket (io $ mallocBytes size) (io . free) $ \dstBuf -> do
                        let -- Read the remaining @n@ bytes into @buf@; 'recv'
                            -- may deliver fewer bytes than requested, so loop.
                            go _ 0 =
                              traceWith tracer "receiver: done receiving"
                            go _ n | n < 0 =
                              error "receiver: negative byte count"
                            go buf n = do
                              traceWith tracer $ "receiver: " ++ show n ++ " bytes left"
                              bytesReceived <- recv bearer buf n
                              when (bytesReceived == 0) (throwIO $ TestError "receiver: premature hangup")
                              let n' = n - bytesReceived
                              traceWith tracer $ "receiver: " ++ show bytesReceived ++ " bytes received, " ++ show n' ++ " remaining"
                              go (plusPtr buf bytesReceived) n'
                        go (castPtr dstBuf) size
                        -- Copy the bytes out before the buffer is freed.
                        io (BS.packCStringLen (castPtr dstBuf, size))
                      traceWith tracer $ "receiver: received " ++ show retval
                      -- Non-blocking: a result stored by an earlier
                      -- connection is left in place.
                      written <- tryPutMVar retVar retval
                      traceWith tracer $ if written then "receiver: stored " ++ show retval else "receiver: already have result"
                    traceWith tracer "receiver: finishing connection"
                    acceptLoop acceptNext
          accept snocket s >>= acceptLoop

    resBSEither <- bracket (open snocket (addrFamily snocket serverAddr)) (close snocket) $ \s -> do
      traceWith tracer "receiver: starting"
      bind snocket s serverAddr
      listen snocket s
      traceWith tracer "receiver: listening"
      -- The accept loop only ends by throwing, so the test succeeds when the
      -- right-hand action (result available and sender finished) wins.
      race
        (sender `concurrently` receiver s)
        (takeMVar retVar <* takeMVar senderDone)
    return $ resBSEither === Right (messageBytes msg)