{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE DerivingVia         #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}

-- It is useful to keep the 'HasInitiator' constraint on 'connectToNode' and friends, even though GHC reports it as redundant.
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- For Hashable SockAddr
{-# OPTIONS_GHC -Wno-orphans #-}


-- |
-- This module exports the interface for running a node over a TCP \/ IP socket.
--
module Ouroboros.Network.Socket
  ( -- * High level socket interface
    ConnectionTable
  , ConnectionTableRef (..)
  , ValencyCounter
  , NetworkMutableState (..)
  , SomeResponderApplication (..)
  , newNetworkMutableState
  , newNetworkMutableStateSTM
  , cleanNetworkMutableState
  , AcceptedConnectionsLimit (..)
  , ConnectionId (..)
  , withServerNode
  , withServerNode'
  , connectToNode
  , connectToNodeSocket
  , connectToNode'
    -- * Socket configuration
  , configureSocket
  , configureSystemdSocket
  , SystemdSocketTracer (..)
    -- * Traces
  , NetworkConnectTracers (..)
  , nullNetworkConnectTracers
  , debuggingNetworkConnectTracers
  , NetworkServerTracers (..)
  , nullNetworkServerTracers
  , debuggingNetworkServerTracers
  , AcceptConnectionsPolicyTrace (..)
    -- * Helper function for creating servers
  , fromSnocket
  , beginConnection
    -- * Re-export of HandshakeCallbacks
  , HandshakeCallbacks (..)
    -- * Re-export of PeerStates
  , PeerStates
    -- * Re-export connection table functions
  , newConnectionTable
  , refConnection
  , addConnection
  , removeConnection
  , newValencyCounter
  , addValencyCounter
  , remValencyCounter
  , waitValencyCounter
  , readValencyCounter
    -- * Auxiliary functions
  , sockAddrFamily
  ) where

import Control.Concurrent.Async
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Exception (SomeException (..))
-- TODO: remove this; it will not be needed once the `orElse` PR is merged.
import Codec.CBOR.Read qualified as CBOR
import Codec.CBOR.Term qualified as CBOR
import Control.Monad (unless, when)
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Control.Monad.STM qualified as STM
import Data.ByteString.Lazy qualified as BL
import Data.Hashable
import Data.Typeable (Typeable)
import Data.Void
import Data.Word (Word16)
import GHC.IO.Exception
#if !defined(mingw32_HOST_OS)
import Foreign.C.Error
#endif

import Network.Socket (SockAddr, Socket, StructLinger (..))
import Network.Socket qualified as Socket

import Control.Tracer

import Network.Mux.Bearer qualified as Mx
import Network.Mux.Compat qualified as Mx
import Network.Mux.DeltaQ.TraceTransformer
import Network.TypedProtocol.Codec hiding (decode, encode)

import Ouroboros.Network.Context
import Ouroboros.Network.Driver.Limits
import Ouroboros.Network.ErrorPolicy
import Ouroboros.Network.Handshake (HandshakeCallbacks (..))
import Ouroboros.Network.IOManager (IOManager)
import Ouroboros.Network.Mux
import Ouroboros.Network.Protocol.Handshake
import Ouroboros.Network.Protocol.Handshake.Codec
import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Server.ConnectionTable
import Ouroboros.Network.Server.Socket (AcceptConnectionsPolicyTrace (..),
           AcceptedConnectionsLimit (..))
import Ouroboros.Network.Server.Socket qualified as Server
import Ouroboros.Network.Snocket (Snocket)
import Ouroboros.Network.Snocket qualified as Snocket
import Ouroboros.Network.Subscription.PeerState


-- | Tracers used by 'connectToNode' (and derivatives, like
-- 'Ouroboros.Network.NodeToNode.connectTo' or
-- 'Ouroboros.Network.NodeToClient.connectTo').
--
data NetworkConnectTracers addr vNumber = NetworkConnectTracers {
      nctMuxTracer         :: Tracer IO (Mx.WithMuxBearer (ConnectionId addr) Mx.MuxTrace),
      -- ^ low level mux-network tracer, which logs mux sdu (send and received)
      -- and other low level multiplexing events.
      nctHandshakeTracer   :: Tracer IO (Mx.WithMuxBearer (ConnectionId addr)
                                          (TraceSendRecv (Handshake vNumber CBOR.Term)))
      -- ^ handshake protocol tracer; it is important for analysing version
      -- negotiation mismatches.
    }

-- | 'NetworkConnectTracers' which discard every event: both the mux and the
-- handshake tracer are 'nullTracer'.
--
nullNetworkConnectTracers :: NetworkConnectTracers addr vNumber
nullNetworkConnectTracers = NetworkConnectTracers {
      nctMuxTracer       = nullTracer,
      nctHandshakeTracer = nullTracer
    }


-- | 'NetworkConnectTracers' which render every event with 'show' and write it
-- to 'stdoutTracer' (for both the mux and the handshake tracer).
--
debuggingNetworkConnectTracers :: (Show addr, Show vNumber)
                               => NetworkConnectTracers addr vNumber
debuggingNetworkConnectTracers = NetworkConnectTracers {
      nctMuxTracer       = showTracing stdoutTracer,
      nctHandshakeTracer = showTracing stdoutTracer
    }

-- | The address family of a socket address: 'Socket.AF_INET' for IPv4,
-- 'Socket.AF_INET6' for IPv6 and 'Socket.AF_UNIX' for Unix domain socket
-- addresses.
--
sockAddrFamily
    :: Socket.SockAddr
    -> Socket.Family
-- Constructor fields are irrelevant here, so match with '{}' patterns which
-- do not depend on the constructors' arity.
sockAddrFamily Socket.SockAddrInet{}  = Socket.AF_INET
sockAddrFamily Socket.SockAddrInet6{} = Socket.AF_INET6
sockAddrFamily Socket.SockAddrUnix{}  = Socket.AF_UNIX


-- | Configure a socket.  Either an 'Socket.AF_INET' or an 'Socket.AF_INET6'
-- socket is expected.
--
-- Sets 'Socket.ReuseAddr', 'Socket.ReusePort' (except on Windows),
-- 'Socket.NoDelay' and 'Socket.Linger' (enabled, with a zero timeout).  When
-- the given address is an IPv6 address, 'Socket.IPv6Only' is set as well; if no
-- address is given it is left untouched.
--
configureSocket :: Socket -> Maybe SockAddr -> IO ()
configureSocket sock addr = do
    let fml :: Maybe Socket.Family
        fml = sockAddrFamily <$> addr
    Socket.setSocketOption sock Socket.ReuseAddr 1
#if !defined(mingw32_HOST_OS)
    -- not supported on Windows 10
    Socket.setSocketOption sock Socket.ReusePort 1
#endif
    Socket.setSocketOption sock Socket.NoDelay 1
    -- It is safe to set the 'SO_LINGER' option (which implies that every
    -- close will reset the connection), since our protocols are robust.
    -- In particular if invalid data arrives (which includes the rare
    -- case of a late packet from a previous connection), we will abandon
    -- (and close) the connection.
    Socket.setSockOpt sock Socket.Linger
                          (StructLinger { sl_onoff  = 1,
                                          sl_linger = 0 })
    when (fml == Just Socket.AF_INET6)
      -- An AF_INET6 socket can be used to talk to both IPv4 and IPv6 end points, and
      -- it is enabled by default on some systems. Disabled here since we run a separate
      -- IPv4 server instance if configured to use IPv4.
      $ Socket.setSocketOption sock Socket.IPv6Only 1


-- | Configure sockets passed through systemd socket activation.
-- Currently the 'Socket.ReuseAddr' and 'Socket.Linger' options are not
-- configurable with @systemd.socket@, so they are set by this function.  For
-- the other socket options ('Socket.ReusePort' (except on Windows),
-- 'Socket.NoDelay' and, for IPv6 sockets, 'Socket.IPv6Only') we only trace a
-- 'SocketOptionNotSet' event if they are not set.
--
-- 'Socket.ReuseAddr' is only set, and 'Socket.NoDelay' is only checked, for
-- 'Socket.AF_INET' and 'Socket.AF_INET6' sockets.
--
configureSystemdSocket :: Tracer IO SystemdSocketTracer -> Socket -> SockAddr -> IO ()
configureSystemdSocket tracer sock addr = do
    when isInet $
      Socket.setSocketOption sock Socket.ReuseAddr 1
#if !defined(mingw32_HOST_OS)
    -- not supported on Windows 10
    traceIfNotSet Socket.ReusePort
#endif
    -- 'TCP_NODELAY' is a TCP level option; querying it on a non-IP socket
    -- (e.g. an 'Socket.AF_UNIX' one) is not meaningful and may throw, so it is
    -- only checked for IP sockets.
    when isInet $
      traceIfNotSet Socket.NoDelay

    Socket.setSockOpt sock Socket.Linger
                          (StructLinger { sl_onoff  = 1,
                                          sl_linger = 0 })
    when (fml == Socket.AF_INET6) $
      traceIfNotSet Socket.IPv6Only
  where
    fml :: Socket.Family
    fml = sockAddrFamily addr

    -- whether this is an IPv4 or an IPv6 socket
    isInet :: Bool
    isInet = fml == Socket.AF_INET || fml == Socket.AF_INET6

    -- read a socket option and trace 'SocketOptionNotSet' if it is zero
    traceIfNotSet :: Socket.SocketOption -> IO ()
    traceIfNotSet opt = do
      value <- Socket.getSocketOption sock opt
      when (value == 0) $
        traceWith tracer (SocketOptionNotSet opt)

-- | Events traced by 'configureSystemdSocket'.
--
data SystemdSocketTracer =
    -- | The given option is not set on a socket passed by systemd.
    SocketOptionNotSet Socket.SocketOption
  deriving Show


-- | Orphan instance (hence @-Wno-orphans@ for this module).
--
-- IPv4 and IPv6 addresses hash their port (as a 'Word16') together with the
-- host address; the IPv6 flow info and scope id fields are ignored.  Unix
-- domain socket addresses hash their path.
--
instance Hashable Socket.SockAddr where
  hashWithSalt salt (Socket.SockAddrInet  port host)     =
    hashWithSalt salt (fromIntegral port :: Word16, host)
  hashWithSalt salt (Socket.SockAddrInet6 port _ host _) =
    hashWithSalt salt (fromIntegral port :: Word16, host)
  hashWithSalt salt (Socket.SockAddrUnix  path)          =
    hashWithSalt salt path

-- | We place an upper limit of @30s@ on the time we wait on receiving an SDU.
-- There is no upper bound on the time we wait when waiting for a new SDU.
-- This makes it possible for miniprotocols to use timeouts that are larger
-- than 30s or wait forever.  @30s@ for receiving an SDU corresponds to
-- a minimum speed limit of 17kbps:
--
-- > ( 8      -- mux header length
-- > + 0xffff -- maximum SDU payload
-- > )
-- > * 8
-- > = 524_344 -- maximum bits in an SDU
-- >
-- >  524_344 / 30 / 1024 = 17kbps
--
sduTimeout :: DiffTime
sduTimeout = 30

-- | For the handshake, we put a limit of @10s@ on sending or receiving a
-- single @MuxSDU@.
--
sduHandshakeTimeout :: DiffTime
sduHandshakeTimeout = 10


-- |
-- Connect to a remote node.  It uses 'bracket' to enclose the underlying
-- socket acquisition.  This implies that when the continuation exits the
-- underlying bearer will get closed.
--
-- The socket is opened with 'Snocket.openToConnect', configured with the
-- given callback, bound to the local address (if one is given) and connected
-- to the remote address; the handshake and the application are then run by
-- 'connectToNode''.
--
-- The connection will start with handshake protocol sending @Versions@ to the
-- remote peer.  It must fit into @'maxTransmissionUnit'@ (~5k bytes).
--
-- Exceptions thrown by @'MuxApplication'@ are rethrown by 'connectToNode'.
connectToNode
  :: forall appType vNumber vData fd addr a b.
     ( Ord vNumber
     , Typeable vNumber
     , Show vNumber
     , Mx.HasInitiator appType ~ True
     )
  => Snocket IO fd addr
  -> Mx.MakeBearer IO fd
  -> (fd -> IO ()) -- ^ configure a socket
  -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
  -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
  -> VersionDataCodec CBOR.Term vNumber vData
  -> NetworkConnectTracers addr vNumber
  -> HandshakeCallbacks vData
  -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx appType addr BL.ByteString IO a b)
  -- ^ application to run over the connection
  -> Maybe addr
  -- ^ local address; the created socket will bind to it
  -> addr
  -- ^ remote address
  -> IO ()
connectToNode sn makeBearer configureSock handshakeCodec handshakeTimeLimits
              versionDataCodec tracers handshakeCallbacks versions localAddr
              remoteAddr =
    bracket
      (Snocket.openToConnect sn remoteAddr)
      (Snocket.close sn)
      (\sd -> do
          configureSock sd
          -- bind to the local address, if one was given
          mapM_ (Snocket.bind sn sd) localAddr
          Snocket.connect sn sd remoteAddr
          connectToNode' sn makeBearer handshakeCodec handshakeTimeLimits
                         versionDataCodec tracers handshakeCallbacks versions sd
      )

-- |
-- Connect to a remote node using an existing socket.  It is up to the caller
-- to ensure that the socket is closed in case of an exception.
--
-- The connection will start with handshake protocol sending @Versions@ to the
-- remote peer.  It must fit into @'maxTransmissionUnit'@ (~5k bytes).  The
-- handshake runs over a bearer using 'sduHandshakeTimeout'; the negotiated
-- application runs over a bearer using 'sduTimeout'.
--
-- Handshake failures are traced with 'Mx.MuxTraceHandshakeClientError' and
-- rethrown.  A query result from the handshake is traced and reported by
-- throwing 'QueryNotSupported'.
--
-- Exceptions thrown by @'MuxApplication'@ are rethrown by 'connectToNode''.
connectToNode'
  :: forall appType vNumber vData fd addr a b.
     ( Ord vNumber
     , Typeable vNumber
     , Show vNumber
     , Mx.HasInitiator appType ~ True
     )
  => Snocket IO fd addr
  -> Mx.MakeBearer IO fd
  -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
  -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
  -> VersionDataCodec CBOR.Term vNumber vData
  -> NetworkConnectTracers addr vNumber
  -> HandshakeCallbacks vData
  -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx appType addr BL.ByteString IO a b)
  -- ^ application to run over the connection
  -> fd
  -- ^ a configured socket to use to connect to a remote service provider
  -> IO ()
connectToNode' sn makeBearer handshakeCodec handshakeTimeLimits versionDataCodec
               NetworkConnectTracers { nctMuxTracer, nctHandshakeTracer }
               handshakeCallbacks versions sd = do
    connectionId <- (\localAddress remoteAddress -> ConnectionId { localAddress, remoteAddress })
                <$> Snocket.getLocalAddr sn sd
                <*> Snocket.getRemoteAddr sn sd
    muxTracer <- initDeltaQTracer' $ Mx.WithMuxBearer connectionId `contramap` nctMuxTracer
    ts_start <- getMonotonicTime

    handshakeBearer <- Mx.getBearer makeBearer sduHandshakeTimeout muxTracer sd
    app_e <-
      runHandshakeClient
        handshakeBearer
        connectionId
        -- TODO: push 'HandshakeArguments' up the call stack.
        HandshakeArguments {
          haHandshakeTracer  = nctHandshakeTracer,
          haHandshakeCodec   = handshakeCodec,
          haVersionDataCodec = versionDataCodec,
          haAcceptVersion    = acceptCb handshakeCallbacks,
          haQueryVersion     = queryCb handshakeCallbacks,
          haTimeLimits       = handshakeTimeLimits
        }
        versions
    ts_end <- getMonotonicTime
    -- time spent in the handshake; reported by every trace below
    let handshakeDuration = diffTime ts_end ts_start
    case app_e of
         Left (HandshakeProtocolLimit err) -> do
             traceWith muxTracer $ Mx.MuxTraceHandshakeClientError err handshakeDuration
             throwIO err

         Left (HandshakeProtocolError err) -> do
             traceWith muxTracer $ Mx.MuxTraceHandshakeClientError err handshakeDuration
             throwIO err

         Right (HandshakeNegotiationResult app _versionNumber _agreedOptions) -> do
             traceWith muxTracer $ Mx.MuxTraceHandshakeClientEnd handshakeDuration
             -- run the negotiated application over a new bearer which uses
             -- the regular 'sduTimeout'
             bearer <- Mx.getBearer makeBearer sduTimeout muxTracer sd
             Mx.muxStart
               muxTracer
               (toApplication MinimalInitiatorContext { micConnectionId = connectionId }
                              ResponderContext { rcConnectionId = connectionId }
                              app)
               bearer

         Right (HandshakeQueryResult _vMap) -> do
             traceWith muxTracer $ Mx.MuxTraceHandshakeClientEnd handshakeDuration
             throwIO (QueryNotSupported @vNumber)


-- | Wraps a 'Socket.Socket' inside a 'Snocket' (using 'Snocket.socketSnocket'
-- and 'Mx.makeSocketBearer') and calls 'connectToNode''.  As with
-- 'connectToNode'', it is up to the caller to close the socket.
--
connectToNodeSocket
  :: forall appType vNumber vData a b.
     ( Ord vNumber
     , Typeable vNumber
     , Show vNumber
     , Mx.HasInitiator appType ~ True
     )
  => IOManager
  -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
  -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
  -> VersionDataCodec CBOR.Term vNumber vData
  -> NetworkConnectTracers Socket.SockAddr vNumber
  -> HandshakeCallbacks vData
  -> Versions vNumber vData (OuroborosApplicationWithMinimalCtx appType Socket.SockAddr BL.ByteString IO a b)
  -- ^ application to run over the connection
  -> Socket.Socket
  -> IO ()
connectToNodeSocket iocp handshakeCodec handshakeTimeLimits versionDataCodec
                    tracers handshakeCallbacks versions sd =
    connectToNode'
      (Snocket.socketSnocket iocp)
      Mx.makeSocketBearer
      handshakeCodec
      handshakeTimeLimits
      versionDataCodec
      tracers
      handshakeCallbacks
      versions
      sd

-- | Existential wrapper for an 'OuroborosApplicationWithMinimalCtx' whose
-- mode has a responder side (the constructor requires
-- @'Mx.HasResponder' appType ~ True@), e.g. an
-- OuroborosResponderApplication or an
-- OuroborosInitiatorAndResponderApplication.  The @appType@ and @a@ type
-- parameters are hidden.
--
data SomeResponderApplication addr bytes m b where
  SomeResponderApplication
    :: forall appType addr bytes m a b.
       Mx.HasResponder appType ~ True
    => OuroborosApplicationWithMinimalCtx appType addr bytes m a b
    -> SomeResponderApplication addr bytes m b

-- |
-- Accept or reject an incoming connection.  Each constructor contains the new
-- state after accepting / rejecting a connection, together with the
-- 'ConnectionId' of that connection.  When accepting a connection one has to
-- give a mux application which necessarily has the server side, and
-- optionally has the client side (see 'SomeResponderApplication').
--
-- TODO:
-- If the other side does not allow us to run the client side on the incoming
-- connection, the whole connection will terminate.  We might want to be more
-- permissive in this scenario: leave the server thread running and let only
-- the client thread die.
data AcceptConnection st vNumber vData peerid m bytes where

    -- | Accept the connection: the new state, the connection id and the
    -- 'Versions' of 'SomeResponderApplication's to run on it.  The type @b@
    -- does not occur in the result type, so it is existentially quantified.
    AcceptConnection
      :: forall st vNumber vData peerid bytes m b.
         !st
      -> !(ConnectionId peerid)
      -> Versions vNumber vData (SomeResponderApplication peerid bytes m b)
      -> AcceptConnection st vNumber vData peerid m bytes

    -- | Reject the connection: the new state and the connection id.
    RejectConnection
      :: !st
      -> !(ConnectionId peerid)
      -> AcceptConnection st vNumber vData peerid m bytes


-- |
-- Accept or reject incoming connection based on the current state and address
-- of the incoming connection.
--
-- The last argument decides, in 'STM', whether the connection is accepted.
-- For an accepted connection the returned handler:
--
-- * runs the handshake responder ('runHandshakeServer') over a bearer created
--   with 'sduHandshakeTimeout';
-- * on a successful negotiation, creates a new bearer with 'sduTimeout' and
--   runs the negotiated responder application with 'Mx.muxStart';
-- * on a version query, waits 'handshake_QUERY_SHUTDOWN_DELAY' so the client
--   can read the reply and close the connection itself;
-- * on a handshake failure, traces 'Mx.MuxTraceHandshakeServerError' and
--   re-throws the error.
--
beginConnection
    :: forall vNumber vData addr st fd.
       ( Ord vNumber
       , Typeable vNumber
       , Show vNumber
       )
    => Mx.MakeBearer IO fd
    -- ^ used to create both the handshake bearer and the application bearer
    -> Tracer IO (Mx.WithMuxBearer (ConnectionId addr) Mx.MuxTrace)
    -- ^ mux tracer; events are tagged with the connection's 'ConnectionId'
    -> Tracer IO (Mx.WithMuxBearer (ConnectionId addr) (TraceSendRecv (Handshake vNumber CBOR.Term)))
    -- ^ handshake tracer, passed to 'runHandshakeServer'
    -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
    -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    -> VersionDataCodec CBOR.Term vNumber vData
    -> HandshakeCallbacks vData
    -> (Time -> addr -> st -> STM.STM (AcceptConnection st vNumber vData addr IO BL.ByteString))
    -- ^ either accept or reject a connection.
    -> Server.BeginConnection addr fd st ()
beginConnection makeBearer muxTracer handshakeTracer handshakeCodec
                handshakeTimeLimits versionDataCodec handshakeCallbacks
                fn t addr st = do
    accept <- fn t addr st
    case accept of
      AcceptConnection st' connectionId versions ->
        pure $ Server.Accept st' $ \sd -> do
          muxTracer' <- initDeltaQTracer' $ Mx.WithMuxBearer connectionId `contramap` muxTracer

          traceWith muxTracer' Mx.MuxTraceHandshakeStart

          -- The handshake runs on its own bearer (with 'sduHandshakeTimeout');
          -- the application bearer created below uses 'sduTimeout'.
          handshakeBearer <- Mx.getBearer makeBearer sduHandshakeTimeout muxTracer' sd
          app_e <-
            runHandshakeServer
              handshakeBearer
              connectionId
              HandshakeArguments {
                haHandshakeTracer  = handshakeTracer,
                haHandshakeCodec   = handshakeCodec,
                haVersionDataCodec = versionDataCodec,
                haAcceptVersion    = acceptCb handshakeCallbacks,
                haQueryVersion     = queryCb handshakeCallbacks,
                haTimeLimits       = handshakeTimeLimits
              }
              versions

          case app_e of
            Left (HandshakeProtocolLimit err) -> do
              traceWith muxTracer' $ Mx.MuxTraceHandshakeServerError err
              throwIO err

            Left (HandshakeProtocolError err) -> do
              traceWith muxTracer' $ Mx.MuxTraceHandshakeServerError err
              throwIO err

            Right (HandshakeNegotiationResult (SomeResponderApplication app) _versionNumber _agreedOptions) -> do
              traceWith muxTracer' Mx.MuxTraceHandshakeServerEnd
              bearer <- Mx.getBearer makeBearer sduTimeout muxTracer' sd
              Mx.muxStart
                muxTracer'
                (toApplication MinimalInitiatorContext { micConnectionId = connectionId }
                               ResponderContext { rcConnectionId = connectionId }
                               app)
                bearer

            Right (HandshakeQueryResult _vMap) -> do
              traceWith muxTracer' Mx.MuxTraceHandshakeServerEnd
              -- Wait 'handshake_QUERY_SHUTDOWN_DELAY' for the client to receive
              -- the response; the client is expected to close the connection.
              threadDelay handshake_QUERY_SHUTDOWN_DELAY

      RejectConnection st' _peerid ->
        pure $ Server.Reject st'

-- | Open a socket of the given address family, apply the caller's
-- configuration to it, bind it to @addr@ and put it into listening mode.
--
-- If configuring, binding or listening throws, the freshly opened socket is
-- closed before the exception propagates. A surrounding 'bracket' would not
-- close it, because its release action only runs when acquisition succeeds.
mkListeningSocket
    :: Snocket IO fd addr
    -> (fd -> addr -> IO ())
    -- ^ socket configuration, applied before the socket is bound
    -> addr
    -- ^ address to bind to
    -> Snocket.AddressFamily addr
    -- ^ family used to open the socket
    -> IO fd
mkListeningSocket sn configureSock addr family_ = do
    sd <- Snocket.open sn family_
    setUp sd `onException` Snocket.close sn sd
    pure sd
  where
    -- Everything that may fail after the socket exists.
    setUp sd = do
      configureSock sd addr
      Snocket.bind sn sd addr
      Snocket.listen sn sd

-- |
-- Make a server-compatible socket from a network socket.
--
-- Every connection accepted through the returned 'Server.Socket' is added to
-- the 'ConnectionTable' as 'ConnectionInbound'. The close action returned
-- with it removes that entry and closes the accepted socket.
--
fromSnocket
    :: forall fd addr. Ord addr
    => ConnectionTable IO addr
    -> Snocket IO fd addr
    -> fd -- ^ socket or handle
    -> IO (Server.Socket addr fd)
fromSnocket tblVar sn sd = go <$> Snocket.accept sn sd
  where
    go :: Snocket.Accept IO fd addr -> Server.Socket addr fd
    go (Snocket.Accept accept) = Server.Socket $ do
      (result, next) <- accept
      case result of
        Snocket.Accepted sd' remoteAddr -> do
          -- TODO: we don't need to do that on each accept
          -- If the local address cannot be read, close the accepted socket;
          -- otherwise nothing would ever close it.
          localAddr <- Snocket.getLocalAddr sn sd'
                         `onException` Snocket.close sn sd'
          atomically $ addConnection tblVar remoteAddr localAddr ConnectionInbound Nothing
          pure (remoteAddr, sd', close remoteAddr localAddr sd', go next)
        Snocket.AcceptFailure err ->
          -- there is no way to construct 'Server.Socket'; This will be removed in a later commit!
          throwIO err

    -- Undo the bookkeeping done on accept, then close the accepted socket.
    close :: addr -> addr -> fd -> IO ()
    close remoteAddr localAddr sd' = do
        removeConnection tblVar remoteAddr localAddr ConnectionInbound
        Snocket.close sn sd'


-- | Tracers required by a server which handles inbound connections.
--
data NetworkServerTracers addr vNumber = NetworkServerTracers {
      nstMuxTracer          :: Tracer IO (Mx.WithMuxBearer (ConnectionId addr) Mx.MuxTrace),
      -- ^ low level mux-network tracer, which logs mux sdu (send and received)
      -- and other low level multiplexing events.
      nstHandshakeTracer    :: Tracer IO (Mx.WithMuxBearer (ConnectionId addr)
                                            (TraceSendRecv (Handshake vNumber CBOR.Term))),
      -- ^ handshake protocol tracer; it is important for analysing version
      -- negotiation mismatches.
      nstErrorPolicyTracer  :: Tracer IO (WithAddr addr ErrorPolicyTrace),
      -- ^ error policy tracer; must not be 'nullTracer', otherwise all the
      -- exceptions which are not matched by any error policy will be caught
      -- and not logged or rethrown.
      nstAcceptPolicyTracer :: Tracer IO AcceptConnectionsPolicyTrace
      -- ^ tracing rate limiting of accepting connections.
    }

-- | Server tracers which discard every event.
--
-- NOTE(review): 'nstErrorPolicyTracer' is documented as "must not be
-- 'nullTracer'", but it is 'nullTracer' here; confirm this is acceptable
-- wherever these tracers are used.
nullNetworkServerTracers :: NetworkServerTracers addr vNumber
nullNetworkServerTracers = NetworkServerTracers {
      nstMuxTracer          = nullTracer,
      nstHandshakeTracer    = nullTracer,
      nstErrorPolicyTracer  = nullTracer,
      nstAcceptPolicyTracer = nullTracer
    }

-- | Server tracers which print every event to stdout via its 'Show'
-- instance ('showTracing' over 'stdoutTracer'); meant for debugging.
debuggingNetworkServerTracers :: (Show addr, Show vNumber)
                              =>  NetworkServerTracers addr vNumber
debuggingNetworkServerTracers = NetworkServerTracers {
      nstMuxTracer          = showTracing stdoutTracer,
      nstHandshakeTracer    = showTracing stdoutTracer,
      nstErrorPolicyTracer  = showTracing stdoutTracer,
      nstAcceptPolicyTracer = showTracing stdoutTracer
    }


-- | Mutable state maintained by the network component.
--
data NetworkMutableState addr = NetworkMutableState {
    nmsConnectionTable :: ConnectionTable IO addr,
    -- ^ 'ConnectionTable' which maintains information about current upstream and
    -- downstream connections.
    nmsPeerStates      :: StrictTVar IO (PeerStates IO addr)
    -- ^ 'PeerStates' which maintains state of each downstream / upstream peer
    -- that errored, misbehaved or was not interesting to us.
  }

-- | Allocate a new 'NetworkMutableState' in 'STM'. The connection table comes
-- from 'newConnectionTableSTM' and the peer-states variable from
-- 'newPeerStatesVarSTM'.
newNetworkMutableStateSTM :: STM.STM (NetworkMutableState addr)
newNetworkMutableStateSTM =
    NetworkMutableState <$> newConnectionTableSTM
                        <*> newPeerStatesVarSTM

-- | 'IO' version of 'newNetworkMutableStateSTM'; runs it in a single
-- 'atomically' transaction.
newNetworkMutableState :: IO (NetworkMutableState addr)
newNetworkMutableState = atomically newNetworkMutableStateSTM

-- | Clean 'PeerStates' within 'NetworkMutableState' every 200s
--
-- Delegates to 'cleanPeerStates', passing the 'DiffTime' 200 (seconds) and
-- the 'nmsPeerStates' variable. The connection table is not touched.
cleanNetworkMutableState :: NetworkMutableState addr
                         -> IO ()
cleanNetworkMutableState NetworkMutableState {nmsPeerStates} =
    cleanPeerStates 200 nmsPeerStates

-- |
-- Thin wrapper around @'Server.run'@.
--
runServerThread
    :: forall vNumber vData fd addr b.
       ( Ord vNumber
       , Typeable vNumber
       , Show vNumber
       , Ord addr
       )
    => NetworkServerTracers addr vNumber
    -> NetworkMutableState addr
    -> Snocket IO fd addr
    -> Mx.MakeBearer IO fd
    -> fd
    -> AcceptedConnectionsLimit
    -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
    -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    -> VersionDataCodec CBOR.Term vNumber vData
    -> HandshakeCallbacks vData
    -> Versions vNumber vData (SomeResponderApplication addr BL.ByteString IO b)
    -> ErrorPolicies
    -> IO Void
runServerThread :: forall vNumber vData fd addr b.
(Ord vNumber, Typeable vNumber, Show vNumber, Ord addr) =>
NetworkServerTracers addr vNumber
-> NetworkMutableState addr
-> Snocket IO fd addr
-> MakeBearer IO fd
-> fd
-> AcceptedConnectionsLimit
-> Codec (Handshake vNumber Term) DeserialiseFailure IO ByteString
-> ProtocolTimeLimits (Handshake vNumber Term)
-> VersionDataCodec Term vNumber vData
-> HandshakeCallbacks vData
-> Versions
     vNumber vData (SomeResponderApplication addr ByteString IO b)
-> ErrorPolicies
-> IO Void
runServerThread NetworkServerTracers { Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
nstMuxTracer :: forall addr vNumber.
NetworkServerTracers addr vNumber
-> Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
nstMuxTracer :: Tracer IO (WithMuxBearer (ConnectionId addr) MuxTrace)
nstMuxTracer
                                     , Tracer
  IO
  (WithMuxBearer
     (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
nstHandshakeTracer :: forall addr vNumber.
NetworkServerTracers addr vNumber
-> Tracer
     IO
     (WithMuxBearer
        (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
nstHandshakeTracer :: Tracer
  IO
  (WithMuxBearer
     (ConnectionId addr) (TraceSendRecv (Handshake vNumber Term)))
nstHandshakeTracer
                                     , Tracer IO (WithAddr addr ErrorPolicyTrace)
nstErrorPolicyTracer :: forall addr vNumber.
NetworkServerTracers addr vNumber
-> Tracer IO (WithAddr addr ErrorPolicyTrace)
nstErrorPolicyTracer :: Tracer IO (WithAddr addr ErrorPolicyTrace)
nstErrorPolicyTracer
                                     , Tracer IO AcceptConnectionsPolicyTrace
nstAcceptPolicyTracer :: forall addr vNumber.
NetworkServerTracers addr vNumber
-> Tracer IO AcceptConnectionsPolicyTrace
nstAcceptPolicyTracer :: Tracer IO AcceptConnectionsPolicyTrace
nstAcceptPolicyTracer
                                     }
                NetworkMutableState { ConnectionTable IO addr
nmsConnectionTable :: forall addr. NetworkMutableState addr -> ConnectionTable IO addr
nmsConnectionTable :: ConnectionTable IO addr
nmsConnectionTable
                                    , StrictTVar IO (PeerStates IO addr)
nmsPeerStates :: forall addr.
NetworkMutableState addr -> StrictTVar IO (PeerStates IO addr)
nmsPeerStates :: StrictTVar IO (PeerStates IO addr)
nmsPeerStates }
                Snocket IO fd addr
sn
                MakeBearer IO fd
makeBearer
                fd
sd
                AcceptedConnectionsLimit
acceptedConnectionsLimit
                Codec (Handshake vNumber Term) DeserialiseFailure IO ByteString
handshakeCodec
                ProtocolTimeLimits (Handshake vNumber Term)
handshakeTimeLimits
                VersionDataCodec Term vNumber vData
versionDataCodec
                HandshakeCallbacks vData
handshakeCallbacks
                Versions
  vNumber vData (SomeResponderApplication addr ByteString IO b)
versions
                ErrorPolicies
errorPolicies = do
    sockAddr <- Snocket IO fd addr -> fd -> IO addr
forall (m :: * -> *) fd addr. Snocket m fd addr -> fd -> m addr
Snocket.getLocalAddr Snocket IO fd addr
sn fd
sd
    serverSocket <- fromSnocket nmsConnectionTable sn sd
    Server.run
        nstErrorPolicyTracer
        nstAcceptPolicyTracer
        serverSocket
        acceptedConnectionsLimit
        (acceptException sockAddr)
        (beginConnection makeBearer nstMuxTracer nstHandshakeTracer handshakeCodec handshakeTimeLimits versionDataCodec handshakeCallbacks (acceptConnectionTx sockAddr))
        -- register producer when application starts, it will be unregistered
        -- using 'CompleteConnection'
        (\addr
remoteAddr Async ()
thread PeerStates IO addr
st -> PeerStates IO addr -> STM (PeerStates IO addr)
forall a. a -> STM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PeerStates IO addr -> STM (PeerStates IO addr))
-> PeerStates IO addr -> STM (PeerStates IO addr)
forall a b. (a -> b) -> a -> b
$ addr -> Async IO () -> PeerStates IO addr -> PeerStates IO addr
forall (m :: * -> *) addr.
(Ord addr, Ord (Async m ())) =>
addr -> Async m () -> PeerStates m addr -> PeerStates m addr
registerProducer addr
remoteAddr Async ()
Async IO ()
thread
        PeerStates IO addr
st)
        completeTx mainTx (toLazyTVar nmsPeerStates)
  where
    mainTx :: Server.Main (PeerStates IO addr) Void
    mainTx :: Main (PeerStates IO addr) Void
mainTx (ThrowException e
e) = e -> STM Void
forall e a. Exception e => e -> STM a
forall (m :: * -> *) e a. (MonadThrow m, Exception e) => e -> m a
throwIO e
e
    mainTx PeerStates{}       = STM Void
STM IO Void
forall a. STM IO a
forall (m :: * -> *) a. MonadSTM m => STM m a
retry

    -- When a connection completes, we do nothing. State is ().
    -- Crucially: we don't re-throw exceptions, because doing so would
    -- bring down the server.
    completeTx :: Server.CompleteConnection
                    addr
                    (PeerStates IO addr)
                    (WithAddr addr ErrorPolicyTrace)
                    ()
    completeTx :: CompleteConnection addr (PeerStates IO addr) Any ()
completeTx Result addr ()
result PeerStates IO addr
st = case Result addr ()
result of

      Server.Result Async ()
thread addr
remoteAddr Time
t (Left (SomeException e
e)) ->
        (PeerStates IO addr -> PeerStates IO addr)
-> CompleteApplicationResult IO addr (PeerStates IO addr)
-> CompleteApplicationResult IO addr (PeerStates IO addr)
forall a b.
(a -> b)
-> CompleteApplicationResult IO addr a
-> CompleteApplicationResult IO addr b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (addr -> Async IO () -> PeerStates IO addr -> PeerStates IO addr
forall (m :: * -> *) addr.
(Ord addr, Ord (Async m ())) =>
addr -> Async m () -> PeerStates m addr -> PeerStates m addr
unregisterProducer addr
remoteAddr Async ()
Async IO ()
thread)
          (CompleteApplicationResult IO addr (PeerStates IO addr)
 -> CompleteApplicationResult IO addr (PeerStates IO addr))
-> STM (CompleteApplicationResult IO addr (PeerStates IO addr))
-> STM (CompleteApplicationResult IO addr (PeerStates IO addr))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ErrorPolicies
-> CompleteApplication IO (PeerStates IO addr) addr Any
forall (m :: * -> *) addr a.
(MonadAsync m, Ord addr, Ord (Async m ())) =>
ErrorPolicies -> CompleteApplication m (PeerStates m addr) addr a
completeApplicationTx ErrorPolicies
errorPolicies (Time -> addr -> e -> Result addr Any
forall e addr r. Exception e => Time -> addr -> e -> Result addr r
ApplicationError Time
t addr
remoteAddr e
e) PeerStates IO addr
st

      Server.Result Async ()
thread addr
remoteAddr Time
t (Right ()
r) ->
        (PeerStates IO addr -> PeerStates IO addr)
-> CompleteApplicationResult IO addr (PeerStates IO addr)
-> CompleteApplicationResult IO addr (PeerStates IO addr)
forall a b.
(a -> b)
-> CompleteApplicationResult IO addr a
-> CompleteApplicationResult IO addr b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (addr -> Async IO () -> PeerStates IO addr -> PeerStates IO addr
forall (m :: * -> *) addr.
(Ord addr, Ord (Async m ())) =>
addr -> Async m () -> PeerStates m addr -> PeerStates m addr
unregisterProducer addr
remoteAddr Async ()
Async IO ()
thread)
          (CompleteApplicationResult IO addr (PeerStates IO addr)
 -> CompleteApplicationResult IO addr (PeerStates IO addr))
-> STM (CompleteApplicationResult IO addr (PeerStates IO addr))
-> STM (CompleteApplicationResult IO addr (PeerStates IO addr))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ErrorPolicies
-> CompleteApplication IO (PeerStates IO addr) addr ()
forall (m :: * -> *) addr a.
(MonadAsync m, Ord addr, Ord (Async m ())) =>
ErrorPolicies -> CompleteApplication m (PeerStates m addr) addr a
completeApplicationTx ErrorPolicies
errorPolicies (Time -> addr -> () -> Result addr ()
forall addr r. Time -> addr -> r -> Result addr r
ApplicationResult Time
t addr
remoteAddr ()
r) PeerStates IO addr
st

    -- Recognise an 'IOError' that means the remote end aborted the connection
    -- before @accept@ completed, as opposed to any other accept failure.
    -- Detection is platform-specific (see the CPP branches below).
    iseCONNABORTED :: IOError -> Bool
#if defined(mingw32_HOST_OS)
    -- On Windows the network package classifies all errors
    -- as OtherError. This means that we're forced to match
    -- on the error string. The text string comes from
    -- the network package's winSockErr.c, and if it ever
    -- changes we must update our text string too.
    iseCONNABORTED (IOError _ _ _ "Software caused connection abort (WSAECONNABORTED)" _ _) = True
    iseCONNABORTED _ = False
#else
    -- Elsewhere the errno carried by the exception (if any) is compared
    -- against ECONNABORTED.
    iseCONNABORTED (IOError _ _ _ _ (Just cerrno) _) = eCONNABORTED == Errno cerrno
#if defined(darwin_HOST_OS)
    -- There is a bug in accept for IPv6 sockets. Instead of returning -1
    -- and setting errno to ECONNABORTED an invalid (>= 0) file descriptor
    -- is returned, with the client address left unchanged. The uninitialized
    -- client address causes the network package to throw the user error below.
    iseCONNABORTED (IOError _ UserError _ "Network.Socket.Types.peekSockAddr: address family '0' not supported." _ _) = True
#endif
    iseCONNABORTED _ = False
#endif


    -- Handle an 'IOException' thrown while accepting a connection on address
    -- @a@.  The exception is always traced through 'nstErrorPolicyTracer'
    -- (tagged with @a@).  It is then swallowed if 'iseCONNABORTED' classifies
    -- it as a remote abort, and rethrown otherwise.
    acceptException :: addr -> IOException -> IO ()
    acceptException a e = do
      traceWith (WithAddr a `contramap` nstErrorPolicyTracer) $
        ErrorPolicyAcceptException e

      -- Try to determine if the connection was aborted by the remote end
      -- before we could process the accept, or if it was a resource exhaustion
      -- problem.
      -- NB. This piece of code is fragile and depends on specific
      -- strings/mappings in the network and base libraries.
      if iseCONNABORTED e then return ()
                          else throwIO e

    -- STM transaction deciding whether to accept an incoming connection from
    -- @connAddr@ on the local address @sockAddr@ at time @t@.  The decision
    -- comes from 'beforeConnectTx' applied to the peer states @st@.
    --
    -- * 'AllowConnection' and 'OnlyAccept' both accept the connection,
    --   offering @versions@.
    -- * 'DisallowConnection' rejects it.
    --
    -- In every case the updated peer states @st'@ chosen by 'beforeConnectTx'
    -- are carried in the result.
    --
    -- NOTE(review): no type signature is given because the result mentions
    -- @vNumber@/@vData@ from the enclosing function's scope; add one only if
    -- those type variables are brought into scope there.
    acceptConnectionTx sockAddr t connAddr st = do
      let connId = ConnectionId sockAddr connAddr
      d <- beforeConnectTx t connAddr st
      pure $ case d of
        AllowConnection st'    -> AcceptConnection st' connId versions
        OnlyAccept st'         -> AcceptConnection st' connId versions
        DisallowConnection st' -> RejectConnection st' connId

-- | Run a server application.  It listens on the given address for incoming
-- connections; otherwise it behaves like 'withServerNode''.
--
-- The listening socket is created by 'mkListeningSocket' (using the supplied
-- configuration callback and the address family of @addr@) inside 'bracket'.
-- It is therefore passed to 'Snocket.close' once 'withServerNode'' returns or
-- throws.
withServerNode
    :: forall vNumber vData t fd addr b.
       ( Ord vNumber
       , Typeable vNumber
       , Show vNumber
       , Ord addr
       )
    => Snocket IO fd addr
    -> Mx.MakeBearer IO fd
    -> (fd -> addr -> IO ()) -- ^ callback to configure a socket
    -> NetworkServerTracers addr vNumber
    -> NetworkMutableState addr
    -> AcceptedConnectionsLimit
    -> addr
    -- ^ address to listen on
    -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
    -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    -> VersionDataCodec CBOR.Term vNumber vData
    -> HandshakeCallbacks vData
    -> Versions vNumber vData (SomeResponderApplication addr BL.ByteString IO b)
    -- ^ The mux application that will be run on each incoming connection from
    -- a given address.  Note that if @'MuxClientAndServerApplication'@ is
    -- returned, the connection will run a full duplex set of mini-protocols.
    -> ErrorPolicies
    -> (addr -> Async Void -> IO t)
    -- ^ callback which takes the @Async@ of the thread that is running the server.
    -- Note: the server thread will terminate when the callback returns or
    -- throws an exception.
    -> IO t
withServerNode sn makeBearer
               configureSock
               tracers
               networkState
               acceptedConnectionsLimit
               addr
               handshakeCodec
               handshakeTimeLimits
               versionDataCodec
               handshakeCallbacks
               versions
               errorPolicies
               k =
    bracket (mkListeningSocket sn configureSock addr (Snocket.addrFamily sn addr))
            (Snocket.close sn) $ \sd ->
      withServerNode'
        sn
        makeBearer
        tracers
        networkState
        acceptedConnectionsLimit
        sd
        handshakeCodec
        handshakeTimeLimits
        versionDataCodec
        handshakeCallbacks
        versions
        errorPolicies
        k

-- |
-- Run a server application on the provided socket.  The socket must already
-- be ready to accept connections.  The server thread is started with the
-- @withAsync@ function, which means that it will terminate when the callback
-- terminates or throws an exception.
--
-- TODO: we should track connections in the state and refuse connections from
-- peers we are already connected to.  This is also the right place to ban
-- connections from peers which misbehaved.
--
-- The server will run the handshake protocol on each incoming connection.  We
-- assume that each version negotiation message should fit into
-- @'maxTransmissionUnit'@ (~5k bytes).
--
-- Note: the socket is supplied by the caller, and its local address is read in
-- the current thread before the socket is passed to the spawned thread which
-- runs the server.  This makes it useful for testing, where we need to
-- guarantee that a socket is open before we try to connect to it.
withServerNode'
    :: forall vNumber vData t fd addr b.
       ( Ord vNumber
       , Typeable vNumber
       , Show vNumber
       , Ord addr
       )
    => Snocket IO fd addr
    -> Mx.MakeBearer IO fd
    -> NetworkServerTracers addr vNumber
    -> NetworkMutableState addr
    -> AcceptedConnectionsLimit
    -> fd
    -- ^ a configured socket to be used by the server.  As stated above, it must
    -- already be ready to accept connections.  The server will not set any
    -- socket or tcp options on it.
    -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
    -> ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    -> VersionDataCodec CBOR.Term vNumber vData
    -> HandshakeCallbacks vData
    -> Versions vNumber vData (SomeResponderApplication addr BL.ByteString IO b)
    -- ^ The mux application that will be run on each incoming connection from
    -- a given address.  Note that if @'MuxClientAndServerApplication'@ is
    -- returned, the connection will run a full duplex set of mini-protocols.
    -> ErrorPolicies
    -> (addr -> Async Void -> IO t)
    -- ^ callback which takes the local address of the socket and the @Async@ of
    -- the thread that is running the server.
    -- Note: the server thread will terminate when the callback returns or
    -- throws an exception.
    -> IO t
withServerNode' sn makeBearer
                tracers
                networkState
                acceptedConnectionsLimit
                sd
                handshakeCodec
                handshakeTimeLimits
                versionDataCodec
                handshakeCallbacks
                versions
                errorPolicies
                k = do
      -- The local address is read here, in the calling thread, and handed to
      -- the callback together with the server thread's 'Async'.
      addr' <- Snocket.getLocalAddr sn sd
      withAsync
        (runServerThread
          tracers
          networkState
          sn
          makeBearer
          sd
          acceptedConnectionsLimit
          handshakeCodec
          handshakeTimeLimits
          versionDataCodec
          handshakeCallbacks
          versions
          errorPolicies)
        (k addr')