{-# LANGUAGE CPP                        #-}
{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}

-- | Helpers for 'RawBearer' round-trip tests: an 'Arbitrary' payload type and
-- a property that pushes one payload through a 'Snocket' connection and
-- checks that the receiving side observes exactly the bytes that were sent.
module Ouroboros.Network.RawBearer.Test.Utils where

import Ouroboros.Network.RawBearer
import Ouroboros.Network.Snocket

import Control.Concurrent.Class.MonadMVar
import Control.Exception (Exception)
import Control.Monad (when)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork (labelThisThread)
import Control.Monad.Class.MonadST (MonadST, stToIO)
import Control.Monad.Class.MonadThrow (MonadThrow, bracket, finally, throwIO)
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Control.Tracer (Tracer (..), traceWith)
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.Word (Word8)
import Foreign.Marshal (copyBytes, free, mallocBytes)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Test.QuickCheck

-- | A raw payload to transfer over a 'RawBearer'.
--
-- The 'Arbitrary' instance only generates, and only shrinks to, non-empty
-- payloads.
newtype Message = Message { messageBytes :: ByteString }
  deriving stock (Show, Eq, Ord)

instance Arbitrary Message where
  arbitrary = Message . BS.pack <$> listOf1 arbitrary
  shrink =
      filter (not . BS.null . messageBytes)
    . fmap (Message . BS.pack)
    . shrink
    . BS.unpack
    . messageBytes

-- | Failure raised by 'rawBearerSendAndReceive' when a transfer cannot be
-- completed (the peer reported zero bytes, or the byte accounting went
-- negative).
newtype TestError = TestError String
  deriving stock (Show)

instance Exception TestError

-- | Round-trip a 'Message' through a pair of 'RawBearer's.
--
-- A listening socket is bound to @serverAddr@. A client socket (optionally
-- bound to @mclientAddr@ first) connects to it and sends @msg@ through a
-- 'RawBearer' obtained from @mkrb@; each accepted connection reads exactly as
-- many bytes back through its own 'RawBearer'. The resulting property holds
-- when the bytes received equal the bytes sent.
--
-- Both transfer loops throw 'TestError' if a @send@/@recv@ call reports zero
-- bytes (premature hangup) or the remaining byte count goes negative.
rawBearerSendAndReceive
  :: forall m fd addr
   . ( MonadST m
     , MonadThrow m
     , MonadAsync m
     , MonadMVar m
     , Show addr
     )
  => Tracer m String
  -> Snocket m fd addr
  -> MakeRawBearer m fd
  -> addr        -- ^ address the receiver binds to and listens on
  -> Maybe addr  -- ^ address to bind the sender to before connecting, if any
  -> Message     -- ^ payload to transfer
  -> m Property
rawBearerSendAndReceive tracer snocket mkrb serverAddr mclientAddr msg = do
    -- Filled (via 'tryPutMVar') by the first connection that reads a full
    -- payload; later connections leave it untouched.
    retVar     <- newEmptyMVar
    -- Signalled by the sender after the whole payload has been sent.
    senderDone <- newEmptyMVar

    let -- Run a raw 'IO' buffer operation (malloc/free/copy/pack) inside @m@.
        io :: IO a -> m a
        io = stToIO . unsafeIOToST

        size :: Int
        size = BS.length (messageBytes msg)

        -- Call @step@ (a 'send' or 'recv' on a bearer) until exactly @n@
        -- bytes starting at @buf@ have been transferred, advancing the
        -- pointer by however many bytes each call reports. @role@ prefixes
        -- every trace line; @doneMsg@ and @verb@ supply the role-specific
        -- wording of the final and per-chunk trace lines.
        transferAll :: String
                    -> String
                    -> String
                    -> (Ptr Word8 -> Int -> m Int)
                    -> Ptr Word8
                    -> Int
                    -> m ()
        transferAll role doneMsg verb step = go
          where
            go _ 0 = traceWith tracer (role ++ ": " ++ doneMsg)
            go buf n
              -- 'throwIO' rather than 'error': the failure is raised in @m@
              -- at this point in the sequence, as a 'TestError' like the
              -- premature-hangup case below.
              | n < 0 =
                  throwIO (TestError (role ++ ": negative byte count"))
              | otherwise = do
                  traceWith tracer $ role ++ ": " ++ show n ++ " bytes left"
                  transferred <- step buf n
                  when (transferred == 0) $
                    throwIO (TestError (role ++ ": premature hangup"))
                  let n' = n - transferred
                  traceWith tracer $
                    role ++ ": " ++ show transferred ++ " bytes " ++ verb
                      ++ ", " ++ show n' ++ " remaining"
                  go (plusPtr buf transferred) n'

        -- Connect to the listener, push the whole payload through a
        -- 'RawBearer', then signal @senderDone@. The socket is closed by
        -- 'bracket' however this ends.
        sender :: m ()
        sender =
          bracket
            (openToConnect snocket serverAddr)
            (\s -> traceWith tracer "sender: closing" >> close snocket s)
            $ \s -> do
              case mclientAddr of
                Nothing -> return ()
                Just clientAddr -> do
                  traceWith tracer $ "sender: binding to " ++ show clientAddr
                  bind snocket s clientAddr
              traceWith tracer $ "sender: connecting to " ++ show serverAddr
              connect snocket s serverAddr
              traceWith tracer "sender: connected"
              bearer <- getRawBearer mkrb s
              bracket (io $ mallocBytes size) (io . free) $ \srcBuf -> do
                io $ BS.useAsCStringLen (messageBytes msg)
                                        (uncurry (copyBytes srcBuf))
                transferAll "sender" "done" "sent"
                            (send bearer) (castPtr srcBuf) size
                putMVar senderDone ()

        -- Accept connections on the listening socket, one after another.
        receiver :: fd -> m ()
        receiver s = accept snocket s >>= acceptLoop

        -- Handle one accepted connection, close it, then wait for the next.
        -- There is no termination case: this only ends by exception or
        -- cancellation.
        acceptLoop :: Accept m fd addr -> m ()
        acceptLoop accept0 = do
          traceWith tracer "receiver: accepting connection"
          (accepted, acceptNext) <- runAccept accept0
          case accepted of
            AcceptFailure err ->
              throwIO err
            Accepted s' _ -> do
              labelThisThread "accept"
              traceWith tracer "receiver: connection accepted"
              let closeConn = do
                    traceWith tracer "receiver: closing connection"
                    close snocket s'
                    traceWith tracer "receiver: connection closed"
              receiveOne s' `finally` closeConn
              acceptLoop acceptNext

        -- Read exactly @size@ bytes from an accepted connection and offer
        -- them as the result.
        receiveOne :: fd -> m ()
        receiveOne s' = do
          bearer <- getRawBearer mkrb s'
          retval <- bracket (io $ mallocBytes size) (io . free) $ \dstBuf -> do
            transferAll "receiver" "done receiving" "received"
                        (recv bearer) (castPtr dstBuf) size
            -- 'packCStringLen' copies, so freeing @dstBuf@ afterwards is safe.
            io $ BS.packCStringLen (castPtr dstBuf, size)
          traceWith tracer $ "receiver: received " ++ show retval
          written <- tryPutMVar retVar retval
          traceWith tracer $
            if written
              then "receiver: stored " ++ show retval
              else "receiver: already have result"
          traceWith tracer "receiver: finishing connection"

    resBSEither <-
      bracket (open snocket (addrFamily snocket serverAddr)) (close snocket) $ \s -> do
        traceWith tracer "receiver: starting"
        bind snocket s serverAddr
        listen snocket s
        traceWith tracer "receiver: listening"
        -- The accept loop never returns on its own, so the left branch only
        -- ends by exception; the right branch completes once a payload has
        -- been stored and the sender has finished, and 'race' then cancels
        -- the left branch.
        race
          (sender `concurrently` receiver s)
          (takeMVar retVar <* takeMVar senderDone)
    return $ resBSEither === Right (messageBytes msg)