{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC "-fno-warn-name-shadowing" #-}
module Ouroboros.Network.Server.Socket
( AcceptedConnectionsLimit (..)
, AcceptConnectionsPolicyTrace (..)
, BeginConnection
, HandleConnection (..)
, ApplicationStart
, CompleteConnection
, CompleteApplicationResult (..)
, Result (..)
, Main
, run
, Socket (..)
, ioSocket
) where
import Control.Concurrent.Async (Async)
import Control.Concurrent.Async qualified as Async
import Control.Concurrent.STM (STM)
import Control.Concurrent.STM qualified as STM
import Control.Exception (IOException, SomeException (..), finally, mask, mask_,
onException, try)
import Control.Monad (forM_, join)
import Control.Monad.Class.MonadTime.SI (Time, getMonotonicTime)
import Control.Monad.Class.MonadTimer.SI (threadDelay)
import Control.Tracer (Tracer, traceWith)
import Data.Foldable (traverse_)
import Data.Set (Set)
import Data.Set qualified as Set
import Ouroboros.Network.ErrorPolicy (CompleteApplicationResult (..),
ErrorPolicyTrace, WithAddr)
import Ouroboros.Network.Server.RateLimiting
-- | An abstract source of incoming connections.
--
-- It is deliberately independent of any concrete networking library: all the
-- server needs is an action that yields one connection at a time together
-- with the 'Socket' to use for the following accept.
data Socket addr channel = Socket
  { acceptConnection :: IO (addr, channel, IO (), Socket addr channel)
    -- ^ Accept one connection, returning the peer address, the channel,
    -- an action that closes that channel, and the 'Socket' from which the
    -- next connection is accepted.
  }
-- | Turn an action producing @(addr, channel)@ pairs into a 'Socket'.
--
-- Every accept simply runs @io@ again. The close action does nothing, and
-- the follow-up socket is rebuilt from the same @io@.
ioSocket :: IO (addr, channel) -> Socket addr channel
ioSocket io =
  Socket
    { acceptConnection =
        io >>= \(addr, channel) -> pure (addr, channel, pure (), ioSocket io)
    }
-- | Transactional variable holding the user-defined server state.
type StatusVar = STM.TVar
-- | The verdict of 'BeginConnection' on a newly accepted connection.
--
-- Both alternatives carry the state to store in the 'StatusVar'.
data HandleConnection channel st r
  = Reject !st
    -- ^ Close the connection straight away.
  | Accept !st !(channel -> IO r)
    -- ^ Run the given handler on the connection's channel in a new thread.
-- | Decide what to do with a freshly accepted connection.
--
-- Called inside a single STM transaction with three arguments: the time at
-- which the connection was accepted, the peer address, and the current
-- server state.
type BeginConnection addr channel st r =
     Time
  -> addr
  -> st
  -> STM (HandleConnection channel st r)

-- | Hook run by a connection thread just before its handler starts.
--
-- It receives the peer address, the thread's own 'Async' handle and the
-- current state, and returns the new state.
type ApplicationStart addr st =
     addr
  -> Async ()
  -> st
  -> STM st

-- | Update the state once a connection thread has finished.
--
-- It receives the thread's 'Result' and the current state. It returns the new
-- state, threads to cancel, and traces to emit.
-- Note that the @tr@ parameter does not appear on the right-hand side.
type CompleteConnection addr st tr r =
     Result addr r
  -> st
  -> STM (CompleteApplicationResult IO addr st)

-- | Decide, from the current state, whether the server should stop.
--
-- 'mainLoop' tries this transaction first. If it retries, serving
-- continues; if it returns a value, 'mainLoop' finishes with that value.
type Main st t =
     st
  -> STM t

-- | Queue through which connection threads hand their 'Result' to
-- 'mainLoop'.
type ResultQ addr r =
  STM.TQueue (Result addr r)
-- | What a connection thread reports when it finishes.
data Result addr r = Result
  { resultThread :: !(Async ())
    -- ^ The thread that produced this result.
  , resultAddr   :: !addr
    -- ^ Address of the peer the thread served.
  , resultTime   :: !Time
    -- ^ Monotonic time read after the handler finished.
  , resultValue  :: !(Either SomeException r)
    -- ^ The handler's value, or the exception it ended with. Any exception
    -- is captured, asynchronous ones included.
  }
-- | Handles of the connection threads that are currently registered.
--
-- * 'spawnOne' inserts a thread.
-- * 'mainLoop' removes it when it processes the thread's 'Result'.
-- * 'run' cancels whatever is left on shutdown.
--
-- Its size is also the connection count used for rate limiting.
type ThreadsVar = STM.TVar (Set (Async ()))
-- | Spawn a connection thread running @io@ and register it in the
-- 'ThreadsVar'.
--
-- The spawned thread does three things, in order:
--
-- 1. Runs the 'ApplicationStart' hook, still masked.
-- 2. Runs @io@ with asynchronous exceptions unmasked.
-- 3. Pushes a 'Result' into the 'ResultQ'. The result carries either the
--    value or the exception that ended steps 1–2.
--
-- Unregistering the thread is left to 'mainLoop', which does it when it
-- consumes that 'Result'. That is why the hook runs inside 'try' as well.
-- Otherwise, a hook that throws (or is cancelled while blocked in @retry@)
-- would kill the thread without a 'Result'. Its 'ThreadsVar' entry would
-- then never be removed, and it would keep counting against the connection
-- limit forever.
--
-- The spawn happens under 'mask_', so once the thread exists it is always
-- inserted into the 'ThreadsVar'.
--
-- NOTE(review): if the hook fails, @io@ never runs. Any cleanup that @io@
-- itself performs (e.g. closing the channel) is therefore skipped. Callers
-- needing that guarantee must arrange it separately.
spawnOne
  :: addr
  -> StatusVar st
  -> ResultQ addr r
  -> ThreadsVar
  -> ApplicationStart addr st
  -> IO r
  -> IO ()
spawnOne remoteAddr statusVar resQ threadsVar applicationStart io = mask_ $ do
  rec let threadAction unmask = do
            val <- try $ do
              STM.atomically $
                    STM.readTVar statusVar
                >>= applicationStart remoteAddr thread
                >>= (STM.writeTVar statusVar $!)
              unmask io
            t <- getMonotonicTime
            -- Writing to an unbounded 'TQueue' never blocks. So even if the
            -- server has stopped draining the queue, this cannot deadlock.
            STM.atomically $ STM.writeTQueue resQ (Result thread remoteAddr t val)
      thread <- Async.asyncWithUnmask $ \unmask -> threadAction unmask
  -- Removed again by 'mainLoop' once it receives this thread's 'Result'.
  STM.atomically $ STM.modifyTVar' threadsVar (Set.insert thread)
-- | Accept connections forever.
--
-- Each iteration delegates to 'acceptOne':
--
-- * If it returns the next 'Socket', the loop continues with that socket.
-- * If the accept failed ('Nothing'), the loop pauses for half a second and
--   retries with the same socket. The pause is presumably there to avoid a
--   busy loop on a persistently failing accept.
--
-- The loop has no normal exit. It ends only when an exception escapes,
-- either delivered asynchronously or re-thrown by the accept-exception
-- callback.
acceptLoop
  :: Tracer IO AcceptConnectionsPolicyTrace
  -> ResultQ addr r
  -> ThreadsVar
  -> StatusVar st
  -> AcceptedConnectionsLimit
  -> BeginConnection addr channel st r
  -> ApplicationStart addr st
  -> (IOException -> IO ())
     -- ^ Called on an 'IOException' thrown while accepting.
  -> Socket addr channel
  -> IO ()
acceptLoop acceptPolicyTrace resQ threadsVar statusVar acceptedConnectionLimit
           beginConnection applicationStart acceptException = go
  where
    go currentSocket =
          acceptOne acceptPolicyTrace resQ threadsVar statusVar
                    acceptedConnectionLimit beginConnection applicationStart
                    acceptException currentSocket
      >>= maybe (threadDelay 0.5 >> go currentSocket) go
-- | Accept a single connection and act on the 'BeginConnection' verdict.
--
-- Everything runs under 'mask' so that an accepted channel cannot be lost
-- to an asynchronous exception between steps:
--
-- 1. Apply the connection rate limits. The current connection count is the
--    size of the 'ThreadsVar'.
-- 2. Accept from the socket with exceptions restored. On an 'IOException',
--    hand it to the accept-exception callback and return 'Nothing'.
-- 3. In one STM transaction, run 'BeginConnection' on the current state and
--    store the state it returns. If that is interrupted, close the channel.
-- 4. On 'Reject', close the channel. On 'Accept', spawn a thread running
--    the handler, wrapped so that the channel is closed when it exits.
--
-- On a successful accept, returns the 'Socket' to use next.
acceptOne
  :: forall addr channel st r.
     Tracer IO AcceptConnectionsPolicyTrace
  -> ResultQ addr r
  -> ThreadsVar
  -> StatusVar st
  -> AcceptedConnectionsLimit
  -> BeginConnection addr channel st r
  -> ApplicationStart addr st
  -> (IOException -> IO ())
     -- ^ Called on an 'IOException' thrown while accepting.
  -> Socket addr channel
  -> IO (Maybe (Socket addr channel))
acceptOne acceptPolicyTrace resQ threadsVar statusVar acceptedConnectionsLimit
          beginConnection applicationStart acceptException socket =
    mask $ \restore -> do
      runConnectionRateLimits
        acceptPolicyTrace
        (Set.size <$> STM.readTVar threadsVar)
        acceptedConnectionsLimit
      outcome <- try (restore (acceptConnection socket))
      case outcome of
        Left (ex :: IOException) -> do
          -- The callback decides whether the failure is fatal; if so it
          -- re-throws and the exception propagates.
          restore (acceptException ex)
          pure Nothing
        Right (addr, channel, close, nextSocket) -> do
          now <- getMonotonicTime
          choice <- STM.atomically (decide now addr) `onException` close
          maybe close
                (\handler ->
                   spawnOne addr statusVar resQ threadsVar applicationStart
                            (handler channel `finally` close))
                choice
          pure (Just nextSocket)
  where
    -- Consult 'BeginConnection' with the current state, store the state it
    -- returns, and yield the handler only when the connection is accepted.
    decide :: Time -> addr -> STM (Maybe (channel -> IO r))
    decide now addr = do
      st <- STM.readTVar statusVar
      !handleConn <- beginConnection now addr st
      case handleConn of
        Reject st'         -> Nothing      <$ STM.writeTVar statusVar st'
        Accept st' handler -> Just handler <$ STM.writeTVar statusVar st'
-- | The supervising loop that runs alongside 'acceptLoop'.
--
-- Each step is one STM transaction that does one of two things:
--
-- * If the 'Main' transaction succeeds on the current state, the loop ends
--   with its result.
-- * Otherwise, it takes the next 'Result' from the queue. It does so only
--   once that thread is registered in the 'ThreadsVar'. It feeds the result
--   to 'CompleteConnection', stores the new state and unregisters the
--   thread. Outside the transaction, it then cancels the threads and emits
--   the traces that 'CompleteConnection' returned, and loops again.
mainLoop
  :: forall addr st tr r t .
     Tracer IO (WithAddr addr ErrorPolicyTrace)
  -> ResultQ addr r
  -> ThreadsVar
  -> StatusVar st
  -> CompleteConnection addr st tr r
  -> Main st t
  -> IO t
mainLoop errorPolicyTrace resQ threadsVar statusVar complete main = loop
  where
    loop :: IO t
    loop = join (STM.atomically (stopTx `STM.orElse` completeTx))

    -- Succeeds, ending the loop, exactly when 'main' does not retry.
    stopTx :: STM (IO t)
    stopTx = pure <$> (STM.readTVar statusVar >>= main)

    -- Process one finished connection, then continue with 'loop'.
    completeTx :: STM (IO t)
    completeTx = do
      result <- STM.readTQueue resQ
      let thread = resultThread result
      -- Wait until 'spawnOne' has registered the thread before cleaning up.
      registered <- Set.member thread <$> STM.readTVar threadsVar
      STM.check registered
      CompleteApplicationResult { carState, carThreads, carTrace }
        <- STM.readTVar statusVar >>= complete result
      STM.writeTVar statusVar carState
      STM.modifyTVar' threadsVar (Set.delete thread)
      pure $ do
        traverse_ Async.cancel carThreads
        traverse_ (traceWith errorPolicyTrace) carTrace
        loop
-- | Run a server: accept connections from the 'Socket' and supervise their
-- threads until the 'Main' transaction returns.
--
-- The accept loop and the main loop run concurrently:
--
-- * When the main loop returns, the accept loop is cancelled and 'run'
--   returns the main loop's result.
-- * If either loop throws, the other is cancelled and the exception is
--   re-thrown.
--
-- In every case, connection threads still registered in the 'ThreadsVar'
-- are cancelled before 'run' finishes.
run
  :: Tracer IO (WithAddr addr ErrorPolicyTrace)
  -> Tracer IO AcceptConnectionsPolicyTrace
  -> Socket addr channel
  -> AcceptedConnectionsLimit
  -> (IOException -> IO ())
     -- ^ Called on an 'IOException' thrown while accepting a connection.
  -> BeginConnection addr channel st r
  -> ApplicationStart addr st
  -> CompleteConnection addr st tr r
  -> Main st t
  -> STM.TVar st
  -> IO t
run errorPolicyTrace acceptPolicyTrace socket acceptedConnectionLimit
    acceptException beginConnection applicationStart complete main statusVar = do
  resQ <- STM.newTQueueIO
  threadsVar <- STM.newTVarIO Set.empty
  let acceptLoopDo =
        acceptLoop acceptPolicyTrace resQ threadsVar statusVar
                   acceptedConnectionLimit beginConnection applicationStart
                   acceptException socket
      mainDo =
        mainLoop errorPolicyTrace resQ threadsVar statusVar complete main
      killChildren = do
        children <- STM.atomically $ STM.readTVar threadsVar
        forM_ (Set.toList children) Async.cancel
  -- 'Async.race' returns as soon as 'mainDo' does and cancels the accept
  -- loop. 'Async.concurrently' would instead wait for 'acceptLoop', which
  -- never returns normally, so 'run' could never return @t@.
  -- Because 'acceptLoop' only ever ends by throwing, 'Left' is unreachable;
  -- it is reported as an error rather than silently ignored.
  (Async.race acceptLoopDo mainDo >>= either
      (\() -> ioError (userError
        "Ouroboros.Network.Server.Socket.run: accept loop returned unexpectedly"))
      pure)
    `finally` killChildren