{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | A simple server.  The server doesn't control resource usage (e.g. it does
-- not limit the number of inbound connections) and thus should only be used in
-- a safe environment.
--
-- The module should be imported qualified.
module Ouroboros.Network.Server.Simple where

import Control.Applicative (Alternative)
import Control.Concurrent.JobPool qualified as JobPool
import Control.Monad (forever)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork (MonadFork)
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTimer.SI (MonadTimer)
import Control.Tracer (nullTracer)
import Data.ByteString.Lazy qualified as BL
import Data.Functor (void)
import Data.Typeable (Typeable)
import Data.Void (Void)

import Network.Mux qualified as Mx

import Ouroboros.Network.ConnectionId
import Ouroboros.Network.Mux
import Ouroboros.Network.Protocol.Handshake
import Ouroboros.Network.Snocket as Snocket
import Ouroboros.Network.Socket


-- | Run a simple listening server for the duration of a continuation.
--
-- The server:
--
-- * opens a socket for the address family of the given address, applies the
--   socket-configuration callback to it, binds it, starts listening and reads
--   back the local address the socket is actually bound to;
-- * runs an accept loop in a separate thread (via 'withAsync');
-- * serves every accepted connection in its own job of a 'JobPool.JobPool'
--   scoped to this call.  It builds a mux bearer, runs the handshake and, on
--   a successful negotiation, runs the negotiated responder application over
--   a fresh mux.  The accepted socket is closed when that job finishes,
--   whether normally or by an exception.
--
-- The continuation receives the bound local address and the 'Async' of the
-- accept loop.  The bound address differs from the requested one e.g. when
-- binding to port 0.  An 'AcceptFailure' is rethrown inside the accept-loop
-- thread, so it ends the loop and is observable through that 'Async'.  The
-- listening socket is closed when the continuation returns or throws.
with :: forall fd addr vNumber vData m a b.
        ( Alternative (STM m),
          MonadAsync m,
          MonadFork  m,
          MonadLabelledSTM m,
          MonadMask  m,
          MonadTimer m,
          MonadThrow (STM m),
          Ord vNumber,
          Typeable vNumber,
          Show vNumber
        )
     => Snocket m fd addr
     -- ^ snocket used to open, bind, listen on and accept from the socket
     -> Mx.MakeBearer m fd
     -- ^ builds a mux bearer from an accepted socket
     -> (fd -> addr -> m ())
     -- ^ socket-configuration callback.  It is applied to the listening
     -- socket (before binding) and to every accepted socket, in both cases
     -- together with the requested address.
     -> addr
     -- ^ address to bind the listening socket to
     -> HandshakeArguments (ConnectionId addr) vNumber vData m
     -- ^ arguments for the server side of the handshake
     -> Versions vNumber vData (SomeResponderApplication addr BL.ByteString m b)
     -- ^ versions offered in the handshake, with their responder applications
     -> (addr -> Async m Void -> m a)
     -- ^ continuation; receives the bound local address and the accept loop
     -> m a
with sn makeBearer configureSock addr handshakeArgs versions k =
    JobPool.withJobPool $ \jobPool ->
      bracket openListeningSocket (Snocket.close sn . fst) $
        \(listeningSock, boundAddr) ->
          withAsync (forever $ acceptOne jobPool boundAddr listeningSock)
                    (k boundAddr)
  where
    -- Open, configure, bind and listen; return the socket together with the
    -- address it is actually bound to.
    openListeningSocket :: m (fd, addr)
    openListeningSocket = do
      sd <- Snocket.open sn (Snocket.addrFamily sn addr)
      configureSock sd addr
      Snocket.bind sn sd addr
      Snocket.listen sn sd
      boundAddr <- getLocalAddr sn sd
      pure (sd, boundAddr)

    -- Accept a single inbound connection and fork a job serving it.
    -- 'localAddr' is the bound address of the listening socket; it is used as
    -- the local side of every 'ConnectionId' (using the requested address
    -- instead would report e.g. port 0 for ephemeral binds).
    --
    -- A fresh 'Accept' is requested for every connection; the continuation
    -- returned by 'runAccept' is ignored.
    acceptOne :: JobPool.JobPool () m () -> addr -> fd -> m ()
    acceptOne jobPool localAddr listeningSock =
      accept sn listeningSock >>= runAccept >>= \case
        (Accepted sock' remoteAddr, _) ->
          JobPool.forkJob jobPool $
            JobPool.Job
              -- The job owns the accepted socket; close it however the
              -- connection ends, otherwise every connection leaks a socket.
              (serveConnection localAddr sock' remoteAddr
                 `finally` Snocket.close sn sock')
              errorHandler
              ()
              "conn-thread"
        -- Any accept failure is rethrown here, which terminates the accept
        -- loop thread.
        (AcceptFailure e, _) -> throwIO e

    -- Errors of a connection job are rethrown from that job's thread; how
    -- they are reported further is up to "Control.Concurrent.JobPool".
    errorHandler :: SomeException -> m ()
    errorHandler = throwIO

    -- Connection responder thread: build a bearer, run the handshake, then
    -- run the negotiated responder application over a new mux.
    serveConnection :: addr -> fd -> addr -> m ()
    serveConnection localAddr sock' remoteAddr = do
      let connId :: ConnectionId addr
          connId = ConnectionId localAddr remoteAddr
      -- NOTE(review): (-1) is passed as the bearer's DiffTime argument,
      -- presumably meaning "no timeout"; confirm against 'Mx.getBearer'.
      bearer <- Mx.getBearer makeBearer (-1) nullTracer sock'
      configureSock sock' addr
      r <- runHandshakeServer bearer connId handshakeArgs versions
      case r of
        Left (HandshakeProtocolLimit e) -> throwIO e
        Left (HandshakeProtocolError e) -> throwIO e
        Right HandshakeQueryResult {} ->
          error "handshake query is not supported"
        Right (HandshakeNegotiationResult (SomeResponderApplication app)
                                          vNumber vData) -> do
          let forkPolicy = runForkPolicy noBindForkPolicy remoteAddr
          mux <- Mx.new (toMiniProtocolInfos forkPolicy app)
          -- The mux runs in its own thread and is cancelled once the
          -- application callback returns.  The callback's result (including
          -- any exception it reports) is deliberately discarded.
          withAsync (Mx.run nullTracer mux bearer) $ \aid ->
            void $ simpleMuxCallback connId vNumber vData app mux aid