{-# LANGUAGE DerivingStrategies  #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}

-- | The module should be imported qualified.
--
module Ouroboros.Network.TxSubmission.Mempool.Simple
  ( Mempool (..)
  , MempoolSeq (..)
  , WithIndex (..)
  , empty
  , new
  , read
  , getReader
  , TxSubmissionMempoolReader (..)
  , getWriter
  , TxSubmissionMempoolWriter (..)
  ) where

import Prelude hiding (read, seq)

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI

import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.Foldable qualified as Foldable
import Data.Function (on)
import Data.List (find, nubBy)
import Data.Sequence (Seq)
import Data.Sequence qualified as Seq
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Typeable (Typeable)

import Ouroboros.Network.SizeInBytes
import Ouroboros.Network.TxSubmission.Inbound.V2.Types
import Ouroboros.Network.TxSubmission.Mempool.Reader

-- | A transaction tagged with the mempool index it was assigned when it was
-- inserted.
data WithIndex tx = WithIndex
  { getIdx :: !Integer
    -- ^ index of the transaction in the mempool
  , getTx  :: !tx
    -- ^ the transaction itself
  }

-- | The complete state of the simple mempool.
data MempoolSeq txid tx = MempoolSeq
  { mempoolSet :: !(Set txid)
    -- ^ cached set of `txid`s in the mempool
  , mempoolSeq :: !(Seq (WithIndex tx))
    -- ^ sequence of all `tx`s
  , nextIdx    :: !Integer
    -- ^ next available index
    --
    -- Invariant: larger than the index of the last element of `mempoolSeq`.
  }

-- | A simple in-memory mempool implementation: the whole 'MempoolSeq' state
-- is held in a single 'StrictTVar'.
--
newtype Mempool m txid tx
  = Mempool (StrictTVar m (MempoolSeq txid tx))


-- | Create an empty mempool.
--
-- Indices are `0`-based: `nextIdx` starts at `0`, so the first transaction
-- added to the mempool gets index `0`.  This keeps every index strictly
-- greater than the reader's `mempoolZeroIdx` (`-1`), so that
-- `mempoolTxIdsAfter mempoolZeroIdx` really returns all transactions.
-- (Starting at `-1` would give the first transaction index `-1`, which
-- `mempoolTxIdsAfter (-1)` filters out.)
empty :: MonadSTM m => m (Mempool m txid tx)
empty = Mempool <$> newTVarIO (MempoolSeq Set.empty Seq.empty 0)


-- | Create a mempool pre-populated with the given transactions.
--
-- The transactions are kept in list order and get consecutive `0`-based
-- indices.  `nextIdx` is set to the number of transactions, which is exactly
-- the first unused index.
new :: ( MonadSTM m
       , Ord txid
       )
    => (tx -> txid)
    -> [tx]
    -> m (Mempool m txid tx)
new getTxId txs =
      fmap Mempool
    . newTVarIO
    $ MempoolSeq { mempoolSet = Set.fromList (getTxId <$> txs),
                   mempoolSeq,
                   -- indices are `0 .. length txs - 1`, so the next free
                   -- index is `length txs`
                   nextIdx    = fromIntegral (Seq.length mempoolSeq)
                 }
  where
    mempoolSeq = Seq.fromList $ zipWith WithIndex [0..] txs


-- | Return all transactions currently in the mempool, in the order in which
-- they are stored in the underlying sequence.
--
-- The state is read with 'readTVarIO', i.e. outside of any larger `STM`
-- transaction.
read :: MonadSTM m => Mempool m txid tx -> m [tx]
read (Mempool mempool) = do
    snapshot <- readTVarIO mempool
    return (getTx <$> toList (mempoolSeq snapshot))


-- | A mempool reader backed by the simple in-memory 'Mempool'.
--
-- Each snapshot is a pure view of the 'MempoolSeq' read in the same `STM`
-- transaction.
getReader :: forall tx txid m.
             ( MonadSTM m
             , Ord txid
             )
          => (tx -> txid)
          -> (tx -> SizeInBytes)
          -> Mempool m txid tx
          -> TxSubmissionMempoolReader txid tx Integer m
getReader getTxId getTxSize (Mempool mempool) =
    TxSubmissionMempoolReader {
        mempoolGetSnapshot = snapshotOf <$> readTVar mempool,
        -- Assumes `0`-based indices (as assigned by `new`): `-1` is below
        -- every index, so `mempoolTxIdsAfter mempoolZeroIdx` returns all txs.
        mempoolZeroIdx     = -1
      }
  where
    snapshotOf :: MempoolSeq txid tx -> MempoolSnapshot txid tx Integer
    snapshotOf MempoolSeq { mempoolSeq = entries, mempoolSet = txIds } =
      MempoolSnapshot {
          -- every tx with an index strictly greater than `idx`, in sequence
          -- order, together with its index and size
          mempoolTxIdsAfter = \idx ->
            [ (getTxId tx, i, getTxSize tx)
            | WithIndex { getIdx = i, getTx = tx } <- toList entries
            , i > idx
            ],
          -- linear search for the tx stored under exactly `idx`
          mempoolLookupTx = \idx ->
            getTx <$> find ((== idx) . getIdx) entries,
          mempoolHasTx = (`Set.member` txIds)
        }


-- | An exception carrying a list of `txid`s, each paired with a `failure`
-- value.  The `txid` and `failure` types are existentially hidden; only their
-- 'Show' and 'Typeable' instances are kept with the value.
data InvalidTxsError where
    InvalidTxsError
      :: ( Show txid
         , Show failure
         , Typeable txid
         , Typeable failure
         )
      => [(txid, failure)]
      -> InvalidTxsError

deriving instance Show InvalidTxsError

instance Exception InvalidTxsError


-- | A simple mempool writer.
--
-- `mempoolAddTxs` does the following in a single `STM` transaction:
--
-- * drops repeated `txid`s within the batch (keeping the first occurrence)
--   and every tx whose `txid` is already in the mempool;
-- * validates the remaining txs with the given validation function;
-- * appends the valid txs to the mempool with fresh consecutive indices
--   starting at `nextIdx`.
--
-- Every tx of the batch whose `txid` was already in the mempool before the
-- call is reported with the duplicate-tx failure, alongside the validation
-- failures.  All failures are then passed to the invalid-txs handler
-- (outside of `STM`) and returned together with the `txid`s of the added txs.
--
getWriter :: forall tx txid failure m.
             ( MonadSTM m
             , MonadTime m
             , Ord txid
             )
          => failure
          -- ^ duplicate tx error
          -> (tx -> txid)
          -- ^ get transaction hash
          -> (UTCTime -> [tx] -> STM m [Either (txid, failure) tx])
          -- ^ validate a tx in an `STM` transaction, this allows acquiring and
          -- updating validation context.
          -> ([(txid, failure)] -> m ())
          -- ^ handle invalid txs, e.g. logging, throwing exceptions, etc
          -> Mempool m txid tx
          -- ^ mempool
          -> TxSubmissionMempoolWriter txid tx Integer m failure
getWriter duplicateTx getTxId validateTx handleInvalidTxs (Mempool mempool) =
    TxSubmissionMempoolWriter {
        txId = getTxId,

        mempoolAddTxs = \txs -> do
          now <- getCurrentTime
          (failures, addedTxIds) <- atomically $ do
            MempoolSeq { mempoolSet, mempoolSeq, nextIdx } <- readTVar mempool
            (invalidTxIds, validTxs) <-
                fmap partitionEithers
              . validateTx now
              . filter (\tx -> getTxId tx `Set.notMember` mempoolSet)
              . nubBy (on (==) getTxId)
              $ txs
            let mempoolTxs' = MempoolSeq {
                    mempoolSet = Foldable.foldl' (\s tx -> getTxId tx `Set.insert` s)
                                                 mempoolSet
                                                 validTxs,
                    mempoolSeq = Foldable.foldl' (Seq.|>) mempoolSeq
                                   (zipWith WithIndex [nextIdx..] validTxs),
                    nextIdx    = nextIdx + fromIntegral (length validTxs)
                  }
            writeTVar mempool mempoolTxs'
            return ( invalidTxIds
                     ++
                     -- duplicates are the txs whose `txid` was already in
                     -- the mempool *before* this call (hence `Set.member`
                     -- against the old `mempoolSet`)
                     [ (txid, duplicateTx)
                     | txid <- filter (`Set.member` mempoolSet)
                             . map getTxId
                             $ txs
                     ]
                   , map getTxId validTxs
                   )
          handleInvalidTxs failures
          return (addedTxIds, failures)
      }