{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
module Ouroboros.Network.Diffusion.Utils
( withSockets
, withLocalSocket
) where
import Control.Monad.Class.MonadThrow
import Control.Tracer (Tracer, traceWith)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NonEmpty
import Data.Typeable (Typeable)
import Ouroboros.Network.Snocket (FileDescriptor, Snocket)
import Ouroboros.Network.Snocket qualified as Snocket
import Ouroboros.Network.Diffusion.Common
-- | Acquire every node-to-node listening socket described by @addresses@ and
-- run the continuation @k@ with all of them.
--
-- Each element of @addresses@ is either
--
-- * @Left fd@ — a socket that is already open (presumably handed over by
--   systemd). Its local address is queried and both are passed to
--   @configureSystemdSocket@. It is not bound or put into listening mode here.
-- * @Right addr@ — an address for which a fresh socket is opened, configured
--   with @configureSocket@, bound and put into listening mode. Each step is
--   traced.
--
-- The continuation receives the sockets and their local addresses in the
-- same order as @addresses@. Every socket, whether pre-opened or created
-- here, is closed when the continuation exits, normally or by exception.
--
-- Throws 'NoSocket' when @addresses@ is empty.
withSockets :: forall m ntnFd ntnAddr ntcAddr a.
               ( MonadCatch m
               , Typeable ntnAddr
               , Show     ntnAddr
               )
            => Tracer m (DiffusionTracer ntnAddr ntcAddr)
            -> Snocket m ntnFd ntnAddr
            -> (ntnFd -> ntnAddr -> m ())
            -- ^ configure a socket opened by this function
            -> (ntnFd -> ntnAddr -> m ())
            -- ^ configure a socket that was passed in already open
            -> [Either ntnFd ntnAddr]
            -- ^ pre-opened sockets ('Left') or addresses to listen on ('Right')
            -> (NonEmpty ntnFd -> NonEmpty ntnAddr -> m a)
            -- ^ continuation receiving the sockets and their local addresses
            -> m a
withSockets tracer sn configureSocket configureSystemdSocket addresses k =
    go [] addresses
  where
    -- Acquire the sockets one at a time, nesting each one inside the scope
    -- of the previous one, so that all of them stay open for the duration
    -- of @k@. The accumulator holds the acquired sockets in reverse order.
    go :: [(ntnFd, ntnAddr)] -> [Either ntnFd ntnAddr] -> m a
    go !acc (a : as) = withSocket a (\sa -> go (sa : acc) as)
    go !acc []       =
      -- 'NonEmpty.nonEmpty' is total, unlike 'NonEmpty.fromList'. An empty
      -- accumulator at this point means no address was supplied at all.
      case NonEmpty.nonEmpty (reverse acc) of
        Nothing   -> throwIO NoSocket
        Just acc' -> (k $! (fst <$> acc')) $! (snd <$> acc')

    withSocket :: Either ntnFd ntnAddr
               -> ((ntnFd, ntnAddr) -> m a)
               -> m a
    -- Pre-opened socket. It is closed on every exit, which matches the
    -- 'Right' case below and 'withLocalSocket'. This includes a failure of
    -- 'Snocket.getLocalAddr' or of the configuration callback.
    withSocket (Left sock) f =
      bracket
        (pure sock)
        (Snocket.close sn)
        $ \fd -> do
          !addr <- Snocket.getLocalAddr sn fd
          configureSystemdSocket fd addr
          f (fd, addr)
    -- Socket created here: open, configure, bind, listen, then run @f@.
    -- It is closed when @f@ returns or throws.
    withSocket (Right addr) f =
      bracket
        (do traceWith tracer (CreatingServerSocket addr)
            Snocket.open sn (Snocket.addrFamily sn addr))
        (Snocket.close sn)
        $ \sock -> do
          traceWith tracer $ ConfiguringServerSocket addr
          configureSocket sock addr
          Snocket.bind sn sock addr
          traceWith tracer $ ListeningServerSocket addr
          Snocket.listen sn sock
          traceWith tracer $ ServerSocketUp addr
          f (sock, addr)
-- | Acquire the local (node-to-client) listening socket and run the
-- continuation @k@ with it.
--
-- * @Left sd@ — a socket that is already open (presumably provided by
--   systemd). Its local address is traced and the socket is passed to @k@
--   as it is, without being bound or listened on again. On Windows this case
--   is unsupported: 'UnsupportedReadySocketCase' is traced and
--   'UnsupportedReadySocket' is thrown.
-- * @Right addr@ — a socket of @addr@'s family is opened, bound to @addr@
--   and put into listening mode. Each step is traced together with the
--   socket's file descriptor.
--
-- In both cases the socket is closed when @k@ returns or throws.
--
-- NOTE(review): the @Typeable@/@Show ntnAddr@ constraints are not used by
-- the visible body. They are kept so the signature stays unchanged.
withLocalSocket :: forall ntnAddr ntcFd ntcAddr m a.
                   ( MonadThrow m
                   , Typeable ntnAddr
                   , Show     ntnAddr
                   )
                => Tracer m (DiffusionTracer ntnAddr ntcAddr)
                -> (ntcFd -> m FileDescriptor)
                -- ^ obtain the file descriptor of a socket (used for tracing)
                -> Snocket m ntcFd ntcAddr
                -> Either ntcFd ntcAddr
                -- ^ pre-opened socket ('Left') or address to listen on ('Right')
                -> (ntcFd -> m a)
                -> m a
withLocalSocket tracer getFileDescriptor sn localAddress k =
    bracket
      acquire
      -- The socket is closed here even when it was handed to us.
      (Snocket.close sn . either id fst)
      $ \case
          -- A socket opened by 'acquire': bind it and start listening.
          Right (sd, addr) -> do
            traceWithFd (ConfiguringLocalSocket addr) sd
            Snocket.bind sn sd addr
            traceWithFd (ListeningLocalSocket addr) sd
            Snocket.listen sn sd
            traceWithFd (LocalSocketUp addr) sd
            k sd
          -- A pre-configured socket is used as it is.
          Left sd -> k sd
  where
    -- Produce the socket. 'Left' is a pre-opened socket; 'Right' pairs a
    -- freshly opened socket with the address it still has to be bound to.
    acquire :: m (Either ntcFd (ntcFd, ntcAddr))
    acquire =
      case localAddress of
#if defined(mingw32_HOST_OS)
        -- A ready socket is not supported on Windows.
        Left _ -> traceWith tracer (UnsupportedReadySocketCase
                                      :: DiffusionTracer ntnAddr ntcAddr)
               >> throwIO UnsupportedReadySocket
#else
        Left sd -> do
          addr <- Snocket.getLocalAddr sn sd
          traceWith tracer (UsingSystemdSocket addr)
          return (Left sd)
#endif
        Right addr -> do
          traceWith tracer $ CreateSystemdSocketForSnocketPath addr
          sd <- Snocket.open sn (Snocket.addrFamily sn addr)
          traceWith tracer $ CreatedLocalSocket addr
          return (Right (sd, addr))

    -- Look up the socket's file descriptor, then trace the event built from it.
    traceWithFd :: (FileDescriptor -> DiffusionTracer ntnAddr ntcAddr)
                -> ntcFd
                -> m ()
    traceWithFd mkEvent sd =
      traceWith tracer . mkEvent =<< getFileDescriptor sd