{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Ouroboros.Network.Protocol.Handshake
( runHandshakeClient
, runHandshakeServer
, HandshakeArguments (..)
, Versions (..)
, HandshakeException (..)
, HandshakeProtocolError (..)
, HandshakeResult (..)
, RefuseReason (..)
, Accept (..)
, handshake_QUERY_SHUTDOWN_DELAY
) where
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTimer.SI
import Codec.CBOR.Read qualified as CBOR
import Codec.CBOR.Term qualified as CBOR
import Control.Tracer (Tracer, contramap)
import Data.ByteString.Lazy qualified as BL
import Network.Mux.Trace qualified as Mx
import Network.Mux.Types qualified as Mx
import Network.TypedProtocol.Codec
import Ouroboros.Network.Driver.Limits
import Ouroboros.Network.Protocol.Handshake.Client
import Ouroboros.Network.Protocol.Handshake.Codec
import Ouroboros.Network.Protocol.Handshake.Server
import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Protocol.Handshake.Version
-- | Mini-protocol number on which the handshake is run over the bearer.
--
-- Both 'runHandshakeClient' and 'runHandshakeServer' open their channel on
-- this number; it is fixed to @0@.
--
handshakeProtocolNum :: Mx.MiniProtocolNum
handshakeProtocolNum = Mx.MiniProtocolNum 0

-- | Failures reported by 'tryHandshake' when running either the initiator
-- or the responder side of the handshake.
--
data HandshakeException vNumber =
    -- | A 'ProtocolLimitFailure' exception was thrown while running the
    -- handshake peer (presumably a size or time limit enforced by the
    -- limits driver was exceeded — confirm against 'runPeerWithLimits').
    HandshakeProtocolLimit ProtocolLimitFailure

    -- | The handshake action completed but returned a
    -- 'HandshakeProtocolError'.
  | HandshakeProtocolError (HandshakeProtocolError vNumber)
  deriving Show

-- | Run one side of the handshake and fold its failure modes into a single
-- 'HandshakeException'.
--
-- * A 'ProtocolLimitFailure' exception thrown by the action is caught and
--   returned as 'HandshakeProtocolLimit'.
-- * A 'HandshakeProtocolError' returned by the action is returned as
--   'HandshakeProtocolError'.
-- * A successful result is passed through unchanged.
--
-- Exceptions of any other type are not caught and propagate to the caller.
--
tryHandshake :: forall m vNumber r.
                ( MonadAsync m
                , MonadMask m
                )
             => m (Either (HandshakeProtocolError vNumber) r)
             -> m (Either (HandshakeException vNumber) r)
tryHandshake doHandshake = do
    -- 'try' only catches the exception type fixed by the pattern signature
    -- below; everything else is rethrown.
    mapp <- try doHandshake
    pure $ case mapp of
      Left (err :: ProtocolLimitFailure) -> Left (HandshakeProtocolLimit err)
      Right (Left err)                   -> Left (HandshakeProtocolError err)
      Right (Right result)               -> Right result

-- | Arguments shared by 'runHandshakeClient' and 'runHandshakeServer'.
--
data HandshakeArguments connectionId vNumber vData m = HandshakeArguments {
      -- | Tracer for sent and received handshake messages; each event is
      -- tagged with the connection id via 'Mx.WithBearer'.
      --
      haHandshakeTracer :: Tracer m (Mx.WithBearer connectionId
                                      (TraceSendRecv (Handshake vNumber CBOR.Term))),

      -- | Codec used to encode and decode handshake messages to lazy
      -- 'BL.ByteString's.
      --
      haHandshakeCodec
        :: Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m BL.ByteString,

      -- | Codec for the per-version data carried as CBOR terms.
      --
      haVersionDataCodec
        :: VersionDataCodec CBOR.Term vNumber vData,

      -- | Decides whether two version data values are compatible, passed to
      -- both the client and server peers. NOTE(review): which argument is the
      -- local and which the remote version data is decided by the peer
      -- implementations — confirm there.
      --
      haAcceptVersion :: vData -> vData -> Accept vData,

      -- | Predicate on version data, used only by the server side.
      -- Presumably reports whether the remote requested a version query
      -- rather than a real handshake — confirm in 'handshakeServerPeer'.
      --
      haQueryVersion :: vData -> Bool,

      -- | Per-state time limits enforced by the driver while running the
      -- handshake.
      --
      haTimeLimits
        :: ProtocolTimeLimits (Handshake vNumber CBOR.Term)
    }

-- | Run the client (initiator) side of the handshake over a bearer.
--
-- The handshake is driven with 'runPeerWithLimits' on the channel for
-- 'handshakeProtocolNum' in the 'Mx.InitiatorDir' direction. The driver uses:
--
-- * 'byteLimitsHandshake' for size limits;
-- * 'haTimeLimits' for time limits;
-- * 'haHandshakeTracer', with every event tagged by @connectionId@.
--
-- Any trailing bytes returned by the driver are discarded. Limit violations
-- and protocol errors are reported through 'tryHandshake' as 'Left'.
--
runHandshakeClient
    :: ( MonadAsync m
       , MonadFork m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => Mx.Bearer m
    -> connectionId
    -> HandshakeArguments connectionId vNumber vData m
    -> Versions vNumber vData application
    -> m (Either (HandshakeException vNumber)
                 (HandshakeResult application vNumber vData))
runHandshakeClient bearer
                   connectionId
                   HandshakeArguments {
                     haHandshakeTracer,
                     haHandshakeCodec,
                     haVersionDataCodec,
                     haAcceptVersion,
                     haTimeLimits
                   }
                   versions =
    tryHandshake
      (fst <$>
        runPeerWithLimits
          (Mx.WithBearer connectionId `contramap` haHandshakeTracer)
          haHandshakeCodec
          byteLimitsHandshake
          haTimeLimits
          (Mx.bearerAsChannel bearer handshakeProtocolNum Mx.InitiatorDir)
          (handshakeClientPeer haVersionDataCodec haAcceptVersion versions))

-- | Run the server (responder) side of the handshake over a bearer.
--
-- The handshake is driven with 'runPeerWithLimits' on the channel for
-- 'handshakeProtocolNum' in the 'Mx.ResponderDir' direction. The driver uses:
--
-- * 'byteLimitsHandshake' for size limits;
-- * 'haTimeLimits' for time limits;
-- * 'haHandshakeTracer', with every event tagged by @connectionId@.
--
-- Unlike the client, the server peer is also given 'haQueryVersion'. Any
-- trailing bytes returned by the driver are discarded. Limit violations and
-- protocol errors are reported through 'tryHandshake' as 'Left'.
--
runHandshakeServer
    :: ( MonadAsync m
       , MonadFork m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => Mx.Bearer m
    -> connectionId
    -> HandshakeArguments connectionId vNumber vData m
    -> Versions vNumber vData application
    -> m (Either (HandshakeException vNumber)
                 (HandshakeResult application vNumber vData))
runHandshakeServer bearer
                   connectionId
                   HandshakeArguments {
                     haHandshakeTracer,
                     haHandshakeCodec,
                     haVersionDataCodec,
                     haAcceptVersion,
                     haQueryVersion,
                     haTimeLimits
                   }
                   versions =
    tryHandshake
      (fst <$>
        runPeerWithLimits
          (Mx.WithBearer connectionId `contramap` haHandshakeTracer)
          haHandshakeCodec
          byteLimitsHandshake
          haTimeLimits
          (Mx.bearerAsChannel bearer handshakeProtocolNum Mx.ResponderDir)
          (handshakeServerPeer haVersionDataCodec haAcceptVersion haQueryVersion versions))

-- | A fixed delay of 20 seconds, exported for users of the handshake.
--
-- NOTE(review): nothing in this module uses it. The name suggests it is how
-- long to wait after answering a version query before shutting the
-- connection down — confirm with its callers.
--
handshake_QUERY_SHUTDOWN_DELAY :: DiffTime
handshake_QUERY_SHUTDOWN_DELAY = 20