{-# LANGUAGE CPP, DeriveDataTypeable #-}
module Data.ByteString.Handle.Write
    ( writeHandle
    ) where

import Control.Monad ( when )
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Lazy.Internal as BLI
import Data.IORef ( IORef, newIORef, readIORef, modifyIORef, writeIORef )
import Data.Typeable ( Typeable )
import System.IO
    ( Handle, hClose, IOMode( WriteMode )
    , noNewlineTranslation, nativeNewlineMode
    )

import GHC.IO.Buffer ( BufferState(..), emptyBuffer, Buffer(..) )
import GHC.IO.BufferedIO ( BufferedIO(..) )
import GHC.IO.Device ( IODevice(..), IODeviceType(..), SeekMode(..) )
#if MIN_VERSION_base(4,5,0)
import GHC.IO.Encoding ( getLocaleEncoding )
#else
import GHC.IO.Encoding ( localeEncoding )
#endif
import GHC.IO.Exception
    ( ioException, unsupportedOperation
    , IOException(IOError), IOErrorType(InvalidArgument)
    )
import GHC.IO.Handle ( mkFileHandle )

-- | Position bookkeeping for the in-memory write device.
--
-- Both fields are intentionally lazy: 'seek_base' is at times set to an
-- 'error' sentinel that must never be forced before it is replaced.
data SeekState = SeekState
    { seek_pos  :: Integer
      -- ^ The device's current position in the output, as reported by 'tell'.
    , seek_base :: Integer
      -- ^ The start position of the current chunk. More precisely, it is the
      -- stream position that corresponds to index 0 of the raw buffer most
      -- recently handed out by 'newBuffer'.
    }

-- | Mutable state of the in-memory device that backs the 'Handle' created
-- by 'writeHandle'.
data WriteState = WriteState
    { write_chunks_backwards :: IORef [(Integer, B.ByteString)]
      -- ^ Allocated chunks, most recent first. The Integer paired with each
      -- chunk is the cumulative size of that chunk plus all those before it
      -- in the output.
    , write_seek_state       :: IORef SeekState
      -- ^ Current position bookkeeping.
    , write_size             :: IORef Integer
      -- ^ Logical size of the output; the final result is truncated to this
      -- many bytes.
    }
    deriving Typeable

-- | Size to allocate for the next chunk, given the size of the newest one.
--
-- Anything below 16 bytes (including "no chunk yet", i.e. 0) yields 16.
-- Otherwise the size doubles, capped at 'BLI.defaultChunkSize'.
nextChunkSize :: Int -> Int
nextChunkSize lastSize
    | lastSize < smallest = smallest
    | otherwise           = min BLI.defaultChunkSize (lastSize * 2)
  where
    -- the minimum size currently targeted by Data.ByteString.Lazy.cons'
    smallest = 16

-- | Find the chunk containing the absolute stream position @pos@, allocating
-- new chunks if needed. Returns the stream position at which that chunk
-- starts, together with the chunk itself.
--
-- The chunk list is newest-first. Each entry pairs a chunk with the stream
-- position just past its last byte, i.e. the cumulative size of it and every
-- earlier chunk. When @pos@ lies at or beyond the end of the newest chunk, a
-- chunk of size 'nextChunkSize' is appended and the lookup is retried.
--
-- Throws an 'InvalidArgument' I/O error for negative positions.
chunkForPosition :: Integer -> IORef [(Integer, B.ByteString)] -> IO (Integer, B.ByteString)
chunkForPosition pos chunks_backwards_ref = do
    chunks_backwards <- readIORef chunks_backwards_ref
    let (curSize, lastSize) =
            case chunks_backwards of
                [] -> (0, 0)
                ((sz, c):_) -> (sz, B.length c)
        -- stream position at which a chunk entry starts
        chunkStart (sz, c) = sz - fromIntegral (B.length c)
    if pos < curSize
        then
            -- Walk back from the newest chunk to the first one that starts at
            -- or before pos; chunks are contiguous, so pos lies inside it.
            case dropWhile (\entry -> pos < chunkStart entry) chunks_backwards of
                entry@(_, c) : _ -> return (chunkStart entry, c)
                -- only reachable when pos is negative
                [] -> ioe_seekOutOfRange
        else do
            let sz = nextChunkSize lastSize
            -- Zero-fill the new chunk so that regions skipped over (seeking
            -- past the end, growing the size) read back as zeros rather than
            -- whatever the freshly allocated memory happened to contain.
            bs <- BI.create sz $ \p -> BI.memset p 0 (fromIntegral sz) >> return ()
            writeIORef chunks_backwards_ref ((curSize + fromIntegral sz, bs) : chunks_backwards)
            chunkForPosition pos chunks_backwards_ref

-- | A fresh device state: no chunks allocated, positioned at 0, size 0.
initialWriteState :: IO WriteState
initialWriteState = do
    chunksRef <- newIORef []
    seekRef   <- newIORef (SeekState { seek_pos = 0, seek_base = 0 })
    sizeRef   <- newIORef 0
    return WriteState
        { write_chunks_backwards = chunksRef
        , write_seek_state       = seekRef
        , write_size             = sizeRef
        }

-- | Buffered-I/O side of the write device.
--
-- Each write buffer handed to the Handle is a window directly onto the memory
-- of the chunk containing the current position. Flushing therefore copies
-- nothing: it records how far the Handle wrote and hands out a new window at
-- the resulting position.
instance BufferedIO WriteState where
    -- Write-only device: read buffers are not supported.
    newBuffer _ ReadBuffer = ioException unsupportedOperation
    newBuffer ws WriteBuffer = do
        ss <- readIORef (write_seek_state ws)
        (chunkBase, chunk) <- chunkForPosition (seek_pos ss) (write_chunks_backwards ws)
        let chunkOffset = fromIntegral (seek_pos ss - chunkBase)
            (ptr, bsOffset, len) = BI.toForeignPtr chunk
            -- bufL == bufR: the buffer starts empty, positioned at seek_pos,
            -- and can be filled up to the end of the chunk.
            buf = (emptyBuffer ptr (bsOffset + len) WriteBuffer)
                      { bufL = bsOffset + chunkOffset
                      , bufR = bsOffset + chunkOffset
                      }
        -- Remember the stream position of raw-buffer index 0 so that
        -- flushWriteBuffer can turn bufR back into a stream position.
        writeIORef (write_seek_state ws)
                   (ss { seek_base = chunkBase - fromIntegral bsOffset })
        return buf

    -- The default emptyWriteBuffer (which resets bufL/bufR to 0) is kept.
    -- NOTE(review): base documents it as used when a device switches from
    -- reading to writing. A WriteMode-only device presumably never does
    -- that; confirm if this device is ever made readable.

    flushWriteBuffer ws buf = do
        ss <- readIORef (write_seek_state ws)
        let newPos = seek_base ss + fromIntegral (bufR buf)
        -- seek_base becomes a sentinel here; newBuffer below replaces it.
        writeIORef (write_seek_state ws)
                   (SeekState { seek_pos = newPos,
                                seek_base = error "seek_base needs to be updated"
                              })
        -- Update the size strictly. A lazy modifyIORef would pile up one
        -- `max` thunk per flush, only evaluated when writeHandle finishes.
        oldSize <- readIORef (write_size ws)
        writeIORef (write_size ws) $! max oldSize newPos
        newBuffer ws WriteBuffer

    -- Like flushWriteBuffer, additionally reporting how many bytes the
    -- flushed buffer held.
    flushWriteBuffer0 ws buf = do
        let count = bufR buf - bufL buf
        newBuf <- flushWriteBuffer ws buf
        return (count, newBuf)

    fillReadBuffer _ _ = ioException unsupportedOperation

    fillReadBuffer0 _ _ = ioException unsupportedOperation

-- | Device side of the in-memory writer: position, size and closing.
instance IODevice WriteState where
    ready _ _ _ = return True

    -- Nothing to release here; writeHandle reads the chunks out of the
    -- state after the Handle has been closed.
    close _ = return ()

    isSeekable _ = return True

    -- Moves the position and grows (never shrinks) the recorded size to
    -- cover it. Unlike POSIX lseek, seeking past the end therefore extends
    -- the output even if nothing is written there.
    --
    -- NOTE(review): only this device's state is updated, and seek_base is
    -- left as a sentinel. The byte buffer the Handle already holds was handed
    -- out by newBuffer for the old position and does not appear to be
    -- replaced, so writing and flushing after hSeek may fail with
    -- "seek_base needs to be updated". Confirm against GHC.IO.Handle.hSeek.
    seek ws seekMode seekPos = do
        curSeekState <- readIORef (write_seek_state ws)
        newSeekPos <-
            case seekMode of
                AbsoluteSeek -> return seekPos
                RelativeSeek -> return $ seek_pos curSeekState + seekPos
                -- can probably assume last buffer is flushed, so could probably count the
                -- current end pos if we really wanted to
                SeekFromEnd -> ioException unsupportedOperation
        when (newSeekPos < 0) $ ioe_seekOutOfRange
        writeIORef (write_seek_state ws)
                   (SeekState { seek_pos = newSeekPos,
                                seek_base = error "seek_base needs to be updated"
                              })
        -- Strict update, avoiding a lazily accumulated `max` thunk.
        oldSize <- readIORef (write_size ws)
        writeIORef (write_size ws) $! max oldSize newSeekPos

    tell ws = do
        ss <- readIORef (write_seek_state ws)
        return (seek_pos ss)

    getSize ws = readIORef (write_size ws)

    setSize ws sz = do
        writeIORef (write_size ws) sz
        -- force chunk creation
        _ <- chunkForPosition sz (write_chunks_backwards ws)
        return ()

    devType _ = return RegularFile -- TODO: is this correct?

-- | Throw an 'InvalidArgument' I/O error reporting an attempt to seek
-- outside the file.
ioe_seekOutOfRange :: IO a
ioe_seekOutOfRange = ioException outOfRange
  where
    outOfRange =
        IOError
            Nothing                             -- no associated Handle
            InvalidArgument
            ""                                  -- no location
            "attempt to seek outside the file"
            Nothing                             -- no errno
            Nothing                             -- no file name

-- | Run an action against a fresh in-memory write 'Handle'. Returns
-- everything written to it as a lazy 'BL.ByteString', together with the
-- action's result.
--
-- With @binary@ set, the Handle does no text encoding and no newline
-- translation. Otherwise it uses the locale encoding and the native newline
-- mode. The Handle is closed with 'hClose' after the action returns. If the
-- action throws, the exception propagates without an explicit close.
writeHandle :: Bool -> (Handle -> IO a) -> IO (BL.ByteString, a)
writeHandle binary doOutput = do
    ws <- initialWriteState

#if MIN_VERSION_base(4,5,0)
    localeEnc <- getLocaleEncoding
#else
    localeEnc <- return localeEncoding
#endif

    let encoding = if binary then Nothing else Just localeEnc
        newline  = if binary then noNewlineTranslation else nativeNewlineMode

    h <- mkFileHandle ws "ByteString" WriteMode encoding newline
    result <- doOutput h
    hClose h

    size   <- readIORef (write_size ws)
    chunks <- readIORef (write_chunks_backwards ws)

    -- Chunks are stored newest-first, and the last one may be only
    -- partially used, so restore the order and cut at the logical size.
    let contents = BL.fromChunks [ c | (_, c) <- reverse chunks ]
    return (BL.take (fromIntegral size) contents, result)