{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE ScopedTypeVariables #-}

module DMQ.Protocol.SigSubmission.Codec
  ( codecSigSubmission
  , byteLimitsSigSubmission
  , timeLimitsSigSubmission
  , codecSigSubmissionId
    -- * Exported utility functions
  , encodeSig
  , decodeSig
  , encodeSigId
  , decodeSigId
  , encodeSigOpCertificate
  , decodeSigOpCertificate
  ) where

import Control.Monad (when)
import Control.Monad.Class.MonadST
import Control.Monad.Class.MonadTime.SI
import Data.ByteString.Lazy (ByteString)
import Text.Printf

import Codec.CBOR.Decoding qualified as CBOR
import Codec.CBOR.Encoding qualified as CBOR
import Codec.CBOR.Read qualified as CBOR

import Network.TypedProtocol.Codec.CBOR

import Cardano.Binary (FromCBOR (..), ToCBOR (..))
import Cardano.Crypto.DSIGN.Class (decodeSignedDSIGN, encodeSignedDSIGN)
import Cardano.Crypto.KES.Class (decodeVerKeyKES, encodeVerKeyKES)
import Cardano.KESAgent.KES.Crypto (Crypto (..))
import Cardano.KESAgent.KES.OCert (OCert (..))

import DMQ.Protocol.SigSubmission.Type
import Ouroboros.Network.Protocol.Codec.Utils qualified as Utils
import Ouroboros.Network.Protocol.Limits
import Ouroboros.Network.Protocol.TxSubmission2.Codec qualified as TX



-- | 'SigSubmission' time limits.
--
-- 'SigSubmission' reuses the @TxSubmission2@ state machine (instantiated with
-- 'SigId' and 'Sig'), so the states below are the @TxSubmission2@ states.
--
-- +-----------------------------+---------------+
-- | 'SigSubmission' state       | timeout (s)   |
-- +=============================+===============+
-- | `StInit`                    | `waitForever` |
-- +-----------------------------+---------------+
-- | `StIdle`                    | `waitForever` |
-- +-----------------------------+---------------+
-- | @'StTxIds' 'StBlocking'@    | `waitForever` |
-- +-----------------------------+---------------+
-- | @'StTxIds' 'StNonBlocking'@ | `shortWait`   |
-- +-----------------------------+---------------+
-- | `StTxs`                     | `shortWait`   |
-- +-----------------------------+---------------+
--
timeLimitsSigSubmission :: forall crypto. ProtocolTimeLimits (SigSubmission crypto)
timeLimitsSigSubmission = ProtocolTimeLimits stateToLimit
  where
    stateToLimit :: forall (st :: SigSubmission crypto).
                    ActiveState st => StateToken st -> Maybe DiffTime
    stateToLimit SingInit                    = waitForever
    stateToLimit (SingTxIds SingBlocking)    = waitForever
    stateToLimit (SingTxIds SingNonBlocking) = shortWait
    stateToLimit SingTxs                     = shortWait
    stateToLimit SingIdle                    = waitForever
    -- 'StDone' has nobody agency, so it can never be an active state;
    -- 'notActiveState' discharges this impossible case.
    stateToLimit a@SingDone                  = notActiveState a


-- | 'SigSubmission' size limits: every active state is capped at
-- 'smallByteLimit'.
--
-- The argument is the function used to measure the size of a chunk of
-- @bytes@; it is passed straight through to 'ProtocolSizeLimits'.
--
-- TODO: these limits need to be checked with the mithril team.
-- NOTE(review): 'StTxs' carries whole signatures (KES signature and
-- operational certificate); confirm 'smallByteLimit' fits a full reply batch.
byteLimitsSigSubmission :: forall crypto bytes.
                           (bytes -> Word)
                        -> ProtocolSizeLimits (SigSubmission crypto) bytes
byteLimitsSigSubmission = ProtocolSizeLimits stateToLimit
  where
    stateToLimit :: forall (st :: SigSubmission crypto).
                    ActiveState st => StateToken st -> Word
    stateToLimit SingInit                    = smallByteLimit
    stateToLimit (SingTxIds SingBlocking)    = smallByteLimit
    stateToLimit (SingTxIds SingNonBlocking) = smallByteLimit
    stateToLimit SingTxs                     = smallByteLimit
    stateToLimit SingIdle                    = smallByteLimit
    -- 'StDone' has nobody agency, so it can never be an active state.
    stateToLimit a@SingDone                  = notActiveState a


-- | Encode a 'SigId' as a single CBOR byte string holding the bytes of its
-- 'SigHash'.
encodeSigId :: SigId -> CBOR.Encoding
encodeSigId SigId { getSigId } = CBOR.encodeBytes (getSigHash getSigId)

-- | Decode a 'SigId' from a single CBOR byte string; the inverse of
-- 'encodeSigId'.
decodeSigId :: forall s. CBOR.Decoder s SigId
decodeSigId = SigId . SigHash <$> CBOR.decodeBytes


-- | Encode an operational certificate as a 4-element CBOR list:
--
-- @[ocertVkHot, ocertN, ocertKESPeriod, ocertSigma]@
--
-- We follow the same encoding as in @cardano-ledger@ for 'OCert'.
--
encodeSigOpCertificate :: Crypto crypto
                       => SigOpCertificate crypto -> CBOR.Encoding
encodeSigOpCertificate (SigOpCertificate ocert) =
       CBOR.encodeListLen 4
    <> encodeVerKeyKES   (ocertVkHot ocert)
    <> toCBOR            (ocertN ocert)
    <> toCBOR            (ocertKESPeriod ocert)
    <> encodeSignedDSIGN (ocertSigma ocert)


-- | Decode an operational certificate; the inverse of
-- 'encodeSigOpCertificate'.
--
-- Fails unless the enclosing CBOR list has exactly 4 items.
decodeSigOpCertificate :: forall s crypto. Crypto crypto
                       => CBOR.Decoder s (SigOpCertificate crypto)
decodeSigOpCertificate = do
    len <- CBOR.decodeListLen
    when (len /= 4) $
      fail (printf "decodeSigOpCertificate: unexpected number of parameters %d" len)
    ocertVkHot     <- decodeVerKeyKES
    ocertN         <- fromCBOR
    ocertKESPeriod <- fromCBOR
    ocertSigma     <- decodeSignedDSIGN
    return $ SigOpCertificate OCert
      { ocertVkHot
      , ocertN
      , ocertKESPeriod
      , ocertSigma
      }


-- | 'SigSubmission' protocol codec.
--
-- An instance of the generic @TX.anncodecTxSubmission2'@. It plugs in the
-- signature-id and signature encoders/decoders. 'SigWithBytes' combines a
-- message's bytes with the decoded 'SigRawWithSignedBytes' to form a 'Sig'.
--
codecSigSubmission
  :: forall crypto m.
     ( Crypto crypto
     , MonadST m
     )
  => AnnotatedCodec (SigSubmission crypto) CBOR.DeserialiseFailure m ByteString
codecSigSubmission =
    TX.anncodecTxSubmission2'
      SigWithBytes
      encodeSigId decodeSigId
      encodeSig   decodeSig


-- | Encode a 'Sig' from its stored raw bytes ('sigRawBytes') using
-- 'Utils.encodeBytes'.
--
-- NOTE(review): presumably 'Utils.encodeBytes' splices already-encoded CBOR,
-- so the originally received bytes are forwarded unchanged; confirm.
encodeSig :: Sig crypto -> CBOR.Encoding
encodeSig = Utils.encodeBytes . sigRawBytes

-- | Decode a signature.
--
-- Wire format: a 4-element CBOR list
--
-- @[payload, KES signature, operational certificate, cold key]@
--
-- where @payload@ is itself a 4-element list
-- @[sig id, body, KES period, expires-at]@.
--
-- The byte offsets around @payload@ (the signed data) are recorded while
-- decoding. The result is therefore a function: once the full bytes of the
-- message are supplied, it slices out exactly the signed bytes.
decodeSig :: forall crypto s.
             ( Crypto crypto
             )
          => CBOR.Decoder s (ByteString -> SigRawWithSignedBytes crypto)
decodeSig = do
    outerLen <- CBOR.decodeListLen
    when (outerLen /= 4) $
      fail (printf "decodeSig: unexpected number of parameters %d for Sig" outerLen)

    -- start of signed data
    startOffset <- CBOR.peekByteOffset
    (sigRawId, sigRawBody, sigRawKESPeriod, sigRawExpiresAt)
      <- decodePayload
    endOffset <- CBOR.peekByteOffset
    -- end of signed data

    sigRawKESSignature  <- SigKESSignature <$> CBOR.decodeBytes
    sigRawOpCertificate <- decodeSigOpCertificate
    sigRawColdKey       <- SigColdKey <$> CBOR.decodeBytes
    -- 'bytes' must be the full bytes of the message, not just the sig part.
    return $ \bytes -> SigRawWithSignedBytes
      { sigRawSignedBytes = Utils.bytesBetweenOffsets startOffset endOffset bytes
      , sigRaw = SigRaw
          { sigRawId
          , sigRawBody
          , sigRawKESSignature
          , sigRawKESPeriod
          , sigRawOpCertificate
          , sigRawColdKey
          , sigRawExpiresAt
          }
      }
  where
    -- The signed part of a signature: id, body, KES period and expiry.
    decodePayload :: CBOR.Decoder s (SigId, SigBody, SigKESPeriod, POSIXTime)
    decodePayload = do
      payloadLen <- CBOR.decodeListLen
      when (payloadLen /= 4) $
        fail (printf "decodeSig: unexpected number of parameters %d for Sig's payload" payloadLen)
      (,,,) <$> decodeSigId
            <*> (SigBody <$> CBOR.decodeBytes)
            <*> CBOR.decodeWord
            -- the expiry is an unsigned 32-bit integer on the wire; convert it
            -- directly rather than round-tripping through 'Rational'
            <*> (fromIntegral <$> CBOR.decodeWord32)



-- | Identity codec for 'SigSubmission'. Its \"bytes\" are the 'AnyMessage'
-- values themselves, so no CBOR serialisation takes place. It delegates to
-- @TX.codecTxSubmission2Id@.
codecSigSubmissionId
  :: Monad m
  => Codec (SigSubmission crypto) CodecFailure m (AnyMessage (SigSubmission crypto))
codecSigSubmissionId = TX.codecTxSubmission2Id