{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An extension of 'Network.TypedProtocol.Channel', with additional 'Channel'
-- implementations.
--
module Network.Mux.Channel
  ( -- * Channel
    Channel (..)
    -- ** Channel API
  , isoKleisliChannel
  , hoistChannel
  , channelEffect
  , delayChannel
  , loggingChannel
    -- ** create a `Channel`
  , mvarsAsChannel
    -- ** connected `Channel`s
  , createConnectedChannels
    -- * `ByteChannel`
  , ByteChannel
    -- ** create a `ByteChannel`
  , handlesAsChannel
  , withFifosAsChannel
  , socketAsChannel
    -- ** connected `ByteChannel`s
  , createBufferConnectedChannels
  , createPipeConnectedChannels
#if !defined(mingw32_HOST_OS)
  , createSocketConnectedChannels
#endif
  ) where

import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as LBS
import Data.ByteString.Lazy.Internal qualified as LBS (smallChunkSize)
import Network.Socket qualified as Socket
import Network.Socket.ByteString qualified as Socket
import System.IO qualified as IO (Handle, IOMode (..), hFlush, hIsEOF, withFile)
import System.Process qualified as IO (createPipe)

import Control.Concurrent.Class.MonadSTM
import Control.Concurrent.Class.MonadSTM.Strict qualified as StrictSTM
import Control.Monad ((>=>))
import Control.Monad.Class.MonadSay
import Control.Monad.Class.MonadTimer.SI


-- | A channel which can send and receive values.
--
-- It is more general than what `network-mux` requires, see `ByteChannel`
-- instead.  However this is useful for testing purposes when one is either
-- using `mux` or connecting two ends directly.
--
data Channel m a = Channel {

    -- | Write a value to the channel.
    --
    -- It may raise exceptions.
    --
    send :: a -> m (),

    -- | Read some input from the channel, or @Nothing@ to indicate EOF.
    --
    -- Note that having received EOF it is still possible to send.
    -- The EOF condition is however monotonic.
    --
    -- It may raise exceptions (as appropriate for the monad and kind of
    -- channel).
    --
    recv :: m (Maybe a)
  }

-- | Given an isomorphism between @a@ and @b@ (in Kleisli category), transform
-- a @'Channel' m a@ into @'Channel' m b@.
--
-- Values handed to the new 'send' are converted with @finv@ before being
-- passed to the underlying 'send'.  Values returned by the underlying 'recv'
-- are converted with @f@.  An EOF ('Nothing') from the underlying channel is
-- passed through unchanged, and @f@ is not called for it.
--
isoKleisliChannel
  :: forall a b m. Monad m
  => (a -> m b)
  -> (b -> m a)
  -> Channel m a
  -> Channel m b
isoKleisliChannel f finv Channel{send, recv} = Channel {
    send = finv >=> send,
    -- 'traverse' over 'Maybe' only runs @f@ on 'Just'.
    recv = recv >>= traverse f
  }


-- | Change the monad a 'Channel' runs in, using a natural transformation.
--
-- Both the 'send' and the 'recv' action of the original channel are run
-- through @nat@; no other effects are added.
--
hoistChannel
  :: (forall x . m x -> n x)
  -> Channel m a
  -> Channel n a
hoistChannel nat channel = Channel
  { send = nat . send channel
  , recv = nat (recv channel)
  }

-- | Wrap a 'Channel' with extra effects around its operations.
--
-- @beforeSend@ is run with each value just before that value is passed to
-- the underlying 'send'.  @afterRecv@ is run with each result of the
-- underlying 'recv', including 'Nothing' on EOF, before that result is
-- returned unchanged.
--
channelEffect :: forall m a.
                 Monad m
              => (a -> m ())       -- ^ Action before 'send'
              -> (Maybe a -> m ()) -- ^ Action after 'recv'
              -> Channel m a
              -> Channel m a
channelEffect beforeSend afterRecv Channel{send, recv} =
    Channel{
      send = \x -> do
        beforeSend x
        send x

    , recv = do
        mx <- recv
        afterRecv mx
        return mx
    }

-- | Delay a channel on the receiver end.
--
-- Every 'recv' on the returned channel, including one that reports EOF
-- ('Nothing'), is followed by @'threadDelay' delay@ before its result is
-- returned.  'send' is left unchanged.
--
-- This is intended for testing, as a crude approximation of network delays.
-- More accurate models along these lines are of course possible.
--
delayChannel :: MonadDelay m
             => DiffTime
             -> Channel m a
             -> Channel m a
delayChannel delay = channelEffect (\_ -> return ())
                                   (\_ -> threadDelay delay)

-- | Channel which logs sent and received messages.
--
-- Every value passed to 'send' is logged with 'say' before it is forwarded to
-- the wrapped channel.  Every value returned by the wrapped 'recv' is logged
-- after it arrives.  Each log line is the 'show'n @ident@, followed by the tag
-- @:send:@ or @:recv:@, followed by the 'show'n value.  An EOF ('Nothing') is
-- passed through without being logged.
--
loggingChannel :: ( MonadSay m
                  , Show id
                  , Show a
                  )
               => id
               -> Channel m a
               -> Channel m a
loggingChannel ident Channel{send, recv} =
    Channel {
      send = loggingSend,
      recv = loggingRecv
    }
  where
    -- Emit one log line: channel identifier, direction tag, then the value.
    logMessage tag a = say (show ident ++ tag ++ show a)

    loggingSend a = do
      logMessage ":send:" a
      send a

    loggingRecv = do
      msg <- recv
      -- 'mapM_' over 'Maybe' only logs a 'Just'; EOF is not logged.
      mapM_ (logMessage ":recv:") msg
      return msg


-- | Make a 'Channel' from a pair of 'TMVar's, one for reading and one for
-- writing.
--
-- Each 'send' is a single 'atomically' 'StrictSTM.putTMVar' into
-- @bufferWrite@, so it waits while that 'TMVar' is full.  Each 'recv' is a
-- single 'atomically' 'StrictSTM.takeTMVar' from @bufferRead@, so it waits
-- while that 'TMVar' is empty.  'recv' always returns a 'Just' and never
-- reports EOF.
--
mvarsAsChannel :: MonadSTM m
               => StrictSTM.StrictTMVar m a
               -> StrictSTM.StrictTMVar m a
               -> Channel m a
mvarsAsChannel bufferRead bufferWrite =
    Channel{send, recv}
  where
    send x = atomically (StrictSTM.putTMVar bufferWrite x)
    recv   = atomically (Just <$> StrictSTM.takeTMVar bufferRead)


-- | Create a pair of channels that are connected via one-place buffers.
--
-- Whatever the first channel sends is received by the second, and vice
-- versa.  Each direction has its own empty 'StrictSTM.StrictTMVar' as a
-- buffer (see 'mvarsAsChannel').
--
-- This is primarily useful for testing protocols.
--
createConnectedChannels :: MonadSTM m => m (Channel m a, Channel m a)
createConnectedChannels = do
    -- Create two TMVars to act as the channel buffer (one for each direction)
    -- and use them to make both ends of a bidirectional channel
    bufferA <- StrictSTM.newEmptyTMVarIO
    bufferB <- StrictSTM.newEmptyTMVarIO

    -- The first end writes bufferA and reads bufferB; the second end is the
    -- mirror image.
    return (mvarsAsChannel bufferB bufferA,
            mvarsAsChannel bufferA bufferB)

--
-- ByteChannel
--

-- | A 'Channel' whose messages are lazy 'LBS.ByteString's.
--
-- This is the channel type @network-mux@ requires.  'Channel' is the more
-- general form, parameterised over the message type.
--
type ByteChannel n = Channel n LBS.ByteString


-- | Make a 'Channel' from a pair of IO 'Handle's, one for reading and one
-- for writing.
--
-- The Handles should be open in the appropriate read or write mode, and in
-- binary mode. Writes are flushed after each write, so it is safe to use
-- a buffering mode.
--
-- For bidirectional handles it is safe to pass the same handle for both.
--
-- 'recv' returns 'Nothing' once 'IO.hIsEOF' reports end of file on the read
-- handle.  Otherwise it returns the result of a single 'BS.hGetSome' of at
-- most 'LBS.smallChunkSize' bytes.
--
handlesAsChannel :: IO.Handle -- ^ Read handle
                 -> IO.Handle -- ^ Write handle
                 -> Channel IO LBS.ByteString
handlesAsChannel hndRead hndWrite =
    Channel{send, recv}
  where
    send :: LBS.ByteString -> IO ()
    send chunk = do
      LBS.hPut hndWrite chunk
      -- Flush so the peer sees the data even if the handle is buffered.
      IO.hFlush hndWrite

    recv :: IO (Maybe LBS.ByteString)
    recv = do
      eof <- IO.hIsEOF hndRead
      if eof
        then return Nothing
        else Just . LBS.fromStrict <$> BS.hGetSome hndRead LBS.smallChunkSize

-- | Create a pair of 'Channel's that are connected internally.
--
-- This is intended for inter-thread communication, such as between a
-- multiplexing thread and a thread running a peer.
--
-- It uses lazy 'ByteString's but it ensures that data written to the channel
-- is /fully evaluated/ first. This ensures that any work to serialise the data
-- takes place on the /writer side and not the reader side/.
--
-- Each direction is a single empty 'TMVar' holding one strict chunk.  A
-- 'send' writes the chunks of its lazy 'LBS.ByteString' one at a time, each
-- in its own 'atomically' transaction, so one 'recv' on the other end returns
-- at most one chunk.  Sending an empty 'LBS.ByteString' writes nothing.
-- 'recv' always returns a 'Just' and never reports EOF.
--
createBufferConnectedChannels :: forall m. MonadSTM m
                              => m (ByteChannel m,
                                    ByteChannel m)
createBufferConnectedChannels = do
    bufferA <- newEmptyTMVarIO
    bufferB <- newEmptyTMVarIO

    -- The first end writes bufferA and reads bufferB; the second end is the
    -- mirror image.
    return (buffersAsChannel bufferB bufferA,
            buffersAsChannel bufferA bufferB)
  where
    -- Build one end of the connection from its read and write buffers.
    buffersAsChannel :: TMVar m BS.ByteString
                     -> TMVar m BS.ByteString
                     -> ByteChannel m
    buffersAsChannel bufferRead bufferWrite =
        Channel{send, recv}
      where
        send :: LBS.ByteString -> m ()
        send x = sequence_ [ atomically (putTMVar bufferWrite c)
                           | !c <- LBS.toChunks x ]
                           -- Evaluate the chunk c /before/ doing the STM
                           -- transaction to write it to the buffer.

        recv :: m (Maybe LBS.ByteString)
        recv = Just . LBS.fromStrict <$> atomically (takeTMVar bufferRead)


-- | Create a local pipe, with both ends in this process, and expose that as
-- a pair of 'Channel's, one for each end.
--
-- Two unidirectional pipes are created.  The first channel reads from the
-- first pipe and writes to the second, and the second channel does the
-- opposite.  Nothing here closes the pipe handles.
--
-- This is primarily for testing purposes since it does not allow actual IPC.
--
createPipeConnectedChannels :: IO (ByteChannel IO,
                                   ByteChannel IO)
createPipeConnectedChannels = do
    -- Create two pipes (each one is unidirectional) to make both ends of
    -- a bidirectional channel
    (hndReadA, hndWriteB) <- IO.createPipe
    (hndReadB, hndWriteA) <- IO.createPipe

    return (handlesAsChannel hndReadA hndWriteA,
            handlesAsChannel hndReadB hndWriteB)

-- | Open a pair of Unix FIFOs, and expose that as a 'Channel'.
--
-- The peer process needs to open the same files but the other way around,
-- for writing and reading.
--
-- The read FIFO is opened first, then the write FIFO, both with
-- 'IO.withFile'.  Both handles are therefore closed when @action@ returns or
-- throws, so the channel must not be used after @action@ finishes.
--
-- NOTE(review): depending on how GHC opens files, opening a FIFO for writing
-- may fail if no reader has opened it yet.  Confirm the intended start-up
-- order between the two processes.
--
-- This is primarily for the purpose of demonstrations that use communication
-- between multiple local processes. It is Unix specific.
--
withFifosAsChannel :: FilePath -- ^ FIFO for reading
                   -> FilePath -- ^ FIFO for writing
                   -> (ByteChannel IO -> IO a) -> IO a
withFifosAsChannel fifoPathRead fifoPathWrite action =
    IO.withFile fifoPathRead  IO.ReadMode  $ \hndRead  ->
    IO.withFile fifoPathWrite IO.WriteMode $ \hndWrite ->
      let channel = handlesAsChannel hndRead hndWrite
       in action channel


-- | Make a 'Channel' from a 'Socket.Socket'. The socket must be a stream
-- socket type and status connected.
--
-- 'send' hands all chunks of the lazy 'LBS.ByteString' to a single
-- 'Socket.sendMany' call.  'recv' reads at most 'LBS.smallChunkSize' bytes per
-- call and returns 'Nothing' when the socket yields a zero-length chunk.
--
socketAsChannel :: Socket.Socket -> ByteChannel IO
socketAsChannel socket =
    Channel{send, recv}
  where
    send :: LBS.ByteString -> IO ()
    send chunks =
      -- Use vectored writes.
      Socket.sendMany socket (LBS.toChunks chunks)
      -- TODO: limit write sizes, or break them into multiple sends.

    recv :: IO (Maybe LBS.ByteString)
    recv = do
      -- We rely on the behaviour of stream sockets that a zero length chunk
      -- indicates EOF.
      chunk <- Socket.recv socket LBS.smallChunkSize
      if BS.null chunk
        then return Nothing
        else return (Just (LBS.fromStrict chunk))

#if !defined(mingw32_HOST_OS)
-- | Create a local socket, with both ends in this process, and expose that as
-- a pair of 'ByteChannel's, one for each end.
--
-- The two ends come from a single 'Socket.socketPair' of type
-- 'Socket.Stream' using 'Socket.defaultProtocol'.  Each end is wrapped with
-- 'socketAsChannel'.
--
-- This is primarily for testing purposes since it does not allow actual IPC.
--
createSocketConnectedChannels :: Socket.Family -- ^ Usually AF_UNIX or AF_INET
                              -> IO (ByteChannel IO,
                                     ByteChannel IO)
createSocketConnectedChannels family = do
   -- Create a socket pair to make both ends of a bidirectional channel
   (socketA, socketB) <- Socket.socketPair family Socket.Stream
                                           Socket.defaultProtocol

   return (socketAsChannel socketA,
           socketAsChannel socketB)
#endif