{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric  #-}
{-# LANGUAGE LambdaCase     #-}
{-# LANGUAGE NamedFieldPuns #-}

module DMQ.NodeToClient where

import Codec.CBOR.Term qualified as CBOR
import Control.DeepSeq (NFData)
import Control.Monad ((>=>))
import Data.Bits (Bits (..))
import Data.Text (Text)
import Data.Text qualified as T
import GHC.Generics (Generic)

import Control.Monad.Class.MonadST (MonadST)
import Control.Tracer (Tracer, nullTracer)

import Network.Mux qualified as Mx

import Ouroboros.Network.CodecCBORTerm (CodecCBORTerm (..))
import Ouroboros.Network.ConnectionId (ConnectionId)
import Ouroboros.Network.Driver.Simple (TraceSendRecv)
import Ouroboros.Network.Handshake.Acceptable (Acceptable (..))
import Ouroboros.Network.Handshake.Queryable (Queryable (..))
import Ouroboros.Network.Magic (NetworkMagic (..))
import Ouroboros.Network.Protocol.Handshake (Accept (..), Handshake,
           HandshakeArguments (..))
import Ouroboros.Network.Protocol.Handshake.Codec (cborTermVersionDataCodec,
           codecHandshake, noTimeLimitsHandshake)

-- | Versions of the DMQ node-to-client protocol.
--
-- There is currently a single version, 'NodeToClientV_1'.
data NodeToClientVersion
  = NodeToClientV_1
  deriving ( Eq
           , Ord
           , Enum
           , Bounded
           , Show
           , Generic
           , NFData
           )

-- | CBOR codec for the 'NodeToClientVersion' numbers exchanged by the
-- handshake.
--
-- A version is encoded as a CBOR integer: its version number with bit index
-- 12 (@nodeToClientVersionBit@) set, so 'NodeToClientV_1' is sent as
-- @1 .|. 4096 = 4097@.  Decoding accepts only integers that have that bit set,
-- clears it, and maps the remaining number back to a 'NodeToClientVersion'.
--
-- A decoding error carries a description and, when an integer was decoded,
-- that integer.
nodeToClientVersionCodec :: CodecCBORTerm (Text, Maybe Int) NodeToClientVersion
nodeToClientVersionCodec = CodecCBORTerm { encodeTerm, decodeTerm }
    where
      encodeTerm :: NodeToClientVersion -> CBOR.Term
      encodeTerm = \case
          NodeToClientV_1 -> enc 1
        where
          enc :: Int -> CBOR.Term
          enc = CBOR.TInt . (`setBit` nodeToClientVersionBit)

      decodeTerm :: CBOR.Term -> Either (Text, Maybe Int) NodeToClientVersion
      decodeTerm =
          dec >=> \case
            1 -> Right NodeToClientV_1
            n -> Left (unknownTag n)
        where
          -- Strip the node-to-client marker bit.  Integers without it are
          -- rejected and reported with their raw (unstripped) value.
          dec :: CBOR.Term -> Either (Text, Maybe Int) Int
          dec (CBOR.TInt x) | x `testBit` nodeToClientVersionBit
                            = Right (x `clearBit` nodeToClientVersionBit)
                            | otherwise
                            = Left (unknownTag x)
          dec _             = Left ( T.pack "decode NodeToClientVersion: unexpected term"
                                   , Nothing
                                   )

          unknownTag :: Int -> (Text, Maybe Int)
          unknownTag x = ( T.pack "decode NodeToClientVersion: unknown tag: " <> T.pack (show x)
                         , Just x
                         )

      -- Bit index 12 (value 4096, i.e. the 13th bit counting from one) marks a
      -- version number as a node-to-client version, to distinguish
      -- `NodeToNodeVersion` from `NodeToClientVersion`.  This is deliberately
      -- different from the marker bit defined in ouroboros-network.  The bit is
      -- set on every encoded version, so it is part of the wire format:
      -- changing it breaks compatibility with existing peers.
      nodeToClientVersionBit :: Int
      nodeToClientVersionBit = 12

-- | Version data for the NodeToClient protocol v1.
--
-- This data type is inspired by the one defined in @ouroboros-network-api@;
-- however, it is redefined here to tie it to our custom 'NodeToClientVersion'
-- and to avoid divergences.
data NodeToClientVersionData = NodeToClientVersionData
  { networkMagic :: !NetworkMagic
    -- ^ Network magic.  Handshake negotiation ('Acceptable' instance) refuses
    -- a peer whose network magic differs from ours.
  , query        :: !Bool
    -- ^ Query flag, exposed through the 'Queryable' instance.  In the
    -- negotiated version data it is set when either side set it.
  }
  deriving (Eq, Show)

-- | Negotiation succeeds only when both sides use the same network magic.
-- The agreed version data keeps the local network magic and sets the query
-- flag when either side set it.  Otherwise the peer is refused with a message
-- that shows both version data values.
instance Acceptable NodeToClientVersionData where
    acceptableVersion local remote
      | magicsAgree = Accept agreed
      | otherwise   = Refuse (T.pack mismatch)
      where
        magicsAgree = networkMagic local == networkMagic remote

        agreed = local { query = query local || query remote }

        mismatch = concat [ "version data mismatch: "
                          , show local
                          , " /= "
                          , show remote
                          ]

-- | The query flag is read directly from the version data.
instance Queryable NodeToClientVersionData where
    queryVersion NodeToClientVersionData { query } = query

-- | CBOR codec for 'NodeToClientVersionData'; the same encoding is used for
-- every 'NodeToClientVersion'.
--
-- The data is encoded as a two-element CBOR list: the network magic as an
-- integer, followed by the query flag as a boolean.  Decoding rejects any
-- other term shape, and rejects network magics outside @[0, 0xffffffff]@.
nodeToClientCodecCBORTerm :: NodeToClientVersion -> CodecCBORTerm Text NodeToClientVersionData
nodeToClientCodecCBORTerm _version =
    CodecCBORTerm { encodeTerm = encodeData, decodeTerm = decodeData }
  where
    encodeData :: NodeToClientVersionData -> CBOR.Term
    encodeData NodeToClientVersionData { networkMagic = magic, query = isQuery } =
      CBOR.TList
        [ CBOR.TInt (fromIntegral (unNetworkMagic magic))
        , CBOR.TBool isQuery
        ]

    decodeData :: CBOR.Term -> Either Text NodeToClientVersionData
    decodeData = \case
      CBOR.TList [CBOR.TInt magic, CBOR.TBool isQuery] ->
        mkVersionData magic isQuery
      other ->
        Left (T.pack ("unknown encoding: " ++ show other))

    -- Range-check the decoded magic before narrowing it to the 32-bit
    -- 'NetworkMagic'.
    mkVersionData :: Int -> Bool -> Either Text NodeToClientVersionData
    mkVersionData magic isQuery
      | magic < 0 || magic > 0xffffffff =
          Left (T.pack ("networkMagic out of bound: " <> show magic))
      | otherwise =
          Right (NodeToClientVersionData (NetworkMagic (fromIntegral magic)) isQuery)

-- | Node-to-client protocols record; it currently has no fields.
-- NOTE(review): presumably a placeholder until mini-protocols are added.
data Protocols = Protocols {}

-- | Trace type of the node-to-client handshake: 'TraceSendRecv' events of the
-- 'Handshake' protocol (over CBOR terms), tagged via 'Mx.WithBearer' with the
-- 'ConnectionId' they belong to.
type HandshakeTr ntcAddr =
  Mx.WithBearer
    (ConnectionId ntcAddr)
    (TraceSendRecv (Handshake NodeToClientVersion CBOR.Term))

-- | 'HandshakeArguments' for the DMQ node-to-client handshake.
--
-- * Handshake messages are traced with the given tracer.  Bearer events are
--   discarded ('nullTracer') for now.
-- * Version numbers are coded with 'nodeToClientVersionCodec', and version
--   data with 'nodeToClientCodecCBORTerm'.
-- * Acceptance and query detection come from the 'Acceptable' and
--   'Queryable' instances of 'NodeToClientVersionData'.
-- * No time limits are applied ('noTimeLimitsHandshake').
ntcHandshakeArguments
  :: MonadST m
  => Tracer m (HandshakeTr ntcAddr)
  -> HandshakeArguments
      (ConnectionId ntcAddr)
      NodeToClientVersion
      NodeToClientVersionData
      m
ntcHandshakeArguments handshakeTracer =
  HandshakeArguments
    { haHandshakeTracer  = handshakeTracer
    , haBearerTracer     = nullTracer -- TODO: bearer events are not traced yet
    , haHandshakeCodec   = codecHandshake nodeToClientVersionCodec
    , haVersionDataCodec = cborTermVersionDataCodec nodeToClientCodecCBORTerm
    , haAcceptVersion    = acceptableVersion
    , haQueryVersion     = queryVersion
    , haTimeLimits       = noTimeLimitsHandshake
    }

-- | Default node-to-client version data for the given network magic, with
-- the query flag turned off.
stdVersionDataNTC :: NetworkMagic -> NodeToClientVersionData
stdVersionDataNTC magic = NodeToClientVersionData magic False