{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Mux.Bearer.Socket (socketAsBearer) where

import Control.Monad (when)
import Control.Tracer
import Data.ByteString.Lazy qualified as BL
import Data.Int

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI hiding (timeout)

import Network.Socket qualified as Socket
#if !defined(mingw32_HOST_OS)
import Data.ByteString.Internal (create)
import Foreign.Marshal.Utils
import Network.Socket.ByteString qualified as Socket (sendMany)
import Network.Socket.ByteString.Lazy qualified as Socket (recv, sendAll)
#else
import System.Win32.Async.Socket.ByteString.Lazy qualified as Win32.Async
#endif

import Network.Mux.Codec qualified as Mx
import Network.Mux.Time qualified as Mx
import Network.Mux.Timeout qualified as Mx
import Network.Mux.Trace qualified as Mx
import Network.Mux.Types (Bearer, BearerTrace)
import Network.Mux.Types qualified as Mx
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
import Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption))
#endif

-- |
-- Create a @'Bearer'@ from a socket.
--
-- On Windows 'System.Win32.Async' operations are used to read and write from
-- a socket.  This means that the socket must be associated with the I/O
-- completion port with
-- 'System.Win32.Async.IOManager.associateWithIOCompletionPort'.
--
-- Note: 'IOException's thrown by the socket send and receive calls are
-- rethrown through 'Mx.handleIOException'.
--
socketAsBearer
  :: Mx.SDUSize
  -- ^ SDU size, exposed as the bearer's 'Mx.sduSize'
  -> Int
  -- ^ batch size, exposed as the bearer's 'Mx.batchSize'
  -> Maybe (Mx.ReadBuffer IO)
  -- ^ optional read buffer; when 'Nothing' every read goes directly to the
  -- socket
  -> DiffTime -- ^ SDU timeout
  -> DiffTime -- ^ egress poll interval
  -> Socket.Socket
  -> Bearer IO
socketAsBearer sduSize batchSize readBuffer_m sduTimeout egressInterval sd =
      Mx.Bearer {
        Mx.read           = readSocket,
        Mx.write          = writeSocket,
        Mx.writeMany      = writeSocketMany,
        Mx.sduSize        = sduSize,
        Mx.batchSize      = batchSize,
        Mx.name           = "socket-bearer",
        Mx.egressInterval
      }
    where
      -- Read one complete SDU from the socket.
      --
      -- The first 'Mx.msHeaderLength' bytes are awaited without any timeout.
      -- The remainder of the header and the payload must then arrive within
      -- @sduTimeout@; otherwise 'Mx.TraceSDUReadTimeoutException' is traced
      -- and 'Mx.SDUReadTimeout' is thrown.  Returns the decoded SDU together
      -- with the monotonic time taken once its payload was fully received.
      readSocket :: Tracer IO BearerTrace -> Mx.TimeoutFn IO -> IO (Mx.SDU, Time)
      readSocket tracer timeout = do
          traceWith tracer Mx.TraceRecvHeaderStart

          -- Wait for the first part of the header without any timeout.
          h0 <- recvAtMost True Mx.msHeaderLength

          -- Optionally wait at most sduTimeout seconds for the complete SDU.
          r_m <- timeout sduTimeout $ recvRem h0
          case r_m of
                Nothing -> do
                    traceWith tracer Mx.TraceSDUReadTimeoutException
                    throwIO Mx.SDUReadTimeout
                Just r -> return r
        where
          -- Complete the header that begins with @h0@, decode it, then read
          -- exactly 'Mx.mhLength' payload bytes.  A decoding error is
          -- rethrown as-is.
          recvRem :: BL.ByteString -> IO (Mx.SDU, Time)
          recvRem !h0 = do
              hbuf <- recvLen' (Mx.msHeaderLength - BL.length h0) [h0]
              case Mx.decodeSDU hbuf of
                   Left  e ->  throwIO e
                   Right header@Mx.SDU { Mx.msHeader } -> do
                       traceWith tracer $ Mx.TraceRecvHeaderEnd msHeader
                       !blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []

                       !ts <- getMonotonicTime
                       let !header' = header {Mx.msBlob = blob}
                       traceWith tracer (Mx.TraceRecvDeltaQObservation msHeader ts)
                       return (header', ts)

          -- Read @l@ more bytes.  @bufs@ holds the chunks read so far in
          -- reverse order; they are concatenated once nothing remains to be
          -- read.  Relies on 'recvAtMost' returning at most the requested
          -- number of bytes; the @l <= 0@ guard makes sure an overshoot still
          -- terminates instead of issuing a read with a negative length.
          recvLen' ::  Int64 -> [BL.ByteString] -> IO BL.ByteString
          recvLen' l bufs
            | l <= 0    = return $ BL.concat $ reverse bufs
            | otherwise = do
                buf <- recvAtMost False l
                recvLen' (l - BL.length buf) (buf : bufs)

          -- Return a non-empty chunk of at most @l@ bytes.  With a read
          -- buffer, bytes are served from it, and when it is empty it is
          -- refilled from the socket with up to 'Mx.rbSize' bytes before
          -- retrying.  Without a read buffer the socket is read directly.
          recvAtMost :: Bool -> Int64 -> IO BL.ByteString
          recvAtMost waitingOnNxtHeader l = do
              traceWith tracer $ Mx.TraceRecvStart $ fromIntegral l

              case readBuffer_m of
                   Nothing -> -- No read buffer available; read directly from socket
                       recvFromSocket waitingOnNxtHeader l
                   Just Mx.ReadBuffer{Mx.rbVar, Mx.rbSize} -> do
                       -- Take up to l bytes out of the buffer atomically.
                       availableData <- atomically $ do
                           buf <- readTVar rbVar
                           if BL.length buf >= l
                              then do
                                let (toProcess, remaining) = BL.splitAt l buf
                                writeTVar rbVar remaining
                                return toProcess
                              else do
                                writeTVar rbVar BL.empty
                                return buf

                       if BL.null availableData
                          then do
#if !defined(mingw32_HOST_OS)
                            -- No data in buffer; read more from socket.
                            -- NOTE(review): if recvFromSocket throws, the
                            -- low-water mark set here is not reset to 1;
                            -- presumably harmless because the bearer is torn
                            -- down on error — confirm.
                            when (not waitingOnNxtHeader) $
                              -- Don't let the kernel wake us up until there are
                              -- at least l bytes of data.
                              Socket.setSocketOption sd Socket.RecvLowWater $ fromIntegral l
#endif
                            newBuf <- recvFromSocket waitingOnNxtHeader $ fromIntegral rbSize
                            atomically $ modifyTVar rbVar (`BL.append` newBuf)
#if !defined(mingw32_HOST_OS)
                            when (not waitingOnNxtHeader) $
                              Socket.setSocketOption sd Socket.RecvLowWater 1
#endif
                            recvAtMost waitingOnNxtHeader l
                          else do
                            traceWith tracer $ Mx.TraceRecvEnd $ fromIntegral $ BL.length availableData
                            return availableData
#if !defined(mingw32_HOST_OS)
          -- Read at most `min rbSize maxLen` bytes from the socket into
          -- rbBuf, then copy them into a freshly allocated ByteString whose
          -- size is exactly the number of bytes read (empty when
          -- 'Socket.recvBuf' returns 0).
          recvBuf :: Mx.ReadBuffer IO -> Int64 -> IO BL.ByteString
          recvBuf Mx.ReadBuffer{Mx.rbBuf, Mx.rbSize} maxLen = do
            len <- Socket.recvBuf sd rbBuf (min rbSize $ fromIntegral maxLen)
            traceWith tracer $ Mx.TraceRecvRaw len
            if len > 0
               then do
                 bs <- create len (\dest -> copyBytes dest rbBuf len)
                 return $ BL.fromStrict bs
               else return $ BL.empty
#endif

          -- Receive up to @l@ bytes from the socket, using the read buffer's
          -- scratch space when one is configured (non-Windows only).  IO
          -- errors are rethrown via 'Mx.handleIOException'.  An empty result
          -- is reported as 'Mx.BearerClosed'.
          recvFromSocket :: Bool -> Int64 -> IO BL.ByteString
          recvFromSocket waitingOnNxtHeader l = do
#if defined(mingw32_HOST_OS)
              buf <- Win32.Async.recv sd (fromIntegral l)
#else
              buf <- (case readBuffer_m of
                          Nothing         -> Socket.recv sd l
                          Just readBuffer -> recvBuf readBuffer l
                     )
#endif
                        `catch` Mx.handleIOException "recv errored"
              if BL.null buf
                  then do
                      when waitingOnNxtHeader $
                          {- This may not be an error, but could be an orderly shutdown.
                           - We wait one second to give the mux protocols time to
                           - perform a clean up and exit.
                           -}
                          threadDelay 1
                      throwIO $ Mx.BearerClosed (show sd ++
                          " closed when reading data, waiting on next header " ++
                          show waitingOnNxtHeader)
                  else return buf

      -- Stamp the SDU with the low 32 bits of the current monotonic time
      -- (via 'Mx.timestampMicrosecondsLow32Bits'), encode it and send it in
      -- full.  The send must finish within @sduTimeout@; otherwise
      -- 'Mx.TraceSDUWriteTimeoutException' is traced and
      -- 'Mx.SDUWriteTimeout' is thrown.  IO errors are rethrown via
      -- 'Mx.handleIOException'.  Returns the time used for the stamp.
      writeSocket :: Tracer IO BearerTrace -> Mx.TimeoutFn IO -> Mx.SDU -> IO Time
      writeSocket tracer timeout sdu = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = Mx.encodeSDU sdu'
          traceWith tracer $ Mx.TraceSendStart (Mx.msHeader sdu')
          r <- timeout sduTimeout $
#if defined(mingw32_HOST_OS)
              Win32.Async.sendAll sd buf
#else
              Socket.sendAll sd buf
#endif
              `catch` Mx.handleIOException "sendAll errored"
          case r of
               Nothing -> do
                    traceWith tracer Mx.TraceSDUWriteTimeoutException
                    throwIO Mx.SDUWriteTimeout
               Just _ -> do
                   traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
                   -- If it was possible to detect if the TraceTCPInfo was
                   -- enabled we wouldn't have to hide the getSockOpt
                   -- syscall in this ifdef. Instead we would only call it if
                   -- we knew that the information would be traced.
                   tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
                   traceWith tracer $ Mx.TraceTCPInfo tcpi (Mx.mhLength $ Mx.msHeader sdu)
#endif
                   return ts

      -- Send a batch of SDUs and return the monotonic time taken before
      -- sending.  On Windows each SDU goes through 'writeSocket' (its own
      -- timestamp and timeout).  Elsewhere every SDU is stamped with that
      -- single time and all of them are written with one 'Socket.sendMany',
      -- whose timeout is @sduTimeout@ multiplied by the number of SDUs.
      writeSocketMany :: Tracer IO BearerTrace -> Mx.TimeoutFn IO -> [Mx.SDU] -> IO Time
#if defined(mingw32_HOST_OS)
      writeSocketMany tracer timeout sdus = do
        ts <- getMonotonicTime
        mapM_ (writeSocket tracer timeout) sdus
        return ts
#else
      writeSocketMany _tracer _timeout [] =
          -- Nothing to send.  Without this case the timeout below would be
          -- 0 * sduTimeout, which may expire before sendMany runs and report
          -- a spurious 'Mx.SDUWriteTimeout'.  The Windows branch likewise
          -- just returns the timestamp for an empty batch.
          getMonotonicTime
      writeSocketMany tracer timeout sdus = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              stamp sdu = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = map (Mx.encodeSDU . stamp) sdus
          r <- timeout (fromIntegral (length sdus) * sduTimeout) $
              Socket.sendMany sd (concatMap BL.toChunks buf)
              `catch` Mx.handleIOException "sendAll errored"
          case r of
               Nothing -> do
                   traceWith tracer Mx.TraceSDUWriteTimeoutException
                   throwIO Mx.SDUWriteTimeout
               Just _ -> do
                   traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
                   -- If it was possible to detect if the TraceTCPInfo was
                   -- enabled we wouldn't have to hide the getSockOpt
                   -- syscall in this ifdef. Instead we would only call it if
                   -- we knew that the information would be traced.
                   -- NOTE(review): the total is summed in the type of
                   -- 'Mx.mhLength' (Word16) and can wrap for large batches.
                   tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
                   traceWith tracer $ Mx.TraceTCPInfo tcpi (sum $ map (Mx.mhLength . Mx.msHeader) sdus)
#endif
                   return ts
#endif