{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns        #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}

module Network.Mux.Ingress
  ( -- $ingress
    demuxer
  ) where

import Data.Array
import Data.ByteString.Lazy qualified as BL
import Data.List (nub)

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTimer.SI hiding (timeout)

import Network.Mux.Timeout
import Network.Mux.Trace
import Network.Mux.Types as Mx


-- | Swap the direction of a mini-protocol: 'InitiatorDir' becomes
-- 'ResponderDir' and 'ResponderDir' becomes 'InitiatorDir'.
flipMiniProtocolDir :: MiniProtocolDir -> MiniProtocolDir
flipMiniProtocolDir InitiatorDir = ResponderDir
flipMiniProtocolDir ResponderDir = InitiatorDir

-- $ingress
-- = Ingress Path
--
-- >                  ●
-- >                  │
-- >                  │ ByteStrings
-- >                  │
-- >         ░░░░░░░░░▼░░░░░░░░░
-- >         ░┌───────────────┐░
-- >         ░│ Bearer.read() │░ Mux Bearer implementation (Socket, Pipes, etc.)
-- >         ░└───────────────┘░
-- >         ░░░░░░░░░│░░░░░░░░░
-- >                 ░│░         SDUs
-- >         ░░░░░░░░░▼░░░░░░░░░
-- >         ░┌───────────────┐░
-- >         ░│     demux     │░ For a given Mux Bearer there is a single demux
-- >         ░└───────┬───────┘░ thread reading from the underlying bearer.
-- >         ░░░░░░░░░│░░░░░░░░░
-- >                 ░│░
-- >        ░░░░░░░░░░▼░░░░░░░░░░
-- >        ░ ╭────┬────┬─────╮ ░ There is a limited queue (in bytes) for each mode
-- >        ░ │    │    │     │ ░ (responder/initiator) per miniprotocol. Overflowing
-- >        ░ ▼    ▼    ▼     ▼ ░ a queue is a protocol violation and an
-- >        ░│  │ │  │ │  │ │  │░ IngressQueueOverRun exception is thrown
-- >        ░│ci│ │  │ │bi│ │br│░ and the bearer torn down.
-- >        ░│ci│ │cr│ │bi│ │br│░
-- >        ░└──┘ └──┘ └──┘ └──┘░ Every ingress queue has a dedicated thread which will
-- >        ░░│░░░░│░░░░│░░░░│░░░ read application encoded data from its queue.
-- >          │    │    │    │
-- >           application data
-- >          │    │    │    │
-- >          ▼    │    │    ▼
-- > ┌───────────┐ │    │  ┌───────────┐
-- > │ muxDuplex │ │    │  │ muxDuplex │
-- > │ Initiator │ │    │  │ Responder │
-- > │ ChainSync │ │    │  │ BlockFetch│
-- > └───────────┘ │    │  └───────────┘
-- >               ▼    ▼
-- >    ┌───────────┐  ┌───────────┐
-- >    │ muxDuplex │  │ muxDuplex │
-- >    │ Responder │  │ Initiator │
-- >    │ ChainSync │  │ BlockFetch│
-- >    └───────────┘  └───────────┘

-- | Each peer's multiplexer has some state that provides the
-- de-multiplexing details needed to dispatch incoming SDUs to the ingress
-- queues of the mini protocols.  This is shared between the muxIngress and
-- the bearerIngress processes.
--
-- The lookup is two-level; see 'setupDispatchTable' for how the two arrays
-- are built and 'lookupMiniProtocol' for how they are indexed.
--
data MiniProtocolDispatch m =
     MiniProtocolDispatch
       -- First level: protocol number to a dense index, or 'Nothing' for a
       -- number inside the array bounds that no mini protocol uses.
       !(Array MiniProtocolNum (Maybe MiniProtocolIx))
       -- Second level: dense index plus direction to the dispatch info.
       !(Array (MiniProtocolIx, MiniProtocolDir)
               (MiniProtocolDispatchInfo m))

-- | What the demultiplexer knows about one (mini protocol, direction) slot
-- of the dispatch table.
data MiniProtocolDispatchInfo m =
     -- A slot that is in use.
     MiniProtocolDispatchInfo
       -- Ingress queue that 'demuxer' appends incoming SDU payloads to.
       !(IngressQueue m)
       -- Queue limit: 'demuxer' throws 'IngressQueueOverRun' rather than let
       -- the buffered length exceed this value.
       !Int
     -- Filler used by 'setupDispatchTable' for a direction that no mini
     -- protocol registered; 'demuxer' answers it with 'InitiatorOnly'.
   | MiniProtocolDirUnused


-- | demux runs as a single separate thread and reads complete 'SDU's from
-- the underlying 'Bearer' and forwards each one to the matching ingress
-- queue.
--
-- The dispatch table is built (and forced) once from the given mini
-- protocols, then the function loops forever:
--
-- * read one SDU from the bearer, using the timeout function supplied by
--   'withTimeoutSerial';
-- * look up the ingress queue for the SDU's protocol number and the
--   /flipped/ direction of the SDU;
-- * append the SDU payload to that queue.
--
-- The loop only ends by an exception.  Those thrown here are:
--
-- * 'UnknownMiniProtocol' when the protocol number is not in the table;
-- * 'InitiatorOnly' when the protocol is known but the requested direction
--   is unused;
-- * 'IngressQueueOverRun' when appending the payload would make the queue
--   longer than its configured limit.
--
demuxer :: (MonadAsync m, MonadFork m, MonadMask m, MonadThrow (STM m),
            MonadTimer m)
      => [MiniProtocolState mode m]
      -> Bearer m
      -> m void
demuxer ptcls bearer =
  -- Force the table up front so construction errors surface immediately
  -- rather than on the first lookup.
  let !dispatchTable = setupDispatchTable ptcls in
  withTimeoutSerial $ \timeout ->
  forever $ do
    (sdu, _) <- Mx.read bearer timeout
    -- say $ printf "demuxing sdu on mid %s mode %s length %d " (show $ msId sdu) (show $ msDir sdu)
    --             (BL.length $ msBlob sdu)
    case lookupMiniProtocol dispatchTable (msNum sdu)
                            -- Notice the mode reversal, ResponderDir is
                            -- delivered to InitiatorDir and vice versa:
                            (flipMiniProtocolDir $ msDir sdu) of
      Nothing ->
        throwIO (UnknownMiniProtocol (msNum sdu))
      Just MiniProtocolDirUnused ->
        throwIO (InitiatorOnly (msNum sdu))
      Just (MiniProtocolDispatchInfo q qMax) ->
        -- The size check and the append happen in one transaction, so the
        -- length that was checked is the length that gets extended.
        atomically $ do
          buf <- readTVar q
          if BL.length buf + BL.length (msBlob sdu) <= fromIntegral qMax
              then writeTVar q $ BL.append buf (msBlob sdu)
              else throwSTM $ IngressQueueOverRun (msNum sdu) (msDir sdu)

-- | Look up the dispatch info for a protocol number and direction.
--
-- Returns 'Nothing' when the protocol number lies outside the bounds of the
-- first-level array or maps to no index there.  Otherwise returns 'Just'
-- the second-level entry, which may itself be 'MiniProtocolDirUnused'.
lookupMiniProtocol :: MiniProtocolDispatch m
                   -> MiniProtocolNum
                   -> MiniProtocolDir
                   -> Maybe (MiniProtocolDispatchInfo m)
lookupMiniProtocol (MiniProtocolDispatch pnumArray ptclArray) pnum pdir
    -- The bounds check comes first because '!' is partial outside the
    -- array's bounds.
  | inRange (bounds pnumArray) pnum
  , Just mpid <- pnumArray ! pnum = Just (ptclArray ! (mpid, pdir))
  | otherwise                   = Nothing

-- | Construct the table that maps 'MiniProtocolNum' and 'MiniProtocolDir' to
-- 'MiniProtocolDispatchInfo'. Use 'lookupMiniProtocol' to index it.
--
-- The list of mini protocols must be non-empty: the bounds of the first
-- array come from 'minimum' and 'maximum' of the protocol numbers, and both
-- fail on an empty list.
--
-- Both arrays are built by listing a default for every index and then
-- appending the real entries.  NOTE(review): this relies on 'array' keeping
-- the last association for a repeated index, which is GHC's documented
-- behaviour rather than what Haskell 2010 specifies.
--
setupDispatchTable :: forall mode m.
                      [MiniProtocolState mode m] -> MiniProtocolDispatch m
setupDispatchTable ptcls =
    MiniProtocolDispatch pnumArray ptclArray
  where
    -- The 'MiniProtocolNum' space is sparse but we don't want a huge single
    -- table if we use large protocol numbers. So we use a two level mapping.
    --
    -- The first array maps 'MiniProtocolNum' to a dense space of intermediate
    -- integer indexes. These indexes are meaningless outside of the context of
    -- this table. Then we use the index and the 'MiniProtocolDir' for the
    -- second table.
    --
    pnumArray :: Array MiniProtocolNum (Maybe MiniProtocolIx)
    pnumArray =
      array (minpnum, maxpnum) $
            -- Fill in Nothing first to cover any unused ones.
            [ (pnum, Nothing)    | pnum <- [minpnum..maxpnum] ]

            -- And override with the ones actually used.
         ++ [ (pnum, Just pix)   | (pnum, pix) <- zip pnums [0..] ]

    ptclArray :: Array (MiniProtocolIx, MiniProtocolDir)
                       (MiniProtocolDispatchInfo m)
    ptclArray =
      array ((minpix, InitiatorDir), (maxpix, ResponderDir)) $
            -- Fill in MiniProtocolDirUnused first to cover any unused ones.
            [ ((pix, dir), MiniProtocolDirUnused)
            | (pix, dir) <- range ((minpix, InitiatorDir),
                                   (maxpix, ResponderDir)) ]

            -- And override with the ones actually used.
         ++ [ ((pix, dir), MiniProtocolDispatchInfo q qMax)
            | MiniProtocolState {
                miniProtocolInfo =
                  MiniProtocolInfo {
                    miniProtocolNum,
                    miniProtocolDir,
                    miniProtocolLimits
                  },
                miniProtocolIngressQueue = q
              } <- ptcls
            , let pix  =
                   case pnumArray ! miniProtocolNum of
                     Just a  -> a
                     -- This error is impossible to trigger - note that
                     -- `pnumArray` is constructed to ensure that every
                     -- `miniProtocolNum` in `ptcls` indexes to a `Just` value.
                     Nothing -> error
                       ("setupDispatchTable: impossible: missing "
                         ++ show miniProtocolNum)
                  dir      = protocolDirEnum miniProtocolDir
                  qMax     = maximumIngressQueue miniProtocolLimits
            ]

    -- The protocol numbers actually used, in the order of the first use within
    -- the 'ptcls' list. The order does not matter provided we do it
    -- consistently between the two arrays.
    pnums :: [MiniProtocolNum]
    pnums   = nub $ map (miniProtocolNum . miniProtocolInfo) ptcls

    -- The dense range of indexes of used protocol numbers.
    minpix, maxpix :: MiniProtocolIx
    minpix  = 0
    maxpix  = fromIntegral (length pnums - 1)

    -- The sparse range of protocol numbers
    minpnum, maxpnum :: MiniProtocolNum
    minpnum = minimum pnums
    maxpnum = maximum pnums