{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.Mux.Bearer.Socket (socketAsBearer) where
import Control.Monad (when)
import Control.Tracer
import Data.ByteString.Lazy qualified as BL
import Data.Int
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI hiding (timeout)
import Network.Socket qualified as Socket
#if !defined(mingw32_HOST_OS)
import Network.Socket.ByteString.Lazy qualified as Socket (recv, sendAll)
#else
import System.Win32.Async.Socket.ByteString.Lazy qualified as Win32.Async
#endif
import Network.Mux.Codec qualified as Mx
import Network.Mux.Time qualified as Mx
import Network.Mux.Timeout qualified as Mx
import Network.Mux.Trace qualified as Mx
import Network.Mux.Types (Bearer)
import Network.Mux.Types qualified as Mx
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
import Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption))
#endif
-- | Create a mux 'Bearer' that exchanges SDUs over a connected
-- 'Socket.Socket'.
--
-- Reading ('Mx.read'):
--
-- * the first chunk of the 8-byte SDU header is awaited without any
--   timeout, so an idle connection does not trip the SDU timeout;
-- * the rest of the header and the payload must then arrive within
--   @sduTimeout@ (enforced through the 'Mx.TimeoutFn' handed to 'Mx.read').
--   Otherwise 'Mx.TraceSDUReadTimeoutException' is traced and
--   'Mx.SDUReadTimeout' is thrown;
-- * a zero-length receive is reported by throwing 'Mx.BearerClosed'.
--
-- Writing ('Mx.write') stamps the SDU with the current monotonic time
-- (converted with 'Mx.timestampMicrosecondsLow32Bits'), encodes it and sends
-- it within @sduTimeout@. On timeout 'Mx.TraceSDUWriteTimeoutException' is
-- traced and 'Mx.SDUWriteTimeout' is thrown. The returned 'Time' is the
-- timestamp taken before sending.
--
-- 'IOException's raised by the underlying receive and send calls are handed
-- to 'Mx.handleIOException'.
--
-- On Windows the @System.Win32.Async@ receive and send functions are used
-- instead of the @network@ ones. NOTE(review): presumably the socket must
-- already be associated with an I/O completion port there; confirm.
socketAsBearer
  :: Mx.SDUSize
  -- ^ SDU size exposed through 'Mx.sduSize'
  -> DiffTime
  -- ^ timeout for the remainder of each read and for each write
  -> Tracer IO Mx.Trace
  -- ^ receives the bearer trace events
  -> Socket.Socket
  -- ^ socket to read from and write to
  -> Bearer IO
socketAsBearer sduSize sduTimeout tracer sd =
    Mx.Bearer
      { Mx.read    = readSocket
      , Mx.write   = writeSocket
      , Mx.sduSize = sduSize
      , Mx.name    = "socket-bearer"
      }
  where
    -- Size in bytes of an encoded SDU header.
    hdrLength :: Int64
    hdrLength = 8

    -- Read one complete SDU and the monotonic time at which it was received.
    readSocket :: Mx.TimeoutFn IO -> IO (Mx.SDU, Time)
    readSocket timeout = do
        traceWith tracer Mx.TraceRecvHeaderStart
        -- Block without a timeout until at least part of the next header
        -- arrives; the SDU timeout only starts once data is flowing.
        h0 <- recvAtMost True hdrLength
        -- The rest of the header and the payload must arrive in time.
        r_m <- timeout sduTimeout $ recvRem h0
        case r_m of
          Nothing -> do
            traceWith tracer Mx.TraceSDUReadTimeoutException
            throwIO Mx.SDUReadTimeout
          Just r -> return r

    -- Complete the header started by @h0@, decode it, then read the payload
    -- whose length the header announces.
    recvRem :: BL.ByteString -> IO (Mx.SDU, Time)
    recvRem !h0 = do
        hbuf <- recvLen' (hdrLength - BL.length h0) [h0]
        case Mx.decodeSDU hbuf of
          Left e -> throwIO e
          Right header@Mx.SDU { Mx.msHeader } -> do
            traceWith tracer $ Mx.TraceRecvHeaderEnd msHeader
            !blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []
            !ts <- getMonotonicTime
            let !header' = header { Mx.msBlob = blob }
            traceWith tracer (Mx.TraceRecvDeltaQObservation msHeader ts)
            return (header', ts)

    -- Receive exactly @l@ more bytes. Chunks are accumulated in reverse
    -- order in @bufs@ and concatenated once nothing is left to read.
    recvLen' :: Int64 -> [BL.ByteString] -> IO BL.ByteString
    recvLen' 0 bufs = return $ BL.concat $ reverse bufs
    recvLen' l bufs = do
        buf <- recvAtMost False l
        recvLen' (l - BL.length buf) (buf : bufs)

    -- Receive at most @l@ bytes; an empty result means the peer closed the
    -- connection and is turned into 'Mx.BearerClosed'.
    recvAtMost :: Bool -> Int64 -> IO BL.ByteString
    recvAtMost waitingOnNxtHeader l = do
        traceWith tracer $ Mx.TraceRecvStart $ fromIntegral l
#if defined(mingw32_HOST_OS)
        buf <- Win32.Async.recv sd (fromIntegral l)
#else
        buf <- Socket.recv sd l
#endif
                 `catch` Mx.handleIOException "recv errored"
        if BL.null buf
          then do
            when waitingOnNxtHeader $
              -- An empty read between SDUs may be an orderly shutdown rather
              -- than an error. NOTE(review): presumably the pause lets the
              -- mux protocols clean up first; under the SI timer API the
              -- argument is a DiffTime, so this is likely one second. Confirm.
              threadDelay 1
            throwIO $ Mx.BearerClosed (show sd ++
                " closed when reading data, waiting on next header " ++
                show waitingOnNxtHeader)
          else do
            traceWith tracer $ Mx.TraceRecvEnd (fromIntegral $ BL.length buf)
            return buf

    -- Timestamp, encode and send one SDU; returns the send timestamp.
    writeSocket :: Mx.TimeoutFn IO -> Mx.SDU -> IO Time
    writeSocket timeout sdu = do
        ts <- getMonotonicTime
        let ts32 = Mx.timestampMicrosecondsLow32Bits ts
            sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
            buf  = Mx.encodeSDU sdu'
        traceWith tracer $ Mx.TraceSendStart (Mx.msHeader sdu')
        r <- timeout sduTimeout $
#if defined(mingw32_HOST_OS)
            Win32.Async.sendAll sd buf
#else
            Socket.sendAll sd buf
#endif
            `catch` Mx.handleIOException "sendAll errored"
        case r of
          Nothing -> do
            traceWith tracer Mx.TraceSDUWriteTimeoutException
            throwIO Mx.SDUWriteTimeout
          Just _ -> do
            traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
            -- Only compiled in when TCP info tracing is enabled at build
            -- time, so the extra getsockopt call is not paid otherwise.
            tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
            traceWith tracer $ Mx.TraceTCPInfo tcpi (Mx.mhLength $ Mx.msHeader sdu)
#endif
            return ts