module DMQ.NodeToClient.LocalMsgNotification
  ( localMsgNotificationServer
  , LocalMsgNotificationProtocolError (..)
  ) where

import Control.Concurrent.Class.MonadSTM
import Control.Monad.Class.MonadThrow
import Control.Tracer
import Data.List.NonEmpty qualified as NonEmpty
import Data.Maybe (fromJust)
import Data.Traversable (mapAccumL, mapAccumR)
import Data.Word

import DMQ.Protocol.LocalMsgNotification.Server
import DMQ.Protocol.LocalMsgNotification.Type
import Ouroboros.Network.TxSubmission.Mempool.Reader

-- | Protocol violations of the local message notification mini-protocol,
-- thrown as exceptions (see its 'Exception' instance).
data LocalMsgNotificationProtocolError
  = ProtocolErrorUnexpectedBlockingRequest
    -- ^ A blocking request arrived when a non-blocking one was expected.
  | ProtocolErrorUnexpectedNonBlockingRequest
    -- ^ A non-blocking request arrived when a blocking one was expected.
  deriving Show

-- | Protocol errors are ordinary exceptions; 'displayException' explains
-- which kind of request the client sent out of turn.
instance Exception LocalMsgNotificationProtocolError where
  displayException err = case err of
    ProtocolErrorUnexpectedBlockingRequest ->
      "The client issued a blocking request when a non-blocking request was expected."
    ProtocolErrorUnexpectedNonBlockingRequest ->
      "The client issued a non-blocking request when a blocking request was expected."

-- | Local Message Notification server application.
--
-- The server remembers the index of the last message it delivered (starting
-- from 'mempoolZeroIdx') and, on every request, replies with at most
-- @maxMsgs0@ messages from the mempool snapshot whose index is after it.
--
-- * A blocking request waits (STM 'retry') until at least one new message is
--   available, and replies with a non-empty list.
-- * A non-blocking request replies immediately, possibly with no messages.
--
-- Every reply also says whether further messages remain beyond the ones
-- returned ('HasMore' / 'DoesNotHaveMore'), and is traced.  When the client
-- terminates, the termination is traced and @mdone@ provides the result.
localMsgNotificationServer
  :: forall m msg msgid idx a. (MonadSTM m {-, MonadThrow m -})
  => Tracer m (TraceMessageNotificationServer msg)
  -- ^ tracer for replies and client termination
  -> m a
  -- ^ action run when the client terminates; its result is the server result
  -> Word16
  -- ^ maximum number of messages included in a single reply
  -> TxSubmissionMempoolReader msgid msg idx m
  -- ^ read access to the message mempool
  -> LocalMsgNotificationServer m msg a
localMsgNotificationServer tracer mdone maxMsgs0
                           TxSubmissionMempoolReader { mempoolZeroIdx
                                                     , mempoolGetSnapshot
                                                     } =
  LocalMsgNotificationServer . pure $ serverIdle mempoolZeroIdx DoesNotHaveMore
  where
    maxMsgs :: Int
    maxMsgs = fromIntegral maxMsgs0

    -- Idle state.  @lastIdx@ is the index of the last message delivered to
    -- the client; @_hasMore@ is what the previous reply reported (currently
    -- unused, see the TODOs below).
    serverIdle :: idx -> HasMore -> ServerIdle m msg a
    serverIdle !lastIdx _hasMore = ServerIdle { msgRequestHandler, msgDoneHandler }
      where
        -- Take up to 'maxMsgs' messages after 'lastIdx' from a snapshot.
        -- Returns the index of the last message taken ('lastIdx' if none),
        -- whether more messages remain, and the messages taken, in the order
        -- produced by 'mempoolTxIdsAfter'.
        process :: MempoolSnapshot msgid msg idx -> (idx, HasMore, [msg])
        process ms = (lastIdx', hasMore', msgs)
          where
            (prefix, rest) = splitAt maxMsgs (mempoolTxIdsAfter ms lastIdx)

            hasMore' :: HasMore
            hasMore' = if null rest then DoesNotHaveMore else HasMore

            -- 'mempoolTxIdsAfter' lists entries oldest-to-newest, so a
            -- left-to-right accumulation leaves the index of the *last*
            -- message taken.  (A right-to-left 'mapAccumR' would leave the
            -- index of the first one and cause the rest to be re-sent.)
            (lastIdx', msgs) = mapAccumL step lastIdx prefix

            -- The index comes from this very snapshot, so the lookup is
            -- expected to succeed.
            step :: idx -> (msgid, idx, SizeInBytes) -> (idx, msg)
            step _ (_msgid, idx, _size) = (idx, fromJust (mempoolLookupTx ms idx))

        msgRequestHandler :: forall blocking.
                             SingBlockingStyle blocking
                          -> m (ServerResponse m blocking msg a)
        msgRequestHandler blocking =
          case blocking of
            -- TODO: reject a blocking request while the previous reply
            -- reported 'HasMore' (throw
            -- 'ProtocolErrorUnexpectedBlockingRequest'; needs 'MonadThrow').
            SingBlocking -> do
              -- Block until the snapshot contains at least one new message.
              (lastIdx', hasMore', msgs) <- atomically do
                (newIdx, more, found) <- process <$> mempoolGetSnapshot
                case NonEmpty.nonEmpty found of
                  Nothing     -> retry
                  Just found' -> return (newIdx, more, found')
              traceWith tracer
                (TraceMsgNotificationServerReply hasMore' (NonEmpty.toList msgs))
              return $ ServerReply (BlockingReply msgs)
                                   hasMore'
                                   (serverIdle lastIdx' hasMore')

            -- TODO: reject a non-blocking request while the previous reply
            -- reported 'DoesNotHaveMore' (throw
            -- 'ProtocolErrorUnexpectedNonBlockingRequest'; needs 'MonadThrow').
            SingNonBlocking -> do
              (lastIdx', hasMore', msgs) <- process <$> atomically mempoolGetSnapshot
              traceWith tracer (TraceMsgNotificationServerReply hasMore' msgs)
              return $ ServerReply (NonBlockingReply msgs)
                                   hasMore'
                                   (serverIdle lastIdx' hasMore')

        -- Client terminated: trace it, then run the caller's done action.
        msgDoneHandler :: m a
        msgDoneHandler =
          traceWith tracer TraceMsgNotificationServerHandleDone >> mdone


-- | Trace events emitted by the local message notification server.
data TraceMessageNotificationServer msg
  = TraceMsgNotificationServerReply HasMore [msg]
    -- ^ A reply is being sent: whether more messages remain, and the
    -- messages included in this reply.
  | TraceMsgNotificationServerHandleDone
    -- ^ The client terminated the protocol.
  deriving Show