{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DerivingStrategies  #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.PeerSelection.RootPeersDNS.LedgerPeers (resolveLedgerPeers) where

import Control.Monad (when)
import Control.Monad.Class.MonadAsync
import Control.Tracer (Tracer, traceWith)
import Data.IP qualified as IP
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map

import Data.Foldable (foldlM)
import Data.Set (Set)
import Data.Set qualified as Set

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow


import Network.DNS qualified as DNS
import Network.Socket qualified as Socket

import Ouroboros.Network.PeerSelection.LedgerPeers.Common
import Ouroboros.Network.PeerSelection.RelayAccessPoint
import Ouroboros.Network.PeerSelection.RootPeersDNS.DNSActions (DNSActions (..),
           DNSorIOError (..), Resource (..))
import Ouroboros.Network.PeerSelection.RootPeersDNS.DNSSemaphore (DNSSemaphore,
           withDNSSemaphore)

-- | Resolve ledger-peer domain names concurrently.
--
-- A resolver is obtained from 'dnsResolverResource'. If that resource yields
-- an error ('DNSError' or 'IOError'), the error is rethrown with 'throwIO'.
-- Otherwise one lookup per 'DomainAccessPoint' is started concurrently. Each
-- lookup runs under 'withDNSSemaphore', so the number of in-flight queries is
-- bounded by the given 'DNSSemaphore' (presumably sized from
-- 'maxDNSConcurrency' by the caller — confirm).
--
-- Every requested 'DomainAccessPoint' becomes a key of the result. Its set
-- holds one @peerAddr@ per resolved IP, built with @toPeerAddr@ and the
-- access point's port ('dapPortNumber'). A domain that resolved to nothing
-- maps to an empty set.
--
-- DNS errors reported by 'dnsLookupWithTTL' are traced
-- ('TraceLedgerPeersFailure') rather than thrown. An exception raised by a
-- lookup action itself is rethrown by 'waitSTM'.
--
resolveLedgerPeers
  :: forall m peerAddr resolver exception.
     ( Ord peerAddr
     , MonadThrow m
     , MonadAsync m
     , Exception exception
     )
  => Tracer m TraceLedgerPeers
  -> (IP.IP -> Socket.PortNumber -> peerAddr)
  -> DNSSemaphore m
  -> DNS.ResolvConf
  -> DNSActions resolver exception m
  -> [DomainAccessPoint]
  -> m (Map DomainAccessPoint (Set peerAddr))
resolveLedgerPeers tracer
                   toPeerAddr
                   dnsSemaphore
                   resolvConf
                   DNSActions {
                      dnsResolverResource,
                      dnsLookupWithTTL
                    }
                   domains
                   = do
    traceWith tracer (TraceLedgerPeersDomains domains)
    rr <- dnsResolverResource resolvConf
    -- The resolver serves only this one batch of lookups, so the follow-up
    -- resource returned by 'withResource' is deliberately not kept.
    (er, _rr') <- withResource rr
    case er of
      Left (DNSError err) -> throwIO err
      Left (IOError  err) -> throwIO err
      Right resolver -> do
        -- The timeouts here are handled by 'dnsLookupWithTTL'. They're
        -- configured via the DNS.ResolvConf resolvTimeout field and
        -- default to 3 sec.
        results <- withAsyncAll (map (lookupDomain resolver) domains)
                                (atomically . mapM waitSTM)
        foldlM processResult Map.empty results
  where
    -- Query one access point's domain while holding a 'DNSSemaphore' slot,
    -- and tag the answer with the access point it belongs to.
    lookupDomain
      :: resolver
      -> DomainAccessPoint
      -> m (DomainAccessPoint, ([DNS.DNSError], [(IP.IP, DNS.TTL)]))
    lookupDomain resolver domain =
      (,) domain
        <$> withDNSSemaphore dnsSemaphore
              (dnsLookupWithTTL resolvConf resolver (dapDomain domain))

    -- Trace one lookup's outcome and merge its addresses into the
    -- accumulated map.
    processResult :: Map DomainAccessPoint (Set peerAddr)
                  -> (DomainAccessPoint, ([DNS.DNSError], [(IP.IP, DNS.TTL)]))
                  -> m (Map DomainAccessPoint (Set peerAddr))
    processResult mr (domain, (errs, ipsttls)) = do
        mapM_ (traceWith tracer . TraceLedgerPeersFailure (dapDomain domain))
              errs
        when (not (null ipsttls)) $
            traceWith tracer $ TraceLedgerPeersResult (dapDomain domain) ipsttls
        -- If the same access point occurs more than once, its addresses are
        -- merged. 'flip' keeps the already-stored set on the left of the
        -- (left-biased) 'Set.union'. The strict 'Map.insertWith' forces the
        -- stored set.
        pure $ Map.insertWith (flip Set.union) domain addrSet mr
      where
        addrSet :: Set peerAddr
        addrSet = Set.fromList
          [ toPeerAddr ip (dapPortNumber domain) | (ip, _ttl) <- ipsttls ]

-- | Start every action in its own 'Async' by nesting 'withAsync' calls. Then
-- run the continuation with the handles, in the same order as the input
-- list.
--
-- Each 'withAsync' scope encloses the remaining ones. All handles therefore
-- stay live while the continuation runs, and each is cancelled when its
-- 'withAsync' scope exits.
withAsyncAll :: MonadAsync m => [m a] -> ([Async m a] -> m b) -> m b
withAsyncAll actions0 continue = go [] actions0
  where
    -- Handles are accumulated in reverse start order, so they are reversed
    -- once before the continuation runs.
    go started []                 = continue (reverse started)
    go started (next : remaining) =
      withAsync next (\handle -> go (handle : started) remaining)