{-# LANGUAGE NamedFieldPuns #-}

-- | Rate limiting of accepted connections
--
module Ouroboros.Network.Server.RateLimiting
  ( AcceptedConnectionsLimit (..)
  , runConnectionRateLimits
    -- * Tracing
  , AcceptConnectionsPolicyTrace (..)
  ) where

import Control.Monad (when)
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer.SI
import Control.Tracer (Tracer, traceWith)

import Data.Typeable (Typeable)
import Data.Word
import Text.Printf


-- | Policy which governs how to limit the number of accepted connections.
--
data AcceptedConnectionsLimit = AcceptedConnectionsLimit {

    -- | Hard limit of accepted connections.  When the number of served
    -- connections reaches this value, accepting new connections is suspended
    -- until the number of connections drops below it again.
    --
    acceptedConnectionsHardLimit :: !Word32,

    -- | Soft limit of accepted connections.  If we are above this threshold,
    -- we will start rate limiting.  It is expected not to exceed
    -- 'acceptedConnectionsHardLimit'.
    --
    acceptedConnectionsSoftLimit :: !Word32,

    -- | Max delay for limiting accepted connections.  We use linear
    -- interpolation starting from 0 at the soft limit up to
    -- 'acceptedConnectionsDelay' at the hard limit.
    --
    acceptedConnectionsDelay     :: !DiffTime
  }
  deriving (Eq, Ord, Show)


-- | Rate limiting instruction.
--
data RateLimitDelay =
      -- | No rate limiting: the next connection can be accepted right away.
      --
      NoRateLimiting

      -- | We are above the soft limit: delay accepting the next connection by
      -- the given amount of time.
      --
    | SoftDelay DiffTime

      -- | We are at or above the hard limit: wait until the number of
      -- connections drops below the given threshold.  Currently this is the
      -- hard limit, which means we keep 'acceptedConnectionsHardLimit' number
      -- of connections; later it could be configured to something between
      -- 'acceptedConnectionsSoftLimit' and 'acceptedConnectionsHardLimit'.
      --
    | HardLimit Word32


-- | Interpretation of the 'AcceptedConnectionsLimit' policy.
--
-- The hard limit is checked first, so it is enforced even when a policy is
-- misconfigured with a soft limit above the hard limit.  For a well-formed
-- policy (soft limit not exceeding the hard limit) the order of the checks
-- makes no difference.
--
getRateLimitDecision :: Int
                     -- ^ number of served concurrent connections
                     -> AcceptedConnectionsLimit
                     -- ^ limits
                     -> RateLimitDelay
getRateLimitDecision numberOfConnections
                     AcceptedConnectionsLimit { acceptedConnectionsHardLimit
                                              , acceptedConnectionsSoftLimit
                                              , acceptedConnectionsDelay
                                              }
    -- at or above the hard limit we will wait until the number of
    -- connections drops below the hard limit
    | numberOfConnections >= hardLimit = HardLimit acceptedConnectionsHardLimit

    -- below the soft limit we accept connections without any delay
    | numberOfConnections <  softLimit = NoRateLimiting

    -- in between we scale the delay linearly: 0 at the soft limit, growing
    -- towards 'acceptedConnectionsDelay' as we approach the hard limit.  The
    -- 'max' guards against a division by zero.
    | otherwise =
        SoftDelay $
            fromIntegral (numberOfConnections - softLimit)
          * acceptedConnectionsDelay
          / fromIntegral ((hardLimit - softLimit) `max` 1)
  where
    hardLimit, softLimit :: Int
    hardLimit = fromIntegral acceptedConnectionsHardLimit
    softLimit = fromIntegral acceptedConnectionsSoftLimit


-- | Get the number of current connections, make decision based on
-- 'AcceptedConnectionsLimit' and execute it:
--
-- * 'NoRateLimiting': return immediately;
-- * 'SoftDelay': trace 'ServerTraceAcceptConnectionRateLimiting' and sleep
--   for the computed delay;
-- * 'HardLimit': trace 'ServerTraceAcceptConnectionHardLimit', block until
--   the number of connections drops below the limit, make sure at least
--   'acceptedConnectionsDelay' elapsed in total, then trace
--   'ServerTraceAcceptConnectionResume'.
--
runConnectionRateLimits
    :: ( MonadSTM   m
       , MonadDelay m
       )
    => Tracer m AcceptConnectionsPolicyTrace
    -> STM m Int
    -- ^ current number of served connections
    -> AcceptedConnectionsLimit
    -> m ()
runConnectionRateLimits tracer
                        numberOfConnectionsSTM
                        acceptedConnectionsLimit@AcceptedConnectionsLimit
                          { acceptedConnectionsDelay } = do
    numberOfConnections <- atomically numberOfConnectionsSTM
    case getRateLimitDecision numberOfConnections acceptedConnectionsLimit of

      NoRateLimiting -> pure ()

      SoftDelay delay -> do
        traceWith tracer
          (ServerTraceAcceptConnectionRateLimiting delay numberOfConnections)
        threadDelay delay

      -- Wait until the current number of connections drops below the limit,
      -- and wait at least 'acceptedConnectionsDelay' in total.  This is to
      -- avoid accepting the last connection too frequently if it fails
      -- almost immediately.
      HardLimit limit -> do
        traceWith tracer (ServerTraceAcceptConnectionHardLimit limit)
        start <- getMonotonicTime
        atomically $ do
          current <- numberOfConnectionsSTM
          check (current < fromIntegral limit)
        end <- getMonotonicTime
        -- only sleep for the part of the minimum delay which the blocking
        -- wait above did not already cover
        let remainingDelay = acceptedConnectionsDelay - end `diffTime` start
        when (remainingDelay > 0) $
          threadDelay remainingDelay
        numberOfConnections' <- atomically numberOfConnectionsSTM
        traceWith tracer (ServerTraceAcceptConnectionResume numberOfConnections')


--
-- trace
--


-- | Trace for the 'AcceptedConnectionsLimit' policy.
--
data AcceptConnectionsPolicyTrace =
      -- | Rate limiting is in effect: the next accept is delayed by the given
      -- time; the 'Int' is the number of connections currently served.
      --
      ServerTraceAcceptConnectionRateLimiting DiffTime Int

      -- | The hard limit was reached: accepting is suspended until the number
      -- of connections drops below the given limit.
      --
    | ServerTraceAcceptConnectionHardLimit Word32

      -- | Accepting resumes after the hard limit was lifted; the 'Int' is the
      -- number of connections currently served.
      --
    | ServerTraceAcceptConnectionResume Int
  deriving (Eq, Ord, Typeable)

-- | Human-readable rendering of the policy traces (not a 'Read'-compatible
-- representation).
instance Show AcceptConnectionsPolicyTrace where
    show (ServerTraceAcceptConnectionRateLimiting delay numberOfConnections) =
      printf
        "rate limiting accepting connections, delaying next accept for %s, currently serving %s connections"
        (show delay) (show numberOfConnections)
    show (ServerTraceAcceptConnectionHardLimit limit) =
      printf
        "hard rate limit reached, waiting until the number of connections drops below %s"
        (show limit)
    show (ServerTraceAcceptConnectionResume numberOfConnections) =
      printf
        "hard rate limit over, accepting connections again, currently serving %d connections"
        numberOfConnections