{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

module Ouroboros.Network.PeerSelection.RootPeersDNS.PublicRootPeers
  ( -- * DNS based provider for public root peers
    publicRootPeersProvider
  , TracePublicRootPeers (..)
  ) where

import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Word (Word32)

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad (when)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Control.Tracer (Tracer (..), traceWith)


import Network.DNS qualified as DNS
import Network.Socket qualified as Socket

import Ouroboros.Network.PeerSelection.PeerAdvertise (PeerAdvertise)
import Ouroboros.Network.PeerSelection.RelayAccessPoint
import Ouroboros.Network.PeerSelection.RootPeersDNS.DNSActions (DNSActions (..),
           DNSorIOError (..), Resource (..))
import Ouroboros.Network.PeerSelection.RootPeersDNS.DNSSemaphore (DNSSemaphore,
           withDNSSemaphore)

---------------------------------------------
-- Public root peer set provider using DNS
--

-- | Trace events of the DNS based public root peers provider
-- ('publicRootPeersProvider').
data TracePublicRootPeers =
       TracePublicRootRelayAccessPoint (Map RelayAccessPoint PeerAdvertise)
       -- ^ The configured relay access points, traced every time they are
       -- read (on start-up and on every request).
     | TracePublicRootDomains [DomainAccessPoint]
       -- ^ A list of domain access points.  Not emitted by
       -- 'publicRootPeersProvider'.
     | TracePublicRootResult  DNS.Domain [(IP, DNS.TTL)]
       -- ^ A lookup of the given domain returned a non-empty list of
       -- addresses, each with its TTL.
     | TracePublicRootFailure DNS.Domain DNS.DNSError
       -- ^ A lookup of the given domain reported this error (one event per
       -- reported error).

       --TODO: classify DNS errors, config error vs transitory
  deriving Show

-- | Provide public root peers by resolving the configured relay access points
-- through DNS.
--
-- On start-up the relay access points are read from @readDomains@ and traced,
-- and a DNS resolver 'Resource' is created from @resolvConf@ and kept in a
-- 'StrictTVar'.  The continuation @action@ is then handed a request function.
-- Each call of that request function:
--
-- * re-reads (and traces) the configured relay access points;
-- * steps the resolver 'Resource', storing the resource it hands back for the
--   next call, and re-throws the underlying error if no resolver could be
--   obtained;
-- * resolves every 'RelayAccessDomain' concurrently, each lookup guarded by
--   @dnsSemaphore@, tracing every lookup error ('TracePublicRootFailure') and
--   every non-empty answer ('TracePublicRootResult');
-- * returns the resolved addresses plus every 'RelayAccessAddress', all
--   converted with @toPeerAddr@, together with the TTL after which the caller
--   may ask again.
--
-- The 'Int' argument of the request function (the number of peers requested)
-- is ignored.
publicRootPeersProvider
  :: forall peerAddr resolver exception a m.
     (MonadThrow m, MonadAsync m, Exception exception,
      Ord peerAddr)
  => Tracer m TracePublicRootPeers
  -> (IP -> Socket.PortNumber -> peerAddr)
  -> DNSSemaphore m
  -> DNS.ResolvConf
  -> STM m (Map RelayAccessPoint PeerAdvertise)
  -> DNSActions resolver exception m
  -> ((Int -> m (Map peerAddr PeerAdvertise, DiffTime)) -> m a)
  -> m a
publicRootPeersProvider tracer
                        toPeerAddr
                        dnsSemaphore
                        resolvConf
                        readDomains
                        DNSActions {
                          dnsResolverResource,
                          dnsLookupWithTTL
                        }
                        action = do
    domains <- atomically readDomains
    traceWith tracer (TracePublicRootRelayAccessPoint domains)
    rr <- dnsResolverResource resolvConf
    resourceVar <- newTVarIO rr
    action (requestPublicRootPeers resourceVar)
  where
    -- Trace the outcome of a single domain lookup: one failure event per
    -- reported error, and one result event if any address was returned.
    -- Only the @(IP, TTL)@ pairs are kept; the errors are dropped.
    processResult :: ((DomainAccessPoint, PeerAdvertise), ([DNS.DNSError], [(IP, DNS.TTL)]))
                  -> m ((DomainAccessPoint, PeerAdvertise), [(IP, DNS.TTL)])
    processResult ((domain, pa), (errs, result)) = do
        mapM_ (traceWith tracer . TracePublicRootFailure (dapDomain domain))
              errs
        when (not $ null result) $
            traceWith tracer $ TracePublicRootResult (dapDomain domain) result

        return ((domain, pa), result)

    -- One request for public root peers; the requested number of peers is
    -- ignored.
    requestPublicRootPeers
      :: StrictTVar m (Resource m (Either (DNSorIOError exception) resolver))
      -> Int
      -> m (Map peerAddr PeerAdvertise, DiffTime)
    requestPublicRootPeers resourceVar _numRequested = do
        domains <- atomically readDomains
        traceWith tracer (TracePublicRootRelayAccessPoint domains)
        -- Step the resolver resource once and keep the resource it returns
        -- for the next request.
        rr <- atomically $ readTVar resourceVar
        (er, rr') <- withResource rr
        atomically $ writeTVar resourceVar rr'
        case er of
          Left (DNSError err) -> throwIO err
          Left (IOError  err) -> throwIO err
          Right resolver -> do
            let lookups =
                  [ ((DomainAccessPoint domain port, pa),)
                      <$> withDNSSemaphore dnsSemaphore
                            (dnsLookupWithTTL
                              resolvConf
                              resolver
                              domain)
                  | (RelayAccessDomain domain port, pa) <- Map.assocs domains ]
            -- The timeouts here are handled by 'dnsLookupWithTTL'. They're
            -- configured via the DNS.ResolvConf resolvTimeout field, which
            -- defaults to 3 sec.
            results <- withAsyncAll lookups (atomically . mapM waitSTM)
            results' <- mapM processResult results
            let successes = [ ( (toPeerAddr ip dapPortNumber, pa)
                              , ipttl)
                            | ( (DomainAccessPoint {dapPortNumber}, pa)
                              , ipttls) <- results'
                            , (ip, ipttl) <- ipttls
                            ]
                !domainsIps = [ (toPeerAddr ip port, pa)
                              | (RelayAccessAddress ip port, pa) <- Map.assocs domains ]
                -- 'Map.union' is left-biased: if a resolved address is also
                -- configured statically, the entry from the DNS answer wins.
                !ips      = Map.fromList (map fst successes)
                            `Map.union` Map.fromList domainsIps
                !ttl      = if null lookups
                               then -- Not having any peers with domains configured is not
                                    -- a DNS error.
                                    ttlForResults [60]
                               else ttlForResults (map snd successes)
            -- If every lookup failed, 'successes' is empty and 'ttlForResults'
            -- falls back to the negative-caching TTL
            -- (@ttlForDnsError DNS.NameError 0@), not to a minimum TTL; the
            -- returned map then only holds the statically configured
            -- addresses, if any.  NOTE(review): the governor is presumably
            -- applying its own exponential backoff on an empty result --
            -- confirm against the governor.
            return (ips, ttl)

-- Aux

-- | Run every action of @xs0@ under its own 'withAsync' bracket, nesting the
-- brackets so that the first action is the outermost one. Then call @action@
-- with the resulting 'Async' handles, listed in the same order as @xs0@.
-- The handles are only valid inside @action@, since every bracket closes once
-- @action@ returns.
withAsyncAll :: MonadAsync m => [m a] -> ([Async m a] -> m b) -> m b
withAsyncAll xs0 action = foldr spawn (action . reverse) xs0 []
  where
    -- Start one job, push its handle (newest first) and carry on with the
    -- remaining jobs; the base continuation restores the original order.
    spawn job k handles = withAsync job (\h -> k (h : handles))

-- | Policy for the TTL of positive results: the largest TTL among the
-- answers, clamped with 'clipTTLAbove' and 'clipTTLBelow'.
--
-- An empty list means we have a successful reply but there is no answer.
-- This covers for example non-existent TLDs since there is no authority
-- to say that they should not exist, so it is treated like a
-- 'DNS.NameError'.
ttlForResults :: [DNS.TTL] -> DiffTime
ttlForResults ttls =
    case ttls of
      []           -> ttlForDnsError DNS.NameError 0
      (t : others) -> clipTTLBelow (clipTTLAbove (toDiffTime (foldr max t others)))
  where
    toDiffTime :: Word32 -> DiffTime
    toDiffTime = fromIntegral

-- | Guard against insane TTL choices: 'clipTTLBelow' raises a TTL to at
-- least one minute, 'clipTTLAbove' lowers it to at most 24 hours.
clipTTLAbove, clipTTLBelow :: DiffTime -> DiffTime
clipTTLBelow ttl = max ttl oneMinute
  where
    oneMinute = 60
clipTTLAbove ttl = min ttl oneDay
  where
    oneDay = 24 * 60 * 60

-- | Policy for the TTL of negative results.
--
-- A 'DNS.NameError' is cached for 3 hours. For any other error the previous
-- TTL is backed off exponentially (@2 * ttl + 5@), capped by 'clipTTLAbove'.
ttlForDnsError :: DNS.DNSError -> DiffTime -> DiffTime
ttlForDnsError err ttl =
    case err of
      DNS.NameError -> 3 * 60 * 60
      _             -> clipTTLAbove (2 * ttl + 5)