{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.Server.ConnectionTable
  ( ConnectionTable
  , ConnectionTableRef (..)
  , ConnectionDirection (..)
  , ValencyCounter
  , newConnectionTableSTM
  , newConnectionTable
  , refConnectionSTM
  , refConnection
  , addConnection
  , removeConnectionSTM
  , removeConnection
  , newValencyCounter
  , addValencyCounter
  , remValencyCounter
  , waitValencyCounter
  , readValencyCounter
  ) where

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad (when)
--import           Control.Tracer XXX Not Yet
import Data.Map.Strict qualified as M
import Data.Set (Set)
import Data.Set qualified as S
import Network.Socket qualified as Socket
import Text.Printf

-- | A 'ConnectionTable' represents a set of connections that is shared between
-- servers and subscription workers. Its main purpose is to avoid the creation of duplicate
-- connections (especially connections with identical source address, destination address, source
-- port and destination port, which would be rejected by the kernel anyway).
-- It is only used for bookkeeping; the sockets represented by the connections are not accessible
-- through this structure.
--
data ConnectionTable m addr = ConnectionTable {
    -- | Table entries, keyed by remote address and connection direction.
    ctTable     :: StrictTVar m (M.Map (addr,ConnectionDirection) (ConnectionTableEntry m addr))
    -- | Last id handed out to a 'ValencyCounter' created for this table.
  , ctLastRefId :: StrictTVar m Int
  }

-- | ValencyCounter tracks how many additional connections a subscription worker may still
-- create. It starts out with a positive value representing a desired number of connections for a
-- specific subscription worker. It can become negative, for example if a peer opens multiple
-- connections to us.
-- The vcId is unique per ConnectionTable and ensures that we won't count the same connection twice.
--
data ValencyCounter m = ValencyCounter {
    -- | Identifier of the counter; the 'Eq' and 'Ord' instances compare only this field.
    vcId  :: Int
    -- | The mutable counter value.
  , vcRef :: StrictTVar m Int
  }

-- | Tracks connection direction. Used to differentiate between outbound connections we are
-- required to create and inbound connections from the same peer.
data ConnectionDirection = ConnectionInbound | ConnectionOutbound
  deriving (Eq, Ord, Show)

-- | Create a new ValencyCounter.
--
-- The counter's id is obtained by incrementing the table's 'ctLastRefId'; its initial value is
-- the requested valency.
newValencyCounter
  :: MonadSTM m
  => ConnectionTable m addr
  -> Int
  -- ^ Desired valency, that is number of connections a subscription worker will attempt to
  -- maintain.
  -> STM m (ValencyCounter m)
newValencyCounter tbl valency = do
  lr <- readTVar $ ctLastRefId tbl
  -- Force the new id before storing it so that no thunk is kept around.
  let !lr' = lr + 1
  writeTVar (ctLastRefId tbl) lr'
  v <- newTVar valency
  return $ ValencyCounter lr' v

-- | Counters are ordered by their 'vcId' only; the current counter value is ignored.
instance Ord (ValencyCounter m) where
    compare a b = compare (vcId a) (vcId b)

-- | Two counters are equal when their 'vcId's are equal; the current counter value is ignored.
instance Eq (ValencyCounter m) where
    (==) a b = vcId a == vcId b

-- | Returns the current ValencyCounter value, representing the number of additional connections
-- that can be created. May be negative.
readValencyCounter :: MonadSTM m => ValencyCounter m -> STM m Int
readValencyCounter vc = readTVar $ vcRef vc

-- | Bookkeeping kept for one key (remote address and direction) of the 'ConnectionTable'.
-- Both fields are strict.
data ConnectionTableEntry m addr = ConnectionTableEntry {
    -- | Set of ValencyCounter's for subscriptions interested in this peer.
      cteRefs           :: !(Set (ValencyCounter m))
    -- | Set of local SockAddr connected to this peer.
    , cteLocalAddresses :: !(Set addr)
    }

-- | Possible outcomes when a subscriber references a peer in the 'ConnectionTable'.
data ConnectionTableRef =
    ConnectionTableCreate
  -- ^ No connection to peer exists, attempt to create one.
  | ConnectionTableExist
  -- ^ A connection to the peer existed, either from another subscriber or the peer opened one
  -- towards us.
  | ConnectionTableDuplicate
  -- ^ This subscriber already has counted a connection to this peer. It must try another target.
  deriving Show

-- | Add a connection.
--
-- The counter holds the number of connections that may still be created, so registering a
-- connection decrements it by one.
addValencyCounter :: MonadSTM m => ValencyCounter m -> STM m ()
addValencyCounter vc = modifyTVar (vcRef vc) (\r -> r - 1)

-- | Remove a connection.
--
-- The inverse of 'addValencyCounter': dropping a connection increments the counter by one.
remValencyCounter :: MonadSTM m => ValencyCounter m -> STM m ()
remValencyCounter vc = modifyTVar (vcRef vc) (+ 1)

-- | Wait until ValencyCounter becomes positive, used for detecting when
-- we can create new connections.
waitValencyCounter :: MonadSTM m => ValencyCounter m -> STM m ()
waitValencyCounter vc = do
  v <- readTVar $ vcRef vc
  -- A value of zero or below means no further connections are wanted: retry the transaction.
  when (v <= 0)
    retry

-- | Create a new ConnectionTable.
--
-- The table starts out empty and the last handed-out reference id starts at 0.
newConnectionTableSTM :: MonadSTM m => STM m (ConnectionTable m addr)
newConnectionTableSTM = do
    tbl <- newTVar M.empty
    li <- newTVar 0
    return $ ConnectionTable tbl li

-- | Create a new ConnectionTable by running 'newConnectionTableSTM' in its own transaction.
newConnectionTable :: MonadSTM m => m (ConnectionTable m addr)
newConnectionTable = atomically newConnectionTableSTM

-- | Insert a new connection into the ConnectionTable.
--
-- The entry is keyed by the remote address and the connection direction. The local address is
-- added to the entry's set of local addresses, and the affected valency counters are updated
-- with 'addValencyCounter' within the same transaction.
addConnection
    :: forall m addr.
       ( MonadSTM m
       , Ord addr
       )
    => ConnectionTable m addr
    -> addr
    -- ^ Remote address.
    -> addr
    -- ^ Local address.
    -> ConnectionDirection
    -> Maybe (ValencyCounter m)
    -- ^ Optional ValencyCounter, used by subscription worker and set to Nothing when
    -- called by a local server.
    -> STM m ()
addConnection ConnectionTable{ctTable} remoteAddr localAddr dir ref_m = do
    readTVar ctTable >>= M.alterF fn (remoteAddr,dir) >>= writeTVar ctTable
  where
    fn :: Maybe (ConnectionTableEntry m addr) -> STM m (Maybe (ConnectionTableEntry m addr))
    -- No entry for this key yet: create one, counting the connection against the caller's
    -- counter (if any).
    fn Nothing = do
        refs <- case ref_m of
                     Just ref -> do
                         addValencyCounter ref
                         return $ S.singleton ref
                     Nothing -> return S.empty
        return $ Just $ ConnectionTableEntry refs (S.singleton localAddr)
    -- An entry already exists: register the caller's counter (if any) and update the entry.
    fn (Just cte) = do
          let refs' = case ref_m of
                           Just ref -> S.insert ref (cteRefs cte)
                           Nothing  -> cteRefs cte
          -- Signal to all parties (dnsSubscriptionWorkers) that are interested in tracking the
          -- number of connections to this particular peer that we've created a new connection.
          mapM_ addValencyCounter refs'
          return $ Just $ cte {
                cteRefs = refs'
              , cteLocalAddresses = S.insert localAddr (cteLocalAddresses cte)
              }

-- TODO This should use Control.Tracer
-- TODO should this be removed? Doesn't seem to be used anywhere
--
-- Debugging helper: prints every entry of the table with 'printf'. The table is read in one
-- transaction and each counter is read in a separate transaction afterwards, so the printed
-- values are not a single consistent snapshot.
_dumpConnectionTable
    :: ConnectionTable IO Socket.SockAddr
    -> IO ()
_dumpConnectionTable ConnectionTable{ctTable} = do
    tbl <- atomically $ readTVar ctTable
    printf "Dumping Table:\n"
    mapM_ dumpTableEntry (M.toList tbl)
  where
    dumpTableEntry :: ((Socket.SockAddr, ConnectionDirection), ConnectionTableEntry IO Socket.SockAddr) -> IO ()
    dumpTableEntry ((remoteAddr, dir), ce) = do
        -- Current value of every counter registered for this entry.
        refs <- mapM (atomically . readTVar . vcRef) (S.elems $ cteRefs ce)
        -- Pair each counter id with the value read above.
        let rids = map vcId $ S.elems $ cteRefs ce
            refids = zip rids refs
        printf "Remote Address: %s\nLocal Addresses %s\nDirection %s\nReferenses %s\n"
            (show remoteAddr) (show $ cteLocalAddresses ce) (show dir) (show refids)

-- | Remove a Connection.
removeConnectionSTM
    :: forall m addr.
       ( MonadSTM m
       , Ord addr
       )
    => ConnectionTable m addr
    -> addr
    -> addr
    -> ConnectionDirection
    -> STM m ()
removeConnectionSTM :: forall (m :: * -> *) addr.
(MonadSTM m, Ord addr) =>
ConnectionTable m addr
-> addr -> addr -> ConnectionDirection -> STM m ()
removeConnectionSTM ConnectionTable{StrictTVar
  m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
ctTable :: forall (m :: * -> *) addr.
ConnectionTable m addr
-> StrictTVar
     m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
ctTable :: StrictTVar
  m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
ctTable} addr
remoteAddr addr
localAddr ConnectionDirection
dir =
    StrictTVar
  m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
-> STM
     m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
forall (m :: * -> *) a. MonadSTM m => StrictTVar m a -> STM m a
readTVar StrictTVar
  m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
ctTable STM
  m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
-> (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr)
    -> STM
         m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr)))
-> STM
     m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
forall a b. STM m a -> (a -> STM m b) -> STM m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= (Maybe (ConnectionTableEntry m addr)
 -> STM m (Maybe (ConnectionTableEntry m addr)))
-> (addr, ConnectionDirection)
-> Map (addr, ConnectionDirection) (ConnectionTableEntry m addr)
-> STM
     m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
forall (f :: * -> *) k a.
(Functor f, Ord k) =>
(Maybe a -> f (Maybe a)) -> k -> Map k a -> f (Map k a)
M.alterF Maybe (ConnectionTableEntry m addr)
-> STM m (Maybe (ConnectionTableEntry m addr))
fn (addr
remoteAddr, ConnectionDirection
dir) STM
  m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
-> (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr)
    -> STM m ())
-> STM m ()
forall a b. STM m a -> (a -> STM m b) -> STM m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= StrictTVar
  m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
-> Map (addr, ConnectionDirection) (ConnectionTableEntry m addr)
-> STM m ()
forall (m :: * -> *) a.
MonadSTM m =>
StrictTVar m a -> a -> STM m ()
writeTVar StrictTVar
  m (Map (addr, ConnectionDirection) (ConnectionTableEntry m addr))
ctTable
  where
    fn :: Maybe (ConnectionTableEntry m addr)
       -> STM m (Maybe (ConnectionTableEntry m addr))
    fn :: Maybe (ConnectionTableEntry m addr)
-> STM m (Maybe (ConnectionTableEntry m addr))
fn Maybe (ConnectionTableEntry m addr)
Nothing = Maybe (ConnectionTableEntry m addr)
-> STM m (Maybe (ConnectionTableEntry m addr))
forall a. a -> STM m a
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (ConnectionTableEntry m addr)
forall a. Maybe a
Nothing -- XXX removing non existent address
    fn (Just ConnectionTableEntry{Set (ValencyCounter m)
cteRefs :: forall (m :: * -> *) addr.
ConnectionTableEntry m addr -> Set (ValencyCounter m)
cteRefs :: Set (ValencyCounter m)
cteRefs, Set addr
cteLocalAddresses :: forall (m :: * -> *) addr. ConnectionTableEntry m addr -> Set addr
cteLocalAddresses :: Set addr
cteLocalAddresses}) = do
        (ValencyCounter m -> STM m ())
-> Set (ValencyCounter m) -> STM m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ValencyCounter m -> STM m ()
forall (m :: * -> *). MonadSTM m => ValencyCounter m -> STM m ()
remValencyCounter Set (ValencyCounter m)
cteRefs
        let localAddresses' :: Set addr
localAddresses' = addr -> Set addr -> Set addr
forall a. Ord a => a -> Set a -> Set a
S.delete addr
localAddr Set addr
cteLocalAddresses
        if Set addr -> Bool
forall a. Set a -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Set addr
localAddresses'
            then Maybe (ConnectionTableEntry m addr)
-> STM m (Maybe (ConnectionTableEntry m addr))
forall a. a -> STM m a
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (ConnectionTableEntry m addr)
forall a. Maybe a
Nothing
            else Maybe (ConnectionTableEntry m addr)
-> STM m (Maybe (ConnectionTableEntry m addr))
forall a. a -> STM m a
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (ConnectionTableEntry m addr)
 -> STM m (Maybe (ConnectionTableEntry m addr)))
-> Maybe (ConnectionTableEntry m addr)
-> STM m (Maybe (ConnectionTableEntry m addr))
forall a b. (a -> b) -> a -> b
$ ConnectionTableEntry m addr -> Maybe (ConnectionTableEntry m addr)
forall a. a -> Maybe a
Just (ConnectionTableEntry m addr
 -> Maybe (ConnectionTableEntry m addr))
-> ConnectionTableEntry m addr
-> Maybe (ConnectionTableEntry m addr)
forall a b. (a -> b) -> a -> b
$ Set (ValencyCounter m) -> Set addr -> ConnectionTableEntry m addr
forall (m :: * -> *) addr.
Set (ValencyCounter m) -> Set addr -> ConnectionTableEntry m addr
ConnectionTableEntry Set (ValencyCounter m)
cteRefs Set addr
localAddresses'

-- | Remove a connection from the table.
--
-- Runs 'removeConnectionSTM' as a single 'atomically' transaction. The
-- arguments are forwarded unchanged, in this order: the connection table,
-- the remote address, the local address and the connection direction.
--
removeConnection
    :: forall m addr.
       ( MonadSTM m
       , Ord addr
       )
    => ConnectionTable m addr
    -> addr
    -> addr
    -> ConnectionDirection
    -> m ()
removeConnection tbl remoteAddr localAddr dir =
    atomically $ removeConnectionSTM tbl remoteAddr localAddr dir

-- | Try to see if it is possible to reference an existing connection rather
-- than creating a new one to the provided peer.
--
-- The table is looked up under the key @(remoteAddr, dir)@:
--
-- * If there is no entry, 'ConnectionTableCreate' is returned and the table
--   is left untouched.
--
-- * If there is an entry whose reference set already contains @refVar@,
--   'ConnectionTableDuplicate' is returned and the table is left untouched.
--
-- * Otherwise @refVar@ is inserted into the entry's reference set,
--   'addValencyCounter' is applied to every counter of the updated set, the
--   updated entry is written back and 'ConnectionTableExist' is returned.
--
refConnectionSTM
    :: ( MonadSTM m
       , Ord addr
       )
    => ConnectionTable m addr
    -> addr
    -> ConnectionDirection
    -> ValencyCounter m
    -> STM m ConnectionTableRef
refConnectionSTM ConnectionTable{ctTable} remoteAddr dir refVar = do
    tbl <- readTVar ctTable
    case M.lookup (remoteAddr, dir) tbl of
         Nothing -> return ConnectionTableCreate
         Just cte ->
             if S.member refVar $ cteRefs cte
                 then return ConnectionTableDuplicate
                 else do
                     -- TODO We look up remoteAddr twice, is it possible
                     -- to use M.alterF given that we need to be able to return
                     -- ConnectionTableCreate or ConnectionTableExist?
                     let refs' = S.insert refVar (cteRefs cte)
                     -- NOTE(review): this bumps every counter in refs',
                     -- including the ones that were already referencing the
                     -- entry before this call, not only refVar -- confirm
                     -- that this is intended.
                     mapM_ addValencyCounter $ S.toList refs'

                     writeTVar ctTable $ M.insert (remoteAddr, dir)
                         (cte { cteRefs = refs'}) tbl
                     return ConnectionTableExist

-- | Try to reference an existing connection in the table.
--
-- Runs 'refConnectionSTM' as a single 'atomically' transaction and returns
-- its 'ConnectionTableRef' result. The arguments are forwarded unchanged, in
-- this order: the connection table, the remote address, the connection
-- direction and the valency counter of the caller.
--
refConnection
    :: ( MonadSTM m
       , Ord addr
       )
    => ConnectionTable m addr
    -> addr
    -> ConnectionDirection
    -> ValencyCounter m
    -> m ConnectionTableRef
refConnection tbl remoteAddr dir refVar =
    atomically $ refConnectionSTM tbl remoteAddr dir refVar