{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DerivingStrategies  #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

module Ouroboros.Network.PeerSelection.RootPeersDNS.LedgerPeers (resolveLedgerPeers) where

import Control.Monad.Class.MonadAsync
import Data.List qualified as List (foldl')
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map

import Data.Set (Set)
import Data.Set qualified as Set

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow

import Network.DNS qualified as DNS
import System.Random

import Ouroboros.Network.PeerSelection.LedgerPeers.Type
import Ouroboros.Network.PeerSelection.RelayAccessPoint
import Ouroboros.Network.PeerSelection.RootPeersDNS.DNSActions
import Ouroboros.Network.PeerSelection.RootPeersDNS.DNSSemaphore (DNSSemaphore,
           withDNSSemaphore)


-- | Resolve the DNS names of a list of ledger relays.
--
-- Every 'RelayAccessDomain' and 'RelayAccessSRVDomain' entry is looked up
-- concurrently. 'RelayAccessAddress' entries carry no name and are skipped.
-- Each lookup runs under 'withDNSSemaphore', so the 'maxDNSConcurrency'
-- limit is respected.
--
-- A single resolver is obtained from 'dnsResolverResource' and shared by
-- every lookup. If obtaining it yields a 'DNSError' or an 'IOError', the
-- wrapped error is rethrown with 'throwIO'.
--
-- The result maps each looked-up domain to the union of all addresses
-- returned for it (TTLs are discarded). A domain for which no lookup
-- succeeded maps to an empty set.
resolveLedgerPeers
  :: forall m peerAddr resolver exception.
     ( Ord peerAddr
     , MonadThrow m
     , MonadAsync m
     , Exception exception
     )
  => DNSSemaphore m
  -> DNS.ResolvConf
  -> DNSActions peerAddr resolver exception m
  -> LedgerPeersKind
  -> [RelayAccessPoint]
  -> StdGen
  -> m (Map DNS.Domain (Set peerAddr))
resolveLedgerPeers dnsSemaphore
                   resolvConf
                   DNSActions {
                      dnsResolverResource,
                      dnsLookupWithTTL
                    }
                   peerKind
                   domains
                   rng
                   = do
    rr <- dnsResolverResource resolvConf
    resourceVar <- newTVarIO rr
    resolveDomains resourceVar
  where
    resolveDomains
      :: StrictTVar m (Resource m (Either (DNSorIOError exception) resolver))
      -> m (Map DNS.Domain (Set peerAddr))
    resolveDomains resourceVar = do
        rr <- readTVarIO resourceVar
        -- 'withResource' yields the current value together with the next
        -- resource. The next resource is stored back into 'resourceVar'.
        (er, rr') <- withResource rr
        atomically $ writeTVar resourceVar rr'
        case er of
          Left (DNSError err) -> throwIO err
          Left (IOError  err) -> throwIO err
          Right resolver -> do
            -- NOTE(review): the same 'rng' is passed to every lookup. If
            -- 'dnsLookupWithTTL' uses it for random choices, each domain's
            -- choice starts from identical generator state. Consider
            -- splitting it per lookup.
            let lookups :: [m (DNS.Domain, DNSLookupResult peerAddr)]
                lookups =
                  [ (domain',) <$> withDNSSemaphore dnsSemaphore
                                     (dnsLookupWithTTL
                                       (DNSLedgerPeer peerKind)
                                       domain
                                       resolvConf
                                       resolver
                                       rng)
                  | domain <- domains
                  , domain' <-
                      case domain of
                        RelayAccessAddress {}  -> []
                        RelayAccessDomain d _p -> [d]
                        RelayAccessSRVDomain d -> [d]
                  ]
            -- The timeouts here are handled by the 'lookupWithTTL'. They're
            -- configured via the DNS.ResolvConf resolvTimeout field and
            -- defaults to 3 sec.
            results <- withAsyncAll lookups (atomically . mapM waitSTM)
            return $ List.foldl' processResult Map.empty results

    -- Merge one lookup result into the accumulator.
    --
    -- * A successful lookup unions its addresses into the domain's
    --   existing set.
    -- * A failed lookup leaves an existing entry untouched.
    -- * Otherwise the domain is recorded with an empty set.
    --
    -- 'flip' keeps the existing set as the left (preferred) union argument.
    processResult :: Map DNS.Domain (Set peerAddr)
                  -> (DNS.Domain, DNSLookupResult peerAddr)
                  -> Map DNS.Domain (Set peerAddr)
    processResult mr (domain, result) =
        Map.insertWith (flip Set.union) domain (addrsOf result) mr

    -- Addresses carried by a lookup result: none for a failure, otherwise
    -- the address components of the (address, TTL) pairs.
    addrsOf :: DNSLookupResult peerAddr -> Set peerAddr
    addrsOf = either (const Set.empty) (Set.fromList . map fst)

-- | Start each action in its own 'Async' and pass the handles to the
-- continuation, in the same order as the input list.
--
-- The actions are started with nested 'withAsync' calls, with the first
-- action outermost. The handles are therefore scoped to the continuation;
-- see 'withAsync' for what happens to them when it returns or throws.
withAsyncAll :: MonadAsync m => [m a] -> ([Async m a] -> m b) -> m b
withAsyncAll []       k = k []
withAsyncAll (x : xs) k =
    -- Prepend this handle to whatever the remaining actions produce, so the
    -- final list comes out in input order without a reverse.
    withAsync x $ \h -> withAsyncAll xs (k . (h :))