{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Mux.Bearer.Socket (socketAsBearer) where

import Control.Monad (when)
import Control.Tracer
import Data.ByteString.Lazy qualified as BL
import Data.Int

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI hiding (timeout)

import Network.Socket qualified as Socket
#if !defined(mingw32_HOST_OS)
import Data.ByteString.Internal (create)
import Foreign.Marshal.Utils
import Network.Socket.ByteString qualified as Socket (sendMany)
import Network.Socket.ByteString.Lazy qualified as Socket (recv, sendAll)
#else
import System.Win32.Async.Socket.ByteString.Lazy qualified as Win32.Async
#endif

import Network.Mux.Codec qualified as Mx
import Network.Mux.Time qualified as Mx
import Network.Mux.Timeout qualified as Mx
import Network.Mux.Trace qualified as Mx
import Network.Mux.Types (Bearer)
import Network.Mux.Types qualified as Mx
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
import Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption))
#endif

-- |
-- Create a 'Bearer' from a socket.
--
-- On Windows "System.Win32.Async" operations are used to read and write from
-- a socket.  This means that the socket must be associated with the I/O
-- completion port with
-- 'System.Win32.Async.IOManager.associateWithIOCompletionPort'.
--
-- Note: 'IOException's thrown by the underlying send and receive calls are
-- caught and handed, together with a short label, to 'Mx.handleIOException'.
--
socketAsBearer
  :: Mx.SDUSize
  -- ^ stored unchanged in the bearer's 'Mx.sduSize' field
  -> Int
  -- ^ stored unchanged in the bearer's 'Mx.batchSize' field
  -> Maybe (Mx.ReadBuffer IO)
  -- ^ optional read buffer; with 'Nothing' every read goes straight to the
  -- socket
  -> DiffTime
  -- ^ SDU timeout: the budget given to the timeout function for receiving
  -- the remainder of an SDU once its first bytes arrived, and for writes
  -> DiffTime
  -- ^ poll interval, stored in the bearer's 'Mx.egressInterval' field
  -> Tracer IO Mx.Trace
  -- ^ tracer for receive and send events
  -> Socket.Socket
  -- ^ the socket to read from and write to
  -> Bearer IO
socketAsBearer sduSize batchSize readBuffer_m sduTimeout pollInterval tracer sd =
      Mx.Bearer {
        Mx.read           = readSocket,
        Mx.write          = writeSocket,
        Mx.writeMany      = writeSocketMany,
        Mx.sduSize        = sduSize,
        Mx.batchSize      = batchSize,
        Mx.name           = "socket-bearer",
        Mx.egressInterval = pollInterval
      }
    where
      -- Read one SDU from the socket and return it together with the
      -- monotonic time at which it was completely received.
      --
      -- The wait for the first bytes of the header is unbounded; only the
      -- rest of the SDU is read under the supplied timeout function.  When
      -- that timeout fires, 'Mx.TraceSDUReadTimeoutException' is traced and
      -- 'Mx.SDUReadTimeout' is thrown.
      readSocket :: Mx.TimeoutFn IO -> IO (Mx.SDU, Time)
      readSocket timeout = do
          traceWith tracer Mx.TraceRecvHeaderStart

          -- Wait for the first part of the header without any timeout
          h0 <- recvAtMost True Mx.msHeaderLength

          -- Optionally wait at most sduTimeout seconds for the complete SDU.
          r_m <- timeout sduTimeout $ recvRem h0
          case r_m of
                Nothing -> do
                    traceWith tracer Mx.TraceSDUReadTimeoutException
                    throwIO Mx.SDUReadTimeout
                Just r -> return r

      -- Given the bytes of the header received so far, read the rest of the
      -- header, decode it, and then read the payload whose length the header
      -- announces.
      --
      -- A header that fails to decode is rethrown as the decoder's error.
      -- The returned time is taken after the payload has been read.
      recvRem :: BL.ByteString -> IO (Mx.SDU, Time)
      recvRem !h0 = do
          -- Top the header up to its full length, keeping what we already have.
          hbuf <- recvLen' (Mx.msHeaderLength - BL.length h0) [h0]
          case Mx.decodeSDU hbuf of
               Left  e ->  throwIO e
               Right header@Mx.SDU { Mx.msHeader } -> do
                   traceWith tracer $ Mx.TraceRecvHeaderEnd msHeader
                   !blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []

                   !ts <- getMonotonicTime
                   let !header' = header {Mx.msBlob = blob}
                   traceWith tracer (Mx.TraceRecvDeltaQObservation msHeader ts)
                   return (header', ts)

      -- Receive exactly the requested number of bytes.
      --
      -- The second argument accumulates the chunks received so far, most
      -- recent first; once nothing remains to be read they are reversed and
      -- concatenated into the result.
      recvLen' :: Int64 -> [BL.ByteString] -> IO BL.ByteString
      recvLen' 0 bufs = return $ BL.concat $ reverse bufs
      recvLen' l bufs = do
          buf <- recvAtMost False l
          recvLen' (l - BL.length buf) (buf : bufs)

      -- Receive at most @l@ bytes.
      --
      -- The flag tells whether we are waiting for the start of the next
      -- header; it is forwarded to 'recvFromSocket'.
      --
      -- Without a read buffer the request goes straight to the socket.  With
      -- a read buffer, data already held in 'Mx.rbVar' is served first (up to
      -- @l@ bytes, the remainder stays buffered); only when the buffer is
      -- empty is the socket read, asking for up to 'Mx.rbSize' bytes, after
      -- which the request is retried against the refilled buffer.
      recvAtMost :: Bool -> Int64 -> IO BL.ByteString
      recvAtMost waitingOnNxtHeader l = do
          traceWith tracer $ Mx.TraceRecvStart $ fromIntegral l

          case readBuffer_m of
               Nothing -> -- No read buffer available; read directly from socket
                   recvFromSocket waitingOnNxtHeader l
               Just Mx.ReadBuffer{Mx.rbVar, Mx.rbSize} -> do
                   -- Atomically take up to l bytes out of the buffer.
                   availableData <- atomically $ do
                       buf <- readTVar rbVar
                       if BL.length buf >= l
                          then do
                            let (toProcess, remaining) = BL.splitAt l buf
                            writeTVar rbVar remaining
                            return toProcess
                          else do
                            writeTVar rbVar BL.empty
                            return buf

                   if BL.null availableData
                      then do
                        -- No data in buffer; read more from socket
#if !defined(mingw32_HOST_OS)
                        when (not waitingOnNxtHeader) $
                          -- Don't let the kernel wake us up until there is
                          -- at least l bytes of data.
                          Socket.setSocketOption sd Socket.RecvLowWater $ fromIntegral l
#endif
                        newBuf <- recvFromSocket waitingOnNxtHeader $ fromIntegral rbSize
                        atomically $ modifyTVar rbVar (`BL.append` newBuf)
#if !defined(mingw32_HOST_OS)
                        -- Restore the low-water mark to a single byte.
                        when (not waitingOnNxtHeader) $
                          Socket.setSocketOption sd Socket.RecvLowWater 1
#endif
                        recvAtMost waitingOnNxtHeader l
                      else do
                        traceWith tracer $ Mx.TraceRecvEnd $ fromIntegral $ BL.length availableData
                        return availableData
#if !defined(mingw32_HOST_OS)
      -- Read at most `min rbSize maxLen` bytes from the socket
      -- into rbBuf.
      -- Creates and returns a Bytestring matching the exact size
      -- of the number of bytes read; the bytes are copied out of rbBuf.
      -- When the socket reports no bytes read, the empty
      -- Bytestring is returned.
      recvBuf :: Mx.ReadBuffer IO -> Int64 -> IO BL.ByteString
      recvBuf Mx.ReadBuffer{Mx.rbBuf, Mx.rbSize} maxLen = do
        len <- Socket.recvBuf sd rbBuf (min rbSize $ fromIntegral maxLen)
        traceWith tracer $ Mx.TraceRecvRaw len
        if len > 0
           then do
             bs <- create len (\dest -> copyBytes dest rbBuf len)
             return $ BL.fromStrict bs
           else return BL.empty
#endif

      -- Perform one receive on the socket, asking for at most @l@ bytes.
      --
      -- On Windows the asynchronous receive is used; elsewhere the read goes
      -- through 'recvBuf' when a read buffer was supplied and through
      -- 'Socket.recv' otherwise.  An 'IOException' from the receive is handed
      -- to 'Mx.handleIOException'.
      --
      -- An empty result is treated as the peer having closed the connection
      -- and 'Mx.BearerClosed' is thrown; when we were waiting for the next
      -- header the throw is preceded by a short delay.
      recvFromSocket :: Bool -> Int64 -> IO BL.ByteString
      recvFromSocket waitingOnNxtHeader l = do
#if defined(mingw32_HOST_OS)
          buf <- Win32.Async.recv sd (fromIntegral l)
#else
          buf <- (case readBuffer_m of
                      Nothing         -> Socket.recv sd l
                      Just readBuffer -> recvBuf readBuffer l
                     )
#endif
                    `catch` Mx.handleIOException "recv errored"
          if BL.null buf
              then do
                  when waitingOnNxtHeader $
                      {- This may not be an error, but could be an orderly shutdown.
                       - We wait 1 second to give the mux protocols time to perform
                       - a clean up and exit.
                       -}
                      threadDelay 1
                  throwIO $ Mx.BearerClosed (show sd ++
                      " closed when reading data, waiting on next header " ++
                      show waitingOnNxtHeader)
              else return buf

      -- Stamp, encode and send a single SDU, returning the monotonic time
      -- that was read just before sending.
      --
      -- The low 32 bits of that time (in microseconds) are written into the
      -- SDU as its timestamp.  The send runs under the supplied timeout
      -- function with a budget of @sduTimeout@; on expiry
      -- 'Mx.TraceSDUWriteTimeoutException' is traced and 'Mx.SDUWriteTimeout'
      -- is thrown.  An 'IOException' from the send is handed to
      -- 'Mx.handleIOException'.
      writeSocket :: Mx.TimeoutFn IO -> Mx.SDU -> IO Time
      writeSocket timeout sdu = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = Mx.encodeSDU sdu'
          traceWith tracer $ Mx.TraceSendStart (Mx.msHeader sdu')
          r <- timeout sduTimeout $
#if defined(mingw32_HOST_OS)
              Win32.Async.sendAll sd buf
#else
              Socket.sendAll sd buf
#endif
              `catch` Mx.handleIOException "sendAll errored"
          case r of
               Nothing -> do
                    traceWith tracer Mx.TraceSDUWriteTimeoutException
                    throwIO Mx.SDUWriteTimeout
               Just _ -> do
                   traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
                   -- If it was possible to detect if the TraceTCPInfo was
                   -- enable we wouldn't have to hide the getSockOpt
                   -- syscall in this ifdef. Instead we would only call it if
                   -- we knew that the information would be traced.
                   tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
                   traceWith tracer $ Mx.TraceTCPInfo tcpi (Mx.mhLength $ Mx.msHeader sdu)
#endif
                   return ts

      -- Send a batch of SDUs, returning the monotonic time that was read
      -- just before sending.
      --
      -- On Windows each SDU is sent in turn with 'writeSocket'.  Elsewhere
      -- all SDUs receive the same timestamp, are encoded, and their chunks
      -- are handed to a single 'Socket.sendMany' call whose timeout budget is
      -- @sduTimeout@ multiplied by the number of SDUs.  On expiry
      -- 'Mx.TraceSDUWriteTimeoutException' is traced and 'Mx.SDUWriteTimeout'
      -- is thrown.  An 'IOException' from the send is handed to
      -- 'Mx.handleIOException'.
      writeSocketMany :: Mx.TimeoutFn IO -> [Mx.SDU] -> IO Time
#if defined(mingw32_HOST_OS)
      writeSocketMany timeout sdus = do
        ts <- getMonotonicTime
        mapM_ (writeSocket timeout) sdus
        return ts
#else
      -- Nothing to send: just report the current time, as the Windows
      -- implementation does.  Without this clause the timeout budget below
      -- would be zero and an empty sendMany would be issued.
      -- NOTE(review): a zero budget presumably makes the timeout function
      -- give up immediately, which would surface as a spurious
      -- SDUWriteTimeout; confirm against the timeout implementation.
      writeSocketMany _ [] = getMonotonicTime
      writeSocketMany timeout sdus = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              buf  = map (Mx.encodeSDU .
                           (\sdu -> Mx.setTimestamp sdu (Mx.RemoteClockModel ts32))) sdus
          r <- timeout ((fromIntegral $ length sdus) * sduTimeout) $
              Socket.sendMany sd (concatMap BL.toChunks buf)
              `catch` Mx.handleIOException "sendMany errored"
          case r of
               Nothing -> do
                   traceWith tracer Mx.TraceSDUWriteTimeoutException
                   throwIO Mx.SDUWriteTimeout
               Just _ -> do
                   traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
                   -- If it was possible to detect if the TraceTCPInfo was
                   -- enable we wouldn't have to hide the getSockOpt
                   -- syscall in this ifdef. Instead we would only call it if
                   -- we knew that the information would be traced.
                   tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
                   -- NOTE(review): the per-SDU lengths are summed at their
                   -- own (16-bit) type, so a large batch may wrap around;
                   -- confirm whether the traced total needs a wider type.
                   traceWith tracer $ Mx.TraceTCPInfo tcpi (sum $ map (Mx.mhLength . Mx.msHeader) sdus)
#endif
                   return ts
#endif