{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

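-- | A simple in-memory mempool.  This module is meant to be imported
-- qualified, since it exports names such as 'read' that clash with the
-- "Prelude".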
--
module Ouroboros.Network.TxSubmission.Mempool.Simple
  ( Mempool (..)
  , empty
  , new
  , read
  , getReader
  , getWriter
  ) where

import Prelude hiding (read, seq)

import Control.Concurrent.Class.MonadSTM.Strict

import Data.Foldable (toList)
import Data.Foldable qualified as Foldable
import Data.Function (on)
import Data.List (find, nubBy)
import Data.Maybe (isJust)
import Data.Sequence (Seq)
import Data.Sequence qualified as Seq
import Data.Set qualified as Set

import Ouroboros.Network.SizeInBytes
import Ouroboros.Network.TxSubmission.Inbound.V2.Types
import Ouroboros.Network.TxSubmission.Mempool.Reader


-- | A simple in-memory mempool implementation.
--
-- The whole mempool is a single 'StrictTVar' holding a 'Seq' of
-- transactions.  A transaction's 0-based position in that 'Seq' is the
-- index exposed through 'getReader'.
--
newtype Mempool m tx = Mempool (StrictTVar m (Seq tx))


-- | Create a mempool that holds no transactions.
--
empty :: MonadSTM m => m (Mempool m tx)
empty = do
    var <- newTVarIO Seq.empty
    return (Mempool var)


-- | Create a mempool pre-populated with the given transactions.  The list
-- order becomes the mempool order: the head of the list is at position 0.
--
new :: MonadSTM m
    => [tx]
    -> m (Mempool m tx)
new txs = Mempool <$> newTVarIO (Seq.fromList txs)


-- | List the transactions currently held by the mempool, in stored order.
--
-- The underlying variable is read with 'readTVarIO', i.e. outside of any
-- enclosing STM transaction.
--
read :: MonadSTM m => Mempool m tx -> m [tx]
read (Mempool var) = do
    txs <- readTVarIO var
    return (toList txs)


-- | Build a 'TxSubmissionMempoolReader' over the given 'Mempool'.
--
-- Transactions are indexed by their 0-based position in the mempool's
-- 'Seq'.  'mempoolZeroIdx' is @-1@, so @mempoolTxIdsAfter mempoolZeroIdx@
-- returns every transaction.  Any index below @-1@ is treated the same way,
-- and an index at or past the end yields no transactions.
--
getReader :: forall tx txid m.
             ( MonadSTM m
             , Eq txid
             )
          => (tx -> txid)
          -- ^ compute a tx's id
          -> (tx -> SizeInBytes)
          -- ^ compute a tx's size
          -> Mempool m tx
          -> TxSubmissionMempoolReader txid tx Int m
getReader getTxId getTxSize (Mempool mempool) =
    -- Using `0`-based index.  `mempoolZeroIdx = -1` so that
    -- `mempoolTxIdsAfter mempoolZeroIdx` returns all txs.
    TxSubmissionMempoolReader { mempoolGetSnapshot,
                                mempoolZeroIdx = -1
                              }
  where
    -- Read the current contents inside STM and freeze them into a snapshot.
    mempoolGetSnapshot :: STM m (MempoolSnapshot txid tx Int)
    mempoolGetSnapshot = getSnapshot <$> readTVar mempool

    getSnapshot :: Seq tx
                -> MempoolSnapshot txid tx Int
    getSnapshot txs =
      MempoolSnapshot {
          mempoolTxIdsAfter = txIdsAfter txs,
          mempoolLookupTx   = \idx -> Seq.lookup idx txs,
          -- Linear scan: the mempool keeps no index by txid.
          mempoolHasTx      = \txid -> isJust $ find (\tx -> getTxId tx == txid) txs
       }

    -- All transactions strictly after position @idx@, in mempool order,
    -- each tagged with its own absolute position (so the positions agree
    -- with 'mempoolLookupTx').
    txIdsAfter :: Seq tx -> Int -> [(txid, Int, SizeInBytes)]
    txIdsAfter txs idx = zipWith entry [start ..] (toList (Seq.drop start txs))
      where
        size = Seq.length txs
        -- First position to report, clamped to @[0, size]@.  Without the
        -- lower clamp an @idx < -1@ makes 'Seq.drop' keep everything while
        -- the reported positions start below 0; without the upper clamp
        -- @idx + 1@ overflows when @idx == maxBound@.
        start | idx < 0     = 0
              | idx >= size = size
              | otherwise   = idx + 1

    entry :: Int -> tx -> (txid, Int, SizeInBytes)
    entry idx tx = (getTxId tx, idx, getTxSize tx)


-- | A simple mempool writer.
--
-- Each call to 'mempoolAddTxs' runs as a single STM transaction.  An
-- incoming transaction is kept only if @validateTx@ accepts it and its id
-- is not already present in the mempool.  Among the survivors, only the
-- first transaction carrying a given id is kept.  Kept transactions are
-- appended to the end of the mempool in input order, and their ids are
-- returned in that same order.
--
getWriter :: forall tx txid m.
             ( MonadSTM m
             , Ord txid
             )
          => (tx -> txid)
          -- ^ compute a tx's id
          -> (tx -> Bool)
          -- ^ validate a tx
          -> Mempool m tx
          -> TxSubmissionMempoolWriter txid tx Int m
getWriter getTxId validateTx (Mempool mempool) =
    TxSubmissionMempoolWriter {
        txId          = getTxId,
        mempoolAddTxs = atomically . addTxsSTM
      }
  where
    -- Append the acceptable part of a batch and report the appended ids.
    addTxsSTM :: [tx] -> STM m [txid]
    addTxsSTM incoming = do
      present <- readTVar mempool
      let knownIds = Set.fromList [ getTxId tx | tx <- toList present ]
          accepted = acceptable knownIds incoming
      writeTVar mempool (Foldable.foldl' (Seq.|>) present accepted)
      pure (getTxId <$> accepted)

    -- Valid, not yet in the mempool, and the first of its id in the batch.
    acceptable :: Set.Set txid -> [tx] -> [tx]
    acceptable knownIds =
        nubBy ((==) `on` getTxId)
      . filter (\tx -> validateTx tx && getTxId tx `Set.notMember` knownIds)