{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.Protocol.Handshake.Server (handshakeServerPeer) where

import Codec.CBOR.Term qualified as CBOR

import Network.TypedProtocol.Peer.Server

import Ouroboros.Network.Protocol.Handshake.Client (acceptOrRefuse,
           decodeQueryResult, encodeVersions)
import Ouroboros.Network.Protocol.Handshake.Codec
import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Protocol.Handshake.Version


-- | Server side of the handshake protocol.
--
-- The server waits for the client's 'MsgProposeVersions' and passes the
-- proposed version map, together with the server's own @versions@, to
-- 'acceptOrRefuse'. The intended policy is to accept the highest version
-- offered by the peer that also belongs to the server's @versions@.
-- Depending on the outcome the server sends exactly one reply and then
-- terminates:
--
-- * accepted, and @query@ holds for the agreed data: reply with
--   'MsgQueryReply' carrying the server's own @versions@ (encoded with
--   'encodeData'), and return the client's proposal decoded with
--   'decodeQueryResult';
--
-- * accepted otherwise: reply with 'MsgAcceptVersion' for the agreed version
--   number and data, and return 'HandshakeNegotiationResult';
--
-- * refused: reply with 'MsgRefuse' and return the refusal reason wrapped in
--   'HandshakeError'.
--
handshakeServerPeer
  :: ( Ord vNumber
     )
  => VersionDataCodec CBOR.Term vNumber vData
  -- ^ codec handed to 'acceptOrRefuse'; its 'encodeData' encodes the
  -- server's replies and its 'decodeData' decodes the client's proposal in
  -- query mode
  -> (vData -> vData -> Accept vData)
  -- ^ version-data acceptance function, passed through to 'acceptOrRefuse'
  -> (vData -> Bool)
  -- ^ @query@: when it holds for the agreed data, the server answers with its
  -- supported versions ('MsgQueryReply') instead of accepting
  -> Versions vNumber vData r
  -- ^ versions supported by this server
  -> Server (Handshake vNumber CBOR.Term)
            NonPipelined StPropose m
            (Either (HandshakeProtocolError vNumber)
                    (HandshakeResult r vNumber vData))
handshakeServerPeer codec@VersionDataCodec {encodeData, decodeData}
                    acceptVersion query versions =
    Await $ \msg -> case msg of
      MsgProposeVersions vMap ->
        case acceptOrRefuse codec acceptVersion versions vMap of
          -- Query mode: report our supported versions rather than committing
          -- to the negotiated one.
          Right (_, _, agreedData) | query agreedData ->
            Yield (MsgQueryReply $ encodeVersions encodeData versions)
                  (Done $ Right $ decodeQueryResult decodeData vMap)

          -- Normal negotiation: confirm the agreed version and its data.
          Right (r, vNumber, agreedData) ->
            Yield (MsgAcceptVersion vNumber $ encodeData vNumber agreedData)
                  (Done $ Right $ HandshakeNegotiationResult r vNumber agreedData)

          -- No acceptable version: tell the client why and fail locally too.
          Left vReason ->
            Yield (MsgRefuse vReason)
                  (Done $ Left $ HandshakeError vReason)