{-# LANGUAGE NamedFieldPuns #-}
module Ouroboros.Network.Server.RateLimiting
( AcceptedConnectionsLimit (..)
, runConnectionRateLimits
, AcceptConnectionsPolicyTrace (..)
) where
import Control.Monad (when)
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Control.Tracer (Tracer, traceWith)
import Data.Typeable (Typeable)
import Data.Word
import Text.Printf
-- | Policy which governs how the number of accepted connections is limited.
--
-- The three fields are interpreted together: below the soft limit no rate
-- limiting is applied, between the soft and the hard limit accepting is
-- delayed proportionally, and at (or above) the hard limit accepting is
-- suspended until the number of connections drops below the hard limit.
--
data AcceptedConnectionsLimit = AcceptedConnectionsLimit {

    -- | Hard limit of accepted connections.  Once the number of connections
    -- reaches this value no new connection is accepted.
    --
    acceptedConnectionsHardLimit :: !Word32,

    -- | Soft limit of accepted connections.  Rate limiting starts once the
    -- number of connections reaches this threshold.
    --
    acceptedConnectionsSoftLimit :: !Word32,

    -- | Maximal delay used when rate limiting.  The applied delay grows
    -- linearly from @0@ at the soft limit towards this value at the hard
    -- limit; it is also the minimal waiting time once the hard limit was hit.
    --
    acceptedConnectionsDelay     :: !DiffTime
  }
  deriving (Eq, Ord, Show)
-- | Rate limiting instruction computed from an 'AcceptedConnectionsLimit'
-- policy and the current number of connections.
--
data RateLimitDelay =
      -- | No rate limiting: accept the next connection immediately.
      --
      NoRateLimiting

      -- | Delay accepting the next connection by the given amount of time.
      --
    | SoftDelay DiffTime

      -- | The hard limit was reached: wait until the number of connections
      -- drops below the given threshold.
      --
    | HardLimit Word32
-- | Interpretation of the 'AcceptedConnectionsLimit' policy.
--
-- * below the soft limit: 'NoRateLimiting';
-- * at or above the hard limit: 'HardLimit' carrying the hard limit;
-- * otherwise: 'SoftDelay' with a delay scaled linearly between @0@ (at the
--   soft limit) and 'acceptedConnectionsDelay' (at the hard limit).
--
getRateLimitDecision :: Int
                     -- ^ number of currently served connections
                     -> AcceptedConnectionsLimit
                     -- ^ limits
                     -> RateLimitDelay
getRateLimitDecision numberOfConnections
                     AcceptedConnectionsLimit { acceptedConnectionsHardLimit
                                              , acceptedConnectionsSoftLimit
                                              , acceptedConnectionsDelay
                                              }
    -- below the soft limit connections are accepted without any delay
    | numberOfConnections < softLimit
    = NoRateLimiting

    -- at or above the hard limit the caller has to wait for connections to
    -- be closed
    | numberOfConnections >= hardLimit
    = HardLimit acceptedConnectionsHardLimit

    -- in between the delay is interpolated linearly
    | otherwise
    = SoftDelay $
          fromIntegral (numberOfConnections - softLimit)
        * acceptedConnectionsDelay
          -- 'max' guards against a division by zero should both limits
          -- coincide
        / fromIntegral ((hardLimit - softLimit) `max` 1)
  where
    hardLimit, softLimit :: Int
    hardLimit = fromIntegral acceptedConnectionsHardLimit
    softLimit = fromIntegral acceptedConnectionsSoftLimit
-- | Read the current number of connections, make a decision based on the
-- 'AcceptedConnectionsLimit' policy and execute it.
--
-- * 'NoRateLimiting': returns immediately;
-- * 'SoftDelay': traces 'ServerTraceAcceptConnectionRateLimiting' and sleeps
--   for the computed delay;
-- * 'HardLimit': traces 'ServerTraceAcceptConnectionHardLimit', blocks until
--   the number of connections drops below the limit, makes sure that at
--   least 'acceptedConnectionsDelay' has passed since it started waiting and
--   finally traces 'ServerTraceAcceptConnectionResume'.
--
runConnectionRateLimits
    :: ( MonadSTM   m
       , MonadDelay m
       )
    => Tracer m AcceptConnectionsPolicyTrace
    -> STM m Int
    -- ^ transaction which reads the current number of connections
    -> AcceptedConnectionsLimit
    -> m ()
runConnectionRateLimits tracer
                        numberOfConnectionsSTM
                        acceptedConnectionsLimit@AcceptedConnectionsLimit
                          { acceptedConnectionsDelay } = do
    numberOfConnections <- atomically numberOfConnectionsSTM
    case getRateLimitDecision numberOfConnections acceptedConnectionsLimit of

      NoRateLimiting -> pure ()

      SoftDelay delay -> do
        traceWith tracer (ServerTraceAcceptConnectionRateLimiting delay numberOfConnections)
        threadDelay delay

      -- Wait until the number of connections drops below the limit, and wait
      -- at least 'acceptedConnectionsDelay' in total, so that accepting is
      -- not resumed right away when the count hovers around the limit.
      HardLimit limit -> do
        traceWith tracer (ServerTraceAcceptConnectionHardLimit limit)
        start <- getMonotonicTime
        -- 'check' retries the transaction until the condition holds
        atomically $ do
          numberOfConnections' <- numberOfConnectionsSTM
          check (numberOfConnections' < fromIntegral limit)
        end <- getMonotonicTime
        -- only sleep for the part of the delay which has not elapsed yet
        let remainingDelay = acceptedConnectionsDelay - end `diffTime` start
        when (remainingDelay > 0)
          $ threadDelay remainingDelay
        -- re-read the counter so that the trace reports the current value
        numberOfConnections' <- atomically numberOfConnectionsSTM
        traceWith tracer $ ServerTraceAcceptConnectionResume numberOfConnections'
-- | Trace events emitted while applying the accepted connections policy.
--
data AcceptConnectionsPolicyTrace
      -- | Accepting is delayed by the given time; the 'Int' is the number of
      -- connections observed when the decision was made.
      --
    = ServerTraceAcceptConnectionRateLimiting DiffTime Int

      -- | The hard limit (the carried value) was reached; accepting is
      -- suspended until the number of connections drops below it.
      --
    | ServerTraceAcceptConnectionHardLimit Word32

      -- | Accepting is resumed after hitting the hard limit; the 'Int' is the
      -- number of connections at that moment.
      --
    | ServerTraceAcceptConnectionResume Int
  deriving (Eq, Ord, Typeable)
-- | Human readable rendering of the trace events, intended for log output
-- (it does not produce Haskell syntax like a derived instance would).
--
instance Show AcceptConnectionsPolicyTrace where
    show (ServerTraceAcceptConnectionRateLimiting delay numberOfConnections) =
      printf
        "rate limiting accepting connections, delaying next accept for %s, currently serving %s connections"
        (show delay) (show numberOfConnections)
    show (ServerTraceAcceptConnectionHardLimit limit) =
      printf
        "hard rate limit reached, waiting until the number of connections drops below %s"
        (show limit)
    show (ServerTraceAcceptConnectionResume numberOfConnections) =
      printf
        "hard rate limit over, accepting connections again, currently serving %d connections"
        numberOfConnections