module Network.TLS.Client.Enumerator
( ConnInfo
, connClose
, connIter
, connEnum
, sslClientConn
, socketConn
, TLSCertificateRejectReason(..)
, TLSCertificateUsage(..)
, ConnectionReset (..)
) where
import Control.Exception (Exception, finally)
import Data.Typeable (Typeable)
import System.IO (Handle, hClose)

import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Trans.Class (lift)
import Crypto.Random.AESCtr (makeSystem)
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Certificate.X509 (X509)
import Data.Enumerator
    ( Iteratee (..), Enumerator, Step (..), Stream (..), continue, returnI
    , tryIO, throwError
    )
import Network.Socket (Socket, sClose)
import Network.Socket.ByteString (recv, sendAll)
import Network.TLS
import Network.TLS.Extra (ciphersuite_all)
-- | A connection described purely by the 'IO' actions needed to use it, so
-- that plain sockets and TLS sessions can be driven through the same
-- interface.
data ConnInfo = ConnInfo
    { connRead  :: IO [ByteString]
      -- ^ Fetch the next batch of chunks from the connection.
    , connWrite :: [ByteString] -> IO ()
      -- ^ Send the given chunks over the connection.
    , connClose :: IO ()
      -- ^ Shut the connection down.
    }
-- | Exception signalling that a read from the connection produced no data.
-- It is raised by 'connEnum' when every chunk returned by a read is empty.
data ConnectionReset = ConnectionReset
    deriving (Show, Typeable)

instance Exception ConnectionReset
-- | An iteratee that writes everything it is fed to the connection.
--
-- Every 'Chunks' batch is handed to the connection's 'connWrite' action
-- (run through 'tryIO'), after which the iteratee asks for more input. It
-- finishes with @()@ once it receives 'EOF'.
connIter :: MonadIO m => ConnInfo -> Iteratee ByteString m ()
connIter ConnInfo { connWrite = write } =
    continue go
  where
    -- Finish on end of input and hand 'EOF' back as the leftover stream, so
    -- that an iteratee sequenced after this one also observes the end of
    -- input. A plain @return ()@ would report @Chunks []@ as the leftover
    -- and silently swallow the 'EOF'.
    go EOF = returnI (Yield () EOF)
    go (Chunks bss) = do
        tryIO $ write bss
        continue go
-- | An enumerator that feeds an iteratee with data read from the connection.
--
-- While the iteratee is still in the 'Continue' state, the connection's
-- 'connRead' action is run (through 'tryIO') and its chunks are passed on.
-- If a read yields nothing but empty chunks (an empty list included), the
-- enumerator fails with 'ConnectionReset'. Any other step is returned
-- unchanged.
connEnum :: MonadIO m => ConnInfo -> Enumerator ByteString m b
connEnum ConnInfo { connRead = readChunks } = pump
  where
    -- Only an iteratee that still wants input triggers a read.
    pump (Continue feed) = tryIO readChunks >>= deliver feed
    pump finished = returnI finished

    -- Pass one batch of chunks to the iteratee and keep pumping with the
    -- step it produces; a batch without any data aborts the enumeration.
    deliver feed chunks
        | all S.null chunks = throwError ConnectionReset
        | otherwise = lift (runIteratee (feed (Chunks chunks))) >>= pump
-- | Build a 'ConnInfo' on top of a plain 'Socket'.
--
-- Each read issues a single 'recv' asking for at most 4096 bytes and wraps
-- the result in a one-element list; writes send every chunk in order with
-- 'sendAll'; closing calls 'sClose' on the socket.
socketConn :: Socket -> ConnInfo
socketConn sock = ConnInfo
    { connRead  = readChunk
    , connWrite = writeChunks
    , connClose = sClose sock
    }
  where
    readChunk = do
        chunk <- recv sock 4096
        return [chunk]

    writeChunks chunks = mapM_ (sendAll sock) chunks
-- | Start a TLS client session over an already-connected 'Handle' and expose
-- it as a 'ConnInfo'.
--
-- The session is configured from 'defaultParams' with 'pConnectVersion' set
-- to 'TLS10', 'TLS10' and 'TLS11' as the allowed versions, and
-- 'ciphersuite_all' as the cipher list. The first argument is installed as
-- the 'onCertificatesRecv' callback, so it decides what to do with the
-- certificates received from the peer. The handshake is performed before
-- this action returns.
--
-- The resulting 'connClose' runs 'bye' and then closes the handle; the
-- handle is closed even when 'bye' throws, and the exception is re-raised
-- afterwards.
sslClientConn :: ([X509] -> IO TLSCertificateUsage) -> Handle -> IO ConnInfo
sslClientConn onCerts h = do
    let tcp = defaultParams
            { pConnectVersion = TLS10
            , pAllowedVersions = [ TLS10, TLS11 ]
            , pCiphers = ciphersuite_all
            , onCertificatesRecv = onCerts
            }
    gen <- makeSystem
    istate <- client tcp gen h
    handshake istate
    return ConnInfo
        { connRead = recvD istate
        , connWrite = sendData istate . L.fromChunks
          -- 'finally' guarantees the handle is released even if the TLS
          -- shutdown itself fails (e.g. on an already-broken connection).
        , connClose = bye istate `finally` hClose h
        }
  where
    -- Read until a non-empty payload arrives; empty reads are skipped rather
    -- than returned to the caller.
    -- NOTE(review): if 'recvData' reports end-of-stream as an empty result in
    -- the tls version in use, this loop would never terminate -- confirm.
    recvD istate = do
        x <- recvData istate
        if S.null x
            then recvD istate
            else return [x]