{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- `withLocalSocket` has some constraints that are only required on Windows.
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module Ouroboros.Network.Diffusion.Utils
  ( withSockets
  , withLocalSocket
  ) where


import Control.Monad.Class.MonadThrow
import Control.Tracer (Tracer, traceWith)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NonEmpty
import Data.Typeable (Typeable)

import Ouroboros.Network.Snocket (FileDescriptor, Snocket)
import Ouroboros.Network.Snocket qualified as Snocket

import Ouroboros.Network.Diffusion.Common

--
-- Socket utility functions
--

-- | Acquire one listening node-to-node socket per entry of the address list,
-- then run the continuation with all of them.
--
-- * A @Left fd@ entry is a socket that was handed to us already (e.g. by
--   systemd). Its local address is looked up with 'Snocket.getLocalAddr' and
--   it is passed to the systemd configuration callback. It is closed only if
--   an exception escapes; on normal return it is left open.
-- * A @Right addr@ entry gets a fresh socket. The socket is opened,
--   configured, bound and put into listening mode. It is always closed (via
--   'bracket') when the continuation finishes or throws.
--
-- Sockets are acquired in list order and nested, so the continuation runs
-- while every socket is still held. Both non-empty lists passed to the
-- continuation are in the same order as the input list.
--
-- Throws 'NoSocket' if the address list is empty.
withSockets :: forall m ntnFd ntnAddr ntcAddr a.
               ( MonadCatch m
               , Typeable ntnAddr
               , Show     ntnAddr
               )
            => Tracer m (DiffusionTracer ntnAddr ntcAddr)
            -> Snocket m ntnFd ntnAddr
            -> (ntnFd -> ntnAddr -> m ()) -- ^ configure a socket
            -> (ntnFd -> ntnAddr -> m ()) -- ^ configure a systemd socket
            -> [Either ntnFd ntnAddr]
            -> (NonEmpty ntnFd -> NonEmpty ntnAddr -> m a)
            -> m a
withSockets tracer sn
            configureSocket
            configureSystemdSocket
            addresses k = go [] addresses
  where
    -- Walk the address list, acquiring one socket per entry. Each acquired
    -- @(socket, address)@ pair is accumulated, most recent first, and the
    -- continuation runs once the list is exhausted.
    go :: [(ntnFd, ntnAddr)] -> [Either ntnFd ntnAddr] -> m a
    go !acc (a : as)  = withSocket a (\sa -> go (sa : acc) as)
    go []   []        = throwIO NoSocket
    go (sa : rest) [] =
      -- The accumulator is built in reverse, so restore input order. Matching
      -- on a cons cell keeps this total, unlike 'NonEmpty.fromList'.
      let acc' :: NonEmpty (ntnFd, ntnAddr)
          acc' = NonEmpty.reverse (sa :| rest)
      in (k $! (fst <$> acc')) $! (snd <$> acc')

    -- Acquire a single listening socket for one address-list entry and pass
    -- the @(socket, local address)@ pair to the continuation.
    withSocket :: Either ntnFd ntnAddr
               -> ((ntnFd, ntnAddr) -> m a)
               -> m a
    -- Pre-opened socket: it is only closed if something goes wrong.
    withSocket (Left sock) f =
      do !addr <- Snocket.getLocalAddr sn sock
         configureSystemdSocket sock addr
         f (sock, addr)
      `onException` Snocket.close sn sock
    -- Fresh socket: open, configure, bind, listen; always closed afterwards.
    withSocket (Right addr) f =
      bracket
        (do traceWith tracer (CreatingServerSocket addr)
            Snocket.open sn (Snocket.addrFamily sn addr))
        (Snocket.close sn)
        $ \sock -> do
          traceWith tracer $ ConfiguringServerSocket addr
          configureSocket sock addr
          Snocket.bind sn sock addr
          traceWith tracer $ ListeningServerSocket addr
          Snocket.listen sn sock
          traceWith tracer $ ServerSocketUp addr
          f (sock, addr)


-- | Acquire the local (node-to-client) listening socket and run the
-- continuation with it.
--
-- * @Left sd@: a socket that was handed to us already (e.g. by systemd). It
--   is passed to the continuation without binding or listening. This case is
--   not supported on Windows: there, 'UnsupportedReadySocketCase' is traced
--   and 'UnsupportedReadySocket' is thrown.
-- * @Right addr@: a fresh socket is opened for @addr@, bound and put into
--   listening mode before the continuation runs.
--
-- In both cases the socket is closed when the continuation finishes or
-- throws, including a socket that was provided to us.
--
-- The @Typeable ntnAddr@ and @Show ntnAddr@ constraints are only needed by
-- the Windows code path.
withLocalSocket :: forall ntnAddr ntcFd ntcAddr m a.
                   ( MonadThrow m
                     -- Win32 only constraints:
                   , Typeable ntnAddr
                   , Show     ntnAddr
                   )
                => Tracer m (DiffusionTracer ntnAddr ntcAddr)
                -> (ntcFd -> m FileDescriptor)
                -> Snocket m ntcFd ntcAddr
                -> Either ntcFd ntcAddr
                -> (ntcFd -> m a)
                -> m a
withLocalSocket tracer getFileDescriptor sn localAddress k =
  bracket
    -- Acquire: only obtain the socket here. Any further fallible step on an
    -- acquired socket runs in the body below, where the release handler
    -- still closes the socket if that step throws.
    (
      case localAddress of
#if defined(mingw32_HOST_OS)
         -- Windows uses named pipes so can't take advantage of existing sockets
         Left _ -> traceWith tracer (UnsupportedReadySocketCase
                                       :: DiffusionTracer ntnAddr ntcAddr)
                >> throwIO UnsupportedReadySocket
#else
         Left sd -> return (Left sd)
#endif
         Right addr -> do
             traceWith tracer $ CreateSystemdSocketForSnocketPath addr
             sd <- Snocket.open sn (Snocket.addrFamily sn addr)
             return (Right (sd, addr))
    )
    -- We close the socket here, even if it was provided to us.
    (\case
      Right (sd, _) -> Snocket.close sn sd
      Left   sd     -> Snocket.close sn sd
    )
    $ \case
      -- unconfigured socket: bind and listen before handing it over
      Right (sd, addr) -> do
        traceWith tracer $ CreatedLocalSocket addr
        traceWith tracer . ConfiguringLocalSocket addr
           =<< getFileDescriptor sd
        Snocket.bind sn sd addr
        traceWith tracer . ListeningLocalSocket addr
           =<< getFileDescriptor sd
        Snocket.listen sn sd
        traceWith tracer . LocalSocketUp addr
           =<< getFileDescriptor sd
        k sd

      -- pre-configured systemd socket
      Left sd -> do
#if !defined(mingw32_HOST_OS)
        -- Looked up here rather than in the acquire step so that a failing
        -- lookup still closes the provided socket.
        addr <- Snocket.getLocalAddr sn sd
        traceWith tracer (UsingSystemdSocket addr)
#endif
        k sd