{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
module Ouroboros.Network.TxSubmission.Mempool.Simple
( Mempool (..)
, MempoolSeq (..)
, WithIndex (..)
, empty
, new
, read
, getReader
, TxSubmissionMempoolReader (..)
, getWriter
, TxSubmissionMempoolWriter (..)
) where
import Prelude hiding (read, seq)
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI
import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.Foldable qualified as Foldable
import Data.Function (on)
import Data.List (find, nubBy)
import Data.Sequence (Seq)
import Data.Sequence qualified as Seq
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Typeable (Typeable)
import Ouroboros.Network.SizeInBytes
import Ouroboros.Network.TxSubmission.Inbound.V2.Types
import Ouroboros.Network.TxSubmission.Mempool.Reader
-- | A transaction paired with the index under which it is stored in the
-- mempool sequence.
data WithIndex tx = WithIndex
  { getIdx :: !Integer
    -- ^ index of the transaction
  , getTx  :: !tx
    -- ^ the transaction itself
  }
-- | The contents of a mempool.
data MempoolSeq txid tx = MempoolSeq
  { mempoolSet :: !(Set txid)
    -- ^ ids of the transactions held in 'mempoolSeq'; used for membership
    -- tests
  , mempoolSeq :: !(Seq (WithIndex tx))
    -- ^ the transactions, each tagged with its index
  , nextIdx    :: !Integer
    -- ^ index given to the next transaction added through the writer
  }
-- | A mempool: a 'StrictTVar' holding the current 'MempoolSeq'.
newtype Mempool m txid tx = Mempool (StrictTVar m (MempoolSeq txid tx))
-- | Create a mempool with no transactions.
--
-- 'nextIdx' starts at @0@, so the first transaction added through the writer
-- gets index @0@, the same index 'new' gives to its first transaction.  The
-- reader returned by 'getReader' uses @-1@ as its zero index and only reports
-- entries whose index is strictly greater than the requested one; starting
-- at @-1@ here would therefore hide the first added transaction from a
-- reader that starts at the zero index.
empty :: MonadSTM m => m (Mempool m txid tx)
empty =
    Mempool <$> newTVarIO
      ( MempoolSeq
          { mempoolSet = Set.empty
          , mempoolSeq = Seq.empty
          , nextIdx    = 0
          }
      )
-- | Create a mempool pre-populated with the given transactions.
--
-- The transactions are indexed @0, 1, ..@ in list order and 'nextIdx' is set
-- to the number of transactions, so that indices stay contiguous when more
-- transactions are added later.
new :: ( MonadSTM m
       , Ord txid
       )
    => (tx -> txid)
    -- ^ compute the id of a transaction
    -> [tx]
    -- ^ initial transactions
    -> m (Mempool m txid tx)
new getTxId txs =
    Mempool <$> newTVarIO
      ( MempoolSeq
          { mempoolSet = Set.fromList (map getTxId txs)
          , mempoolSeq = indexed
            -- entries occupy indices 0 .. n-1, hence the next free one is n
          , nextIdx    = fromIntegral (Seq.length indexed)
          }
      )
  where
    indexed = Seq.fromList (zipWith WithIndex [0 ..] txs)
-- | Return the transactions currently held in the mempool, in the order in
-- which they appear in the underlying sequence.
read :: MonadSTM m => Mempool m txid tx -> m [tx]
read (Mempool var) = contents <$> readTVarIO var
  where
    -- drop the index of every entry, keeping only the transaction
    contents MempoolSeq { mempoolSeq = entries } =
      [ getTx entry | entry <- toList entries ]
-- | Build a 'TxSubmissionMempoolReader' on top of a 'Mempool'.
--
-- Every snapshot is computed from the value of the mempool's 'StrictTVar' at
-- the time 'mempoolGetSnapshot' is run.  The zero index is @-1@.
getReader :: forall tx txid m.
             ( MonadSTM m
             , Ord txid
             )
          => (tx -> txid)
          -- ^ compute the id of a transaction
          -> (tx -> SizeInBytes)
          -- ^ compute the size of a transaction
          -> Mempool m txid tx
          -> TxSubmissionMempoolReader txid tx Integer m
getReader getTxId getTxSize (Mempool mempool) =
    TxSubmissionMempoolReader
      { mempoolGetSnapshot = snapshotOf <$> readTVar mempool
      , mempoolZeroIdx     = -1
      }
  where
    snapshotOf :: MempoolSeq txid tx -> MempoolSnapshot txid tx Integer
    snapshotOf MempoolSeq { mempoolSeq = entries, mempoolSet = known } =
      MempoolSnapshot
        { -- all entries whose index is strictly greater than the given one
          mempoolTxIdsAfter = \idx ->
            [ (getTxId tx, i, getTxSize tx)
            | WithIndex { getIdx = i, getTx = tx } <- toList entries
            , i > idx
            ]
          -- linear search for the entry stored under the given index
        , mempoolLookupTx = \idx ->
            getTx <$> find ((== idx) . getIdx) entries
        , mempoolHasTx = (`Set.member` known)
        }
-- | Exception carrying a list of @(txid, failure)@ pairs.
--
-- The @txid@ and @failure@ types are existentially quantified; the
-- constructor only requires them to be 'Typeable' and 'Show', which is what
-- the derived 'Show' instance below relies on.
data InvalidTxsError where
    InvalidTxsError :: forall txid failure.
                       ( Typeable txid
                       , Typeable failure
                       , Show txid
                       , Show failure
                       )
                    => [(txid, failure)]
                    -> InvalidTxsError

deriving instance Show InvalidTxsError
instance Exception InvalidTxsError
-- | Build a 'TxSubmissionMempoolWriter' on top of a 'Mempool'.
--
-- 'mempoolAddTxs' performs the following steps, the first four inside a
-- single STM transaction:
--
-- 1. drops transactions whose id repeats an earlier one in the same batch
--    and transactions whose id is already in the mempool;
-- 2. validates the remaining ones with the supplied validation function;
-- 3. appends the valid ones to the mempool, assigning consecutive indices
--    starting at 'nextIdx';
-- 4. collects the rejected ids: those rejected by validation, followed by
--    the ids of submitted transactions which were already in the mempool
--    (reported with the supplied duplicate failure);
-- 5. after the transaction commits, passes the rejected list to the invalid
--    transaction handler.
--
-- The result is the pair of accepted ids and rejected @(txid, failure)@s.
getWriter :: forall tx txid failure m.
             ( MonadSTM m
             , MonadTime m
             , Ord txid
             )
          => failure
          -- ^ failure reported for a transaction which is already in the
          -- mempool
          -> (tx -> txid)
          -- ^ compute the id of a transaction
          -> (UTCTime -> [tx] -> STM m [Either (txid, failure) tx])
          -- ^ validate transactions at the given time; 'Left' marks a
          -- rejected transaction, 'Right' an accepted one
          -> ([(txid, failure)] -> m ())
          -- ^ called with all rejected transactions after the STM
          -- transaction has committed
          -> Mempool m txid tx
          -> TxSubmissionMempoolWriter txid tx Integer m failure
getWriter duplicateTx getTxId validateTx handleInvalidTxs (Mempool mempool) =
    TxSubmissionMempoolWriter
      { txId = getTxId

      , mempoolAddTxs = \txs -> do
          now <- getCurrentTime
          (rejected, acceptedIds) <- atomically $ do
            MempoolSeq { mempoolSet, mempoolSeq, nextIdx } <- readTVar mempool

            let isKnown tx = getTxId tx `Set.member` mempoolSet
                candidates = filter (not . isKnown)
                           . nubBy ((==) `on` getTxId)
                           $ txs

            (invalidTxs, validTxs) <-
              partitionEithers <$> validateTx now candidates

            writeTVar mempool $ MempoolSeq
              { mempoolSet = Foldable.foldl'
                               (\s tx -> Set.insert (getTxId tx) s)
                               mempoolSet
                               validTxs
              , mempoolSeq = Foldable.foldl'
                               (Seq.|>)
                               mempoolSeq
                               (zipWith WithIndex [nextIdx ..] validTxs)
              , nextIdx    = nextIdx + fromIntegral (length validTxs)
              }

            -- A duplicate is a submitted transaction whose id was already in
            -- the mempool before this batch was applied; `mempoolSet` here is
            -- still the set read at the start of the transaction.
            let duplicates =
                  [ (txid, duplicateTx)
                  | txid <- map getTxId txs
                  , txid `Set.member` mempoolSet
                  ]

            return (invalidTxs ++ duplicates, map getTxId validTxs)

          handleInvalidTxs rejected
          return (acceptedIds, rejected)
      }