{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.Mux.Bearer.Socket (socketAsBearer) where
import Control.Monad (when)
import Control.Tracer
import Data.ByteString.Lazy qualified as BL
import Data.Int
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI hiding (timeout)
import Network.Socket qualified as Socket
#if !defined(mingw32_HOST_OS)
import Data.ByteString.Internal (create)
import Foreign.Marshal.Utils
import Network.Socket.ByteString qualified as Socket (sendMany)
import Network.Socket.ByteString.Lazy qualified as Socket (recv, sendAll)
#else
import System.Win32.Async.Socket.ByteString.Lazy qualified as Win32.Async
#endif
import Network.Mux.Codec qualified as Mx
import Network.Mux.Time qualified as Mx
import Network.Mux.Timeout qualified as Mx
import Network.Mux.Trace qualified as Mx
import Network.Mux.Types (Bearer)
import Network.Mux.Types qualified as Mx
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
import Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption))
#endif
-- | Wrap a connected 'Socket.Socket' as a mux 'Bearer'.
--
-- Reading: the first bytes of the next SDU header are awaited without any
-- timeout. Once they arrive, the remainder of the SDU (rest of the header plus
-- the payload announced by it) must be received within @sduTimeout@ (via the
-- supplied 'Mx.TimeoutFn'), otherwise 'Mx.SDUReadTimeout' is thrown.
--
-- Writing: each SDU is stamped with a 32-bit timestamp derived from the
-- current monotonic time ('Mx.timestampMicrosecondsLow32Bits'), encoded and
-- sent within @sduTimeout@ (scaled by the number of SDUs for a batched write),
-- otherwise 'Mx.SDUWriteTimeout' is thrown.
--
-- When a 'Mx.ReadBuffer' is supplied, received bytes are staged in its
-- 'Mx.rbVar' and the socket is read in chunks of up to 'Mx.rbSize' bytes;
-- otherwise every read goes straight to the socket.
socketAsBearer
  :: Mx.SDUSize
  -- ^ SDU size, exposed as 'Mx.sduSize'.
  -> Int
  -- ^ Batch size, exposed as 'Mx.batchSize'.
  -> Maybe (Mx.ReadBuffer IO)
  -- ^ Optional read buffer used to coalesce socket reads.
  -> DiffTime
  -- ^ @sduTimeout@: time budget for receiving the remainder of an SDU, and
  -- for sending one SDU.
  -> DiffTime
  -- ^ Poll interval, exposed as 'Mx.egressInterval'.
  -> Tracer IO Mx.Trace
  -> Socket.Socket
  -> Bearer IO
socketAsBearer sduSize batchSize readBuffer_m sduTimeout pollInterval tracer sd =
    Mx.Bearer
      { Mx.read           = readSocket
      , Mx.write          = writeSocket
      , Mx.writeMany      = writeSocketMany
      , Mx.sduSize        = sduSize
      , Mx.batchSize      = batchSize
      , Mx.name           = "socket-bearer"
      , Mx.egressInterval = pollInterval
      }
  where
    -- Read one complete SDU and return it together with the monotonic time
    -- at which its payload finished arriving.
    readSocket :: Mx.TimeoutFn IO -> IO (Mx.SDU, Time)
    readSocket timeout = do
      traceWith tracer Mx.TraceRecvHeaderStart
      -- Waiting for the start of the next SDU is unbounded (presumably so an
      -- idle connection is not treated as a timeout); only the rest is timed.
      h0 <- recvAtMost True Mx.msHeaderLength
      r_m <- timeout sduTimeout (recvRem h0)
      case r_m of
        Nothing -> do
          traceWith tracer Mx.TraceSDUReadTimeoutException
          throwIO Mx.SDUReadTimeout
        Just r -> return r

    -- Complete the header whose first bytes are @h0@, decode it, then read
    -- exactly the payload length it announces. Decoding errors are rethrown.
    recvRem :: BL.ByteString -> IO (Mx.SDU, Time)
    recvRem !h0 = do
      hbuf <- recvLen' (Mx.msHeaderLength - BL.length h0) [h0]
      case Mx.decodeSDU hbuf of
        Left e -> throwIO e
        Right header@Mx.SDU { Mx.msHeader } -> do
          traceWith tracer $ Mx.TraceRecvHeaderEnd msHeader
          !blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []
          !ts <- getMonotonicTime
          let !header' = header { Mx.msBlob = blob }
          traceWith tracer (Mx.TraceRecvDeltaQObservation msHeader ts)
          return (header', ts)

    -- Receive exactly @l@ more bytes. @bufs@ holds already-received chunks
    -- in reverse order; the result is those chunks (in order) followed by the
    -- newly received bytes.
    recvLen' :: Int64 -> [BL.ByteString] -> IO BL.ByteString
    recvLen' 0 bufs = return $ BL.concat $ reverse bufs
    recvLen' l bufs = do
      buf <- recvAtMost False l
      recvLen' (l - BL.length buf) (buf : bufs)

    -- Return a non-empty chunk of at most @l@ bytes, served from the read
    -- buffer when one is configured. @waitingOnNxtHeader@ is 'True' only while
    -- waiting for the first bytes of a new SDU.
    recvAtMost :: Bool -> Int64 -> IO BL.ByteString
    recvAtMost waitingOnNxtHeader l = do
      traceWith tracer $ Mx.TraceRecvStart $ fromIntegral l
      case readBuffer_m of
        Nothing -> do
          -- No read buffer: read directly from the socket. Trace the end of
          -- the receive here too, so every TraceRecvStart has a matching
          -- TraceRecvEnd just like on the buffered path below.
          buf <- recvFromSocket waitingOnNxtHeader l
          traceWith tracer $ Mx.TraceRecvEnd $ fromIntegral $ BL.length buf
          return buf
        Just Mx.ReadBuffer { Mx.rbVar, Mx.rbSize } -> do
          -- Take up to @l@ bytes from the buffer, leaving any surplus there.
          availableData <- atomically $ do
            buf <- readTVar rbVar
            if BL.length buf >= l
              then do
                let (toProcess, remaining) = BL.splitAt l buf
                writeTVar rbVar remaining
                return toProcess
              else do
                writeTVar rbVar BL.empty
                return buf
          if BL.null availableData
            then do
              -- Buffer empty: refill it from the socket and try again.
#if !defined(mingw32_HOST_OS)
              -- While inside an SDU, ask the kernel (SO_RCVLOWAT) not to
              -- complete the receive until the @l@ bytes we need are there.
              when (not waitingOnNxtHeader) $
                Socket.setSocketOption sd Socket.RecvLowWater $ fromIntegral l
#endif
              newBuf <- recvFromSocket waitingOnNxtHeader $ fromIntegral rbSize
              atomically $ modifyTVar rbVar (`BL.append` newBuf)
#if !defined(mingw32_HOST_OS)
              when (not waitingOnNxtHeader) $
                Socket.setSocketOption sd Socket.RecvLowWater 1
#endif
              recvAtMost waitingOnNxtHeader l
            else do
              traceWith tracer $ Mx.TraceRecvEnd $ fromIntegral $ BL.length availableData
              return availableData

#if !defined(mingw32_HOST_OS)
    -- Receive at most @min rbSize maxLen@ bytes into the buffer's scratch
    -- memory 'Mx.rbBuf' and copy them into a fresh lazy ByteString. Returns
    -- 'BL.empty' when 'Socket.recvBuf' reports no bytes.
    recvBuf :: Mx.ReadBuffer IO -> Int64 -> IO BL.ByteString
    recvBuf Mx.ReadBuffer { Mx.rbBuf, Mx.rbSize } maxLen = do
      len <- Socket.recvBuf sd rbBuf (min rbSize $ fromIntegral maxLen)
      traceWith tracer $ Mx.TraceRecvRaw len
      if len > 0
        then do
          bs <- create len (\dest -> copyBytes dest rbBuf len)
          return $ BL.fromStrict bs
        else return BL.empty
#endif

    -- A single receive of at most @l@ bytes from the socket. IO errors are
    -- translated by 'Mx.handleIOException'; an empty read means the peer
    -- closed the connection and is reported as 'Mx.BearerClosed'.
    recvFromSocket :: Bool -> Int64 -> IO BL.ByteString
    recvFromSocket waitingOnNxtHeader l = do
      buf <- recvRaw l `catch` Mx.handleIOException "recv errored"
      if BL.null buf
        then do
          -- An EOF while idle between SDUs may be an orderly shutdown;
          -- presumably this 1 s pause (SI 'threadDelay') gives the
          -- mini-protocols time to wind down before the closure is reported.
          when waitingOnNxtHeader $
            threadDelay 1
          throwIO $ Mx.BearerClosed (show sd ++
            " closed when reading data, waiting on next header " ++
            show waitingOnNxtHeader)
        else return buf

    -- Platform-specific single receive of at most @n@ bytes.
    recvRaw :: Int64 -> IO BL.ByteString
#if defined(mingw32_HOST_OS)
    recvRaw n = Win32.Async.recv sd (fromIntegral n)
#else
    recvRaw n = case readBuffer_m of
      Nothing         -> Socket.recv sd n
      Just readBuffer -> recvBuf readBuffer n
#endif

    -- Timestamp, encode and send a single SDU; returns the timestamp used.
    writeSocket :: Mx.TimeoutFn IO -> Mx.SDU -> IO Time
    writeSocket timeout sdu = do
      ts <- getMonotonicTime
      let ts32 = Mx.timestampMicrosecondsLow32Bits ts
          sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
          buf  = Mx.encodeSDU sdu'
      traceWith tracer $ Mx.TraceSendStart (Mx.msHeader sdu')
      r <- timeout sduTimeout $
             sendRaw buf `catch` Mx.handleIOException "sendAll errored"
      case r of
        Nothing -> do
          traceWith tracer Mx.TraceSDUWriteTimeoutException
          throwIO Mx.SDUWriteTimeout
        Just _ -> do
          traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
          tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
          traceWith tracer $ Mx.TraceTCPInfo tcpi (Mx.mhLength $ Mx.msHeader sdu)
#endif
          return ts

    -- Platform-specific send of a complete lazy ByteString.
    sendRaw :: BL.ByteString -> IO ()
#if defined(mingw32_HOST_OS)
    sendRaw = Win32.Async.sendAll sd
#else
    sendRaw = Socket.sendAll sd
#endif

    -- Send a batch of SDUs and return the time taken before the first write.
    -- On Windows each SDU goes through 'writeSocket' (and gets its own
    -- timestamp); elsewhere all SDUs share one timestamp and are written with
    -- a single vectored send whose timeout scales with the batch length.
    writeSocketMany :: Mx.TimeoutFn IO -> [Mx.SDU] -> IO Time
#if defined(mingw32_HOST_OS)
    writeSocketMany timeout sdus = do
      ts <- getMonotonicTime
      mapM_ (writeSocket timeout) sdus
      return ts
#else
    writeSocketMany timeout sdus = do
      ts <- getMonotonicTime
      -- An empty batch has nothing to send. Without this guard the timeout
      -- below would be @0 * sduTimeout@, which a System.Timeout-style
      -- 'Mx.TimeoutFn' reports as an immediate timeout. Returning early also
      -- matches the Windows branch, which is a no-op for an empty batch.
      if null sdus
        then return ts
        else do
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              stamp s = Mx.setTimestamp s (Mx.RemoteClockModel ts32)
              chunks = concatMap (BL.toChunks . Mx.encodeSDU . stamp) sdus
          r <- timeout (fromIntegral (length sdus) * sduTimeout) $
                 Socket.sendMany sd chunks
                   `catch` Mx.handleIOException "sendAll errored"
          case r of
            Nothing -> do
              traceWith tracer Mx.TraceSDUWriteTimeoutException
              throwIO Mx.SDUWriteTimeout
            Just _ -> do
              traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
              tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
              -- NOTE(review): 'Mx.mhLength' is a Word16, so this sum may wrap
              -- for large batches if 'Mx.TraceTCPInfo' also takes a Word16 —
              -- confirm against the Trace definition.
              traceWith tracer $ Mx.TraceTCPInfo tcpi (sum $ map (Mx.mhLength . Mx.msHeader) sdus)
#endif
              return ts
#endif