{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.Mux.Bearer.Socket (socketAsBearer) where
import Control.Monad (when)
import Control.Tracer
import Data.ByteString.Lazy qualified as BL
import Data.Int
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI hiding (timeout)
import Network.Socket qualified as Socket
#if !defined(mingw32_HOST_OS)
import Data.ByteString.Internal (create)
import Foreign.Marshal.Utils
import Network.Socket.ByteString qualified as Socket (sendMany)
import Network.Socket.ByteString.Lazy qualified as Socket (recv, sendAll)
#else
import System.Win32.Async.Socket.ByteString.Lazy qualified as Win32.Async
#endif
import Network.Mux.Codec qualified as Mx
import Network.Mux.Time qualified as Mx
import Network.Mux.Timeout qualified as Mx
import Network.Mux.Trace qualified as Mx
import Network.Mux.Types (Bearer, BearerTrace)
import Network.Mux.Types qualified as Mx
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
import Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption))
#endif
-- | Create a mux 'Bearer' that sends and receives SDUs over a socket.
--
-- Reading: the first bytes of each SDU header are awaited with no timeout.
-- The rest of the header and the whole payload must then arrive within
-- @sduTimeout@, otherwise 'Mx.SDUReadTimeout' is thrown.
--
-- Writing: every SDU is stamped with the low 32 bits of the current
-- monotonic time (via 'Mx.timestampMicrosecondsLow32Bits') before it is
-- encoded. It must be sent within @sduTimeout@, or within
-- @length sdus * sduTimeout@ for a batch, otherwise 'Mx.SDUWriteTimeout'
-- is thrown.
--
-- When a 'Mx.ReadBuffer' is supplied, each socket read requests up to
-- 'Mx.rbSize' bytes into 'Mx.rbVar'. Later reads are served from that
-- buffer until it is empty, so one socket read can serve several SDUs.
socketAsBearer
  :: Mx.SDUSize
  -- ^ SDU size, exposed as 'Mx.sduSize'
  -> Int
  -- ^ batch size, exposed as 'Mx.batchSize'
  -> Maybe (Mx.ReadBuffer IO)
  -- ^ optional read buffer used to batch socket reads
  -> DiffTime
  -- ^ timeout for completing an SDU read, and per SDU for writes
  -> DiffTime
  -- ^ egress interval, exposed as 'Mx.egressInterval'
  -> Socket.Socket
  -- ^ the socket to read from and write to
  -> Bearer IO
socketAsBearer sduSize batchSize readBuffer_m sduTimeout egressInterval sd =
      Mx.Bearer {
        Mx.read           = readSocket,
        Mx.write          = writeSocket,
        Mx.writeMany      = writeSocketMany,
        Mx.sduSize        = sduSize,
        Mx.batchSize      = batchSize,
        Mx.name           = "socket-bearer",
        Mx.egressInterval
      }
    where
      -- Read one SDU. Returns it together with the monotonic time taken
      -- right after its payload was fully received.
      readSocket :: Tracer IO BearerTrace -> Mx.TimeoutFn IO -> IO (Mx.SDU, Time)
      readSocket tracer timeout = do
          traceWith tracer Mx.TraceRecvHeaderStart

          -- No timeout while waiting for the first bytes of the next header:
          -- the gap between two SDUs is not bounded here.
          h0 <- recvAtMost True Mx.msHeaderLength

          -- Once an SDU has started, the remainder must arrive within
          -- 'sduTimeout'.
          r_m <- timeout sduTimeout $ recvRem h0
          case r_m of
            Nothing -> do
              traceWith tracer Mx.TraceSDUReadTimeoutException
              throwIO Mx.SDUReadTimeout
            Just r -> return r

        where
          -- Complete the header that starts with @h0@ and decode it. Then
          -- read exactly 'Mx.mhLength' payload bytes into 'Mx.msBlob'.
          recvRem :: BL.ByteString -> IO (Mx.SDU, Time)
          recvRem !h0 = do
              hbuf <- recvLen' (Mx.msHeaderLength - BL.length h0) [h0]
              case Mx.decodeSDU hbuf of
                Left e -> throwIO e
                Right header@Mx.SDU { Mx.msHeader } -> do
                  traceWith tracer $ Mx.TraceRecvHeaderEnd msHeader
                  !blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []
                  !ts <- getMonotonicTime
                  let !header' = header {Mx.msBlob = blob}
                  traceWith tracer (Mx.TraceRecvDeltaQObservation msHeader ts)
                  return (header', ts)

          -- Read exactly @l@ more bytes. @bufs@ holds the chunks already
          -- read, most recent first, so they are reversed before joining.
          recvLen' :: Int64 -> [BL.ByteString] -> IO BL.ByteString
          recvLen' 0 bufs = return $ BL.concat $ reverse bufs
          recvLen' l bufs = do
              buf <- recvAtMost False l
              recvLen' (l - BL.length buf) (buf : bufs)

          -- Return a non-empty chunk of at most @l@ bytes. It comes either
          -- from the read buffer or directly from the socket.
          -- @waitingOnNxtHeader@ is 'True' only while waiting for the start
          -- of a new SDU.
          recvAtMost :: Bool -> Int64 -> IO BL.ByteString
          recvAtMost waitingOnNxtHeader l = do
              traceWith tracer $ Mx.TraceRecvStart $ fromIntegral l
              case readBuffer_m of
                Nothing -> recvFromSocket waitingOnNxtHeader l
                Just Mx.ReadBuffer{Mx.rbVar, Mx.rbSize} -> do
                  -- Take up to @l@ already-buffered bytes and leave the rest
                  -- in the buffer.
                  availableData <- atomically $ do
                    buf <- readTVar rbVar
                    if BL.length buf >= l
                      then do
                        let (toProcess, remaining) = BL.splitAt l buf
                        writeTVar rbVar remaining
                        return toProcess
                      else do
                        writeTVar rbVar BL.empty
                        return buf
                  if BL.null availableData
                    then do
                      -- The buffer is empty: refill it with one socket read
                      -- of up to 'rbSize' bytes, then retry.
#if !defined(mingw32_HOST_OS)
                      -- In the middle of an SDU, raise SO_RCVLOWAT to @l@ so
                      -- the read is not satisfied by a smaller chunk.
                      -- It is restored to 1 after the read.
                      when (not waitingOnNxtHeader) $
                        Socket.setSocketOption sd Socket.RecvLowWater $ fromIntegral l
#endif
                      newBuf <- recvFromSocket waitingOnNxtHeader $ fromIntegral rbSize
                      atomically $ modifyTVar rbVar (`BL.append` newBuf)
#if !defined(mingw32_HOST_OS)
                      when (not waitingOnNxtHeader) $
                        Socket.setSocketOption sd Socket.RecvLowWater 1
#endif
                      recvAtMost waitingOnNxtHeader l
                    else do
                      traceWith tracer $ Mx.TraceRecvEnd $ fromIntegral $ BL.length availableData
                      return availableData

#if !defined(mingw32_HOST_OS)
          -- Receive up to @min rbSize maxLen@ bytes into 'Mx.rbBuf'.
          -- Because that memory is reused by every call, the bytes are
          -- copied into a fresh ByteString. Returns 'BL.empty' when
          -- @recvBuf@ reports 0 bytes.
          recvBuf :: Mx.ReadBuffer IO -> Int64 -> IO BL.ByteString
          recvBuf Mx.ReadBuffer{Mx.rbBuf, Mx.rbSize} maxLen = do
              len <- Socket.recvBuf sd rbBuf (min rbSize $ fromIntegral maxLen)
              traceWith tracer $ Mx.TraceRecvRaw len
              if len > 0
                then do
                  bs <- create len (\dest -> copyBytes dest rbBuf len)
                  return $ BL.fromStrict bs
                else return BL.empty
#endif

          -- Read up to @l@ bytes from the socket. The read goes through
          -- 'recvBuf' when a read buffer is configured (non-Windows only).
          -- IO errors are rethrown through 'Mx.handleIOException'. An empty
          -- read is reported as 'Mx.BearerClosed'.
          recvFromSocket :: Bool -> Int64 -> IO BL.ByteString
          recvFromSocket waitingOnNxtHeader l = do
#if defined(mingw32_HOST_OS)
              buf <- Win32.Async.recv sd (fromIntegral l)
#else
              buf <- (case readBuffer_m of
                        Nothing         -> Socket.recv sd l
                        Just readBuffer -> recvBuf readBuffer l)
#endif
                     `catch` Mx.handleIOException "recv errored"
              if BL.null buf
                then do
                  when waitingOnNxtHeader $
                    -- NOTE(review): the SI 'threadDelay' takes a DiffTime, so
                    -- this pauses for one second before reporting a closure
                    -- seen between SDUs. Presumably this gives an orderly
                    -- shutdown time to complete first; confirm.
                    threadDelay 1
                  throwIO $ Mx.BearerClosed (show sd ++
                      " closed when reading data, waiting on next header " ++
                      show waitingOnNxtHeader)
                else return buf

      -- Stamp @sdu@ with the current time, encode it and send it within
      -- 'sduTimeout'. Returns the timestamp that was taken before sending.
      writeSocket :: Tracer IO BearerTrace -> Mx.TimeoutFn IO -> Mx.SDU -> IO Time
      writeSocket tracer timeout sdu = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = Mx.encodeSDU sdu'
          traceWith tracer $ Mx.TraceSendStart (Mx.msHeader sdu')
          r <- timeout sduTimeout $
#if defined(mingw32_HOST_OS)
              Win32.Async.sendAll sd buf
#else
              Socket.sendAll sd buf
#endif
              `catch` Mx.handleIOException "sendAll errored"
          case r of
            Nothing -> do
              traceWith tracer Mx.TraceSDUWriteTimeoutException
              throwIO Mx.SDUWriteTimeout
            Just _ -> do
              traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
              tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
              traceWith tracer $ Mx.TraceTCPInfo tcpi (Mx.mhLength $ Mx.msHeader sdu)
#endif
              return ts

      -- Send a batch of SDUs and return the timestamp taken before sending.
      writeSocketMany :: Tracer IO BearerTrace -> Mx.TimeoutFn IO -> [Mx.SDU] -> IO Time
#if defined(mingw32_HOST_OS)
      -- On Windows the SDUs are sent one at a time with 'writeSocket'. Each
      -- SDU is therefore stamped with its own timestamp.
      writeSocketMany tracer timeout sdus = do
          ts <- getMonotonicTime
          mapM_ (writeSocket tracer timeout) sdus
          return ts
#else
      -- An empty batch sends nothing and just returns the current time,
      -- as the Windows branch does.
      -- Without this equation the timeout below would be
      -- @0 * sduTimeout = 0@. A 'Mx.TimeoutFn' with 'System.Timeout.timeout'
      -- semantics returns 'Nothing' immediately for a zero delay, so an
      -- empty batch would spuriously raise 'Mx.SDUWriteTimeout'.
      writeSocketMany _ _ [] = getMonotonicTime
      writeSocketMany tracer timeout sdus = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              -- Every SDU in the batch carries the same timestamp.
              stamp sdu = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf = map (Mx.encodeSDU . stamp) sdus
          -- The whole batch is allowed one 'sduTimeout' per SDU.
          r <- timeout (fromIntegral (length sdus) * sduTimeout) $
              Socket.sendMany sd (concatMap BL.toChunks buf)
              `catch` Mx.handleIOException "sendAll errored"
          case r of
            Nothing -> do
              traceWith tracer Mx.TraceSDUWriteTimeoutException
              throwIO Mx.SDUWriteTimeout
            Just _ -> do
              traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
              tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
              -- NOTE(review): 'Mx.mhLength' is a Word16, so this sum wraps
              -- around for batches over 65535 bytes. Confirm whether
              -- TraceTCPInfo should carry a wider type.
              traceWith tracer $ Mx.TraceTCPInfo tcpi (sum $ map (Mx.mhLength . Mx.msHeader) sdus)
#endif
              return ts
#endif