{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Mux.Bearer.Pipe (
    PipeChannel (..)
  , pipeChannelFromHandles
#if defined(mingw32_HOST_OS)
  , pipeChannelFromNamedPipe
#endif
  , pipeAsBearer
  ) where

import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime.SI
import           Control.Tracer
import qualified Data.ByteString.Lazy as BL
import           System.IO (Handle, hFlush)

#if defined(mingw32_HOST_OS)
import           Data.Foldable (traverse_)

import qualified System.Win32.Types as Win32 (HANDLE)
import qualified System.Win32.Async as Win32.Async
#endif

import           Network.Mux.Types (Bearer)
import qualified Network.Mux.Types as Mx
import qualified Network.Mux.Trace as Mx
import qualified Network.Mux.Codec as Mx
import qualified Network.Mux.Time as Mx
import qualified Network.Mux.Timeout as Mx


-- | Abstraction over the byte-level read and write operations of a
-- pipe-like channel.  Two ways of constructing one are provided:
--
--  * 'pipeChannelFromHandles': based on 'Handle'; OS independent, but does
--    not work well on Windows,
--  * 'pipeChannelFromNamedPipe': based on 'Win32.HANDLE'; Windows specific.
--
data PipeChannel = PipeChannel {
    -- | Read up to the given number of bytes.  May return fewer bytes than
    -- requested; 'pipeAsBearer' keeps reading until it has enough, and
    -- treats an empty result as the pipe having been closed.
    readHandle  :: Int -> IO BL.ByteString,
    -- | Write the given bytes to the channel.
    writeHandle :: BL.ByteString -> IO ()
  }

-- | Build a 'PipeChannel' from a pair of 'Handle's.
--
-- Reads are served by 'BL.hGet' on the read handle.  Every write is
-- followed by 'hFlush' on the write handle, so the bytes are not left
-- sitting in the handle's buffer where the peer cannot see them.
--
pipeChannelFromHandles :: Handle
                       -- ^ read handle
                       -> Handle
                       -- ^ write handle
                       -> PipeChannel
pipeChannelFromHandles r w = PipeChannel {
    readHandle  = BL.hGet r,
    writeHandle = \bs -> BL.hPut w bs >> hFlush w
  }

#if defined(mingw32_HOST_OS)
-- | Wrap a Windows named-pipe 'Win32.HANDLE' as a 'PipeChannel'.  This lets
-- named pipes stand in for anonymous pipes on Windows.
--
-- A read returns whatever a single 'Win32.Async.readHandle' call delivers,
-- converted to a lazy 'BL.ByteString'.  A write pushes each strict chunk of
-- the lazy 'BL.ByteString' through 'Win32.Async.writeHandle', in order.
--
pipeChannelFromNamedPipe :: Win32.HANDLE
                         -> PipeChannel
pipeChannelFromNamedPipe h =
    PipeChannel { readHandle  = readChunk
                , writeHandle = writeChunks
                }
  where
    readChunk n    = BL.fromStrict <$> Win32.Async.readHandle h n
    writeChunks bs = traverse_ (Win32.Async.writeHandle h) (BL.toChunks bs)
#endif

-- | Turn a 'PipeChannel' into a mux 'Bearer' named @"pipe"@.
--
-- Incoming SDUs are framed as a fixed-size header, decoded with
-- 'Mx.decodeSDU', followed by a payload of 'Mx.mhLength' bytes.  Outgoing
-- SDUs are stamped with 'Mx.timestampMicrosecondsLow32Bits' of the current
-- monotonic time before being encoded with 'Mx.encodeSDU'.
--
-- The 'Mx.TimeoutFn' passed to 'Mx.read' and 'Mx.write' is ignored; this
-- bearer applies no timeout of its own.
--
-- Failures:
--
--  * a header that fails to decode is rethrown with 'throwIO';
--  * an empty read (the pipe was closed) throws 'Mx.BearerClosed';
--  * 'IOException's raised by the channel are passed to
--    'Mx.handleIOException'.
--
pipeAsBearer
  :: Mx.SDUSize
  -> Tracer IO Mx.Trace
  -> PipeChannel
  -> Bearer IO
pipeAsBearer sduSize tracer channel =
      Mx.Bearer {
          Mx.read    = readPipe,
          Mx.write   = writePipe,
          Mx.sduSize = sduSize,
          Mx.name    = "pipe"
        }
    where
      -- Number of bytes read as the SDU header before decoding it.
      sduHeaderSize :: Int
      sduHeaderSize = 8

      -- Read one SDU: header first, then exactly 'Mx.mhLength' payload
      -- bytes.  Returns the SDU together with the monotonic time taken
      -- after the payload has been read.
      readPipe :: Mx.TimeoutFn IO -> IO (Mx.SDU, Time)
      readPipe _ = do
          traceWith tracer Mx.TraceRecvHeaderStart
          hbuf <- recvLen' sduHeaderSize []
          case Mx.decodeSDU hbuf of
              Left e -> throwIO e
              Right header@Mx.SDU { Mx.msHeader } -> do
                  traceWith tracer $ Mx.TraceRecvHeaderEnd msHeader
                  blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []
                  ts <- getMonotonicTime
                  traceWith tracer (Mx.TraceRecvDeltaQObservation msHeader ts)
                  return (header {Mx.msBlob = blob}, ts)

      -- Read exactly @l@ bytes.  Chunks are accumulated in reverse order
      -- and concatenated once the count reaches zero.  A read that returns
      -- no bytes means the pipe was closed.
      recvLen' :: Int -> [BL.ByteString] -> IO BL.ByteString
      recvLen' 0 bufs = return $ BL.concat $ reverse bufs
      recvLen' l bufs = do
          traceWith tracer $ Mx.TraceRecvStart l
          buf <- readHandle channel l
                    `catch` Mx.handleIOException "readHandle errored"
          if BL.null buf
              then throwIO $ Mx.BearerClosed "Pipe closed when reading data"
              else do
                  traceWith tracer $ Mx.TraceRecvEnd (fromIntegral $ BL.length buf)
                  recvLen' (l - fromIntegral (BL.length buf)) (buf : bufs)

      -- Stamp the SDU with the current monotonic time, encode it and write
      -- it to the channel.  Returns the time used for the stamp.
      writePipe :: Mx.TimeoutFn IO -> Mx.SDU -> IO Time
      writePipe _ sdu = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = Mx.encodeSDU sdu'
          traceWith tracer $ Mx.TraceSendStart (Mx.msHeader sdu')
          writeHandle channel buf
            `catch` Mx.handleIOException "writeHandle errored"
          traceWith tracer Mx.TraceSendEnd
          return ts