{-# LANGUAGE CPP                   #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Ouroboros.Network.RawBearer where

import Data.Word (Word8)
import Foreign.Ptr (Ptr)
import Network.Socket (Socket)
import Network.Socket qualified as Socket

#if defined(mingw32_HOST_OS)
import Foreign.Ptr (castPtr)
import System.Win32 qualified as Win32
#endif

-- | Generalized API for sending and receiving raw bytes over a file
-- descriptor, socket, or similar object.
--
-- A 'RawBearer' is a record of two operations in some monad @m@. Each one
-- takes a pointer to a byte buffer and a length in bytes, and returns an
-- 'Int'. In the implementations in this module, the 'Int' is the byte count
-- reported by the underlying OS call. It may be smaller than the requested
-- length, so callers that need the whole buffer transferred should loop.
data RawBearer m =
  RawBearer
    { send :: Ptr Word8 -> Int -> m Int
      -- ^ Write up to the given number of bytes from the buffer; returns
      -- the count reported as written.
    , recv :: Ptr Word8 -> Int -> m Int
      -- ^ Read up to the given number of bytes into the buffer; returns
      -- the count reported as read.
    }

-- | A way to turn a handle of type @fd@ (a socket, a Win32 @HANDLE@, ...)
-- into a 'RawBearer' that runs in @m@.
--
-- Construction itself happens in @m@, so an implementation may perform
-- effects when wrapping the handle.
newtype MakeRawBearer m fd = MakeRawBearer {
  getRawBearer :: fd -> m (RawBearer m)
}

-- | 'MakeRawBearer' for network 'Socket's.
--
-- Wrapping performs no I/O of its own. It returns the pure wrapper built by
-- 'socketToRawBearer'.
makeSocketRawBearer :: MakeRawBearer IO Socket
makeSocketRawBearer = MakeRawBearer (return . socketToRawBearer)

-- | Wrap a 'Socket' as a 'RawBearer'.
--
-- 'send' and 'recv' delegate directly to 'Socket.sendBuf' and
-- 'Socket.recvBuf' on the given socket. Those calls may transfer fewer
-- bytes than requested, and the returned count reflects what was actually
-- transferred.
socketToRawBearer :: Socket -> RawBearer IO
socketToRawBearer s =
    RawBearer
      { send = Socket.sendBuf s
      , recv = Socket.recvBuf s
      }

#if defined(mingw32_HOST_OS)

-- | 'MakeRawBearer' for Win32 'Win32.HANDLE's.
--
-- Wrapping performs no I/O of its own. It returns the wrapper built by
-- 'win32HandleToRawBearer'.
win32MakeRawBearer :: MakeRawBearer IO Win32.HANDLE
win32MakeRawBearer =
    MakeRawBearer $ \handle -> return (win32HandleToRawBearer handle)

-- | Wrap a Win32 'Win32.HANDLE' as a 'RawBearer'.
--
-- 'send' delegates to 'Win32.win32_WriteFile' and 'recv' to
-- 'Win32.win32_ReadFile', both synchronous (no @OVERLAPPED@ structure).
-- Byte counts are converted between 'Int' and the Win32 length type with
-- 'fromIntegral'.
-- NOTE(review): a negative or oversized 'Int' length would wrap in that
-- conversion; callers are presumably expected to pass valid sizes.
win32HandleToRawBearer :: Win32.HANDLE -> RawBearer IO
win32HandleToRawBearer h =
    RawBearer { send = writeTo, recv = readFrom }
  where
    writeTo buf len =
      fromIntegral <$> Win32.win32_WriteFile h (castPtr buf) (fromIntegral len) Nothing
    readFrom buf len =
      fromIntegral <$> Win32.win32_ReadFile h (castPtr buf) (fromIntegral len) Nothing
#endif