{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Mux.Bearer.Socket (socketAsBearer) where

import Control.Monad (when)
import Control.Tracer
import Data.ByteString.Lazy qualified as BL
import Data.Int

import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI hiding (timeout)

import Network.Socket qualified as Socket
#if !defined(mingw32_HOST_OS)
import Network.Socket.ByteString.Lazy qualified as Socket (recv, sendAll)
#else
import System.Win32.Async.Socket.ByteString.Lazy qualified as Win32.Async
#endif

import Network.Mux.Codec qualified as Mx
import Network.Mux.Time qualified as Mx
import Network.Mux.Timeout qualified as Mx
import Network.Mux.Trace qualified as Mx
import Network.Mux.Types (Bearer)
import Network.Mux.Types qualified as Mx
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
import Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption))
#endif

-- |
-- Create a @'Bearer'@ from a socket.
--
-- On Windows "System.Win32.Async" operations are used to read from and write
-- to a socket.  This means that the socket must be associated with the I/O
-- completion port with
-- 'System.Win32.Async.IOManager.associateWithIOCompletionPort'.
--
-- Note: 'IOException's thrown by 'sendAll' and 'recv' are wrapped in
-- 'MuxError'.
--
socketAsBearer
  :: Mx.SDUSize          -- ^ stored unchanged in the bearer's 'Mx.sduSize' field
  -> DiffTime            -- ^ duration handed to the 'Mx.TimeoutFn' when
                         --   completing an SDU read and when sending an SDU
  -> Tracer IO Mx.Trace  -- ^ receives the send / receive trace events
  -> Socket.Socket       -- ^ socket all reads and writes go through
  -> Bearer IO
socketAsBearer sduSize sduTimeout tracer sd =
      Mx.Bearer {
        Mx.read    = readSocket,
        Mx.write   = writeSocket,
        Mx.sduSize = sduSize,
        Mx.name    = "socket-bearer"
      }
    where
      -- Number of bytes that are read before 'Mx.decodeSDU' is attempted.
      hdrLength :: Int64
      hdrLength = 8

      -- Read one SDU.  The first chunk of the header is awaited without a
      -- timeout; the rest of the SDU is read under the supplied timeout
      -- function.  Throws 'Mx.SDUReadTimeout' when that timeout fires.
      readSocket :: Mx.TimeoutFn IO -> IO (Mx.SDU, Time)
      readSocket timeout = do
          traceWith tracer Mx.TraceRecvHeaderStart

          -- Wait for the first part of the header without any timeout
          h0 <- recvAtMost True hdrLength

          -- Optionally wait at most sduTimeout seconds for the complete SDU.
          r_m <- timeout sduTimeout $ recvRem h0
          case r_m of
                Nothing -> do
                    traceWith tracer Mx.TraceSDUReadTimeoutException
                    throwIO Mx.SDUReadTimeout
                Just r -> return r

      -- Given the header bytes received so far, read the rest of the header,
      -- decode it, then read the payload whose length the header announces.
      -- Returns the SDU together with the time taken right after the payload
      -- was received.  A decoding failure is rethrown as is.
      recvRem :: BL.ByteString -> IO (Mx.SDU, Time)
      recvRem !h0 = do
          hbuf <- recvLen' (hdrLength - BL.length h0) [h0]
          case Mx.decodeSDU hbuf of
               Left  e ->  throwIO e
               Right header@Mx.SDU { Mx.msHeader } -> do
                   traceWith tracer $ Mx.TraceRecvHeaderEnd msHeader
                   !blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []

                   !ts <- getMonotonicTime
                   let !header' = header {Mx.msBlob = blob}
                   traceWith tracer (Mx.TraceRecvDeltaQObservation msHeader ts)
                   return (header', ts)

      -- Keep receiving until the requested number of bytes has arrived.
      -- Chunks are accumulated newest-first, hence the 'reverse' before they
      -- are concatenated.
      recvLen' ::  Int64 -> [BL.ByteString] -> IO BL.ByteString
      recvLen' 0 bufs = return $ BL.concat $ reverse bufs
      recvLen' l bufs = do
          buf <- recvAtMost False l
          recvLen' (l - BL.length buf) (buf : bufs)

      -- Perform a single receive of at most @l@ bytes.  An empty result is
      -- reported by throwing 'Mx.BearerClosed'; 'IOException's raised by the
      -- receive call are passed to 'Mx.handleIOException'.  The flag records
      -- whether the caller is waiting for the start of a new header.
      recvAtMost :: Bool -> Int64 -> IO BL.ByteString
      recvAtMost waitingOnNxtHeader l = do
          traceWith tracer $ Mx.TraceRecvStart $ fromIntegral l
#if defined(mingw32_HOST_OS)
          buf <- Win32.Async.recv sd (fromIntegral l)
#else
          buf <- Socket.recv sd l
#endif
                    `catch` Mx.handleIOException "recv errored"
          if BL.null buf
              then do
                  when waitingOnNxtHeader $
                      {- This may not be an error, but could be an orderly shutdown.
                       - We wait 1 second to give the mux protocols time to perform
                       - a clean up and exit.
                       -}
                      threadDelay 1
                  throwIO $ Mx.BearerClosed (show sd ++
                      " closed when reading data, waiting on next header " ++
                      show waitingOnNxtHeader)
              else do
                  traceWith tracer $ Mx.TraceRecvEnd (fromIntegral $ BL.length buf)
                  return buf

      -- Stamp the SDU with the low 32 bits of the current monotonic time (in
      -- microseconds), encode it and send it under the supplied timeout
      -- function.  Returns the time taken before sending; throws
      -- 'Mx.SDUWriteTimeout' when the timeout fires.
      writeSocket :: Mx.TimeoutFn IO -> Mx.SDU -> IO Time
      writeSocket timeout sdu = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = Mx.encodeSDU sdu'
          traceWith tracer $ Mx.TraceSendStart (Mx.msHeader sdu')
          r <- timeout sduTimeout $
#if defined(mingw32_HOST_OS)
              Win32.Async.sendAll sd buf
#else
              Socket.sendAll sd buf
#endif
              `catch` Mx.handleIOException "sendAll errored"
          case r of
               Nothing -> do
                    traceWith tracer Mx.TraceSDUWriteTimeoutException
                    throwIO Mx.SDUWriteTimeout
               Just _ -> do
                   traceWith tracer Mx.TraceSendEnd
#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO)
                   -- If it was possible to detect if the TraceTCPInfo was
                   -- enable we wouldn't have to hide the getSockOpt
                   -- syscall in this ifdef. Instead we would only call it if
                   -- we knew that the information would be traced.
                   tcpi <- Socket.getSockOpt sd TCPInfoSocketOption
                   traceWith tracer $ Mx.TraceTCPInfo tcpi (Mx.mhLength $ Mx.msHeader sdu)
#endif
                   return ts