-- TODO: provide cross-platform network bindings using `network` or
-- `Win32-network`, to get rid of CPP.
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DeriveFoldable      #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE NumericUnderscores  #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.NTP.Client.Query (
    NtpSettings(..)
  , NtpStatus(..)
  , CompletedNtpStatus(..)
  , ntpQuery

  -- * Logging
  , NtpTrace(..)
  , IPVersion (..)
  , ResultOrFailure (..)
  ) where

import           Control.Concurrent (threadDelay)
import           Control.Concurrent.Async
import           Control.Concurrent.STM
import           Control.Exception (Exception (..), IOException, bracket, catch, throwIO)
import           Control.Monad (foldM, forM_, replicateM_, when)
import           Control.Tracer
import           Data.Binary (decodeOrFail, encode)
import           Data.Bifunctor (bimap)
import qualified Data.ByteString.Lazy as LBS
import           Data.Either (partitionEithers)
import           Data.Functor (void)
import           Data.Foldable (Foldable (..), fold)
import           Data.Maybe
import           Network.Socket (Socket, SockAddr (..), AddrInfo (..))
import qualified Network.Socket as Socket
#if !defined(mingw32_HOST_OS)
import qualified Network.Socket.ByteString as Socket.ByteString (recvFrom, sendManyTo)
#else
import qualified System.Win32.Async.Socket.ByteString as Win32.Async
#endif
import           System.IOManager
import           Network.NTP.Client.Packet
                                    ( NtpPacket
                                    , mkNtpPacket
                                    , ntpPacketSize
                                    , Microsecond
                                    , NtpOffset (..)
                                    , getCurrentTime
                                    , clockOffsetPure
                                    )

-- | Settings of the ntp client.
--
data NtpSettings = NtpSettings
    { ntpServers                 :: [String]
      -- ^ List of server addresses. At least three servers are needed.

    , ntpRequiredNumberOfResults :: Int
      -- ^ Minimum number of results needed to compute the offset.  This should
      -- be less than or equal to the length of 'ntpServers': we send a single
      -- @ntp@ packet \/ query to each server, and if a DNS name resolves to
      -- many addresses we pick the first one (per address family).

    , ntpResponseTimeout         :: Microsecond
      -- ^ Timeout for receiving a response from an @ntp@ server.

    , ntpPollDelay               :: Microsecond
      -- ^ How long to wait between two rounds of requests.  This should be set
      -- to something of an order of one hour; @ntp@ servers should not be
      -- abused.
    }


-- | The Ntp client state: either a cached result is available, or the ntp
-- client is engaged in the ntp protocol, or there was a failure: e.g. the
-- connection was lost, or DNS lookups did not return at least
-- 'ntpRequiredNumberOfResults' addresses.
--
data NtpStatus =
      -- | The difference between NTP time and local system time
      NtpDrift !NtpOffset
      -- | NTP client has sent requests to the servers
    | NtpSyncPending
      -- | NTP is not available: the client has not received any response
      -- within 'ntpResponseTimeout' from at least 'ntpRequiredNumberOfResults'
      -- servers.
    | NtpSyncUnavailable
    deriving (Eq, Show)


-- | A version of 'NtpStatus' specialized to the two "completed" states of that
-- type: 'NtpDrift' and 'NtpSyncUnavailable'.
--
-- It derives the same 'Eq' and 'Show' instances as 'NtpStatus', so completed
-- results can be compared and rendered just like the general status.
--
data CompletedNtpStatus =
    -- | Corresponds to 'NtpDrift'
    CNtpDrift !NtpOffset
    -- | Corresponds to 'NtpSyncUnavailable'
  | CNtpSyncUnavailable
  deriving (Eq, Show)


-- | Report the minimum of the collected values (offsets), but only if at least
-- @threshold@ of them were collected; otherwise return 'Nothing'.
--
-- An empty list always yields 'Nothing', even for a non-positive @threshold@:
-- taking the 'minimum' of an empty list would throw.
--
minimumOfSome :: Ord a => Int -> [a] -> Maybe a
minimumOfSome _ [] = Nothing
minimumOfSome threshold offsets
    | length offsets >= threshold = Just (minimum offsets)
    | otherwise                   = Nothing


-- | Look up the passive local UDP addresses (no host name, default port).
-- 'ntpQuery' binds its query sockets to the first address of each family.
--
udpLocalAddresses :: IO [AddrInfo]
udpLocalAddresses =
    Socket.getAddrInfo (Just passiveUdpHints) Nothing (Just (show Socket.defaultPort))
  where
    passiveUdpHints :: AddrInfo
    passiveUdpHints =
      Socket.defaultHints
        { addrSocketType = Socket.Datagram
        , addrFlags      = [Socket.AI_PASSIVE]
        }

-- | Resolve the DNS names in 'ntpServers' and return the NTP 'SockAddr'es,
-- partitioned into IPv4 and IPv6 addresses (in that order).  For each name at
-- most one IPv4 and one IPv6 address is kept, with its port set to @123@.
--
-- A name which fails to resolve is skipped rather than aborting the whole
-- lookup, since the remaining servers may still be enough.  If fewer than
-- 'ntpRequiredNumberOfResults' addresses are found in total, it traces
-- 'NtpTraceLookupsFails' and throws an 'IOException'.
--
lookupNtpServers :: Tracer IO NtpTrace -> NtpSettings -> IO ([SockAddr], [SockAddr])
lookupNtpServers tracer NtpSettings { ntpServers, ntpRequiredNumberOfResults } = do
    addrs@(ipv4s, ipv6s) <- foldM fn ([], []) ntpServers
    when (length ipv4s + length ipv6s < ntpRequiredNumberOfResults) $ do
      -- TODO: this message is useless as it is, it should report addresses we
      -- could not resolve.
      traceWith tracer NtpTraceLookupsFails
      ioError $ userError "lookup NTP servers failed"
    pure addrs
  where
    -- Resolve one host and prepend its (at most one per family) addresses.
    fn :: ([SockAddr], [SockAddr]) -> String -> IO ([SockAddr], [SockAddr])
    fn (v4s, v6s) host = do
      infos <- resolve host
      let (mipv4, mipv6) = bimap listToMaybe listToMaybe (partitionAddrInfos infos)
      pure ( (setNtpPort . Socket.addrAddress <$> maybeToList mipv4) ++ v4s
           , (setNtpPort . Socket.addrAddress <$> maybeToList mipv6) ++ v6s
           )

    -- 'Socket.getAddrInfo' throws an 'IOException' instead of returning an
    -- empty list when a name cannot be resolved.  Treat that as "no addresses"
    -- for this host, so the threshold check above decides whether the lookup
    -- as a whole failed.  NOTE: the per-host error is not traced (see TODO).
    resolve :: String -> IO [AddrInfo]
    resolve host =
      Socket.getAddrInfo (Just hints) (Just host) Nothing
        `catch` \(_ :: IOException) -> pure []

    setNtpPort :: SockAddr -> SockAddr
    setNtpPort addr = case addr of
        SockAddrInet  _ host            -> SockAddrInet  ntpPort host
        SockAddrInet6 _ flow host scope -> SockAddrInet6 ntpPort flow host scope
        sockAddr                        -> sockAddr
      where
        ntpPort :: Socket.PortNumber
        ntpPort = 123

    -- The library uses 'Socket.AI_ADDRCONFIG' as a simple test of whether IPv4
    -- or IPv6 is configured.  According to the documentation,
    -- 'Socket.AI_ADDRCONFIG' is not available on all platforms, but it is
    -- expected to work on win32, Mac OS X and Linux.
    hints :: AddrInfo
    hints =
      Socket.defaultHints
        { addrSocketType = Socket.Datagram
        , addrFlags =
            if Socket.addrInfoFlagImplemented Socket.AI_ADDRCONFIG
              then [Socket.AI_ADDRCONFIG]
              else []
        }


-- | Wait for an 'Async' to finish, like 'waitCatch', but only capture
-- 'IOException's.  Any other exception raised by the asynchronous computation
-- is re-thrown in the waiting thread.
--
waitCatchIOException :: Async a -> IO (Either IOException a)
waitCatchIOException asyncAction =
    waitCatch asyncAction >>= either keepIOException (pure . Right)
  where
    -- an 'IOException' becomes a 'Left'; anything else propagates
    keepIOException err = maybe (throwIO err) (pure . Left) (fromException err)


-- | Split 'AddrInfo's into the IPv4 ones and the IPv6 ones (in that order),
-- preserving their relative order.  Addresses of any other family are dropped.
--
partitionAddrInfos :: [AddrInfo] -> ([AddrInfo], [AddrInfo])
partitionAddrInfos infos = (ofFamily Socket.AF_INET, ofFamily Socket.AF_INET6)
  where
    ofFamily family = filter ((== family) . Socket.addrFamily) infos



-- | Which version of the IP protocol a query was run over.
--
data IPVersion
    = IPv4
    | IPv6
  deriving (Eq, Show)


-- | Result of two threads running concurrently.
--
-- The derived 'Foldable' instance visits the successful result, if any, so
-- 'fold' yields it, or 'mempty' for 'BothFailed'.
--
data ResultOrFailure a
    = BothSucceeded !a
    -- ^ both threads succeeded
    | SuccessAndFailure !a !IPVersion !IOException
    -- ^ one of the threads failed; the 'IPVersion' says which one.
    | BothFailed !IOException !IOException
    -- ^ both threads failed ('ntpQuery' puts the IPv4 error first, then the
    -- IPv6 one)
    deriving (Eq, Foldable)

-- | Render values as constructor applications.  They are parenthesised when
-- they appear as an argument (precedence above 10), so the output stays
-- unambiguous when nested inside another value's 'Show' output.
--
instance Show a => Show (ResultOrFailure a) where
    showsPrec d (BothSucceeded a) = showParen (d > 10) $
        showString "BothSucceeded "
      . showsPrec 11 a
    showsPrec d (SuccessAndFailure a ipVersion e) = showParen (d > 10) $
        showString "SuccessAndFailure "
      . showsPrec 11 a
      . showChar ' '
        -- group ipVersion and error together, to indicate that the ip version
        -- is about which thread errored.
      . shows (ipVersion, e)
    showsPrec d (BothFailed e4 e6) = showParen (d > 10) $
        showString "BothFailed "
      . showsPrec 11 e4
      . showChar ' '
      . showsPrec 11 e6

-- | Perform a series of NTP queries: one for each dns name.  Resolve each dns
-- name, get local addresses: both IPv4 and IPv6 and engage in ntp protocol
-- towards one ip address per address family per dns name, but only for address
-- families for which we have a local address.  This is to avoid trying to send
-- IPv4\/6 requests if IPv4\/6 gateway is not configured.
--
-- We produce a 'CompletedNtpStatus' rather than an 'NtpStatus' because we would
-- never construct an 'NtpStatus' using 'NtpSyncPending', so callers can avoid
-- handling that case.
--
-- It may throw an 'IOException':
--
-- * if resolving the servers fails, or yields fewer than
--   'ntpRequiredNumberOfResults' addresses (see 'lookupNtpServers');
-- * if neither an IPv4 nor an IPv6 local address is configured.
--
-- 'IOException's raised while querying over IPv4 or IPv6 are /not/ re-thrown.
-- They are reported through 'NtpTraceRunProtocolResults', and the offsets
-- obtained over the other address family, if any, are still used.  Other
-- exceptions from the query threads are re-thrown.
--
ntpQuery
    :: IOManager
    -> Tracer IO NtpTrace
    -> NtpSettings
    -> IO CompletedNtpStatus
ntpQuery ioManager tracer ntpSettings@NtpSettings { ntpRequiredNumberOfResults } = do
    traceWith tracer NtpTraceClientStartQuery
    (v4Servers, v6Servers) <- lookupNtpServers tracer ntpSettings
    localAddrs <- udpLocalAddresses
    (v4LocalAddr, v6LocalAddr)
      <- case partitionAddrInfos localAddrs of
          ([], []) -> do
            traceWith tracer NtpTraceNoLocalAddr
            ioError $ userError "no local address IPv4 and IPv6"
          (ipv4s, ipv6s) ->
            -- take the first local address of each family, if there is one
            pure (listToMaybe ipv4s, listToMaybe ipv6s)
    withAsync (runProtocol IPv4 v4LocalAddr v4Servers) $ \ipv4Async ->
      withAsync (runProtocol IPv6 v6LocalAddr v6Servers) $ \ipv6Async -> do
        results <- mkResultOrFailure
                     <$> waitCatchIOException ipv4Async
                     <*> waitCatchIOException ipv6Async
        traceWith tracer (NtpTraceRunProtocolResults results)
        -- 'fold' keeps the offsets of every thread that succeeded.
        handleResults (fold results)
  where
    mkResultOrFailure :: Either IOException [a]  -- ipv4 result
                      -> Either IOException [a]  -- ipv6 result
                      -> ResultOrFailure [a]
    mkResultOrFailure (Right a0) (Right a1) = BothSucceeded (a0 <> a1)
    mkResultOrFailure (Left e)   (Right a)  = SuccessAndFailure a IPv4 e
    mkResultOrFailure (Right a)  (Left e)   = SuccessAndFailure a IPv6 e
    mkResultOrFailure (Left e0)  (Left e1)  = BothFailed e0 e1

    runProtocol :: IPVersion -> Maybe AddrInfo -> [SockAddr] -> IO [NtpOffset]
    -- no addresses to send to
    runProtocol _protocol _localAddr  []      = return []
    -- local address is not configured, e.g. no IPv6 or IPv6 gateway.
    runProtocol _protocol Nothing     _       = return []
    -- local address is configured, remote address list is non empty
    runProtocol protocol  (Just addr) servers =
      runNtpQueries ioManager tracer protocol ntpSettings addr servers

    handleResults :: [NtpOffset] -> IO CompletedNtpStatus
    handleResults results = do
      let result :: Maybe NtpOffset
          result = minimumOfSome ntpRequiredNumberOfResults results
      traceWith tracer (NtpTraceResult (ntpStatus result))
      return (completedNtpStatus result)

    ntpStatus :: Maybe NtpOffset -> NtpStatus
    ntpStatus = maybe NtpSyncUnavailable NtpDrift

    completedNtpStatus :: Maybe NtpOffset -> CompletedNtpStatus
    completedNtpStatus = maybe CNtpSyncUnavailable CNtpDrift



-- | Run an ntp query towards each of the remote addresses over a single UDP
-- socket bound to the given local address.  The socket is closed afterwards
-- (via 'bracket').  The offsets collected in the shared queue are returned
-- once either the receiver thread finishes or the response timeout elapses.
-- The queue is presumably filled by the receiver thread.
--
runNtpQueries
    :: IOManager
    -> Tracer IO NtpTrace
    -> IPVersion   -- ^ address family, it must agree with the local and remote
                   -- addresses
    -> NtpSettings
    -> AddrInfo    -- ^ local address
    -> [SockAddr]  -- ^ remote addresses, they are assumed to have the same
                   -- family as the local address
    -> IO [NtpOffset]
runNtpQueries ioManager tracer protocol netSettings localAddr destAddrs
    = bracket acquire release action
  where
    -- Open a fresh UDP socket in the address family of the local address it
    -- will be bound to.
    acquire :: IO Socket
    acquire =
      let family = Socket.addrFamily localAddr
      in  Socket.socket family Socket.Datagram Socket.defaultProtocol

    -- Close the socket when the queries are done or have failed.
    release :: Socket -> IO ()
    release sock = Socket.close sock

    -- Bind the socket to the local address and start the timeout and receiver
    -- threads.  Send one request per destination, then wait until either
    -- thread returns, and hand back whatever offsets are in the queue.
    action :: Socket -> IO [NtpOffset]
    action sock = do
        associateWithIOManager ioManager (Right sock)
        Socket.setSocketOption sock Socket.ReuseAddr 1
        Socket.bind sock (Socket.addrAddress localAddr)
        inQueue <- newTVarIO []
        withAsync timeout $ \timeoutAsync ->
          withAsync (receiver sock inQueue) $ \receiverAsync -> do
            mapM_ (sendOrTrace sock) destAddrs
            _ <- waitAny [timeoutAsync, receiverAsync]
            pure ()
        readTVarIO inQueue
      where
        -- A failed send is only traced: catching the 'IOException' keeps the
        -- remaining requests going instead of bringing the loop down.
        sendOrTrace :: Socket -> SockAddr -> IO ()
        sendOrTrace s addr =
          sendNtpPacket s addr `catch` \(e :: IOException) ->
            traceWith tracer (NtpTracePacketSendError addr e)

    --
    -- send a single ntp request towards one of the destination addresses
    --
    sendNtpPacket :: Socket -> SockAddr -> IO ()
    sendNtpPacket sock addr = do
        -- A fresh request packet is built for every destination.
        packet <- mkNtpPacket
#if !defined(mingw32_HOST_OS)
        _ <- Socket.ByteString.sendManyTo sock (LBS.toChunks (encode packet)) addr
#else
        -- TODO: add `sendManyTo` to `Win32-network`
        _ <- Win32.Async.sendAllTo sock (LBS.toStrict (encode packet)) addr
#endif
        -- delay 100ms between sending requests, this avoids dealing with ntp
        -- results at the same time from various ntp servers, and thus we
        -- should get better results.
        threadDelay 100_000

    --
    -- timeout thread
    --
    timeout :: IO ()
    timeout = do
        -- Wait for the configured response timeout (in microseconds), plus
        -- 100ms for every destination. The extra time accounts for the 100ms
        -- pause 'sendNtpPacket' makes after each request it sends.
        -- The caller 'waitAny's on this thread and on 'receiver', so whichever
        -- finishes first ends the query.
        let delay :: Int
            delay = fromIntegral (ntpResponseTimeout netSettings)
                  + 100_000 * length destAddrs
        threadDelay delay
        traceWith tracer $ NtpTraceWaitingForRepliesTimeout protocol

    --
    -- receiving thread
    --
    receiver :: Socket -> TVar [NtpOffset] -> IO ()
    receiver socket inQueue = replicateM_ (length destAddrs) $ do
        -- Reads at most one datagram per destination address. A datagram
        -- that fails to decode still uses up one of these iterations.
        --
        -- We don't catch exceptions here, we let them propagate.  This will
        -- reach the top level handler in 'Network.NTP.Client.ntpClientThread'
        -- (see 'queryLoop' therein), which will be able to decide for how long
        -- to pause the ntp-client.
#if !defined(mingw32_HOST_OS)
        (bs, senderAddr) <- Socket.ByteString.recvFrom socket ntpPacketSize
#else
        (bs, senderAddr) <- Win32.Async.recvFrom socket ntpPacketSize
#endif
        -- Local receive time, paired with the reply to compute the offset.
        t <- getCurrentTime
        case decodeOrFail (LBS.fromStrict bs) of
            Left (_, _, err) ->
                traceWith tracer $ NtpTracePacketDecodeError senderAddr err
            -- TODO : filter bad packets, i.e. late packets and spoofed packets
            Right (_, _, packet) -> do
                traceWith tracer $ NtpTracePacketReceived senderAddr packet
                let offset :: NtpOffset
                    offset = clockOffsetPure packet t
                atomically $ modifyTVar' inQueue (offset :)

--
-- Trace
--


-- | Trace events emitted by the ntp client.
--
data NtpTrace
    = NtpTraceStartNtpClient
    | NtpTraceRestartDelay Int
    | NtpTraceRestartingClient
    | NtpTraceIOError IOError
    | NtpTraceLookupsFails
    | NtpTraceClientStartQuery
    | NtpTraceNoLocalAddr
    | NtpTraceResult NtpStatus
    | NtpTraceRunProtocolResults (ResultOrFailure [NtpOffset])
    | NtpTracePacketSent SockAddr NtpPacket
    | NtpTracePacketSendError SockAddr IOException
    -- ^ Sending a request to the given address threw an 'IOException'. The
    -- error is only traced; requests to the remaining addresses are still
    -- sent.
    | NtpTracePacketDecodeError SockAddr String
    -- ^ A datagram from the given address could not be decoded as an
    -- 'NtpPacket'; the 'String' is the decoder's error message.
    | NtpTracePacketReceived SockAddr NtpPacket
    -- ^ A reply from the given address was decoded successfully.
    | NtpTraceWaitingForRepliesTimeout IPVersion
    -- ^ The wait for replies over the given IP version reached its timeout.
    deriving (Show)