{-# LANGUAGE DerivingStrategies  #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}

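-- | A simple in-memory mempool, exposing 'TxSubmissionMempoolReader' and
-- 'TxSubmissionMempoolWriter' views of it via 'getReader' and 'getWriter'.
--
-- This module should be imported qualified, since some of its exports clash
-- with common names (e.g. 'read' clashes with @Prelude.read@).
--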
module Ouroboros.Network.TxSubmission.Mempool.Simple
  ( Mempool (..)
  , MempoolSeq (..)
  , WithIndex (..)
  , empty
  , new
  , read
  , getReader
  , TxSubmissionMempoolReader (..)
  , getWriter
  , TxSubmissionMempoolWriter (..)
  ) where

import Prelude hiding (read, seq)

import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime.SI

import Data.Bifunctor (first)
import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.Foldable qualified as Foldable
import Data.List (find)
import Data.List qualified as List
import Data.Sequence (Seq)
import Data.Sequence qualified as Seq
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Typeable (Typeable)

import Ouroboros.Network.SizeInBytes
import Ouroboros.Network.TxSubmission.Inbound.V2.Types
import Ouroboros.Network.TxSubmission.Mempool.Reader

-- | A transaction paired with the index it was assigned when it was put
-- into the mempool.
data WithIndex tx = WithIndex
  { getIdx :: !Integer
    -- ^ index of the transaction in the mempool
  , getTx  :: !tx
    -- ^ the transaction itself
  }

-- | The complete state of a 'Mempool'.
data MempoolSeq txid tx = MempoolSeq
  { mempoolSet :: !(Set txid)
    -- ^ cached set of `txid`s in the mempool
  , mempoolSeq :: !(Seq (WithIndex tx))
    -- ^ sequence of all `tx`s, each tagged with its index
  , nextIdx    :: !Integer
    -- ^ next available index
    --
    -- Invariant: larger than the index of the last element of `mempoolSeq`.
  }

-- | A simple in-memory mempool implementation.
--
-- The whole mempool state is a single 'MempoolSeq' held in a 'StrictTVar',
-- so each read or update of it is atomic.
--
newtype Mempool m txid tx = Mempool (StrictTVar m (MempoolSeq txid tx))


-- | Create an empty mempool.
--
-- Indices are @0@-based: 'nextIdx' starts at @0@, which is the same state
-- that @new getTxId []@ builds.  The first transaction added therefore gets
-- index @0@, which is strictly greater than the @mempoolZeroIdx = -1@ used by
-- 'getReader', so @mempoolTxIdsAfter mempoolZeroIdx@ reports it.  (Starting
-- at @-1@ would make the first transaction invisible to that query.)
empty :: MonadSTM m => m (Mempool m txid tx)
empty = Mempool <$> newTVarIO (MempoolSeq Set.empty Seq.empty 0)


-- | Create a mempool pre-populated with the given transactions.
--
-- Transactions are indexed from @0@ in list order, and 'nextIdx' is set to
-- the number of transactions supplied, so later additions continue the
-- numbering.
--
-- NOTE: duplicates in @txs@ are not filtered out.  Every element is stored in
-- 'mempoolSeq', while 'mempoolSet' holds each @txid@ only once.
new :: ( MonadSTM m
       , Ord txid
       )
    => (tx -> txid)
    -> [tx]
    -> m (Mempool m txid tx)
new getTxId txs = Mempool <$> newTVarIO initialState
  where
    indexedTxs = Seq.fromList [ WithIndex i tx | (i, tx) <- zip [0 ..] txs ]

    initialState = MempoolSeq
      { mempoolSet = Set.fromList (map getTxId txs)
      , mempoolSeq = indexedTxs
      , nextIdx    = toInteger (Seq.length indexedTxs)
      }


-- | Return all transactions currently in the mempool, in the order in which
-- they are stored in 'mempoolSeq'.
read :: MonadSTM m => Mempool m txid tx -> m [tx]
read (Mempool mempool) = fmap toTxs (readTVarIO mempool)
  where
    toTxs MempoolSeq { mempoolSeq = entries } =
      [ tx | WithIndex { getTx = tx } <- toList entries ]


-- | Expose a 'Mempool' through the 'TxSubmissionMempoolReader' interface.
--
-- The indices handed out are the 'getIdx' values stored in the mempool.
-- 'mempoolZeroIdx' is @-1@, so that @mempoolTxIdsAfter mempoolZeroIdx@
-- returns every transaction with a non-negative index.
getReader :: forall tx txid m.
             ( MonadSTM m
             , Ord txid
             )
          => (tx -> txid)
          -> (tx -> SizeInBytes)
          -> Mempool m txid tx
          -> TxSubmissionMempoolReader txid tx Integer m
getReader getTxId getTxSize (Mempool mempool) =
    TxSubmissionMempoolReader
      { mempoolGetSnapshot = snapshotOf <$> readTVar mempool
      , mempoolZeroIdx     = -1
      }
  where
    -- Turn the state read in the current STM transaction into a pure
    -- snapshot.
    snapshotOf :: MempoolSeq txid tx -> MempoolSnapshot txid tx Integer
    snapshotOf MempoolSeq { mempoolSeq = entries, mempoolSet = txids } =
      MempoolSnapshot
        { mempoolTxIdsAfter = \idx ->
            -- every entry whose index is strictly greater than `idx`,
            -- in sequence order
            [ (getTxId tx, i, getTxSize tx)
            | WithIndex { getIdx = i, getTx = tx } <- toList entries
            , i > idx
            ]
        , mempoolLookupTx = \idx ->
            getTx <$> find ((== idx) . getIdx) entries
        , mempoolHasTx = (`Set.member` txids)
        }


-- | An exception carrying rejected @txid@s together with the reason each one
-- was rejected.
--
-- The @txid@ and @failure@ types are existentially hidden.  The 'Typeable'
-- constraints allow them to be recovered (e.g. with @cast@) after pattern
-- matching, and the 'Show' constraints make the error printable.
--
-- NOTE(review): this type is not exported and is not used in the visible
-- part of this module.
data InvalidTxsError where
    InvalidTxsError :: ( Typeable txid, Typeable failure
                       , Show txid,     Show failure
                       )
                    => [(txid, failure)]
                    -> InvalidTxsError

deriving stock instance Show InvalidTxsError

instance Exception InvalidTxsError


-- | A simple mempool writer.
--
-- @mempoolAddTxs txs@ first reads the current time, then, in a single STM
-- transaction:
--
-- 1. rejects (with the duplicate tx error, and without validating them) the
--    transactions whose @txid@ is already in the mempool;
-- 2. passes the remaining transactions to the validation function;
-- 3. appends the valid transactions to the mempool in submission order, each
--    receiving the next index.  A valid transaction whose @txid@ was already
--    accepted earlier in the same batch is rejected as a duplicate.
--
-- After the transaction commits, all rejected @txid@s (invalid ones first,
-- then duplicates) are passed to the invalid-txs handler.  The accepted
-- @txid@s, in submission order, are returned together with the rejected
-- ones.
--
getWriter :: forall tx txid failure m.
             ( MonadSTM m
             , MonadTime m
             , Ord txid
             )
          => failure
          -- ^ duplicate tx error
          -> (tx -> txid)
          -- ^ get transaction hash
          -> (UTCTime -> [tx] -> STM m [Either (txid, failure) tx])
          -- ^ validate a tx in an `STM` transaction, this allows acquiring and
          -- updating validation context.
          -> ([(txid, failure)] -> m ())
          -- ^ handle invalid txs, e.g. logging, throwing exceptions, etc
          -> Mempool m txid tx
          -- ^ mempool
          -> TxSubmissionMempoolWriter txid tx Integer m failure
getWriter duplicateTxError getTxId validateTx handleInvalidTxs
          (Mempool mempoolVar) =
    TxSubmissionMempoolWriter {
        txId = getTxId,

        mempoolAddTxs = \txs -> do
          now <- getCurrentTime
          (rejectedTxIds, acceptedTxs) <- atomically $ do
            MempoolSeq { mempoolSet, mempoolSeq, nextIdx } <- readTVar mempoolVar

            -- remove txs that are already in the mempool
            -- so we don't validate txs which are already in the mempool
            let duplicateTxIds :: [txid]
                (duplicateTxIds, txs') =
                    first (map getTxId)
                  $ List.partition (\tx -> getTxId tx `Set.member` mempoolSet) txs

            -- validate `txs'`
            -- NOTE: we might have duplicates in the `txs'` list
            (invalidTxIds, validTxs) <-
                fmap partitionEithers
              . validateTx now
              $ txs'

            let -- Process `validTxs` left to right, keeping the first
                -- occurrence of each txid.  `seen` only needs the txids
                -- accepted in this batch: `validTxs` were already filtered
                -- against `mempoolSet` above.
                step (seen, entries, idx, accepted, dups) tx
                  | txid `Set.member` seen
                  = (seen, entries, idx, accepted, txid : dups)
                  | otherwise
                  = ( Set.insert txid seen
                    , entries Seq.|> WithIndex idx tx
                    , succ idx
                    , txid : accepted
                    , dups
                    )
                  where
                    txid = getTxId tx

                -- delta - set of newly accepted txids
                (delta, mempoolSeq', nextIdx', acceptedRev, duplicatesRev) =
                  Foldable.foldl'
                    step
                    (Set.empty, mempoolSeq, nextIdx, [], [])
                    validTxs

                -- `step` conses onto its list accumulators; reverse them so
                -- that txids are reported in submission order.
                acceptedTxs     :: [txid]
                acceptedTxs     = reverse acceptedRev
                -- duplicate txids in the submitted list `txs`
                duplicateTxIds' :: [txid]
                duplicateTxIds' = reverse duplicatesRev

            writeTVar mempoolVar MempoolSeq { mempoolSet = mempoolSet `Set.union` delta
                                            , mempoolSeq = mempoolSeq'
                                            , nextIdx    = nextIdx'
                                            }
            return ( invalidTxIds
                     ++
                     [ (txid, duplicateTxError)
                     | txid <- duplicateTxIds ++ duplicateTxIds'
                     ]
                   , acceptedTxs
                   )
          handleInvalidTxs rejectedTxIds
          return (acceptedTxs, rejectedTxIds)
      }