{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE RecursiveDo         #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}

module Ouroboros.Network.Subscription.Worker
  ( SocketStateChange
  , SocketState (..)
  , CompleteApplication
  , ConnectResult (..)
  , Result (..)
  , Main
  , StateVar
  , LocalAddresses (..)
    -- * Subscription worker
  , WorkerCallbacks (..)
  , WorkerParams (..)
  , worker
    -- * Socket API
  , safeConnect
    -- * Constants
  , defaultConnectionAttemptDelay
  , minConnectionAttemptDelay
  , maxConnectionAttemptDelay
  , ipRetryDelay
    -- * Errors
  , SubscriberError (..)
    -- * Tracing
  , SubscriptionTrace (..)
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM qualified as STM
import Control.Exception (SomeException (..))
import Control.Monad (forever, join, unless, when)
import Control.Monad.Fix (MonadFix)
import Data.Foldable (traverse_)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Void (Void)
import GHC.Stack
import Network.Socket (Family (AF_UNIX))
import Text.Printf

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Control.Tracer

import Ouroboros.Network.ErrorPolicy (CompleteApplication,
           CompleteApplicationResult (..), ErrorPolicyTrace, Result (..),
           WithAddr)
import Ouroboros.Network.Server.ConnectionTable
import Ouroboros.Network.Snocket (Snocket (..))
import Ouroboros.Network.Snocket qualified as Snocket
import Ouroboros.Network.Subscription.Subscriber

-- | Time to wait between connection attempts when we don't have any DeltaQ
-- info.
--
-- The subscription loop falls back to this value when
-- 'wpConnectionAttemptDelay' returns 'Nothing' for a remote address.
--
defaultConnectionAttemptDelay :: DiffTime
defaultConnectionAttemptDelay = 0.025 -- 25ms delay

-- | Minimum time to wait between connection attempts.
--
-- A delay supplied by 'wpConnectionAttemptDelay' is raised to at least this
-- value by the subscription loop.
--
minConnectionAttemptDelay :: DiffTime
minConnectionAttemptDelay = 0.010 -- 10ms delay

-- | Maximum time to wait between connection attempts.
--
-- A delay supplied by 'wpConnectionAttemptDelay' is capped at this value by
-- the subscription loop.
--
maxConnectionAttemptDelay :: DiffTime
maxConnectionAttemptDelay = 2 -- 2s delay

-- | Minimum time to wait between ip reconnects.
--
-- If one pass of the subscription loop took less than this, the loop sleeps
-- for the remainder before fetching the subscription targets again.
--
ipRetryDelay :: DiffTime
ipRetryDelay = 10 -- 10s delay

-- | Item carried by the result queue ('ResultQ'): either a 'Result' ('Res')
-- or an action ('Act') made of a set of threads to kill and an optional
-- trace point.
--
data ResOrAct m addr tr r =
     Res !(Result addr r)
   | Act (Set (Async m ())) -- ^ threads to kill
         (Maybe tr)         -- ^ trace point

-- | Result queue: a 'StrictTQueue' of 'ResOrAct' values.  The spawned threads
-- will keep writing to it, while the main server will read from it.
--
type ResultQ m addr tr r = StrictTQueue m (ResOrAct m addr tr r)

-- | Create a new, empty 'ResultQ'.
--
newResultQ :: forall m addr tr r. MonadSTM m => m (ResultQ m addr tr r)
newResultQ = atomically newTQueue

-- | Mutable state kept by the worker.  All the workers in this module are
-- polymorphic over the state type.  The state is updated with two callbacks:
--
-- * 'CompleteConnect'     - STM transaction which runs when the connect call
--                           returns; if it threw an exception, the exception
--                           is passed to the callback.
-- * 'CompleteApplication' - STM transaction which runs when the application
--                           returns.  It receives the result of the
--                           application or the exception raised by it.
--
type StateVar m s = StrictTVar m s

-- | The set of all spawned threads, kept in a 'StrictTVar'.  Used for waiting
-- on or cancelling them when the server shuts down.
--
type ThreadsVar m = StrictTVar m (Set (Async m ()))


-- | Socket event passed to the 'SocketStateChange' callback.  Each
-- constructor carries an address and an 'Async' handle; the subscription loop
-- in this module supplies the remote address and the connection thread.
--
data SocketState m addr
   = CreatedSocket !addr !(Async m ())
   -- ^ a socket was created
   | ClosedSocket  !addr !(Async m ())
   -- ^ a socket was closed

-- | Callback which fires when we create or close a socket.  It receives the
-- socket event and the current state, and returns the new state in 'STM'.
--
type SocketStateChange m s addr = SocketState m addr -> s -> STM m s

-- | Given the current state, 'retry' to keep the subscription worker going.
-- When this transaction returns, all the threads spawned by the worker will be
-- killed.
--
type Main m s t = s -> STM m t

-- | Optional local addresses, one per supported address family.
--
data LocalAddresses addr = LocalAddresses {
    -- | Local IPv4 address to use, Nothing indicates don't use IPv4
    laIpv4 :: Maybe addr
    -- | Local IPv6 address to use, Nothing indicates don't use IPv6
  , laIpv6 :: Maybe addr
    -- | Local Unix address to use, Nothing indicates don't use Unix sockets
  , laUnix :: Maybe addr
  } deriving (Eq, Show)

-- | Left-biased, field-wise combination: for each address family the left
-- operand's address is kept if present, otherwise the right operand's is used.
--
instance Semigroup (LocalAddresses addr) where
    a <> b = LocalAddresses {
        laIpv4 = laIpv4 a <|> laIpv4 b,
        laIpv6 = laIpv6 a <|> laIpv6 b,
        laUnix = laUnix a <|> laUnix b
      }


-- | Allocate a socket and connect to a peer, execute the continuation with
-- async exceptions masked.  The continuation receives the 'unmask' callback.
--
-- The socket is opened and closed by 'bracket'.  Configuring the socket,
-- binding and connecting run inside 'try', so any exception they raise is
-- handed to the continuation as a 'Left' value rather than propagated.
--
safeConnect :: ( MonadMask m
               )
            => Snocket m sock addr
            -> (sock -> addr -> m ()) -- ^ configure the socket
            -> addr
            -- ^ remote addr
            -> addr
            -- ^ local addr
            -> m ()
            -- ^ allocate extra action; executed with async exceptions masked in
            -- the allocation action of 'bracket'
            -> m ()
            -- ^ release extra action; executed with async exceptions masked in
            -- the closing action of 'bracket'
            -> ((forall x. m x -> m x) -> sock -> Either SomeException () -> m t)
            -- ^ continuation executed with async exceptions
            -- masked; it receives: unmask function, allocated socket and
            -- connection error.
            -> m t
safeConnect sn configureSock remoteAddr localAddr malloc mclean k =
    bracket
      -- NOTE(review): if 'malloc' throws, the socket opened just before it is
      -- not closed here, since 'bracket' runs the release action only after
      -- the acquire action succeeded -- confirm that 'malloc' cannot fail.
      (do sock <- Snocket.open sn (Snocket.addrFamily sn remoteAddr)
          malloc
          pure sock
      )
      (\sock -> Snocket.close sn sock >> mclean)
      (\sock -> mask $ \unmask -> do
          res <- try $ do
            configureSock sock localAddr
            -- Only bind socket families other than AF_UNIX.
            let doBind = case Snocket.addrFamily sn localAddr of
                              Snocket.SocketFamily fam -> fam /= AF_UNIX
                              _                        -> False -- Bind is a nop for Named Pipes anyway
            when doBind $
              Snocket.bind sn sock localAddr
            -- 'connect' is the only step run with async exceptions unmasked.
            unmask $ Snocket.connect sn sock remoteAddr
          k unmask sock res)


--
-- Internal API
--


-- | Sum type which classifies the result of a connection attempt.
--
data ConnectResult =
      ConnectSuccess
    -- ^ Successful connection.
    | ConnectSuccessLast
    -- ^ Successful connection, reached the valency target.  Other ongoing
    -- connection attempts will be killed.
    | ConnectValencyExceeded
    -- ^ Someone else managed to create the final connection to a target before
    -- us.
    deriving (Eq, Ord, Show)

-- | Traverse 'SubscriptionTarget's in an infinite loop.
--
subscriptionLoop
    :: forall m s sock localAddrs addr a x.
       ( MonadAsync m
       , MonadDelay m
       , MonadMask  m
       , MonadFix   m
       , Ord (Async m ())
       , Ord addr
       )
    => Tracer              m (SubscriptionTrace addr)

    -- various state variables of the subscription loop
    -> ConnectionTable     m   addr
    -> ResultQ             m   addr (WithAddr addr ErrorPolicyTrace) a
    -> StateVar            m s
    -> ThreadsVar          m

    -> Snocket             m sock addr
    -> (sock -> addr -> m ())

    -> WorkerCallbacks m s addr a x
    -> WorkerParams m localAddrs addr
    -- ^ given a remote address, pick the local one
    -> (sock -> m a)
    -- ^ application
    -> m Void
subscriptionLoop :: forall (m :: * -> *) s sock (localAddrs :: * -> *) addr a x.
(MonadAsync m, MonadDelay m, MonadMask m, MonadFix m,
 Ord (Async m ()), Ord addr) =>
Tracer m (SubscriptionTrace addr)
-> ConnectionTable m addr
-> ResultQ m addr (WithAddr addr ErrorPolicyTrace) a
-> StateVar m s
-> ThreadsVar m
-> Snocket m sock addr
-> (sock -> addr -> m ())
-> WorkerCallbacks m s addr a x
-> WorkerParams m localAddrs addr
-> (sock -> m a)
-> m Void
subscriptionLoop
      Tracer m (SubscriptionTrace addr)
tr ConnectionTable m addr
tbl ResultQ m addr (WithAddr addr ErrorPolicyTrace) a
resQ StateVar m s
sVar ThreadsVar m
threadsVar Snocket m sock addr
snocket sock -> addr -> m ()
configureSock
      WorkerCallbacks { wcSocketStateChangeTx :: forall (m :: * -> *) s addr a t.
WorkerCallbacks m s addr a t -> SocketStateChange m s addr
wcSocketStateChangeTx   = SocketStateChange m s addr
socketStateChangeTx
                      , wcCompleteApplicationTx :: forall (m :: * -> *) s addr a t.
WorkerCallbacks m s addr a t -> CompleteApplication m s addr a
wcCompleteApplicationTx = CompleteApplication m s addr a
completeApplicationTx
                      }
      WorkerParams { wpLocalAddresses :: forall (m :: * -> *) (localAddrs :: * -> *) addr.
WorkerParams m localAddrs addr -> localAddrs addr
wpLocalAddresses         = localAddrs addr
localAddresses
                   , wpConnectionAttemptDelay :: forall (m :: * -> *) (localAddrs :: * -> *) addr.
WorkerParams m localAddrs addr -> addr -> Maybe DiffTime
wpConnectionAttemptDelay = addr -> Maybe DiffTime
connectionAttemptDelay
                   , wpSubscriptionTarget :: forall (m :: * -> *) (localAddrs :: * -> *) addr.
WorkerParams m localAddrs addr -> m (SubscriptionTarget m addr)
wpSubscriptionTarget     = m (SubscriptionTarget m addr)
subscriptionTargets
                   , wpValency :: forall (m :: * -> *) (localAddrs :: * -> *) addr.
WorkerParams m localAddrs addr -> Int
wpValency                = Int
valency
                   , addr -> localAddrs addr -> Maybe addr
wpSelectAddress :: addr -> localAddrs addr -> Maybe addr
wpSelectAddress :: forall (m :: * -> *) (localAddrs :: * -> *) addr.
WorkerParams m localAddrs addr
-> addr -> localAddrs addr -> Maybe addr
wpSelectAddress
                   }
      sock -> m a
k = do
    valencyVar <- STM m (ValencyCounter m) -> m (ValencyCounter m)
forall a. HasCallStack => STM m a -> m a
forall (m :: * -> *) a.
(MonadSTM m, HasCallStack) =>
STM m a -> m a
atomically (STM m (ValencyCounter m) -> m (ValencyCounter m))
-> STM m (ValencyCounter m) -> m (ValencyCounter m)
forall a b. (a -> b) -> a -> b
$ ConnectionTable m addr -> Int -> STM m (ValencyCounter m)
forall (m :: * -> *) addr.
MonadSTM m =>
ConnectionTable m addr -> Int -> STM m (ValencyCounter m)
newValencyCounter ConnectionTable m addr
tbl Int
valency

    -- outer loop: set new 'conThread' variable, get targets and traverse
    -- through them trying to connect to each addr.
    forever $ do
      traceWith tr (SubscriptionTraceStart valency)
      start <- getMonotonicTime
      conThreads <- newTVarIO Set.empty
      sTarget <- subscriptionTargets
      innerLoop conThreads valencyVar sTarget
      atomically $ waitValencyCounter valencyVar

      -- We always wait at least 'ipRetryDelay' seconds between calls to
      -- 'getTargets', and before trying to restart the subscriptions we also
      -- wait 1 second so that if multiple subscription targets fail around the
      -- same time we will try to restart with a valency
      -- higher than 1.
      threadDelay 1
      end <- getMonotonicTime
      let duration = Time -> Time -> DiffTime
diffTime Time
end Time
start
      currentValency <- atomically $ readValencyCounter valencyVar
      traceWith tr $ SubscriptionTraceRestart duration valency
          (valency - currentValency)

      when (duration < ipRetryDelay) $
          threadDelay $ ipRetryDelay - duration

  where
    -- a single run through @sTarget :: SubscriptionTarget m addr@.
    innerLoop :: StrictTVar m (Set (Async m ()))
              -> ValencyCounter m
              -> SubscriptionTarget m addr
              -> m ()
    innerLoop :: ThreadsVar m
-> ValencyCounter m -> SubscriptionTarget m addr -> m ()
innerLoop ThreadsVar m
conThreads ValencyCounter m
valencyVar SubscriptionTarget m addr
sTarget = do
      mt <- SubscriptionTarget m addr
-> m (Maybe (addr, SubscriptionTarget m addr))
forall (m :: * -> *) target.
SubscriptionTarget m target
-> m (Maybe (target, SubscriptionTarget m target))
getSubscriptionTarget SubscriptionTarget m addr
sTarget
      case mt of
        Maybe (addr, SubscriptionTarget m addr)
Nothing -> do
          len <- (Set (Async m ()) -> Int) -> m (Set (Async m ())) -> m Int
forall a b. (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Set (Async m ()) -> Int
forall a. Set a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (m (Set (Async m ())) -> m Int) -> m (Set (Async m ())) -> m Int
forall a b. (a -> b) -> a -> b
$ STM m (Set (Async m ())) -> m (Set (Async m ()))
forall a. HasCallStack => STM m a -> m a
forall (m :: * -> *) a.
(MonadSTM m, HasCallStack) =>
STM m a -> m a
atomically (STM m (Set (Async m ())) -> m (Set (Async m ())))
-> STM m (Set (Async m ())) -> m (Set (Async m ()))
forall a b. (a -> b) -> a -> b
$ ThreadsVar m -> STM m (Set (Async m ()))
forall (m :: * -> *) a. MonadSTM m => StrictTVar m a -> STM m a
readTVar ThreadsVar m
conThreads
          when (len > 0) $
              traceWith tr $ SubscriptionTraceSubscriptionWaiting len

          -- We wait on the list of active connection threads instead of using
          -- an async wait function since some of the connections may succeed
          -- and then should be left running.
          --
          -- Note: active connections are removed from 'conThreads' when the
          -- 'connect' call finishes.
          atomically $ do
              activeCons <- readTVar conThreads
              unless (null activeCons) retry

          valencyLeft <- atomically $ readValencyCounter valencyVar
          if valencyLeft <= 0
             then traceWith tr SubscriptionTraceSubscriptionRunning
             else traceWith tr SubscriptionTraceSubscriptionFailed

        Just (addr
remoteAddr, SubscriptionTarget m addr
sTargetNext) -> do
          valencyLeft <- STM m Int -> m Int
forall a. HasCallStack => STM m a -> m a
forall (m :: * -> *) a.
(MonadSTM m, HasCallStack) =>
STM m a -> m a
atomically (STM m Int -> m Int) -> STM m Int -> m Int
forall a b. (a -> b) -> a -> b
$ ValencyCounter m -> STM m Int
forall (m :: * -> *). MonadSTM m => ValencyCounter m -> STM m Int
readValencyCounter ValencyCounter m
valencyVar

          -- If we have already created enough connections (valencyLeft <= 0)
          -- we don't need to traverse the rest of the list.
          if valencyLeft <= 0
              then traceWith tr SubscriptionTraceSubscriptionRunning
              else innerStep conThreads valencyVar remoteAddr sTargetNext

    innerStep :: StrictTVar m (Set (Async m ()))
              -- ^ outstanding connection threads; threads are removed as soon
              -- as the connection succeeds.  They are all cancelled when
              -- valency drops to 0.  The asynchronous exception which cancels
              -- the connection thread can only occur while connecting and not
              -- when an application is running.  This is guaranteed since
              -- threads are removed from this set as soon connecting is
              -- finished (successfully or not) and before application is
              -- started.
              -> ValencyCounter m
              -> addr
              -> SubscriptionTarget m addr
              -> m ()
    innerStep :: ThreadsVar m
-> ValencyCounter m -> addr -> SubscriptionTarget m addr -> m ()
innerStep ThreadsVar m
conThreads ValencyCounter m
valencyVar !addr
remoteAddr SubscriptionTarget m addr
sTargetNext = do
      r <- ConnectionTable m addr
-> addr
-> ConnectionDirection
-> ValencyCounter m
-> m ConnectionTableRef
forall (m :: * -> *) addr.
(MonadSTM m, Ord addr) =>
ConnectionTable m addr
-> addr
-> ConnectionDirection
-> ValencyCounter m
-> m ConnectionTableRef
refConnection ConnectionTable m addr
tbl addr
remoteAddr ConnectionDirection
ConnectionOutbound ValencyCounter m
valencyVar
      case r of
        ConnectionTableRef
ConnectionTableCreate ->
          case addr -> localAddrs addr -> Maybe addr
wpSelectAddress addr
remoteAddr localAddrs addr
localAddresses of
            Maybe addr
Nothing ->
              Tracer m (SubscriptionTrace addr) -> SubscriptionTrace addr -> m ()
forall (m :: * -> *) a. Tracer m a -> a -> m ()
traceWith Tracer m (SubscriptionTrace addr)
tr (addr -> SubscriptionTrace addr
forall addr. addr -> SubscriptionTrace addr
SubscriptionTraceUnsupportedRemoteAddr addr
remoteAddr)

            -- This part is very similar to
            -- 'Ouroboros.Network.Server.Socket.spawnOne', it should not
            -- deadlock by the same reasons.  The difference is that we are
            -- using 'mask' and 'async' as 'asyncWithUnmask' is not available.
            Just addr
localAddr ->
             do rec
                  thread <- async $ do
                    traceWith tr $ SubscriptionTraceConnectStart remoteAddr
                    -- Try to connect; 'safeConnect' is using 'bracket' to
                    -- create / close a socket and update the states.  The
                    -- continuation, e.g.  'connAction' runs with async
                    -- exceptions masked, and receives the unmask function from
                    -- this bracket.
                    safeConnect
                      snocket
                      configureSock
                      remoteAddr
                      localAddr
                      (do
                        traceWith tr $ SubscriptionTraceAllocateSocket remoteAddr
                        atomically $ do
                          modifyTVar conThreads (Set.insert thread)
                          modifyTVar threadsVar (Set.insert thread)
                          readTVar sVar
                            >>= socketStateChangeTx (CreatedSocket remoteAddr thread)
                            >>= (writeTVar sVar $!))
                      (do
                        atomically $ do
                          -- The thread is removed from 'conThreads'
                          -- inside 'connAction'.
                          modifyTVar threadsVar (Set.delete thread)
                          readTVar sVar
                            >>= socketStateChangeTx (ClosedSocket remoteAddr thread)
                            >>= (writeTVar sVar $!)
                        traceWith tr $ SubscriptionTraceCloseSocket remoteAddr)
                      (connAction
                        thread conThreads valencyVar
                        remoteAddr)

                let delay = case addr -> Maybe DiffTime
connectionAttemptDelay addr
remoteAddr of
                                Just DiffTime
d  -> DiffTime
d DiffTime -> DiffTime -> DiffTime
forall a. Ord a => a -> a -> a
`max` DiffTime
minConnectionAttemptDelay
                                             DiffTime -> DiffTime -> DiffTime
forall a. Ord a => a -> a -> a
`min` DiffTime
maxConnectionAttemptDelay
                                Maybe DiffTime
Nothing -> DiffTime
defaultConnectionAttemptDelay
                traceWith tr
                          (SubscriptionTraceSubscriptionWaitingNewConnection delay)
                threadDelay delay

        ConnectionTableRef
ConnectionTableExist ->
          Tracer m (SubscriptionTrace addr) -> SubscriptionTrace addr -> m ()
forall (m :: * -> *) a. Tracer m a -> a -> m ()
traceWith Tracer m (SubscriptionTrace addr)
tr (SubscriptionTrace addr -> m ()) -> SubscriptionTrace addr -> m ()
forall a b. (a -> b) -> a -> b
$ addr -> SubscriptionTrace addr
forall addr. addr -> SubscriptionTrace addr
SubscriptionTraceConnectionExist addr
remoteAddr
        ConnectionTableRef
ConnectionTableDuplicate -> () -> m ()
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
      innerLoop conThreads valencyVar sTargetNext

    -- Start connection thread: connect to the remote peer, run application.
    -- This function runs with asynchronous exceptions masked.
    --
    connAction :: Async m ()
               -> StrictTVar m (Set (Async m ()))
               -> ValencyCounter m
               -> addr
               -> (forall y. m y -> m y) -- unmask exceptions
               -> sock
               -> Either SomeException ()
               -> m ()
    connAction :: Async m ()
-> ThreadsVar m
-> ValencyCounter m
-> addr
-> (forall x. m x -> m x)
-> sock
-> Either SomeException ()
-> m ()
connAction Async m ()
thread ThreadsVar m
conThreads ValencyCounter m
valencyVar addr
remoteAddr forall x. m x -> m x
unmask sock
sock Either SomeException ()
connectionRes = do
      t <- m Time
forall (m :: * -> *). MonadMonotonicTime m => m Time
getMonotonicTime
      case connectionRes of
        -- connection error
        Left (SomeException e
e) -> do
          Tracer m (SubscriptionTrace addr) -> SubscriptionTrace addr -> m ()
forall (m :: * -> *) a. Tracer m a -> a -> m ()
traceWith Tracer m (SubscriptionTrace addr)
tr (SubscriptionTrace addr -> m ()) -> SubscriptionTrace addr -> m ()
forall a b. (a -> b) -> a -> b
$ addr -> e -> SubscriptionTrace addr
forall addr e. Exception e => addr -> e -> SubscriptionTrace addr
SubscriptionTraceConnectException addr
remoteAddr e
e
          STM m () -> m ()
forall a. HasCallStack => STM m a -> m a
forall (m :: * -> *) a.
(MonadSTM m, HasCallStack) =>
STM m a -> m a
atomically (STM m () -> m ()) -> STM m () -> m ()
forall a b. (a -> b) -> a -> b
$ do
            -- remove thread from active connections threads
            ThreadsVar m -> (Set (Async m ()) -> Set (Async m ())) -> STM m ()
forall (m :: * -> *) a.
MonadSTM m =>
StrictTVar m a -> (a -> a) -> STM m ()
modifyTVar ThreadsVar m
conThreads (Async m () -> Set (Async m ()) -> Set (Async m ())
forall a. Ord a => a -> Set a -> Set a
Set.delete Async m ()
thread)

            CompleteApplicationResult
              { carState
              , carThreads
              , carTrace
              } <- StateVar m s -> STM m s
forall (m :: * -> *) a. MonadSTM m => StrictTVar m a -> STM m a
readTVar StateVar m s
sVar STM m s
-> (s -> STM m (CompleteApplicationResult m addr s))
-> STM m (CompleteApplicationResult m addr s)
forall a b. STM m a -> (a -> STM m b) -> STM m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= CompleteApplication m s addr a
completeApplicationTx (Time -> addr -> e -> Result addr a
forall e addr r. Exception e => Time -> addr -> e -> Result addr r
ConnectionError Time
t addr
remoteAddr e
e)
            writeTVar sVar carState
            writeTQueue resQ (Act carThreads carTrace)

        -- connection succeeded
        Right ()
_ -> do
          localAddr <- Snocket m sock addr -> sock -> m addr
forall (m :: * -> *) fd addr. Snocket m fd addr -> fd -> m addr
Snocket.getLocalAddr Snocket m sock addr
snocket sock
sock
          connRes <- atomically $ do
            -- we successfully connected, remove the thread from
            -- outstanding connection threads.
            modifyTVar conThreads (Set.delete thread)

            v <- readValencyCounter valencyVar
            if v > 0
              then do
                addConnection tbl remoteAddr localAddr ConnectionOutbound (Just valencyVar)
                CompleteApplicationResult
                  { carState
                  , carThreads
                  , carTrace
                  } <- readTVar sVar >>= completeApplicationTx (Connected t remoteAddr)
                writeTVar sVar carState
                writeTQueue resQ (Act carThreads carTrace)
                return $ if v == 1
                          then ConnectSuccessLast
                          else ConnectSuccess
              else
                return ConnectValencyExceeded

          -- handle connection result
          traceWith tr $ SubscriptionTraceConnectEnd remoteAddr connRes
          case connRes of
            ConnectResult
ConnectValencyExceeded -> () -> m ()
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
            -- otherwise it was a success
            ConnectResult
_           -> do
              Bool -> m () -> m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (ConnectResult
connRes ConnectResult -> ConnectResult -> Bool
forall a. Eq a => a -> a -> Bool
== ConnectResult
ConnectSuccessLast) (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ do
                -- outstanding connection threads
                threads <- STM m (Set (Async m ())) -> m (Set (Async m ()))
forall a. HasCallStack => STM m a -> m a
forall (m :: * -> *) a.
(MonadSTM m, HasCallStack) =>
STM m a -> m a
atomically (STM m (Set (Async m ())) -> m (Set (Async m ())))
-> STM m (Set (Async m ())) -> m (Set (Async m ()))
forall a b. (a -> b) -> a -> b
$ ThreadsVar m -> STM m (Set (Async m ()))
forall (m :: * -> *) a. MonadSTM m => StrictTVar m a -> STM m a
readTVar ThreadsVar m
conThreads
                mapM_ (\Async m ()
tid ->
                        Async m () -> SubscriberError -> m ()
forall e a. Exception e => Async m a -> e -> m ()
forall (m :: * -> *) e a.
(MonadAsync m, Exception e) =>
Async m a -> e -> m ()
cancelWith Async m ()
tid
                        (SubscriberErrorType -> String -> CallStack -> SubscriberError
SubscriberError
                          SubscriberErrorType
SubscriberParallelConnectionCancelled
                          String
"Parallel connection cancelled"
                          CallStack
HasCallStack => CallStack
callStack)
                      )threads


              -- run application
              appRes :: Either SomeException a
                <- m a -> m (Either SomeException a)
forall e a. Exception e => m a -> m (Either e a)
forall (m :: * -> *) e a.
(MonadCatch m, Exception e) =>
m a -> m (Either e a)
try (m a -> m (Either SomeException a))
-> m a -> m (Either SomeException a)
forall a b. (a -> b) -> a -> b
$ m a -> m a
forall x. m x -> m x
unmask (sock -> m a
k sock
sock)

              case appRes of
                Right a
_ -> () -> m ()
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
                Left SomeException
e -> Tracer m (SubscriptionTrace addr) -> SubscriptionTrace addr -> m ()
forall (m :: * -> *) a. Tracer m a -> a -> m ()
traceWith Tracer m (SubscriptionTrace addr)
tr (SubscriptionTrace addr -> m ()) -> SubscriptionTrace addr -> m ()
forall a b. (a -> b) -> a -> b
$ addr -> SomeException -> SubscriptionTrace addr
forall addr e. Exception e => addr -> e -> SubscriptionTrace addr
SubscriptionTraceApplicationException addr
remoteAddr SomeException
e

              t' <- getMonotonicTime
              atomically $ do
                case appRes of
                  Right a
a ->
                    ResultQ m addr (WithAddr addr ErrorPolicyTrace) a
-> ResOrAct m addr (WithAddr addr ErrorPolicyTrace) a -> STM m ()
forall (m :: * -> *) a.
MonadSTM m =>
StrictTQueue m a -> a -> STM m ()
writeTQueue ResultQ m addr (WithAddr addr ErrorPolicyTrace) a
resQ (Result addr a -> ResOrAct m addr (WithAddr addr ErrorPolicyTrace) a
forall (m :: * -> *) addr tr r.
Result addr r -> ResOrAct m addr tr r
Res (Time -> addr -> a -> Result addr a
forall addr r. Time -> addr -> r -> Result addr r
ApplicationResult Time
t' addr
remoteAddr a
a))
                  Left (SomeException e
e) ->
                    ResultQ m addr (WithAddr addr ErrorPolicyTrace) a
-> ResOrAct m addr (WithAddr addr ErrorPolicyTrace) a -> STM m ()
forall (m :: * -> *) a.
MonadSTM m =>
StrictTQueue m a -> a -> STM m ()
writeTQueue ResultQ m addr (WithAddr addr ErrorPolicyTrace) a
resQ (Result addr a -> ResOrAct m addr (WithAddr addr ErrorPolicyTrace) a
forall (m :: * -> *) addr tr r.
Result addr r -> ResOrAct m addr tr r
Res (Time -> addr -> e -> Result addr a
forall e addr r. Exception e => Time -> addr -> e -> Result addr r
ApplicationError Time
t' addr
remoteAddr e
e))
                removeConnectionSTM tbl remoteAddr localAddr ConnectionOutbound

-- | Almost the same as 'Ouroboros.Network.Server.Socket.mainLoop'.
-- 'mainLoop' reads from the result queue and runs the 'CompleteApplication'
-- callback.
--
-- Each iteration runs a single STM transaction which first tries the @main@
-- action on the current state and, only if that retries, takes the next item
-- from the result queue.  The transaction returns the 'IO' continuation to
-- run next: either the final result, or "cancel threads, emit trace, loop
-- again".
--
-- The 'ThreadsVar' argument is not inspected here; it is kept so the
-- signature stays unchanged for callers.
--
mainLoop
  :: forall s r addr t.
     Tracer IO (WithAddr addr ErrorPolicyTrace)
  -> ResultQ IO addr (WithAddr addr ErrorPolicyTrace) r
  -> ThreadsVar IO
  -> StateVar IO s
  -> CompleteApplication IO s addr r
  -> Main IO s t
  -> IO t
mainLoop errorPolicyTracer resQ _threadsVar statusVar completeApplicationTx main =
    loop
  where
    -- One iteration: prefer 'mainTx', fall back to 'connectionTx' when
    -- 'mainTx' retries, then run whatever continuation was chosen.
    loop :: IO t
    loop = join (atomically $ mainTx `STM.orElse` connectionTx)

    -- Sample the state, and run the main action. If it does not retry, then
    -- the `mainLoop` finishes with `pure t` where `t` is the main action result.
    mainTx :: STM IO (IO t)
    mainTx = do
      t <- readTVar statusVar >>= main
      pure (pure t)

    -- Wait for some connection to finish, update the state with its result,
    -- then recurse onto `mainLoop`.
    connectionTx :: STM IO (IO t)
    connectionTx = do
      result <- readTQueue resQ
      case result of
        -- The state was already updated by whoever queued the action; only
        -- the side effects are left to perform.
        Act threads tr -> pure (resume threads tr)
        Res r -> do
          s <- readTVar statusVar
          CompleteApplicationResult
            { carState
            , carThreads
            , carTrace
            } <- completeApplicationTx r s
          writeTVar statusVar carState
          pure (resume carThreads carTrace)

    -- Side effects which cannot run inside STM: cancel the given threads,
    -- emit the optional error-policy trace, and go round the loop again.
    resume :: Set (Async IO ())
           -> Maybe (WithAddr addr ErrorPolicyTrace)
           -> IO t
    resume threads tr = do
      traverse_ cancel threads
      traverse_ (traceWith errorPolicyTracer) tr
      loop


--
-- Worker
--

-- | Worker STM callbacks
--
data WorkerCallbacks m s addr a t = WorkerCallbacks {
    wcSocketStateChangeTx   :: SocketStateChange m s addr,
    -- ^ callback for socket state changes; it is not used directly by
    -- 'worker', which hands the whole record on to 'subscriptionLoop'.
    wcCompleteApplicationTx :: CompleteApplication m s addr a,
    -- ^ callback which 'worker' hands to 'mainLoop'; there it is applied to
    -- each result read from the result queue and the current state.
    wcMainTx                :: Main m s t
    -- ^ main action which 'worker' hands to 'mainLoop'; there it is run
    -- against the current state, and its result is the result of 'worker'.
  }

-- | Worker parameters
--
data WorkerParams m localAddrs addr = WorkerParams {
    wpLocalAddresses         :: localAddrs addr,
    -- ^ local addresses of the server
    wpSelectAddress          :: addr -> localAddrs addr -> Maybe addr,
    -- ^ given remote addr pick the local address
    wpConnectionAttemptDelay :: addr -> Maybe DiffTime,
    -- ^ delay after a connection attempt to 'addr'
    wpSubscriptionTarget     :: m (SubscriptionTarget m addr),
    -- ^ action producing the 'SubscriptionTarget'; presumably the source of
    -- remote addresses to connect to (TODO: confirm against
    -- 'subscriptionLoop')
    wpValency                :: Int
    -- ^ presumably the desired number of simultaneous connections (TODO:
    -- confirm against 'subscriptionLoop')
  }

-- | This is the most abstract worker, which puts all the pieces together.  It
-- will execute until @main :: Main m s t@ returns.  It runs
-- 'subscriptionLoop' in a new thread and will exit when it dies.  Spawned
-- threads are cancelled in a 'finally' callback by throwing 'SubscriberError'.
--
-- NOTE(review): the 'Async' handle of the 'subscriptionLoop' thread is
-- ignored below, so nothing in this function itself waits on that thread;
-- confirm how its death is meant to terminate the worker.
--
-- Note: This function runs in 'IO' only because 'MonadSTM' does not yet support
-- 'orElse', PR #432.
--
worker
    :: forall s sock localAddrs addr a x.
       Ord addr
    => Tracer              IO (SubscriptionTrace addr)
    -> Tracer              IO (WithAddr addr ErrorPolicyTrace)
    -> ConnectionTable     IO   addr
    -> StateVar            IO s

    -> Snocket             IO sock addr
    -> (sock -> addr -> IO ())

    -> WorkerCallbacks     IO s addr a x
    -> WorkerParams        IO   localAddrs addr

    -> (sock -> IO a)
    -- ^ application
    -> IO x
worker tr errTrace tbl sVar snocket configureSock
       workerCallbacks@WorkerCallbacks { wcCompleteApplicationTx, wcMainTx }
       workerParams k = do
    resQ <- newResultQ
    threadsVar <- newTVarIO Set.empty
    withAsync
      (subscriptionLoop tr tbl resQ sVar threadsVar snocket configureSock
         workerCallbacks workerParams k) $ \_ ->
           mainLoop errTrace resQ threadsVar sVar wcCompleteApplicationTx wcMainTx
           -- whatever way 'mainLoop' ends, cancel the threads still
           -- registered in 'threadsVar'
           `finally` killThreads threadsVar
  where
    -- Cancel every thread currently recorded in the given variable by
    -- throwing a 'SubscriberWorkerCancelled' error at it.
    killThreads :: ThreadsVar IO -> IO ()
    killThreads var = do
      let e = SubscriberError
                SubscriberWorkerCancelled
                "SubscriptionWorker exiting"
                callStack
      children <- atomically $ readTVar var
      traverse_ (\child -> cancelWith child e) children


--
-- Auxiliary types: errors, traces
--

-- | Exception thrown at threads spawned by the subscription worker when they
-- are cancelled.
--
data SubscriberError = SubscriberError {
      seType    :: !SubscriberErrorType
      -- ^ which cancellation condition occurred
    , seMessage :: !String
      -- ^ human readable description
    , seStack   :: !CallStack
      -- ^ call stack captured where the error value was built
    } deriving Show

-- | Enumeration of error conditions.
--
data SubscriberErrorType = SubscriberParallelConnectionCancelled
                         -- ^ thrown at the outstanding connection threads
                         -- once the last needed connection has succeeded
                         | SubscriberWorkerCancelled
                         -- ^ thrown at the remaining threads when the
                         -- worker exits
                         deriving (Eq, Show)

-- | Rendered as @\<type\> \<message\> at \<call stack\>@.  The message goes
-- through 'show', so it appears quoted and escaped in the output.
--
instance Exception SubscriberError where
    displayException SubscriberError{seType, seMessage, seStack}
      = printf "%s %s at %s"
         (show seType)
         (show seMessage)
         (prettyCallStack seStack)


-- | Events traced by the subscription worker.  The wording of each event is
-- given by the 'Show' instance below.
--
data SubscriptionTrace addr =
      SubscriptionTraceConnectStart addr
      -- ^ a connection attempt to the given destination starts
    | SubscriptionTraceConnectEnd addr ConnectResult
      -- ^ a connection attempt to the given destination ended with the given
      -- outcome
    | forall e. Exception e => SubscriptionTraceSocketAllocationException addr e
      -- ^ socket allocation exception for the given destination
    | forall e. Exception e => SubscriptionTraceConnectException addr e
      -- ^ connection attempt exception for the given destination
    | forall e. Exception e => SubscriptionTraceApplicationException addr e
      -- ^ the application run on a connection to the given address threw an
      -- exception
    | SubscriptionTraceTryConnectToPeer addr
      -- ^ trying to connect to the given peer
    | SubscriptionTraceSkippingPeer addr
      -- ^ the given peer is skipped
    | SubscriptionTraceSubscriptionRunning
      -- ^ required subscriptions started
    | SubscriptionTraceSubscriptionWaiting Int
      -- ^ waiting on the given number of active connections
    | SubscriptionTraceSubscriptionFailed
      -- ^ failed to start all required subscriptions
    | SubscriptionTraceSubscriptionWaitingNewConnection DiffTime
      -- ^ waiting for the given time before attempting a new connection
    | SubscriptionTraceStart Int
      -- ^ subscription worker starts with the given valency
    | SubscriptionTraceRestart DiffTime Int Int
      -- ^ subscription restarts: duration, desired valency, current valency
    | SubscriptionTraceConnectionExist addr
      -- ^ a connection to the given destination already existed
    | SubscriptionTraceUnsupportedRemoteAddr addr
      -- ^ the given remote target address is not supported
    | SubscriptionTraceMissingLocalAddress
      -- ^ no local address is available
    | SubscriptionTraceAllocateSocket addr
      -- ^ a socket to the given address is allocated
    | SubscriptionTraceCloseSocket addr
      -- ^ the socket to the given address was closed

-- | Human readable rendering of each trace event.  Exceptions carried by an
-- event are rendered with their own 'show'.
--
instance Show addr => Show (SubscriptionTrace addr) where
    show (SubscriptionTraceConnectStart dst) =
        "Connection Attempt Start, destination " ++ show dst
    show (SubscriptionTraceConnectEnd dst res) =
        "Connection Attempt End, destination " ++ show dst ++ " outcome: " ++ show res
    show (SubscriptionTraceSocketAllocationException dst e) =
        "Socket Allocation Exception, destination " ++ show dst ++ " exception: " ++ show e
    show (SubscriptionTraceConnectException dst e) =
        "Connection Attempt Exception, destination " ++ show dst ++ " exception: " ++ show e
    show (SubscriptionTraceTryConnectToPeer addr) =
        "Trying to connect to " ++ show addr
    show (SubscriptionTraceSkippingPeer addr) =
        "Skipping peer " ++ show addr
    show SubscriptionTraceSubscriptionRunning =
        "Required subscriptions started"
    show (SubscriptionTraceSubscriptionWaiting d) =
        "Waiting on " ++ show d ++ " active connections"
    show SubscriptionTraceSubscriptionFailed =
        "Failed to start all required subscriptions"
    show (SubscriptionTraceSubscriptionWaitingNewConnection delay) =
        "Waiting " ++ show delay ++ " before attempting a new connection"
    show (SubscriptionTraceStart val) =
        "Starting Subscription Worker, valency " ++ show val
    show (SubscriptionTraceRestart duration desiredVal currentVal) =
        "Restarting Subscription after " ++ show duration ++ " desired valency " ++
        show desiredVal ++ " current valency " ++ show currentVal
    show (SubscriptionTraceConnectionExist dst) =
        "Connection Existed to " ++ show dst
    show (SubscriptionTraceUnsupportedRemoteAddr dst) =
        "Unsupported remote target address " ++ show dst
    -- TODO: add address family
    show SubscriptionTraceMissingLocalAddress =
        "Missing local address"
    show (SubscriptionTraceApplicationException addr e) =
        "Application Exception: " ++ show addr ++ " " ++ show e
    show (SubscriptionTraceAllocateSocket addr) =
        "Allocate socket to " ++ show addr
    show (SubscriptionTraceCloseSocket addr) =
        "Closed socket to " ++ show addr