{-# LANGUAGE BangPatterns #-}

-- | STM TMergeVar mini-abstraction
--
module Control.Concurrent.Class.MonadSTM.Strict.TMergeVar
  ( TMergeVar
  , newTMergeVar
  , writeTMergeVar
  , takeTMergeVar
  , tryReadTMergeVar
  ) where

import Control.Concurrent.Class.MonadSTM.Strict

-- | A 'TMergeVar' is like a 'TMVar' in that a reader takes its value, leaving
-- it empty. Unlike an ordinary 'TMVar', whose \'put\' blocks while the
-- variable is full, it has a non-blocking, combining write operation: if a
-- value is already present, the existing and new values are combined with the
-- 'Semigroup' operator ('<>').
--
-- It is used much like a 'TMVar' as a one-place queue between threads, with
-- the added property that a writer can \"improve\" the pending value (if any)
-- instead of waiting for it to be taken.
--
-- It is a thin wrapper around a 'StrictTMVar'. The data constructor is not
-- exported from this module; use 'newTMergeVar' to create one.
newtype TMergeVar m a = TMergeVar (StrictTMVar m a)

-- | Create a new, empty 'TMergeVar'.
newTMergeVar :: MonadSTM m => STM m (TMergeVar m a)
newTMergeVar = TMergeVar <$> newEmptyTMVar

-- | Merge the given value into the 'TMergeVar' and return the value it now
-- holds.
--
-- * If the variable is empty, the given value is stored unchanged.
--
-- * If it already holds a value @old@, that value is replaced by
--   @old '<>' new@ (the existing value is the left operand).
--
-- This never blocks: the current value (if any) is taken first, so the
-- variable is always empty when the merged value is put back within the same
-- transaction.
--
writeTMergeVar :: (MonadSTM m, Semigroup a) => TMergeVar m a -> a -> STM m a
writeTMergeVar (TMergeVar v) x = do
    mx0 <- tryTakeTMVar v
    case mx0 of
      Nothing -> x  <$ putTMVar v x
      Just x0 -> x' <$ putTMVar v x'
        where
          -- Evaluate the merged value to WHNF before storing it, rather than
          -- storing an unevaluated '<>' application.
          !x' = x0 <> x

-- | Take the current value out of the 'TMergeVar', leaving it empty.
--
-- Like 'takeTMVar', this retries (blocks) while the variable is empty.
takeTMergeVar :: MonadSTM m => TMergeVar m a -> STM m a
takeTMergeVar (TMergeVar v) = takeTMVar v

-- | Read the current value, if any, without removing it.
--
-- Returns 'Nothing' when the variable is empty; never blocks.
tryReadTMergeVar :: MonadSTM m
                 => TMergeVar m a
                 -> STM m (Maybe a)
tryReadTMergeVar (TMergeVar v) = tryReadTMVar v