{-# LANGUAGE CPP                    #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE NumericUnderscores     #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE UndecidableInstances   #-}

module Network.Mux.Bearer
  ( Bearer (..)
  , MakeBearer (..)
  , makeSocketBearer
  , makeSocketBearer'
  , makePipeChannelBearer
  , makeQueueChannelBearer
#if defined(mingw32_HOST_OS)
  , makeNamedPipeBearer
#endif
  , withReadBufferIO
  ) where

import           Control.Monad.Class.MonadSTM
import           Control.Concurrent.Class.MonadSTM.Strict
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime.SI
import           Control.Tracer (Tracer)

import           Data.ByteString.Lazy qualified as BL
import           Network.Socket (Socket)
#if defined(mingw32_HOST_OS)
import           System.Win32 (HANDLE)
#endif
import           Foreign.Marshal.Alloc

import           Network.Mux.Bearer.Pipe
import           Network.Mux.Bearer.Queues
import           Network.Mux.Bearer.Socket
import           Network.Mux.Trace
import           Network.Mux.Types hiding (sduSize)
#if defined(mingw32_HOST_OS)
import           Network.Mux.Bearer.NamedPipe
#endif

-- | A recipe for turning a value of type @fd@ (a socket, pipe, queue,
-- handle, …) into a mux 'Bearer' running in monad @m@.
newtype MakeBearer m fd = MakeBearer {
    getBearer
      :: DiffTime
      -- ^ timeout for reading an SDU segment; if negative, no
      -- timeout is applied.
      -> Tracer m Trace
      -- ^ tracer
      -> fd
      -- ^ file descriptor
      -> Maybe (ReadBuffer m)
      -- ^ optional read buffer
      -> m (Bearer m)
  }

-- | Lift a pure bearer constructor into the shape expected by
-- 'getBearer'. Every argument (SDU timeout, tracer, file descriptor,
-- optional read buffer) is forwarded unchanged to @f@, and the resulting
-- 'Bearer' is returned with 'pure'.
pureBearer :: Applicative m
           => (DiffTime -> Tracer m Trace -> fd -> Maybe (ReadBuffer m) ->   Bearer m)
           ->  DiffTime -> Tracer m Trace -> fd -> Maybe (ReadBuffer m) -> m (Bearer m)
pureBearer f = \sduTimeout tr fd rb -> pure (f sduTimeout tr fd rb)


-- | Socket 'MakeBearer'. Equivalent to 'makeSocketBearer'' applied to a
-- 'DiffTime' of @0@.
makeSocketBearer :: MakeBearer IO Socket
makeSocketBearer = makeSocketBearer' 0

-- | Socket 'MakeBearer' built with 'socketAsBearer', using an SDU size of
-- 12 288 bytes and a @batch@ value of 131 072.
--
-- @pt@ is passed straight through to 'socketAsBearer' as its second
-- 'DiffTime' argument, right after the SDU timeout.
-- NOTE(review): presumably a polling/egress interval — confirm against
-- 'socketAsBearer'.
makeSocketBearer' :: DiffTime -> MakeBearer IO Socket
makeSocketBearer' pt = MakeBearer $ pureBearer $ \sduTimeout tr fd rb ->
    socketAsBearer size batch rb sduTimeout pt tr fd
  where
    size :: SDUSize
    size = SDUSize 12_288

    batch :: Int
    batch = 131_072

-- | Allocate an aligned buffer of @size@ bytes together with a fresh
-- 'StrictTVar' holding an empty lazy 'BL.ByteString'. Both are passed to
-- the continuation as @Just (ReadBuffer v ptr size)@.
--
-- The buffer is allocated with 'allocaBytesAligned', so it is released as
-- soon as the continuation returns. The 'ReadBuffer' must therefore not
-- escape or be used after @f@ finishes.
withReadBufferIO :: (Maybe (ReadBuffer IO) -> IO b)
                 -> IO b
withReadBufferIO f = allocaBytesAligned size alignment $ \ptr -> do
    v <- atomically $ newTVar BL.empty
    f $ Just $ ReadBuffer v ptr size
  where
    -- Maximum amount of data read in one call.
    -- Corresponds to the default read buffer size on Linux.
    -- We want it larger than 64 KiB, but not too large, since
    -- it is a memory overhead per mux bearer in an application.
    size :: Int
    size = 131_072

    -- Alignment, in bytes, of the allocated buffer.
    alignment :: Int
    alignment = 8

-- | 'MakeBearer' for a 'PipeChannel'. The SDU timeout and the optional
-- read buffer are ignored. The bearer is built by 'pipeAsBearer' with a
-- fixed SDU size of 32 768 bytes.
makePipeChannelBearer :: MakeBearer IO PipeChannel
makePipeChannelBearer = MakeBearer $ pureBearer (\_ tr fd _ -> pipeAsBearer size tr fd)
  where
    size :: SDUSize
    size = SDUSize 32_768

-- | 'MakeBearer' for a 'QueueChannel'. The SDU timeout and the optional
-- read buffer are ignored. The bearer is built by 'queueChannelAsBearer'
-- with a fixed SDU size of 1 280 bytes.
makeQueueChannelBearer :: ( MonadSTM   m
                          , MonadMonotonicTime m
                          , MonadThrow m
                          )
                       => MakeBearer m (QueueChannel m)
makeQueueChannelBearer = MakeBearer $ pureBearer (\_ tr q _ -> queueChannelAsBearer size tr q)
  where
    size :: SDUSize
    size = SDUSize 1_280

#if defined(mingw32_HOST_OS)
-- | 'MakeBearer' for a Windows named pipe 'HANDLE', built by
-- 'namedPipeAsBearer' with a fixed SDU size of 24 576 bytes.
makeNamedPipeBearer :: MakeBearer IO HANDLE
makeNamedPipeBearer = MakeBearer (pureBearer mk)
  where
    -- The SDU timeout (first argument) and the optional read buffer
    -- (last argument) are not used for named pipes.
    mk _ tr fd _ = namedPipeAsBearer size tr fd

    size :: SDUSize
    size = SDUSize 24_576
#endif