{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

-- | A reconnecting client for the @node-to-client@ mini-protocols, connecting
-- over a local address (unix socket or named pipe).
module Cardano.Client.Subscription
  ( -- * Subscription API
    subscribe
  , SubscriptionParams (..)
    -- 'Decision' must be exported: it is the result type of the exported
    -- 'spCompleteCb' field, so callers need its constructors to build
    -- a 'SubscriptionParams'.
  , Decision (..)
  , SubscriptionTracers (..)
  , SubscriptionTrace (..)
    -- * Re-exports
    -- ** Mux
  , MuxMode
  , MuxTrace
  , Mx.WithBearer
    -- ** Connections
  , ConnectionId (..)
  , LocalAddress (..)
    -- ** Protocol API
  , NodeToClientProtocols (..)
  , MiniProtocolCb (..)
  , RunMiniProtocol (..)
  , ControlMessage (..)
  ) where

import Codec.CBOR.Term qualified as CBOR
import Control.Exception
import Control.Monad (join)
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Control.Tracer (Tracer, traceWith)
import Data.ByteString.Lazy qualified as BSL
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Maybe (fromMaybe)
import Data.Void (Void)

import Network.Mux qualified as Mx

import Ouroboros.Network.ControlMessage (ControlMessage (..))
import Ouroboros.Network.Magic (NetworkMagic)
import Ouroboros.Network.Mux (MiniProtocolCb (..),
           OuroborosApplicationWithMinimalCtx, RunMiniProtocol (..))

import Ouroboros.Network.ConnectionId (ConnectionId (..))
import Ouroboros.Network.NodeToClient (Handshake, LocalAddress (..),
           NetworkConnectTracers (..), NodeToClientProtocols,
           NodeToClientVersion, NodeToClientVersionData (..), TraceSendRecv,
           Versions)
import Ouroboros.Network.NodeToClient qualified as NtC
import Ouroboros.Network.Snocket qualified as Snocket

-- | Alias for 'Mx.Mode', re-exported so clients need not depend on
-- "Network.Mux" directly.
type MuxMode = Mx.Mode

-- | Alias for 'Mx.Trace', the event type carried by 'stMuxTracer'.
type MuxTrace = Mx.Trace

-- | Configuration of a 'subscribe' loop.
data SubscriptionParams a = SubscriptionParams {
      spAddress           :: !LocalAddress,
      -- ^ unix socket or named pipe address to connect to
      spReconnectionDelay :: !(Maybe DiffTime),
      -- ^ delay between connection attempts.  The default value is `5s`.
      spCompleteCb        :: Either SomeException a -> Decision
      -- ^ invoked after every connection attempt with its outcome: 'Right'
      -- the value produced by the protocols, or 'Left' an exception that was
      -- either thrown by, or returned from, the attempt.  The resulting
      -- 'Decision' selects whether the loop reconnects or stops.
    }

-- | Verdict returned by 'spCompleteCb' once a connection attempt has ended.
data Decision
  = Abort
    -- ^ stop the subscription loop ('subscribe' traces
    -- 'SubscriptionTerminate' and returns)
  | Reconnect
    -- ^ wait for 'spReconnectionDelay' and connect again ('subscribe' traces
    -- 'SubscriptionReconnect' first)

-- | The tracers used by 'subscribe'.  The first two are handed to the
-- networking layer; the last one receives the loop's own events.
data SubscriptionTracers a = SubscriptionTracers
  { stMuxTracer          :: Tracer IO (Mx.WithBearer (ConnectionId LocalAddress) MuxTrace)
    -- ^ low level mux-network tracer, which logs mux sdu (send and received)
    -- and other low level multiplexing events.
  , stHandshakeTracer    :: Tracer IO (Mx.WithBearer (ConnectionId LocalAddress)
                                         (TraceSendRecv (Handshake NodeToClientVersion CBOR.Term)))
    -- ^ handshake protocol tracer; it is important for analysing version
    -- negotiation mismatches.
  , stSubscriptionTracer :: Tracer IO (SubscriptionTrace a)
    -- ^ tracer for the subscription loop's own events: results, errors,
    -- reconnects and termination.
  }

-- | Events emitted on 'stSubscriptionTracer' by the 'subscribe' loop.
data SubscriptionTrace a
  = SubscriptionResult a
    -- ^ a connection attempt finished with a value
  | SubscriptionError SomeException
    -- ^ a connection attempt ended with an exception
  | SubscriptionReconnect
    -- ^ the loop is about to wait and then connect again
  | SubscriptionTerminate
    -- ^ the loop is stopping
  deriving (Show)

-- | Subscribe using the @node-to-client@ mini-protocols.
--
-- @blockVersion@ ought to be instantiated with @BlockNodeToClientVersion blk@.
-- The protocol callback receives the @blockVersion@ associated with each
-- 'NodeToClientVersion' and can be used to create codecs with
-- @Ouroboros.Consensus.Network.NodeToClient.clientCodecs@.
--
-- Every connection attempt runs 'NtC.connectTo' against 'spAddress'.  The
-- attempt is wrapped in 'try' at 'SomeException', so any exception it throws
-- (asynchronous ones included) is captured and merged with an error it may
-- return.  The outcome is traced on 'stSubscriptionTracer' and then handed to
-- 'spCompleteCb':
--
-- * 'Abort' traces 'SubscriptionTerminate' and returns;
--
-- * 'Reconnect' traces 'SubscriptionReconnect', sleeps for
--   'spReconnectionDelay' (5 seconds when 'Nothing') and tries again.
--
-- The loop runs under 'mask'; only the connection attempt itself runs with
-- asynchronous exceptions unmasked.
subscribe
  :: forall blockVersion a.
     Snocket.LocalSnocket
  -> NetworkMagic
  -> Map NodeToClientVersion blockVersion
  -- ^ Use `supportedNodeToClientVersions` from `ouroboros-consensus`.
  -> SubscriptionTracers a
  -> SubscriptionParams a
  -> (   NodeToClientVersion
      -> blockVersion
      -> NodeToClientProtocols Mx.InitiatorMode LocalAddress BSL.ByteString IO a Void)
  -> IO ()
subscribe snocket networkMagic supportedVersions
          SubscriptionTracers { stMuxTracer          = muxTr
                              , stHandshakeTracer    = hsTr
                              , stSubscriptionTracer = subTr
                              }
          SubscriptionParams  { spAddress           = address
                              , spReconnectionDelay = mDelay
                              , spCompleteCb        = onComplete
                              }
          protocols =
    mask (\restore -> go restore connectOnce)
  where
    -- One connection attempt.  The same action is re-run on every reconnect.
    connectOnce :: IO (Either SomeException a)
    connectOnce =
      NtC.connectTo
        snocket
        NetworkConnectTracers { nctMuxTracer       = muxTr
                              , nctHandshakeTracer = hsTr
                              }
        (versionedProtocols networkMagic supportedVersions protocols)
        (getFilePath address)

    go :: (forall x. IO x -> IO x) -> IO (Either SomeException a) -> IO ()
    go restore attempt = do
      -- 'join' collapses "threw an exception" and "returned Left" into a
      -- single 'Either' for the callback.
      outcome <- join <$> try (restore attempt)
      -- ($!) forces the outcome before tracing, so a lazy tracer cannot
      -- leave it unevaluated.
      traceWith subTr $! either SubscriptionError SubscriptionResult outcome
      case onComplete outcome of
        Abort     -> traceWith subTr SubscriptionTerminate
        Reconnect -> do
          traceWith subTr SubscriptionReconnect
          threadDelay (fromMaybe 5 mDelay)
          go restore attempt


-- | Build the 'Versions' table offered during the @node-to-client@ handshake.
--
-- One entry is produced for every @(version, blockVersion)@ pair in
-- @supportedVersions@, visited in ascending key order via 'Map.toList'.  Each
-- entry advertises the given 'NetworkMagic' with @query = False@.  Its
-- protocols are obtained by applying @callback@ to that version pair.
--
-- NOTE(review): 'NtC.foldMapVersions' carries a 'HasCallStack' constraint,
-- which suggests it fails on an empty input.  Callers should presumably pass
-- a non-empty map; confirm.
versionedProtocols ::
     forall m appType bytes blockVersion a.
     NetworkMagic
  -- ^ network magic advertised in every version's 'NodeToClientVersionData'
  -> Map NodeToClientVersion blockVersion
  -- ^ Use `supportedNodeToClientVersions` from `ouroboros-consensus`.
  -> (   NodeToClientVersion
      -> blockVersion
      -> NodeToClientProtocols appType LocalAddress bytes m a Void)
     -- ^ callback which receives a 'NodeToClientVersion' together with the
     -- @blockVersion@ associated with it in the map.  It returns the
     -- 'NodeToClientProtocols' registered for that version.
  -> Versions
       NodeToClientVersion
       NodeToClientVersionData
       (OuroborosApplicationWithMinimalCtx appType LocalAddress bytes m a Void)
versionedProtocols networkMagic supportedVersions callback =
    NtC.foldMapVersions applyVersion (Map.toList supportedVersions)
  where
    -- A single-version 'Versions' entry for one map element.
    applyVersion
      :: (NodeToClientVersion, blockVersion)
      -> Versions
           NodeToClientVersion
           NodeToClientVersionData
           (OuroborosApplicationWithMinimalCtx appType LocalAddress bytes m a Void)
    applyVersion (version, blockVersion) =
      NtC.versionedNodeToClientProtocols
        version
        NodeToClientVersionData {
          networkMagic,
          query = False
        }
        (callback version blockVersion)