{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns        #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}

module Network.Mux.Ingress
  ( -- $ingress
    demuxer
  ) where

import Data.Array
import Data.ByteString.Builder.Internal (lazyByteStringInsert,
           lazyByteStringThreshold)
import Data.ByteString.Lazy qualified as BL
import Data.List (nub)
import Data.Strict.Tuple (pattern (:!:))

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTimer.SI hiding (timeout)
import Control.Tracer (Tracer)

import Network.Mux.Timeout
import Network.Mux.Trace
import Network.Mux.Types as Mx


-- | The opposite 'MiniProtocolDir'.  The demuxer uses this to route an
-- incoming SDU: an SDU tagged with one direction is delivered to the local
-- queue registered for the other direction.
flipMiniProtocolDir :: MiniProtocolDir -> MiniProtocolDir
flipMiniProtocolDir dir =
  case dir of
    InitiatorDir -> ResponderDir
    ResponderDir -> InitiatorDir

-- $ingress
-- = Ingress Path
--
-- >                  ●
-- >                  │
-- >                  │ ByteStrings
-- >                  │
-- >         ░░░░░░░░░▼░░░░░░░░░
-- >         ░┌───────────────┐░
-- >         ░│ Bearer.read() │░ Mux Bearer implementation (Socket, Pipes, etc.)
-- >         ░└───────────────┘░
-- >         ░░░░░░░░░│░░░░░░░░░
-- >                 ░│░         SDUs
-- >         ░░░░░░░░░▼░░░░░░░░░
-- >         ░┌───────────────┐░
-- >         ░│     demux     │░ For a given Mux Bearer there is a single demux
-- >         ░└───────┬───────┘░ thread reading from the underlying bearer.
-- >         ░░░░░░░░░│░░░░░░░░░
-- >                 ░│░
-- >        ░░░░░░░░░░▼░░░░░░░░░░
-- >        ░ ╭────┬────┬─────╮ ░ There is a limited queue (in bytes) for each mode
-- >        ░ │    │    │     │ ░ (responder/initiator) per miniprotocol. Overflowing
-- >        ░ ▼    ▼    ▼     ▼ ░ a queue is a protocol violation and an
-- >        ░│  │ │  │ │  │ │  │░ IngressQueueOverRun exception is thrown
-- >        ░│ci│ │  │ │bi│ │br│░ and the bearer torn down.
-- >        ░│ci│ │cr│ │bi│ │br│░
-- >        ░└──┘ └──┘ └──┘ └──┘░ Every ingress queue has a dedicated thread which will
-- >        ░░│░░░░│░░░░│░░░░│░░░ read application encoded data from its queue.
-- >          │    │    │    │
-- >           application data
-- >          │    │    │    │
-- >          ▼    │    │    ▼
-- > ┌───────────┐ │    │  ┌───────────┐
-- > │ muxDuplex │ │    │  │ muxDuplex │
-- > │ Initiator │ │    │  │ Responder │
-- > │ ChainSync │ │    │  │ BlockFetch│
-- > └───────────┘ │    │  └───────────┘
-- >               ▼    ▼
-- >    ┌───────────┐  ┌───────────┐
-- >    │ muxDuplex │  │ muxDuplex │
-- >    │ Responder │  │ Initiator │
-- >    │ ChainSync │  │ BlockFetch│
-- >    └───────────┘  └───────────┘

-- | The demultiplexing table for one bearer.  It is built once by
-- 'setupDispatchTable' and queried by 'lookupMiniProtocol'.
--
-- The table has two levels because the mini-protocol number space is
-- sparse:
--
-- * the first array maps a 'MiniProtocolNum' to a dense 'MiniProtocolIx',
--   or to 'Nothing' for numbers in range that are not in use;
-- * the second array maps that index, paired with a 'MiniProtocolDir', to
--   the dispatch information for the corresponding ingress queue.
--
data MiniProtocolDispatch m
  = MiniProtocolDispatch
      !(Array MiniProtocolNum (Maybe MiniProtocolIx))
      -- ^ sparse protocol number to dense index
      !(Array (MiniProtocolIx, MiniProtocolDir) (MiniProtocolDispatchInfo m))
      -- ^ (dense index, direction) to dispatch information

-- | One entry of the second level of a 'MiniProtocolDispatch' table: what
-- to do with an SDU addressed to a given protocol index and direction.
data MiniProtocolDispatchInfo m
  = MiniProtocolDispatchInfo
      !(IngressQueue m)
      -- ^ queue that incoming payloads are appended to
      !Int
      -- ^ maximum number of bytes the queue may hold
  | MiniProtocolDirUnused
      -- ^ the protocol number is in use, but not in this direction


-- | The demultiplexer loop.  It runs as a single separate thread per bearer,
-- reads complete 'SDU's from the 'Bearer', and appends each payload to the
-- ingress queue registered for the SDU's protocol number and flipped
-- direction.
--
-- The loop never returns normally, hence the unconstrained result type
-- @void@.  It ends only by throwing one of:
--
-- * 'UnknownMiniProtocol', when no protocol with the SDU's number is
--   registered;
-- * 'InitiatorOnly', when the protocol is registered but has no queue for
--   the direction the SDU is delivered to;
-- * 'IngressQueueOverRun' (thrown inside the STM transaction), when
--   appending the payload would exceed the queue's 'maximumIngressQueue';
-- * whatever 'Mx.read' throws.
demuxer :: forall mode m void.
           (MonadAsync m, MonadFork m, MonadMask m, MonadThrow (STM m),
            MonadTimer m)
        => [MiniProtocolState mode m]
        -> Tracer m BearerTrace
        -> Bearer m
        -> m void
demuxer ptcls tracer bearer =
    -- Build and force the dispatch table once, before entering the loop.
    dispatchTable `seq`
      withTimeoutSerial (\timeout -> forever (demuxOne timeout))
  where
    dispatchTable :: MiniProtocolDispatch m
    dispatchTable = setupDispatchTable ptcls

    -- Read a single SDU from the bearer and route it to its ingress queue.
    demuxOne :: TimeoutFn m -> m ()
    demuxOne timeout = do
      (sdu, _) <- Mx.read bearer tracer timeout
      let pnum     = msNum sdu
          -- Mode reversal: an SDU tagged ResponderDir is delivered to the
          -- local InitiatorDir queue and vice versa.
          localDir = flipMiniProtocolDir (msDir sdu)
      case lookupMiniProtocol dispatchTable pnum localDir of
        Nothing ->
          throwIO (UnknownMiniProtocol pnum)
        Just MiniProtocolDirUnused ->
          throwIO (InitiatorOnly pnum)
        Just (MiniProtocolDispatchInfo q qMax) ->
          atomically (enqueue sdu q qMax)

    -- Append an SDU payload to an ingress queue, enforcing its byte limit.
    enqueue :: SDU -> IngressQueue m -> Int -> STM m ()
    enqueue sdu q qMax = do
      len :!: buf <- readTVar q
      let payload = msBlob sdu
          !len'   = len + BL.length payload
      when (len' > fromIntegral qMax) $
        throwSTM (IngressQueueOverRun (msNum sdu) (msDir sdu))
      -- An empty queue takes the payload without copying it.  Otherwise
      -- chunks smaller than 128 bytes are copied into the builder.
      let !buf' = if len == 0
                    then lazyByteStringInsert payload
                    else buf <> lazyByteStringThreshold 128 payload
      writeTVar q (len' :!: buf')

-- | Look up the dispatch entry for a protocol number and direction.
--
-- Returns 'Nothing' when the number lies outside the first-level array's
-- bounds, or maps to no registered protocol.  A 'Just' result can still be
-- 'MiniProtocolDirUnused', when the protocol is registered but not for the
-- requested direction.
lookupMiniProtocol :: MiniProtocolDispatch m
                   -> MiniProtocolNum
                   -> MiniProtocolDir
                   -> Maybe (MiniProtocolDispatchInfo m)
lookupMiniProtocol (MiniProtocolDispatch numToIx table) num dir = do
    guard (inRange (bounds numToIx) num)
    ix <- numToIx ! num
    pure (table ! (ix, dir))

-- | Construct the table that maps 'MiniProtocolNum' and 'MiniProtocolDir' to
-- 'MiniProtocolDispatchInfo'. Use 'lookupMiniProtocol' to index it.
--
-- Every protocol number occurring in @ptcls@ gets an entry for both
-- directions.  A direction for which @ptcls@ supplies no state is marked
-- 'MiniProtocolDirUnused'.  An empty @ptcls@ yields an empty table, on which
-- every lookup returns 'Nothing'.
--
setupDispatchTable :: forall mode m.
                      [MiniProtocolState mode m] -> MiniProtocolDispatch m
setupDispatchTable ptcls =
    MiniProtocolDispatch pnumArray ptclArray
  where
    -- The 'MiniProtocolNum' space is sparse but we don't want a huge single
    -- table if we use large protocol numbers. So we use a two level mapping.
    --
    -- The first array maps 'MiniProtocolNum' to a dense space of intermediate
    -- integer indexes. These indexes are meaningless outside of the context of
    -- this table. Then we use the index and the 'MiniProtocolDir' for the
    -- second table.
    --
    -- Both arrays list a default for every index first, followed by the real
    -- entries.  This relies on GHC's 'array' keeping the last association
    -- for a repeated index.
    --
    pnumArray :: Array MiniProtocolNum (Maybe MiniProtocolIx)
    pnumArray =
      array (minpnum, maxpnum) $
            -- Fill in Nothing first to cover any unused ones.
            [ (pnum, Nothing)  | pnum <- [minpnum..maxpnum] ]
            -- And override with the ones actually used.
         ++ [ (pnum, Just pix) | (pnum, pix) <- zip pnums [0..] ]

    ptclArray :: Array (MiniProtocolIx, MiniProtocolDir)
                       (MiniProtocolDispatchInfo m)
    ptclArray =
      array ptclBounds $
            -- Fill in MiniProtocolDirUnused first to cover any unused ones.
            [ (ix, MiniProtocolDirUnused) | ix <- range ptclBounds ]
            -- And override with the ones actually used.
         ++ [ ((pix, dir), MiniProtocolDispatchInfo q qMax)
            | MiniProtocolState {
                miniProtocolInfo =
                  MiniProtocolInfo {
                    miniProtocolNum,
                    miniProtocolDir,
                    miniProtocolLimits
                  },
                miniProtocolIngressQueue = q
              } <- ptcls
            , let pix  = case pnumArray ! miniProtocolNum of
                    Just a  -> a
                    -- This error is impossible to trigger: `pnumArray` is
                    -- constructed so that every `miniProtocolNum` in `ptcls`
                    -- indexes to a `Just` value.
                    Nothing -> error ("setupDispatchTable: impossible: missing "
                                      ++ show miniProtocolNum)
                  dir  = protocolDirEnum miniProtocolDir
                  qMax = maximumIngressQueue miniProtocolLimits
            ]

    -- Bounds of the second-level array: every dense index in both
    -- directions.
    ptclBounds :: ((MiniProtocolIx, MiniProtocolDir),
                   (MiniProtocolIx, MiniProtocolDir))
    ptclBounds = ((minpix, InitiatorDir), (maxpix, ResponderDir))

    -- The protocol numbers actually used, in the order of the first use within
    -- the 'ptcls' list. The order does not matter provided we do it
    -- consistently between the two arrays.
    pnums :: [MiniProtocolNum]
    pnums = nub $ map (miniProtocolNum . miniProtocolInfo) ptcls

    -- The dense range of indexes of used protocol numbers.  With no protocols
    -- this is (0, -1), an empty range.
    minpix, maxpix :: MiniProtocolIx
    minpix = 0
    maxpix = fromIntegral (length pnums - 1)

    -- The sparse range of protocol numbers.  With no protocols, use an empty
    -- range (lower bound above upper bound).  Calling 'minimum' / 'maximum'
    -- on the empty list would throw as soon as the strict
    -- 'MiniProtocolDispatch' field forces 'pnumArray'.
    minpnum, maxpnum :: MiniProtocolNum
    (minpnum, maxpnum)
      | null pnums = (toEnum 1, toEnum 0)
      | otherwise  = (minimum pnums, maximum pnums)