{-# LANGUAGE NumericUnderscores  #-}
{-# LANGUAGE LambdaCase          #-}
module Network.NTP.Client (
-- * The API for starting an ntp client-thread.
    withNtpClient
  , NtpSettings(..)
  , NtpClient(..)
  , NtpStatus(..)
  -- ** Low level interface
  -- *** Running an @ntp@ query
  , ntpQuery

  -- * Logging interface
  , NtpTrace(..)
  , IPVersion(..)
  , ResultOrFailure(..)
  ) where

import           Control.Concurrent (threadDelay)
import           Control.Concurrent.Async
import           Control.Concurrent.STM (STM, atomically, check)
import           Control.Concurrent.STM.TVar
import           Control.Monad (when)
import           System.IO.Error (tryIOError)
import           Control.Tracer
import           Data.Void (Void)

import           System.IOManager

import           Network.NTP.Client.Query

-- | 'NtpClient' which receives updates of the wall clock drift every
-- 'ntpPollDelay'.  It also allows forcing an ntp query.
--
data NtpClient = NtpClient
    { -- | Query the current NTP status.
      ntpGetStatus     :: STM NtpStatus
      -- | Force an update of the ntp state, unless an ntp query is already
      -- running.  This is a blocking operation.
    , ntpQueryBlocking :: IO NtpStatus
      -- | Ntp client thread
    , ntpThread        :: Async Void
    }


-- | Set up an 'NtpClient' and run an application that uses the provided
-- 'NtpClient'.  The client thread is started with 'withAsync', so it is
-- terminated when the callback returns.  The application can 'waitCatch' on
-- 'ntpThread'.
--
withNtpClient :: IOManager
              -> Tracer IO NtpTrace
              -> NtpSettings
              -> (NtpClient -> IO a)
              -> IO a
withNtpClient ioManager tracer ntpSettings action = do
    traceWith tracer NtpTraceStartNtpClient
    -- The shared status starts as 'NtpSyncPending'; the client thread
    -- overwrites it with the outcome of each query.
    ntpStatus <- newTVarIO NtpSyncPending
    withAsync (ntpClientThread ioManager tracer ntpSettings ntpStatus) $ \tid -> do
        let client = NtpClient
              { ntpGetStatus = readTVar ntpStatus
              , ntpQueryBlocking = do
                  -- Trigger an update, unless the status is already
                  -- 'NtpSyncPending'.  This has to be committed in its own
                  -- transaction so the client thread can observe it before
                  -- we start blocking below.
                  atomically $ do
                    status <- readTVar ntpStatus
                    when (status /= NtpSyncPending)
                      $ writeTVar ntpStatus NtpSyncPending
                  -- Block until the status is no longer 'NtpSyncPending'
                  -- and return the new status.
                  atomically $ do
                      status <- readTVar ntpStatus
                      check $ status /= NtpSyncPending
                      return status
              , ntpThread = tid
              }
        action client

-- | Wait until either @t@ microseconds have passed or the status stored in
-- @tvar@ is 'NtpSyncPending', whichever happens first.  The two waits are
-- raced against each other with 'race_'.
awaitPendingWithTimeout :: TVar NtpStatus -> Int -> IO ()
awaitPendingWithTimeout tvar t
    = race_
       ( threadDelay t )
       ( atomically $ do
           s <- readTVar tvar
           -- retry the transaction until the status becomes pending
           check $ s == NtpSyncPending
       )

-- | ntp client thread which wakes up every 'ntpPollDelay' to make ntp queries.
-- It can be woken up earlier by setting 'NtpStatus' to 'NtpSyncPending'.
--
-- The loop never returns.  An 'IOError' raised by 'ntpQuery' is caught,
-- traced, and handled like an unavailable sync: the thread retries after an
-- error delay that starts at 'initialErrorDelay' seconds and doubles on each
-- consecutive failure, up to 600 seconds.
ntpClientThread
    :: IOManager
    -> Tracer IO NtpTrace
    -> NtpSettings
    -> TVar NtpStatus
    -> IO Void
ntpClientThread ioManager tracer ntpSettings ntpStatus = queryLoop initialErrorDelay
  where
    -- Run one query and dispatch on its outcome.  The argument is the delay
    -- (in seconds) to use if this query fails.
    queryLoop :: Int -> IO Void
    queryLoop errorDelay = tryIOError (ntpQuery ioManager tracer ntpSettings) >>= \case
        Right (CNtpDrift offset) -> do
            atomically $ writeTVar ntpStatus (NtpDrift offset)
            -- After a successful query the client sleeps
            -- for the time interval set in `ntpPollDelay`.
            awaitPendingWithTimeout ntpStatus $ fromIntegral $ ntpPollDelay ntpSettings
            queryLoop initialErrorDelay -- Use the initialErrorDelay.
        Right CNtpSyncUnavailable -> fastRetry errorDelay
        Left err -> do
            traceWith tracer $ NtpTraceIOError err
            fastRetry errorDelay

    -- When a query was not successful client does a fast retry.
    -- It sleeps for the time defined by `errorDelay` (in seconds).
    fastRetry :: Int -> IO Void
    fastRetry errorDelay = do
        atomically $ writeTVar ntpStatus NtpSyncUnavailable
        traceWith tracer $ NtpTraceRestartDelay errorDelay
        -- 'awaitPendingWithTimeout' takes microseconds.
        awaitPendingWithTimeout ntpStatus $ errorDelay * 1_000_000
        traceWith tracer NtpTraceRestartingClient
        -- Double the error delay, but do not wait more than 600s.  The
        -- parentheses matter: an infix `min` has the default fixity
        -- (infixl 9) and so binds tighter than (*); without them the
        -- expression is 2 * (errorDelay `min` 600), which caps at 1200s.
        queryLoop ((2 * errorDelay) `min` 600)

    -- Delay, in seconds, before the first retry after a failed query.
    initialErrorDelay :: Int
    initialErrorDelay = 5