{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Ouroboros.Network.Server.Simple
  ( with
  , ServerTracer (..)
  ) where
import Control.Applicative (Alternative)
import Control.Concurrent.JobPool qualified as JobPool
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTimer.SI
import Control.Tracer (Tracer, traceWith)
import Data.ByteString.Lazy qualified as BL
import Data.Functor (void)
import Data.Typeable (Typeable)
import Data.Void (Void)
import Network.Mux qualified as Mx
import Ouroboros.Network.ConnectionId
import Ouroboros.Network.Mux
import Ouroboros.Network.Protocol.Handshake
import Ouroboros.Network.Server (isECONNABORTED, server_CONNABORTED_DELAY)
import Ouroboros.Network.Snocket (Snocket)
import Ouroboros.Network.Snocket qualified as Snocket
import Ouroboros.Network.Socket
-- | Events emitted by the simple server through its 'Tracer'.
data ServerTracer addr
  = AcceptException SomeException
    -- ^ @accept@ failed with an exception other than @ECONNABORTED@.  The
    -- accept loop re-throws the exception after tracing it.
  | ConnectionHandlerException (ConnectionId addr) SomeException
    -- ^ The job handling the given connection terminated with an exception.
    -- The exception is re-thrown in that job after tracing it.
  deriving Show
-- | Run a simple, responder-only server.
--
-- Opens a listening socket for the given address ('Snocket.open',
-- 'Snocket.bind', 'Snocket.listen') and configures it with the supplied
-- configuration function.  It then runs an accept loop in a separate thread
-- while the continuation runs.  The continuation receives the address the
-- socket is actually bound to ('Snocket.getLocalAddr') and the 'Async' of the
-- accept loop.  The listening socket is closed when the continuation returns
-- or throws.
--
-- Every accepted connection is handled in its own 'JobPool.Job', which:
--
-- * creates a mux bearer for the accepted socket and configures the socket,
-- * runs the handshake as the responder,
-- * on successful negotiation runs mux together with the negotiated
--   application via 'simpleMuxCallback',
-- * closes the accepted socket when it finishes, normally or by exception.
--
-- @accept@ failures recognised by 'isECONNABORTED' only cause a delay of
-- 'server_CONNABORTED_DELAY'.  Any other @accept@ failure is traced as
-- 'AcceptException' and re-thrown, which ends the accept loop.  Exceptions
-- raised by a connection job are traced as 'ConnectionHandlerException' and
-- re-thrown inside that job.
--
-- Handshake queries are not supported: a 'HandshakeQueryResult' makes the
-- connection job fail with 'error'.
with :: forall fd addr vNumber vData m a b.
        ( Alternative (STM m),
          MonadAsync m,
          MonadDelay m,
          MonadFork  m,
          MonadLabelledSTM m,
          MonadMask  m,
          MonadTimer m,
          MonadThrow (STM m),
          Ord vNumber,
          Typeable vNumber,
          Show vNumber
        )
     => Snocket m fd addr
     -- ^ low level snocket API
     -> Tracer m (ServerTracer addr)
     -- ^ server tracer
     -> Mx.TracersWithBearer (ConnectionId addr) m
     -- ^ mux tracers, specialised per connection with 'Mx.tracersWithBearer'
     -> Mx.MakeBearer m fd
     -- ^ creates the mux bearer for each accepted socket
     -> (fd -> addr -> m ())
     -- ^ socket configuration, applied both to the listening socket and to
     -- every accepted socket
     -> addr
     -- ^ address to bind to
     -> HandshakeArguments (ConnectionId addr) vNumber vData m
     -- ^ handshake arguments
     -> Versions vNumber vData (SomeResponderApplication addr BL.ByteString m b)
     -- ^ supported versions and the responder applications to run for them
     -> (addr -> Async m Void -> m a)
     -- ^ continuation; receives the local address the server is bound to and
     -- the 'Async' of the accept loop
     -> m a
with sn tracer muxTracers makeBearer configureSock addr handshakeArgs versions k =
    JobPool.withJobPool $ \jobPool ->
    bracket
      (do sd <- Snocket.open sn (Snocket.addrFamily sn addr)
          configureSock sd addr
          Snocket.bind sn sd addr
          Snocket.listen sn sd
          -- report the address the socket is actually bound to, which is
          -- what the continuation receives
          addr' <- Snocket.getLocalAddr sn sd
          pure (sd, addr'))
      (Snocket.close sn . fst)
      (\(sock, localAddress) ->
        -- the accept loop only lives as long as the continuation runs
        withAsync (Snocket.accept sn sock >>= acceptLoop jobPool localAddress)
                  (k localAddress))
  where
    -- Accept connections forever, handing each outcome to 'acceptOne'.
    acceptLoop :: JobPool.JobPool () m ()
               -> addr
               -> Snocket.Accept m fd addr
               -> m Void
    acceptLoop jobPool localAddress Snocket.Accept { Snocket.runAccept } = do
        (accepted, acceptNext) <- runAccept
        acceptOne accepted
        acceptLoop jobPool localAddress acceptNext
      where
        -- Handle the outcome of a single accept call:
        --  * ECONNABORTED: wait 'server_CONNABORTED_DELAY', keep accepting;
        --  * any other failure: trace it and re-throw (ends the loop);
        --  * a new socket: fork a connection job for it.
        acceptOne :: Snocket.Accepted fd addr -> m ()
        acceptOne (Snocket.AcceptFailure e)
          | Just ioErr <- fromException e
          , isECONNABORTED ioErr
          = threadDelay server_CONNABORTED_DELAY
        acceptOne (Snocket.AcceptFailure e)
          = do traceWith tracer (AcceptException e)
               throwIO e
        acceptOne (Snocket.Accepted sock' remoteAddress) = do
            let connId :: ConnectionId addr
                connId = ConnectionId { localAddress, remoteAddress }

                -- Connection responder thread: handshake, then mux.
                connThread :: m ()
                connThread = do
                  -- NOTE(review): the meaning of the (-1) 'DiffTime' argument
                  -- is defined by 'Mx.MakeBearer'; presumably it disables the
                  -- bearer timeout - confirm.
                  bearer <- Mx.getBearer makeBearer (-1) sock' Nothing
                  configureSock sock' localAddress
                  r <- runHandshakeServer bearer connId handshakeArgs versions
                  case r of
                    Left (HandshakeProtocolLimit e) -> throwIO e
                    Left (HandshakeProtocolError e) -> throwIO e
                    Right HandshakeQueryResult {}   ->
                      error "handshake query is not supported"
                    Right (HandshakeNegotiationResult
                            (SomeResponderApplication app) vNumber vData) -> do
                      mux <- Mx.new (connId `Mx.tracersWithBearer` muxTracers)
                                    (toMiniProtocolInfos
                                      (runForkPolicy noBindForkPolicy remoteAddress)
                                      app)
                      withAsync (Mx.run mux bearer) $ \aid ->
                        void $ simpleMuxCallback connId vNumber vData app mux aid

                -- Trace the failure of a connection job, then re-throw it.
                errorHandler :: SomeException -> m ()
                errorHandler e = do
                  traceWith tracer (ConnectionHandlerException connId e)
                  throwIO e

            -- Nothing else in this server closes the accepted socket, so the
            -- job closes it however 'connThread' ends; without this every
            -- connection leaks a file descriptor.
            JobPool.forkJob jobPool
              $ JobPool.Job (connThread `finally` Snocket.close sn sock')
                            errorHandler
                            ()
                            "conn-thread"