{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- `withLocalSocket` has some constraints that are only required on Windows.
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module Ouroboros.Network.Diffusion.Utils
  ( withSockets
  , withLocalSocket
  ) where


import Control.Monad.Class.MonadThrow
import Control.Tracer (Tracer, traceWith)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NonEmpty
import Data.Typeable (Typeable)

import Ouroboros.Network.Snocket (FileDescriptor, Snocket)
import Ouroboros.Network.Snocket qualified as Snocket

import Ouroboros.Network.Diffusion.Types

--
-- Socket utility functions
--

-- | Obtain one node-to-node listening socket per entry of @addresses@ and run
-- the continuation with all of them.
--
-- * @'Left' fd@ is a socket that was handed to us already (the callback name
--   suggests a systemd socket). Its local address is queried with
--   'Snocket.getLocalAddr' and passed, with the socket, to the systemd
--   configuration callback. The socket is closed only if an exception
--   escapes. On normal return it is left open (presumably it is owned by
--   whoever supplied it — NOTE(review): confirm with callers).
-- * @'Right' addr@ is an address to listen on. A socket is opened for its
--   address family, configured, bound and put into listening mode. It is
--   closed via 'bracket' when the continuation finishes, normally or by
--   exception.
--
-- The continuation receives the sockets and their addresses in the same
-- order as @addresses@. If @addresses@ is empty, 'NoSocket' is thrown.
withSockets :: forall m ntnFd ntnAddr ntcAddr a.
               ( MonadCatch m
               , Typeable ntnAddr
               , Show     ntnAddr
               )
            => Tracer m (DiffusionTracer ntnAddr ntcAddr)
            -> Snocket m ntnFd ntnAddr
            -> (ntnFd -> ntnAddr -> m ()) -- ^ configure a socket
            -> (ntnFd -> ntnAddr -> m ()) -- ^ configure a systemd socket
            -> [Either ntnFd ntnAddr]
            -> (NonEmpty ntnFd -> NonEmpty ntnAddr -> m a)
            -> m a
withSockets tracer sn configureSocket configureSystemdSocket addresses k =
    go [] addresses
  where
    -- Nest one 'withSocket' scope per entry so that every acquired socket is
    -- still open when the continuation runs. Acquired (socket, address)
    -- pairs are accumulated in reverse order.
    go :: [(ntnFd, ntnAddr)] -> [Either ntnFd ntnAddr] -> m a
    go !acc (a : as) = withSocket a (\sa -> go (sa : acc) as)
    go !acc []       =
      -- 'nonEmpty' makes the "no addresses at all" case explicit instead of
      -- relying on clause order to keep a partial 'fromList' safe.
      case NonEmpty.nonEmpty (reverse acc) of
        Nothing   -> throwIO NoSocket
        Just acc' -> (k $! (fst <$> acc')) $! (snd <$> acc')

    -- Acquire a single socket and hand it, with its local address, to @f@.
    withSocket :: Either ntnFd ntnAddr
               -> ((ntnFd, ntnAddr) -> m a)
               -> m a
    -- Pre-opened socket: learn its address, apply the systemd-specific
    -- configuration, and close it only if something throws.
    withSocket (Left sock) f =
      do !addr <- Snocket.getLocalAddr sn sock
         configureSystemdSocket sock addr
         f (sock, addr)
      `onException` Snocket.close sn sock
    -- Address only: open a fresh socket whose lifetime is tied to this scope.
    withSocket (Right addr) f =
      bracket
        (do traceWith tracer (CreatingServerSocket addr)
            Snocket.open sn (Snocket.addrFamily sn addr))
        (Snocket.close sn)
        $ \sock -> do
          traceWith tracer $ ConfiguringServerSocket addr
          configureSocket sock addr
          Snocket.bind sn sock addr
          traceWith tracer $ ListeningServerSocket addr
          Snocket.listen sn sock
          traceWith tracer $ ServerSocketUp addr
          f (sock, addr)


-- | Run a continuation with the node-to-client (local) listening socket.
--
-- * @'Right' addr@: open a new socket for @addr@'s address family, bind it,
--   run the @configureSocketFile@ callback on @addr@, and put the socket into
--   listening mode before calling the continuation.
-- * @'Left' fd@: a socket that was handed to us already (a pre-configured
--   systemd socket). Its local address is queried and passed to
--   @configureSocketFile@. It is not bound or put into listening mode here.
--   On Windows, which uses named pipes, this case is unsupported: the
--   function traces 'UnsupportedReadySocketCase' and throws
--   'UnsupportedReadySocket'.
--
-- In both cases the socket is closed once the continuation returns or throws.
withLocalSocket :: forall ntnAddr ntcFd ntcAddr m a.
                   ( MonadThrow m
                     -- Win32 only constraints:
                   , Typeable ntnAddr
                   , Show     ntnAddr
                   )
                => Tracer m (DiffusionTracer ntnAddr ntcAddr)
                -> (ntcFd -> m FileDescriptor)
                -> (ntcAddr -> m ())
                -- ^ configure the local socket file.
                -> Snocket m ntcFd ntcAddr
                -> Either ntcFd ntcAddr
                -> (ntcFd -> m a)
                -> m a
withLocalSocket tracer getFileDescriptor configureSocketFile sn localAddress k =
  bracket acquire release $ \case
      -- Socket we opened ourselves: it still needs binding and listening.
      Right (sd, addr) -> do
        -- Traced here rather than in 'acquire': once the socket exists, a
        -- throwing tracer must not be able to leak it.
        traceWith tracer $ CreatedLocalSocket addr
        fd <- getFileDescriptor sd
        traceWith tracer (ConfiguringLocalSocket addr fd)
        Snocket.bind sn sd addr
        configureSocketFile addr
        traceWith tracer (ConfiguredLocalSocket addr fd)
        traceWith tracer (ListeningLocalSocket addr fd)
        Snocket.listen sn sd
        traceWith tracer (LocalSocketUp addr fd)
        k sd

      -- Pre-configured systemd socket; its address was already queried in
      -- 'acquire', so it is not looked up a second time.
      Left (sd, addr) -> do
        configureSocketFile addr
        k sd
  where
    -- Obtain the socket together with its local address.
    -- 'Left' = socket provided to us, 'Right' = socket opened here.
    -- For a freshly opened socket, 'Snocket.open' is the last effect so that
    -- every later failure is covered by 'release'.
    acquire :: m (Either (ntcFd, ntcAddr) (ntcFd, ntcAddr))
    acquire =
      case localAddress of
#if defined(mingw32_HOST_OS)
        -- Windows uses named pipes so can't take advantage of existing sockets
        Left _ -> traceWith tracer (UnsupportedReadySocketCase
                                      :: DiffusionTracer ntnAddr ntcAddr)
               >> throwIO UnsupportedReadySocket
#else
        Left sd -> do
          addr <- Snocket.getLocalAddr sn sd
          traceWith tracer (UsingSystemdSocket addr)
          return (Left (sd, addr))
#endif
        Right addr -> do
          traceWith tracer $ CreateSystemdSocketForSnocketPath addr
          sd <- Snocket.open sn (Snocket.addrFamily sn addr)
          return (Right (sd, addr))

    -- We close the socket here, even if it was provided to us.
    release :: Either (ntcFd, ntcAddr) (ntcFd, ntcAddr) -> m ()
    release = Snocket.close sn . either fst fst