{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE PatternGuards #-}

-- | HTTP over TLS support for Warp via the TLS package.
--
--   If HTTP\/2 is negotiated by ALPN, HTTP\/2 over TLS is used.
--   Otherwise HTTP\/1.1 over TLS is used.
--
--   Support for SSL is now obsoleted.

module Network.Wai.Handler.WarpTLS (
    -- * Settings
      TLSSettings
    , defaultTlsSettings
    -- * Smart constructors
    , tlsSettings
    , tlsSettingsMemory
    , tlsSettingsChain
    , tlsSettingsChainMemory
    -- * Accessors
    , certFile
    , keyFile
    , tlsLogging
    , tlsAllowedVersions
    , tlsCiphers
    , tlsWantClientCert
    , tlsServerHooks
    , tlsServerDHEParams
    , tlsSessionManagerConfig
    , onInsecure
    , OnInsecure (..)
    -- * Runner
    , runTLS
    , runTLSSocket
    -- * Exception
    , WarpTLSException (..)
    , DH.Params
    , DH.generateParams
    ) where

import Control.Applicative ((<|>))
import Control.Exception (Exception, throwIO, bracket, finally, handle, fromException, try, IOException, onException, SomeException(..), handleJust)
import qualified Control.Exception as E
import Control.Monad (void, guard)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Default.Class (def)
import qualified Data.IORef as I
import Data.Streaming.Network (bindPortTCP, safeRecv)
import Data.Typeable (Typeable)
import Network.Socket (Socket, close, withSocketsDo, SockAddr, accept)
#if MIN_VERSION_network(3,1,1)
import Network.Socket (gracefulClose)
#endif
import Network.Socket.ByteString (sendAll)
import qualified Network.TLS as TLS
import qualified Crypto.PubKey.DH as DH
import qualified Network.TLS.Extra as TLSExtra
import qualified Network.TLS.SessionManager as SM
import Network.Wai (Application)
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Internal
import System.IO.Error (isEOFError)

----------------------------------------------------------------

-- | Settings for WarpTLS.
--
-- Build values from 'defaultTlsSettings' (or one of the smart
-- constructors) and override individual fields with record-update syntax.
data TLSSettings = TLSSettings {
    certFile :: FilePath
    -- ^ File containing the certificate.
  , chainCertFiles :: [FilePath]
    -- ^ Files containing chain certificates.
  , keyFile :: FilePath
    -- ^ File containing the key
  , certMemory :: Maybe S.ByteString
    -- ^ In-memory certificate bytes. When this field or 'keyMemory' is
    -- set, credentials are loaded from memory; if this one is 'Nothing'
    -- in that case, the contents of 'certFile' are read instead.
  , chainCertsMemory :: [S.ByteString]
    -- ^ In-memory chain certificates. Only consulted when credentials
    -- are loaded from memory.
  , keyMemory :: Maybe S.ByteString
    -- ^ In-memory key bytes. When this field or 'certMemory' is set,
    -- credentials are loaded from memory; if this one is 'Nothing'
    -- in that case, the contents of 'keyFile' are read instead.
  , onInsecure :: OnInsecure
    -- ^ Do we allow insecure connections with this server as well?
    --
    -- >>> onInsecure defaultTlsSettings
    -- DenyInsecure "This server only accepts secure HTTPS connections."
    --
    -- Since 1.4.0
  , tlsLogging :: TLS.Logging
    -- ^ The level of logging to turn on.
    --
    -- Default: 'TLS.defaultLogging'.
    --
    -- Since 1.4.0
  , tlsAllowedVersions :: [TLS.Version]
#if MIN_VERSION_tls(1,5,0)
    -- ^ The TLS versions this server accepts.
    --
    -- >>> tlsAllowedVersions defaultTlsSettings
    -- [TLS13,TLS12,TLS11,TLS10]
    --
    -- Since 1.4.2
#else
    -- ^ The TLS versions this server accepts.
    --
    -- >>> tlsAllowedVersions defaultTlsSettings
    -- [TLS12,TLS11,TLS10]
    --
    -- Since 1.4.2
#endif
  , tlsCiphers :: [TLS.Cipher]
#if MIN_VERSION_tls(1,5,0)
    -- ^ The TLS ciphers this server accepts.
    --
    -- >>> tlsCiphers defaultTlsSettings
    -- [ECDHE-ECDSA-AES256GCM-SHA384,ECDHE-ECDSA-AES128GCM-SHA256,ECDHE-RSA-AES256GCM-SHA384,ECDHE-RSA-AES128GCM-SHA256,DHE-RSA-AES256GCM-SHA384,DHE-RSA-AES128GCM-SHA256,ECDHE-ECDSA-AES256CBC-SHA384,ECDHE-RSA-AES256CBC-SHA384,DHE-RSA-AES256-SHA256,ECDHE-ECDSA-AES256CBC-SHA,ECDHE-RSA-AES256CBC-SHA,DHE-RSA-AES256-SHA1,RSA-AES256GCM-SHA384,RSA-AES256-SHA256,RSA-AES256-SHA1,AES128GCM-SHA256,AES256GCM-SHA384]
    --
    -- Since 1.4.2
#else
    -- ^ The TLS ciphers this server accepts.
    --
    -- >>> tlsCiphers defaultTlsSettings
    -- [ECDHE-ECDSA-AES256GCM-SHA384,ECDHE-ECDSA-AES128GCM-SHA256,ECDHE-RSA-AES256GCM-SHA384,ECDHE-RSA-AES128GCM-SHA256,DHE-RSA-AES256GCM-SHA384,DHE-RSA-AES128GCM-SHA256,ECDHE-ECDSA-AES256CBC-SHA384,ECDHE-RSA-AES256CBC-SHA384,DHE-RSA-AES256-SHA256,ECDHE-ECDSA-AES256CBC-SHA,ECDHE-RSA-AES256CBC-SHA,DHE-RSA-AES256-SHA1,RSA-AES256GCM-SHA384,RSA-AES256-SHA256,RSA-AES256-SHA1]
    --
    -- Since 1.4.2
#endif
  , tlsWantClientCert :: Bool
    -- ^ Whether or not to demand a certificate from the client.  If this
    -- is set to True, you must handle received certificates in a server hook
    -- or all connections will fail.
    --
    -- >>> tlsWantClientCert defaultTlsSettings
    -- False
    --
    -- Since 3.0.2
  , tlsServerHooks :: TLS.ServerHooks
    -- ^ The server-side hooks called by the tls package, including actions
    -- to take when a client certificate is received.  See the "Network.TLS"
    -- module for details.
    --
    -- Default: def
    --
    -- Since 3.0.2
  , tlsServerDHEParams :: Maybe DH.Params
    -- ^ Configuration for ServerDHEParams.
    -- More functions for working with these parameters live in the
    -- @cryptonite@ package.
    --
    -- Default: Nothing
    --
    -- Since 3.2.2
  , tlsSessionManagerConfig :: Maybe SM.Config
    -- ^ Configuration for in-memory TLS session manager.
    -- If Nothing, 'TLS.noSessionManager' is used.
    -- Otherwise, an in-memory TLS session manager is created
    -- according to 'Config'.
    --
    -- Default: Nothing
    --
    -- Since 3.2.4
  }

-- | Default 'TLSSettings'. Use this to create 'TLSSettings' with the field record name (aka accessors).
--
-- The defaults load @certificate.pem@ and @key.pem@ from disk, use no
-- chain certificates and no in-memory credentials, deny plain HTTP
-- connections, do not request client certificates, and use no DHE
-- parameters and no session manager configuration.
defaultTlsSettings :: TLSSettings
defaultTlsSettings = TLSSettings {
    certFile = "certificate.pem"
  , chainCertFiles = []
  , keyFile = "key.pem"
  , certMemory = Nothing
  , chainCertsMemory = []
  , keyMemory = Nothing
  , onInsecure = DenyInsecure "This server only accepts secure HTTPS connections."
  , tlsLogging = def
#if MIN_VERSION_tls(1,5,0)
  , tlsAllowedVersions = [TLS.TLS13,TLS.TLS12,TLS.TLS11,TLS.TLS10]
#else
  , tlsAllowedVersions = [TLS.TLS12,TLS.TLS11,TLS.TLS10]
#endif
  , tlsCiphers = ciphers
  , tlsWantClientCert = False
  , tlsServerHooks = def
  , tlsServerDHEParams = Nothing
  , tlsSessionManagerConfig = Nothing
  }

-- | Cipher suites offered by default: the @ciphersuite_strong@ list from
-- "Network.TLS.Extra" (taken from the stunnel example in tls-extra).
ciphers :: [TLS.Cipher]
ciphers = TLSExtra.ciphersuite_strong

----------------------------------------------------------------

-- | An action when a plain HTTP comes to HTTP over TLS/SSL port.
data OnInsecure = DenyInsecure L.ByteString
                -- ^ Reject plain HTTP. The payload is presumably the
                -- message returned to the client; the handler that
                -- consumes it is not part of this section, so confirm
                -- there.
                | AllowInsecure
                -- ^ Accept plain HTTP connections on the TLS port.
                deriving (Show)

----------------------------------------------------------------

-- | A smart constructor for 'TLSSettings' based on 'defaultTlsSettings'.
--
-- Only 'certFile' and 'keyFile' are overridden; every other field keeps
-- its default.
tlsSettings :: FilePath -- ^ Certificate file
            -> FilePath -- ^ Key file
            -> TLSSettings
tlsSettings cert key = defaultTlsSettings {
    certFile = cert
  , keyFile = key
  }

-- | A smart constructor for 'TLSSettings' that allows specifying
-- chain certificates based on 'defaultTlsSettings'.
--
-- Since 3.0.3
tlsSettingsChain
            :: FilePath -- ^ Certificate file
            -> [FilePath] -- ^ Chain certificate files
            -> FilePath -- ^ Key file
            -> TLSSettings
tlsSettingsChain cert chainCerts key = defaultTlsSettings {
    certFile = cert
  , chainCertFiles = chainCerts
  , keyFile = key
  }

-- | A smart constructor for 'TLSSettings', but uses in-memory representations
-- of the certificate and key based on 'defaultTlsSettings'.
--
-- Since 3.0.1
tlsSettingsMemory
    :: S.ByteString -- ^ Certificate bytes
    -> S.ByteString -- ^ Key bytes
    -> TLSSettings
tlsSettingsMemory cert key = defaultTlsSettings
    { certMemory = Just cert
    , keyMemory = Just key
    }

-- | A smart constructor for 'TLSSettings', but uses in-memory representations
-- of the certificate, chain certificates and key based on
-- 'defaultTlsSettings'.
--
-- Since 3.0.3
tlsSettingsChainMemory
    :: S.ByteString -- ^ Certificate bytes
    -> [S.ByteString] -- ^ Chain certificate bytes
    -> S.ByteString -- ^ Key bytes
    -> TLSSettings
tlsSettingsChainMemory cert chainCerts key = defaultTlsSettings
    { certMemory = Just cert
    , chainCertsMemory = chainCerts
    , keyMemory = Just key
    }

----------------------------------------------------------------

-- | Running 'Application' with 'TLSSettings' and 'Settings'.
--
-- Binds a listening socket on the port and host taken from 'Settings'
-- and hands it to 'runTLSSocket'. The socket is closed when the server
-- returns or throws.
runTLS :: TLSSettings -> Settings -> Application -> IO ()
runTLS tset set app = withSocketsDo $
    bracket
        (bindPortTCP (getPort set) (getHost set))
        close
        (\sock -> runTLSSocket tset set sock app)

----------------------------------------------------------------

-- | Running 'Application' with 'TLSSettings' and 'Settings' using
--   specified 'Socket'.
--
--   Credentials are loaded once, before any connection is accepted:
--
--   * if neither 'certMemory' nor 'keyMemory' is set, they are read from
--     'certFile', 'chainCertFiles' and 'keyFile';
--
--   * otherwise they are built from the in-memory values, reading
--     'certFile' or 'keyFile' for whichever half is missing.
--
--   A credential that cannot be loaded aborts startup via 'error' with
--   the message reported by the tls package.
runTLSSocket :: TLSSettings -> Settings -> Socket -> Application -> IO ()
runTLSSocket tlsset@TLSSettings{..} set sock app = do
    credential <- case (certMemory, keyMemory) of
        (Nothing, Nothing) ->
            -- Bind the result in IO so that a loading failure is raised
            -- here, not later when the credential is first forced during
            -- a handshake. This matches the in-memory branch below.
            either error return =<<
            TLS.credentialLoadX509Chain certFile chainCertFiles keyFile
        (mcert, mkey) -> do
            cert <- maybe (S.readFile certFile) return mcert
            key <- maybe (S.readFile keyFile) return mkey
            either error return $
              TLS.credentialLoadX509ChainFromMemory cert chainCertsMemory key
    mgr <- case tlsSessionManagerConfig of
      Nothing     -> return TLS.noSessionManager
      Just config -> SM.newSessionManager config
    runTLSSocket' tlsset set credential mgr sock app

-- | Worker for 'runTLSSocket': assembles the TLS server parameters from
-- the settings, the loaded credential and the session manager, then runs
-- Warp with a connection maker that accepts from the given socket.
runTLSSocket' :: TLSSettings -> Settings -> TLS.Credential -> TLS.SessionManager -> Socket -> Application -> IO ()
runTLSSocket' tlsset@TLSSettings{..} set credential mgr sock app =
    runSettingsConnectionMakerSecure set get app
  where
    get = getter tlsset set sock params
    params = def { -- TLS.ServerParams
        TLS.serverWantClientCert = tlsWantClientCert
      , TLS.serverCACertificates = []
      , TLS.serverDHEParams      = tlsServerDHEParams
      , TLS.serverHooks          = hooks
      , TLS.serverShared         = shared
      , TLS.serverSupported      = supported
#if MIN_VERSION_tls(1,5,0)
      , TLS.serverEarlyDataSize  = 2018
#endif
      }
    -- Adding alpn to the tlsServerHooks supplied by the user.
    -- A hook already present there takes precedence; the built-in
    -- 'alpn' is only installed when HTTP/2 is enabled in 'Settings'.
    hooks = tlsServerHooks {
        TLS.onALPNClientSuggest = TLS.onALPNClientSuggest tlsServerHooks <|>
          (if settingsHTTP2Enabled set then Just alpn else Nothing)
      }
    shared = def {
        TLS.sharedCredentials    = TLS.Credentials [credential]
      , TLS.sharedSessionManager = mgr
      }
    supported = def { -- TLS.Supported
        TLS.supportedVersions       = tlsAllowedVersions
      , TLS.supportedCiphers        = tlsCiphers
      , TLS.supportedCompressions   = [TLS.nullCompression]
      , TLS.supportedSecureRenegotiation = True
      , TLS.supportedClientInitiatedRenegotiation = False
      , TLS.supportedSession             = True
      , TLS.supportedFallbackScsv        = True
#if MIN_VERSION_tls(1,5,0)
      , TLS.supportedGroups              = [TLS.X25519,TLS.P256,TLS.P384]
#endif
      }

-- | ALPN protocol selection: answers @h2@ when the client suggested it,
-- and @http/1.1@ in every other case (including an empty suggestion list).
alpn :: [S.ByteString] -> IO S.ByteString
alpn xs
  | "h2"    `elem` xs = return "h2"
  | otherwise         = return "http/1.1"

----------------------------------------------------------------

-- | Accepts one connection from the listening socket and returns a
-- deferred connection maker for it, together with the peer address.
-- The TLS handshake is not performed here; it happens when the returned
-- action ('mkConn') is run.
getter :: TLS.TLSParams params => TLSSettings -> Settings -> Socket -> params -> IO (IO (Connection, Transport), SockAddr)
getter tlsset@TLSSettings{..} set sock params = do
#if WINDOWS
    (s, sa) <- windowsThreadBlockHack $ accept sock
#else
    (s, sa) <- accept sock
#endif
    setSocketCloseOnExec s
    return (mkConn tlsset set s params, sa)

-- | Builds the Warp connection for an accepted socket.
--
-- The first chunk received (up to 4096 bytes) decides the protocol: a
-- non-empty chunk starting with byte 0x16 is handed to 'httpOverTls',
-- anything else (including an empty read) to 'plainHTTP'. The chunk is
-- passed along so no input is lost. If setup throws, the socket is
-- closed before the exception propagates.
mkConn :: TLS.TLSParams params => TLSSettings -> Settings -> Socket -> params -> IO (Connection, Transport)
mkConn tlsset set s params = switch `onException` close s
  where
    switch = do
        firstBS <- safeRecv s 4096
        -- 0x16 is the TLS record content type "handshake"
        -- (RFC 5246 section 6.2.1).
        if not (S.null firstBS) && S.head firstBS == 0x16 then
            httpOverTls tlsset set s firstBS params
          else
            plainHTTP tlsset set s firstBS

----------------------------------------------------------------

-- | Sets up an HTTP-over-TLS connection on an accepted socket.
--
-- @bs0@ is the data already read from the socket while sniffing the
-- protocol; it is given to 'makePlainReceiveN' together with the socket
-- (presumably so the TLS layer sees it first - confirm against Warp).
-- After the handshake, the negotiated ALPN protocol decides whether the
-- connection is flagged as HTTP\/2. The 'Settings' argument is unused.
httpOverTls :: TLS.TLSParams params => TLSSettings -> Settings -> Socket -> S.ByteString -> params -> IO (Connection, Transport)
httpOverTls TLSSettings{..} _set s bs0 params = do
    recvN <- makePlainReceiveN s bs0
    ctx <- TLS.contextNew (backend recvN) params
    TLS.contextHookSetLogging ctx tlsLogging
    TLS.handshake ctx
    h2 <- (== Just "h2") <$> TLS.getNegotiatedProtocol ctx
    isH2 <- I.newIORef h2
    writeBuf <- allocateBuffer bufferSize
    -- Creating a cache for leftover input data.
    ref <- I.newIORef ""
    tls <- getTLSinfo ctx
    return (conn ctx writeBuf ref isH2, tls)
  where
    -- Raw socket backend used by the TLS context.
    backend recvN = TLS.Backend {
        TLS.backendFlush = return ()
#if MIN_VERSION_network(3,1,1)
        -- Best effort: any failure while closing is ignored.
      , TLS.backendClose = gracefulClose s 5000 `E.catch` \(SomeException _) -> return ()
#else
      , TLS.backendClose = close s
#endif
      , TLS.backendSend  = sendAll' s
      , TLS.backendRecv  = recvN
      }
    -- Any exception raised while sending is reported as
    -- ConnectionClosedByPeer.
    sendAll' sock bs = sendAll sock bs `E.catch` \(SomeException _) ->
        throwIO ConnectionClosedByPeer
    conn ctx writeBuf ref isH2 = Connection {
        connSendMany         = TLS.sendData ctx . L.fromChunks
      , connSendAll          = sendall
      , connSendFile         = sendfile
      , connClose            = close'
      , connFree             = freeBuffer writeBuf
      , connRecv             = recv ref
      , connRecvBuf          = recvBuf ref
      , connWriteBuffer      = writeBuf
      , connBufferSize       = bufferSize
      , connHTTP2            = isH2
      }
      where
        sendall = TLS.sendData ctx . L.fromChunks . return
        sendfile fid offset len hook headers =
            readSendFile writeBuf bufferSize sendall fid offset len hook headers

        -- Try to send close_notify, then close the context whether or
        -- not that attempt succeeded.
        close' = void (tryIO sendBye) `finally`
                 TLS.contextClose ctx

        sendBye =
          -- It's fine if the connection was closed by the other side before
          -- receiving close_notify, see RFC 5246 section 7.2.1.
          handleJust
            (\e -> guard (e == ConnectionClosedByPeer) >> return e)
            (const (return ()))
            (TLS.bye ctx)

        -- TLS version of recv with a cache for leftover input data.
        -- The cache is shared with recvBuf.
        recv cref = do
            cached <- I.readIORef cref
            if cached /= "" then do
                I.writeIORef cref ""
                return cached
              else
                recv'

        -- TLS version of recv (decrypting) without a cache.
        -- End of input (a TLS EOF error or an EOF IOException) is
        -- reported as an empty ByteString; other exceptions are rethrown.
        recv' = handle onEOF go
          where
            onEOF e
              | Just TLS.Error_EOF <- fromException e       = return S.empty
              | Just ioe <- fromException e, isEOFError ioe = return S.empty
              | otherwise                                   = throwIO e
            -- Skip empty results from recvData and keep reading until
            -- some data arrives.
            go = do
                x <- TLS.recvData ctx
                if S.null x then
                    go
                  else
                    return x

        -- TLS version of recvBuf with a cache for leftover input data.
        recvBuf cref buf siz = do
            cached <- I.readIORef cref
            (ret, leftover) <- fill cached buf siz recv'
            I.writeIORef cref leftover
            return ret

-- | Fills a buffer with exactly @siz0@ bytes, taking them first from the
-- cached input @bs0@ and then from repeated calls to @recv@.
--
-- Returns @(True, leftover)@ once the requested amount has been copied,
-- where @leftover@ holds the bytes read beyond it (possibly empty).
-- Returns @(False, "")@ if @recv@ yields an empty chunk before the
-- request is satisfied.
--
-- NOTE(review): this relies on 'copy' returning the buffer position just
-- past the copied bytes - confirm against Warp internals.
fill :: S.ByteString -> Buffer -> BufSize -> Recv -> IO (Bool,S.ByteString)
fill bs0 buf0 siz0 recv
  -- The cache alone satisfies the request.
  | siz0 <= len0 = do
      let (bs, leftover) = S.splitAt siz0 bs0
      void $ copy buf0 bs
      return (True, leftover)
  -- Use up the whole cache, then read the remainder.
  | otherwise = do
      buf <- copy buf0 bs0
      loop buf (siz0 - len0)
  where
    len0 = S.length bs0
    loop _   0   = return (True, "")
    loop buf siz = do
      bs <- recv
      let len = S.length bs
      if len == 0 then return (False, "")
        else if (len <= siz) then do
          buf' <- copy buf bs
          loop buf' (siz - len)
        else do
          -- The chunk overshoots: keep the excess as leftover.
          let (bs1,bs2) = S.splitAt siz bs
          void $ copy buf bs1
          return (True, bs2)

-- | Build the Warp 'Transport' value describing a TLS context.
--
--   The negotiated protocol is queried first, then the context
--   information. When no information is available the result is 'TCP';
--   otherwise the client certificate chain is fetched and a 'TLS'
--   transport is returned carrying the protocol version, the negotiated
--   protocol, the cipher ID and the client certificate chain.
getTLSinfo :: TLS.Context -> IO Transport
getTLSinfo ctx = do
    proto <- TLS.getNegotiatedProtocol ctx
    TLS.contextGetInformation ctx >>= maybe (return TCP) (describe proto)
  where
    describe proto TLS.Information{ TLS.infoVersion = ver, TLS.infoCipher = cipher } = do
        clientCert <- TLS.getClientCertificateChain ctx
        let (major, minor) = versionPair ver
        return TLS {
                tlsMajorVersion = major
              , tlsMinorVersion = minor
              , tlsNegotiatedProtocol = proto
              , tlsChiperID = TLS.cipherID cipher
              , tlsClientCertificate = clientCert
              }

    -- Map a protocol version constructor to its (major, minor) pair.
    versionPair ver = case ver of
        TLS.SSL2  -> (2,0)
        TLS.SSL3  -> (3,0)
        TLS.TLS10 -> (3,1)
        TLS.TLS11 -> (3,2)
        TLS.TLS12 -> (3,3)
#if MIN_VERSION_tls(1,5,0)
        TLS.TLS13 -> (3,4)
#endif

-- | 'try' specialised to 'IOException': run the action and return
--   @Left@ with the exception if an 'IOException' is thrown, @Right@
--   with the result otherwise.
tryIO :: IO a -> IO (Either IOException a)
tryIO action = try action

----------------------------------------------------------------

-- | Handle a connection whose first bytes were not TLS records.
--
--   @bs0@ is the data already read from the socket @s@.
--
--   * With 'AllowInsecure' a plain connection is created from the
--     socket; its receive function first returns @bs0@ (see
--     'recvPlain') and then falls back to the socket. The transport is
--     'TCP'.
--
--   * With 'DenyInsecure' a 426 response followed by the configured
--     body is sent, the socket is closed and
--     'InsecureConnectionDenied' is thrown.
plainHTTP :: TLSSettings -> Settings -> Socket -> S.ByteString -> IO (Connection, Transport)
plainHTTP TLSSettings{..} set s bs0 = case onInsecure of
    AllowInsecure -> do
        conn' <- socketConnection set s
        cachedRef <- I.newIORef bs0
        let conn'' = conn'
                { connRecv = recvPlain cachedRef (connRecv conn')
                }
        return (conn'', TCP)
    DenyInsecure lbs -> do
        -- Listening port 443 but TLS records do not arrive.
        -- We want to let the browser know that TLS is required.
        -- So, we use 426.
        --     http://tools.ietf.org/html/rfc2817#section-4.2
        --     https://tools.ietf.org/html/rfc7231#section-6.5.15
        -- FIXME: should we distinguish HTTP/1.1 and HTTP/2?
        --        In the case of HTTP/2, should we send
        --        GOAWAY + INADEQUATE_SECURITY?
        -- FIXME: Content-Length:
        -- FIXME: TLS/<version>
        sendAll s upgradeRequired
        mapM_ (sendAll s) $ L.toChunks lbs
        close s
        throwIO InsecureConnectionDenied
  where
    -- Status line and headers of the 426 response. This module is
    -- compiled with CPP, and the preprocessor splices lines that end in
    -- a backslash, so a string-gap literal would not yield the intended
    -- bytes. The response is therefore assembled from literals that
    -- contain no gaps.
    upgradeRequired = S.concat
        [ "HTTP/1.1 426 Upgrade Required\r\n"
        , "Upgrade: TLS/1.0, HTTP/1.1\r\n"
        , "Connection: Upgrade\r\n"
        , "Content-Type: text/plain\r\n"
        , "\r\n"
        ]

----------------------------------------------------------------

-- | Wrap a receive function so that a chunk stored in the given @IORef@
-- is delivered first. When the @IORef@ holds a non-empty chunk, that
-- chunk is returned and the @IORef@ is reset to empty; when it is empty,
-- the supplied receive function is called instead.
recvPlain :: I.IORef S.ByteString -> IO S.ByteString -> IO S.ByteString
recvPlain ref fallback = I.readIORef ref >>= serve
  where
    serve cached
      | S.null cached = fallback
      | otherwise     = I.writeIORef ref S.empty >> return cached

----------------------------------------------------------------

-- | Exceptions thrown by this module.
data WarpTLSException
    = InsecureConnectionDenied
      -- ^ Thrown by the plain-HTTP handler when the insecure-connection
      --   policy is 'DenyInsecure', after the 426 response has been sent
      --   and the socket closed.
    deriving (Show, Typeable)

instance Exception WarpTLSException