{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

module Ouroboros.Network.PeerSelection.RootPeersDNS.PublicRootPeers
  ( -- * DNS based provider for public root peers
    publicRootPeersProvider
  , TracePublicRootPeers (..)
  ) where

import Data.List (partition)
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Word (Word32)
import System.Random

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Tracer (Tracer (..), traceWith)

import Network.DNS (DNSError)
import Network.DNS qualified as DNS
import Network.Socket qualified as Socket

import Ouroboros.Network.PeerSelection.PeerAdvertise (PeerAdvertise)
import Ouroboros.Network.PeerSelection.RelayAccessPoint
import Ouroboros.Network.PeerSelection.RootPeersDNS.DNSActions
import Ouroboros.Network.PeerSelection.RootPeersDNS.DNSSemaphore (DNSSemaphore,
           withDNSSemaphore)

---------------------------------------------
-- Public root peer set provider using DNS
--

-- | Trace events of the DNS based public root peer provider.
data TracePublicRootPeers
  = TracePublicRootRelayAccessPoint (Map RelayAccessPoint PeerAdvertise)
    -- ^ The configured public root relay access points (with their
    -- advertise flags), traced each time they are read.
  | TracePublicRootDomains [RelayAccessPoint]
    -- ^ A list of relay access points.
    -- NOTE(review): not emitted by the code in this module; confirm where
    -- (or whether) it is still used.
  deriving Show

-- | DNS based provider of public root peers.
--
-- On start-up it reads the configured relay access points (via the given
-- 'STM' action), traces them, and acquires a resolver resource with
-- 'dnsResolverResource'. It then runs the continuation with a
-- @requestPublicRootPeers@ function. Each call of that function:
--
-- * re-reads and traces the configured relay access points;
-- * steps the resolver resource and rethrows its error, if any;
-- * uses plain 'RelayAccessAddress'es as they are and looks up every
--   domain / SRV domain concurrently (each lookup under the
--   'DNSSemaphore');
-- * returns the union of the resolved and the literal addresses together
--   with a TTL for the answer.
--
-- The 'Int' argument of the request function (number of peers requested)
-- is ignored.
--
-- NOTE(review): the same 'rng' is passed to every lookup on every request,
-- so any randomness 'dnsLookupWithTTL' derives from it repeats; confirm
-- that this is intended.
publicRootPeersProvider
  :: forall peerAddr resolver exception a m.
     (MonadThrow m, MonadAsync m, Exception exception,
      Ord peerAddr)
  => Tracer m TracePublicRootPeers
  -> (IP -> Socket.PortNumber -> peerAddr)
  -> DNSSemaphore m
  -> DNS.ResolvConf
  -> STM m (Map RelayAccessPoint PeerAdvertise)
  -> DNSActions peerAddr resolver exception m
  -> StdGen
  -> ((Int -> m (Map peerAddr PeerAdvertise, DiffTime)) -> m a)
  -> m a
publicRootPeersProvider tracer
                        toPeerAddr
                        dnsSemaphore
                        resolvConf
                        readDomains
                        DNSActions {
                          dnsResolverResource,
                          dnsLookupWithTTL
                        }
                        rng
                        action = do
    _ <- readAndTraceDomains
    rr <- dnsResolverResource resolvConf
    resourceVar <- newTVarIO rr
    action (requestPublicRootPeers resourceVar)
  where
    -- Read the currently configured relay access points and trace them.
    readAndTraceDomains :: m (Map RelayAccessPoint PeerAdvertise)
    readAndTraceDomains = do
        domains <- atomically readDomains
        traceWith tracer (TracePublicRootRelayAccessPoint domains)
        return domains

    -- Only domain based relay access points need a DNS lookup; literal
    -- addresses are used as they are. The match is exhaustive on purpose,
    -- so that a new 'RelayAccessPoint' constructor forces a decision here.
    needsLookup :: RelayAccessPoint -> Bool
    needsLookup = \case
      RelayAccessAddress   {} -> False
      RelayAccessDomain    {} -> True
      RelayAccessSRVDomain {} -> True

    requestPublicRootPeers
      :: StrictTVar m (Resource m (Either (DNSorIOError exception) resolver))
      -> Int
      -> m (Map peerAddr PeerAdvertise, DiffTime)
    requestPublicRootPeers resourceVar _numRequested = do
        domains <- readAndTraceDomains
        rr <- readTVarIO resourceVar
        (er, rr') <- withResource rr
        -- store the stepped resource so the next request continues from it
        atomically $ writeTVar resourceVar rr'
        case er of
          Left (DNSError err) -> throwIO err
          Left (IOError  err) -> throwIO err
          Right resolver -> do
            let (doms, relayAddrs) =
                  partition (needsLookup . fst) (Map.assocs domains)

                lookupDomain :: RelayAccessPoint -> m (DNSLookupResult peerAddr)
                lookupDomain domain = do
                  labelThisThread "dnsLookupWithTTL"
                  withDNSSemaphore dnsSemaphore
                    (dnsLookupWithTTL
                      DNSPublicPeer
                      domain
                      resolvConf
                      resolver
                      rng)

                lookups :: [m (DNSLookupResult peerAddr, PeerAdvertise)]
                lookups =
                  [ (, pa) <$> lookupDomain domain
                  | (domain, pa) <- doms
                  ]
            -- The timeouts here are handled by the 'lookupWithTTL'. They're
            -- configured via the DNS.ResolvConf resolvTimeout field and defaults
            -- to 3 sec.
            results <- withAsyncAll lookups (atomically . mapM waitSTM)
            let successes = [ ((addr, pa), ttl')
                            | (Right addrttls, pa) <- results
                            , (addr, ttl') <- addrttls
                            ]
                !domainsIps = [ (toPeerAddr ip port, pa)
                              | (RelayAccessAddress ip port, pa) <- relayAddrs
                              ]
                -- left-biased union: resolved entries win over literal ones
                !addrs      = Map.fromList (map fst successes)
                              `Map.union` Map.fromList domainsIps
                !ttl        = if null lookups
                                then -- Not having any peers with domains configured is not
                                     -- a DNS error.
                                     ttlForResults [60]
                                else ttlForResults (map snd successes)
            -- If every lookup failed (or returned no answer), 'successes' is
            -- empty and 'ttlForResults []' yields the 'DNS.NameError'
            -- negative-cache TTL of 'ttlForDnsError' (3hrs), not a minimum
            -- TTL; the literal addresses, if any, are still returned.
            -- NOTE(review): presumably the governor applies its own
            -- exponential backoff when the result brings no new peers;
            -- confirm.
            return (addrs, ttl)

-- Aux

-- | Start each action in its own 'Async' using nested 'withAsync' calls and
-- hand all the handles, in the same order as the input actions, to the
-- continuation. Every async only lives for the dynamic extent of the
-- continuation.
withAsyncAll :: MonadAsync m => [m a] -> ([Async m a] -> m b) -> m b
withAsyncAll []           k = k []
withAsyncAll (job : jobs) k =
    -- prepend this job's handle to whatever the remaining jobs produce
    withAsync job $ \running -> withAsyncAll jobs (k . (running :))

-- | Policy for TTL for positive results: take the largest of the returned
-- TTLs, cap it with 'clipTTLAbove' and then raise it with 'clipTTLBelow'.
--
-- An empty list means we have a successful reply but there is no answer.
-- This covers for example non-existent TLDs since there is no authority
-- to say that they should not exist, so it is treated like a
-- 'DNS.NameError'. (Callers in this module also pass an empty list when
-- every lookup failed.)
ttlForResults :: [DNS.TTL] -> DiffTime
ttlForResults ttls = case ttls of
    []         -> ttlForDnsError DNS.NameError 0
    (t : rest) ->
      let largest = foldr max t rest :: Word32
      in  clipTTLBelow (clipTTLAbove (fromIntegral largest))

-- | Limit insane TTL choices: 'clipTTLBelow' raises anything shorter than
-- 1min to 1min, 'clipTTLAbove' lowers anything longer than 24hrs to 24hrs.
clipTTLAbove, clipTTLBelow :: DiffTime -> DiffTime
clipTTLBelow ttl = if ttl < 60 then 60 else ttl          -- at least 1min
clipTTLAbove ttl = if ttl > 86400 then 86400 else ttl    -- at most 24hrs

-- | Policy for TTL for negative results.
--
-- A 'DNS.NameError' is cached for 3hrs (10800 s), whatever the previous TTL.
-- Any other error doubles the previous TTL and adds 5 s, capped at 24hrs
-- by 'clipTTLAbove', which gives an exponential backoff up to that limit.
ttlForDnsError :: DNSError -> DiffTime -> DiffTime
ttlForDnsError err ttl = case err of
    DNS.NameError -> 10800
    _             -> clipTTLAbove (ttl * 2 + 5)