{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Rank2Types #-}

module Data.Conduit.Cereal.Internal
  ( ConduitErrorHandler
  , SinkErrorHandler
  , SinkTerminationHandler

  , mkConduitGet
  , mkSinkGet
  ) where

import           Control.Monad (forever, when)
import qualified Data.ByteString as BS
import           Data.Conduit (ConduitT, await, leftover, yield)
import           Data.Serialize hiding (get, put)

-- | What should we do if the Get fails?
--
-- The handler receives the failure message reported by the parser and
-- supplies the conduit that runs in place of the failed parse.
type ConduitErrorHandler m o = String -> ConduitT BS.ByteString o m ()
-- | Same role as 'ConduitErrorHandler', but for a sink: the handler must
-- produce the sink's result @r@ and has to work for any output type @o@.
type SinkErrorHandler m r = forall o. String -> ConduitT BS.ByteString o m r

-- | What should we do if the stream is done before the Get is done?
--
-- The handler is given the parser's pending continuation (the function that
-- was still waiting for more input) and must produce the sink's result.
type SinkTerminationHandler m r = forall o. (BS.ByteString -> Result r) -> ConduitT BS.ByteString o m r

-- | Construct a conduitGet with the specified 'ConduitErrorHandler'.
--
-- The resulting conduit runs @get@ over the incoming 'BS.ByteString' chunks
-- again and again, yielding every value it decodes.
--
-- * If a parse fails, all input consumed by that parse attempt is pushed back
--   with 'leftover' and the error handler is run with the failure message.
-- * If upstream ends in the middle of a parse, the chunks buffered for that
--   parse are pushed back with 'leftover' and the conduit finishes.
-- * If @get@ completes without being given any input, its value is yielded
--   forever.
mkConduitGet :: Monad m
             => ConduitErrorHandler m o
             -> Get o
             -> ConduitT BS.ByteString o m ()
mkConduitGet errorHandler get = consume True (runGetPartial get) [] BS.empty
  where
    -- @pull f b s@: make sure there is input to feed to the continuation @f@.
    -- @b@ holds the chunks already fed to the current parse, newest first;
    -- @s@ is the input still to be fed (empty means "await a new chunk").
    pull f b s
      | BS.null s = await >>= maybe (when (not $ null b) (leftover $ BS.concat $ reverse b)) (pull f b)
      | otherwise = consume False f b s

    -- @consume initial f b s@: feed @s@ to @f@ and act on the result.
    -- @initial@ is True only for the very first step, which feeds the empty
    -- string before anything has been awaited.
    consume initial f b s = case f s of
      Fail msg _ -> do
        -- Give back everything this parse attempt consumed. The check is on
        -- the bytes themselves rather than on @b@: after a previous 'Done'
        -- the buffer @b@ starts out empty, and the current chunk @s@ must
        -- still be restored in that case.
        when (not $ BS.null unparsed) (leftover unparsed)
        errorHandler msg
      Partial p -> pull p consumed BS.empty
      Done a s'
        -- this only works because the Get will either _always_ consume no
        -- input, or _never_ consume no input.
        | initial   -> forever $ yield a
        -- Start the next parse directly on the unconsumed remainder @s'@.
        | otherwise -> yield a >> pull (runGetPartial get) [] s'
      where
        consumed = s : b
        unparsed = BS.concat (reverse consumed)

-- | Construct a sinkGet with the specified 'SinkErrorHandler' and
-- 'SinkTerminationHandler'.
--
-- The resulting sink runs @get@ once over the incoming 'BS.ByteString'
-- chunks and returns the decoded value.
--
-- * On success, any unconsumed remainder of the last chunk is pushed back
--   with 'leftover'.
-- * If the parse fails, the input consumed so far is pushed back with
--   'leftover' and the error handler is run with the failure message.
-- * If upstream ends before the parse is complete, the buffered chunks are
--   pushed back with 'leftover' and the termination handler is run with the
--   parser's pending continuation.
mkSinkGet :: Monad m
          => SinkErrorHandler m r
          -> SinkTerminationHandler m r
          -> Get r
          -> ConduitT BS.ByteString o m r
mkSinkGet errorHandler terminationHandler get = consume (runGetPartial get) [] BS.empty
  where
    -- @pull f b s@: make sure there is input to feed to the continuation @f@.
    -- @b@ holds the chunks already fed to the parser, newest first; @s@ is
    -- the input still to be fed (empty means "await a new chunk").
    pull f b s
      | BS.null s = await >>= maybe upstreamDone (pull f b)
      | otherwise = consume f b s
      where
        -- Upstream is exhausted while the parser still wants input.
        upstreamDone =
          when (not $ null b) (leftover $ BS.concat $ reverse b) >> terminationHandler f

    -- @consume f b s@: feed @s@ to @f@ and act on the result.
    consume f b s = case f s of
      Fail msg _ -> do
        -- Give back everything the parser consumed, unless that is nothing
        -- at all (a failure on the initial empty input).
        when (not $ BS.null unparsed) (leftover unparsed)
        errorHandler msg
      Partial p -> pull p consumed BS.empty
      Done r s' -> when (not $ BS.null s') (leftover s') >> return r
      where
        consumed = s : b
        unparsed = BS.concat (reverse consumed)