{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.Channel
  ( Channel (..)
  , module Mx
  , fixedInputChannel
  , createConnectedBufferedChannelsUnbounded
  , createConnectedBufferedChannels
  , createConnectedBufferedChannelsSTM
  , createPipelineTestChannels
  ) where

import Numeric.Natural

import Control.Concurrent.Class.MonadSTM.Strict

import Network.Mux.Channel as Mx


-- | A 'Channel' whose input is a fixed list of chunks and whose output is
-- thrown away.
--
-- Each call to 'recv' yields the next element of the list exactly as it was
-- given, so the consumer sees precisely the chunk boundaries chosen by the
-- caller. Once the list is exhausted, 'recv' returns 'Nothing'. 'send'
-- ignores its argument.
--
-- Only useful for testing: controlling the chunk boundaries makes it possible
-- to check that framing and other codecs work with any possible chunking.
--
fixedInputChannel :: forall m a. MonadSTM m => [a] -> m (Channel m a)
fixedInputChannel xs0 = do
    remaining <- newTVarIO xs0
    pure Channel { send = discard, recv = popChunk remaining }
  where
    -- Output is dropped on the floor.
    discard :: a -> m ()
    discard = const (pure ())

    -- Atomically take the head of the remaining input, if there is one.
    popChunk :: StrictTVar m [a] -> m (Maybe a)
    popChunk ref = atomically $ do
      chunks <- readTVar ref
      case chunks of
        []         -> pure Nothing
        (c : rest) -> do
          writeTVar ref rest
          pure (Just c)



-- | Create a pair of 'Channel's joined back-to-back by two unbounded
-- 'StrictTQueue's, one per direction.
--
-- A message sent on one end is received on the other. 'recv' reads with
-- 'readTQueue' and always wraps the message in 'Just'.
--
-- This is primarily useful for testing protocols.
--
createConnectedBufferedChannelsUnbounded :: forall m a. MonadSTM m
                                         => m (Channel m a, Channel m a)
createConnectedBufferedChannelsUnbounded = do
    -- One queue for each direction of the bidirectional link.
    qA <- atomically newTQueue
    qB <- atomically newTQueue
    pure ( endpoint qB qA
         , endpoint qA qB
         )
  where
    -- One end of the link: receives from @inbox@, sends into @outbox@.
    endpoint :: StrictTQueue m a -> StrictTQueue m a -> Channel m a
    endpoint inbox outbox = Channel
      { send = atomically . writeTQueue outbox
      , recv = atomically (fmap Just (readTQueue inbox))
      }


-- | Create a pair of 'Channel's joined by two bounded buffers, each able to
-- hold @sz@ messages.
--
-- With this variant 'send' /blocks/ when the buffer in its direction is full.
-- Use it when you want the environment, rather than the 'Peer', to limit the
-- pipelining.
--
-- The buffers come from 'createConnectedBufferedChannelsSTM'. Every 'send' and
-- 'recv' of the returned channels runs as its own 'atomically' transaction.
--
-- This is primarily useful for testing protocols.
--
createConnectedBufferedChannels :: forall m a. MonadLabelledSTM m
                                => Natural -> m (Channel m a, Channel m a)
createConnectedBufferedChannels sz =
    both <$> atomically (createConnectedBufferedChannelsSTM sz)
  where
    both :: (Channel (STM m) a, Channel (STM m) a)
         -> (Channel m a, Channel m a)
    both (c1, c2) = (liftChannel c1, liftChannel c2)

    -- Turn an STM-level channel into one whose operations are each run in
    -- a separate transaction.
    liftChannel :: Channel (STM m) a -> Channel m a
    liftChannel stmChan = Channel
      { send = atomically . send stmChan
      , recv = atomically (recv stmChan)
      }

-- | The 'STM' counterpart of 'createConnectedBufferedChannels'.
--
-- It creates two 'StrictTBQueue's of capacity @sz@ (labelled chann-a and
-- chann-b) inside the current transaction. The returned channels have 'send'
-- and 'recv' operations that are themselves 'STM' actions.
--
-- TODO: it should return a pair of `Channel m a`.
createConnectedBufferedChannelsSTM :: forall m a. MonadLabelledSTM m
                                   => Natural -> STM m (Channel (STM m) a, Channel (STM m) a)
createConnectedBufferedChannelsSTM sz = do
    -- One bounded queue per direction of the bidirectional link.
    qA <- newLabelledQueue "chann-a"
    qB <- newLabelledQueue "chann-b"
    pure (endpoint qB qA, endpoint qA qB)
  where
    -- Allocate a queue of capacity @sz@ and attach the given label to it.
    newLabelledQueue :: String -> STM m (StrictTBQueue m a)
    newLabelledQueue label = do
      q <- newTBQueue sz
      labelTBQueue q label
      pure q

    -- One end of the link: receives from @inbox@, sends into @outbox@.
    endpoint :: StrictTBQueue m a -> StrictTBQueue m a -> Channel (STM m) a
    endpoint inbox outbox = Channel
      { send = writeTBQueue outbox
      , recv = Just <$> readTBQueue inbox
      }


-- | Create a pair of 'Channel's joined by two bounded buffers, each able to
-- hold @sz@ messages.
--
-- This variant does not wait when 'send' finds the buffer in its direction
-- already full: it /fails/ with an 'error' instead. Use it when you want the
-- 'PeerPipelined' to limit the pipelining itself, and you want to check that
-- it does not exceed the expected level of pipelining.
--
-- This is primarily useful for testing protocols.
--
createPipelineTestChannels :: forall m a. MonadSTM m
                           => Natural -> m (Channel m a, Channel m a)
createPipelineTestChannels sz = do
    -- One bounded queue per direction of the bidirectional link.
    qA <- atomically (newTBQueue sz)
    qB <- atomically (newTBQueue sz)
    pure (endpoint qB qA, endpoint qA qB)
  where
    -- One end of the link: receives from @inbox@, sends into @outbox@.
    endpoint :: StrictTBQueue m a -> StrictTBQueue m a -> Channel m a
    endpoint inbox outbox = Channel
      { send = atomically . sendOrFail outbox
      , recv = atomically (Just <$> readTBQueue inbox)
      }

    -- Enqueue only if there is room; a full buffer means the configured
    -- pipeline depth @sz@ was exceeded, which is reported as an error.
    sendOrFail :: StrictTBQueue m a -> a -> STM m ()
    sendOrFail q msg = do
      full <- isFullTBQueue q
      if not full
        then writeTBQueue q msg
        else error failureMsg

    failureMsg :: String
    failureMsg = "createPipelineTestChannels: "
              ++ "maximum pipeline depth exceeded: " ++ show sz