{-|
Module      : Hookup
Description : Network connections generalized over TLS and SOCKS
Copyright   : (c) Eric Mertens, 2016
License     : ISC
Maintainer  : emertens@gmail.com

This module provides a uniform interface to network connections
with optional support for TLS and SOCKS.

This library is careful to support both IPv4 and IPv6. It will attempt to
connect to each of the addresses that a domain name resolves to until the
first successful connection.

Use 'connect' and 'close' to establish and close network connections.

Use 'recv', 'recvLine', and 'send' to receive and transmit data on an
open network connection.

TLS and SOCKS parameters can be provided. When both are provided a connection
will first be established to the SOCKS server and then the TLS connection will
be established through that proxy server. This is most useful when connecting
through a dynamic port forward of an SSH client via the @-D@ flag.

-}
module Hookup
  (
  -- * Connections
  Connection,
  connect,
  connectWithSocket,
  close,

  -- * Reading and writing data
  recv,
  recvLine,
  send,
  putBuf,

  -- * Configuration
  ConnectionParams(..),
  SocksParams(..),
  TlsParams(..),
  defaultFamily,
  defaultTlsParams,


  -- * Errors
  ConnectionFailure(..),
  CommandReply(..)

  -- * SSL Information
  , getClientCertificate
  , getPeerCertificate
  , getPeerCertFingerprintSha1
  , getPeerCertFingerprintSha256
  , getPeerCertFingerprintSha512
  , getPeerPubkeyFingerprintSha1
  , getPeerPubkeyFingerprintSha256
  , getPeerPubkeyFingerprintSha512
  ) where

import           Control.Concurrent
import           Control.Exception
import           Control.Monad
import           System.IO.Error (isDoesNotExistError, ioeGetErrorString)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import           Data.Foldable
import           Data.List (intercalate)
import           Network.Socket (Socket, AddrInfo, PortNumber, HostName, Family)
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as SocketB
import           OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL as SSL
import qualified OpenSSL.Session as SSL
import           OpenSSL.X509.SystemStore
import           OpenSSL.X509 (X509)
import qualified OpenSSL.X509 as X509
import qualified OpenSSL.PEM as PEM
import qualified OpenSSL.EVP.Digest as Digest
import           Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as Parser

import           Hookup.OpenSSL (installVerification, getPubKeyDer)
import           Hookup.Socks5


-- | Everything 'connect' needs to know to reach a service.
--
-- Sensible field defaults are provided by 'defaultFamily' and
-- 'defaultTlsParams'.
--
-- Setting the address family restricts connection attempts to IPv4 or
-- IPv6 only; by default both are allowed. Pinning one family is useful
-- when a hostname resolves to both but one of them is misconfigured.
--
-- Supplying 'SocksParams' routes the connection through a SOCKS
-- (version 5) proxy.
--
-- Supplying 'TlsParams' makes the connection negotiate TLS when it is
-- established, protecting the stream.
data ConnectionParams = ConnectionParams
  { -- | IP protocol family (default 'Socket.AF_UNSPEC')
    cpFamily :: Family
    -- | Destination host
  , cpHost   :: HostName
    -- | Destination TCP port
  , cpPort   :: PortNumber
    -- | SOCKS proxy to go through, if any
  , cpSocks  :: Maybe SocksParams
    -- | TLS settings, if the stream should be encrypted
  , cpTls    :: Maybe TlsParams
  }


-- | Location of the SOCKS proxy server to connect through.
data SocksParams = SocksParams
  { -- | Host name of the SOCKS server
    spHost :: HostName
    -- | Port the SOCKS server listens on
  , spPort :: PortNumber
  }


-- | Settings for a TLS session. They are handed to OpenSSL when a
-- secure connection is set up.
data TlsParams = TlsParams
  { -- | File containing a client certificate, if one should be presented
    tpClientCertificate :: Maybe FilePath
    -- | File containing the client's private key, if any
  , tpClientPrivateKey  :: Maybe FilePath
    -- | File containing a CA certificate bundle; 'Nothing' means the
    -- system-provided CAs are used
  , tpServerCertificate :: Maybe FilePath
    -- | OpenSSL cipher suite name (e.g. @\"HIGH\"@)
  , tpCipherSuite       :: String
    -- | 'True' turns off certificate checking
  , tpInsecure          :: Bool
  }

-- | Failures raised by this package.
data ConnectionFailure
  = HostnameResolutionFailure HostName String
    -- ^ resolving the remote host with 'getAddrInfo' failed
    -- (host name, resolver error message)
  | ConnectionFailure [IOError]
    -- ^ no 'connect' attempt to the remote host succeeded; holds one
    -- error per failed address
  | LineTooLong
    -- ^ 'recvLine' hit its maximum line length
  | LineTruncated
    -- ^ the stream ended partway through a line in 'recvLine'
  | SocksError CommandReply
    -- ^ the SOCKS server rejected the command with this reply code
  | SocksAuthenticationError
    -- ^ the SOCKS server did not accept the offered authentication method
  | SocksProtocolError
    -- ^ the SOCKS server sent a malformed message, or no message at all
  | SocksBadDomainName
    -- ^ the destination domain name is too long for the SOCKS protocol
  deriving Show

-- | Human-readable messages for each 'ConnectionFailure', exposed
-- through 'displayException'.
instance Exception ConnectionFailure where
  displayException failure =
    case failure of
      LineTruncated -> "connection closed while reading line"
      LineTooLong   -> "line length exceeded maximum"
      ConnectionFailure errs ->
        "connection attempt failed due to: " ++
        intercalate ", " (map displayException errs)
      HostnameResolutionFailure host msg ->
        concat ["hostname resolution failed (", host, "): ", msg]
      SocksAuthenticationError -> "SOCKS authentication method rejected"
      SocksProtocolError       -> "SOCKS server protocol error"
      SocksBadDomainName       -> "SOCKS domain name length limit exceeded"
      SocksError reply         -> "SOCKS command rejected: " ++ describeReply reply
    where
      -- Text for each SOCKS reply code; unrecognised codes show the raw value.
      describeReply :: CommandReply -> String
      describeReply reply =
        case reply of
          Succeeded         -> "succeeded"
          GeneralFailure    -> "general SOCKS server failure"
          NotAllowed        -> "connection not allowed by ruleset"
          NetUnreachable    -> "network unreachable"
          HostUnreachable   -> "host unreachable"
          ConnectionRefused -> "connection refused"
          TTLExpired        -> "TTL expired"
          CmdNotSupported   -> "command not supported"
          AddrNotSupported  -> "address type not supported"
          CommandReply code -> "unknown reply " ++ show code

-- | The default address family, 'Socket.AF_UNSPEC'. It leaves the
-- family open so that both IPv4 (INET) and IPv6 (INET6) addresses
-- can be used.
defaultFamily :: Socket.Family
defaultFamily =
  Socket.AF_UNSPEC

-- | Default TLS settings:
--
-- * the @\"HIGH\"@ OpenSSL cipher suite,
-- * certificate checking enabled ('tpInsecure' is 'False'),
-- * no client certificate or private key,
-- * the system-provided CA roots.
defaultTlsParams :: TlsParams
defaultTlsParams =
  TlsParams
    { tpCipherSuite       = "HIGH"
    , tpInsecure          = False
    , tpClientCertificate = Nothing
    , tpClientPrivateKey  = Nothing
    , tpServerCertificate = Nothing -- Nothing selects the system CA store
    }

------------------------------------------------------------------------
-- Opening sockets
------------------------------------------------------------------------

-- | Open a socket to the destination in the parameters. The socket
-- connects either directly or through the configured SOCKS proxy.
--
-- In the SOCKS case, the socket to the proxy is closed again if the
-- SOCKS negotiation throws.
openSocket :: ConnectionParams -> IO Socket
openSocket params =
  case cpSocks params of
    Nothing -> openSocket' family (cpHost params) (cpPort params)
    Just sp ->
      -- bracketOnError acquires the proxy socket with async exceptions
      -- masked. No exception can then arrive between the socket being
      -- created and the cleanup handler being installed, which a plain
      -- `onException` after the bind would allow.
      bracketOnError (openSocket' family (spHost sp) (spPort sp)) Socket.close $ \sock ->
        sock <$ socksConnect sock (cpHost params) (cpPort params)
  where
    family = cpFamily params


-- | Run an attoparsec parser over bytes read from the socket.
--
-- Input is pulled one byte at a time so that nothing beyond the message
-- is consumed. A parse failure, or any input left over after a
-- successful parse, is reported as 'SocksProtocolError'.
netParse :: Show a => Socket -> Parser a -> IO a
netParse sock parser =
  do result <- Parser.parseWith recvByte parser B.empty
     case result of
       Parser.Done rest x | B.null rest -> return x
       _                                -> throwIO SocksProtocolError
  where
    -- One-byte reads are inefficient, but these messages are very short,
    -- and reading ahead could consume bytes that follow the message.
    recvByte = SocketB.recv sock 1


-- | Ask the SOCKS 5 server on the other end of the socket to connect to
-- the given host and port. The host is sent as a domain name, and only
-- the "no authentication required" method is offered.
--
-- Throws 'SocksBadDomainName', 'SocksAuthenticationError',
-- 'SocksProtocolError', 'SocksError', 'IOError'
socksConnect :: Socket -> HostName -> PortNumber -> IO ()
socksConnect sock host port =
  do -- The domain name must fit the protocol's length limit. Check this
     -- before exchanging any messages with the proxy so an unusable name
     -- fails immediately instead of after a wasted round trip.
     unless (B.length dnBytes < 256)
       (throwIO SocksBadDomainName)

     SocketB.sendAll sock $
       buildClientHello ClientHello
         { cHelloMethods = [AuthNoAuthenticationRequired] }

     validateHello =<< netParse sock parseServerHello

     SocketB.sendAll sock $
       buildRequest Request
         { reqCommand = Connect
         , reqAddress = Address (DomainName dnBytes) port
         }

     validateResponse =<< netParse sock parseResponse
  where
    -- NOTE(review): B8.pack keeps only the low 8 bits of each Char, so
    -- this assumes the host name is already ASCII (e.g. IDNA-encoded).
    dnBytes = B8.pack host


-- | Check the server's method selection. Anything other than the
-- "no authentication required" method fails with
-- 'SocksAuthenticationError'.
validateHello :: ServerHello -> IO ()
validateHello hello =
  when (sHelloMethod hello /= AuthNoAuthenticationRequired) $
    throwIO SocksAuthenticationError

-- | Accept a 'Succeeded' reply. Any other reply code fails with
-- 'SocksError' carrying that code.
validateResponse :: Response -> IO ()
validateResponse response
  | reply == Succeeded = return ()
  | otherwise          = throwIO (SocksError reply)
  where
    reply = rspReply response


-- | Resolve a host and port to stream-socket addresses and connect to
-- the first one that accepts.
--
-- If resolution fails with a "does not exist" error, it is reported as
-- 'HostnameResolutionFailure'. Any other resolver error is rethrown
-- unchanged.
openSocket' :: Family -> HostName -> PortNumber -> IO Socket
openSocket' family h p =
  do resolved <- try (Socket.getAddrInfo (Just hints) (Just h) (Just (show p)))
     either resolutionFailed (attemptConnections []) resolved
  where
    hints = Socket.defaultHints
      { Socket.addrFamily     = family
      , Socket.addrSocketType = Socket.Stream
      , Socket.addrFlags      = [Socket.AI_ADDRCONFIG, Socket.AI_NUMERICSERV]
      }

    resolutionFailed :: IOError -> IO Socket
    resolutionFailed ioe
      | isDoesNotExistError ioe =
          throwIO (HostnameResolutionFailure h (ioeGetErrorString ioe))
      | otherwise = throwIO ioe -- unexpected


-- | Try each candidate 'AddrInfo' in turn and return the first socket
-- that connects successfully.
--
-- If every candidate fails, throw a 'ConnectionFailure' containing all
-- of the encountered errors, most recent first.
attemptConnections ::
  [IOError]         {- ^ accumulated errors  -} ->
  [Socket.AddrInfo] {- ^ candidate AddrInfos -} ->
  IO Socket         {- ^ connected socket    -}
attemptConnections errs candidates =
  case candidates of
    [] -> throwIO (ConnectionFailure errs)
    ai : rest ->
      try (connectToAddrInfo ai)
        >>= either (\err -> attemptConnections (err : errs) rest) return

-- | Create a socket for the given 'AddrInfo' and connect it to that
-- address. The new socket is closed if connecting throws.
connectToAddrInfo :: AddrInfo -> IO Socket
connectToAddrInfo info =
  bracketOnError (socket' info) Socket.close $ \sock ->
    do Socket.connect sock (Socket.addrAddress info)
       return sock

-- | Create an unconnected 'Socket' using the family, socket type and
-- protocol recorded in an 'AddrInfo'.
socket' :: AddrInfo -> IO Socket
socket' ai = Socket.socket family sockType protocol
  where
    family   = Socket.addrFamily ai
    sockType = Socket.addrSocketType ai
    protocol = Socket.addrProtocol ai


------------------------------------------------------------------------
-- Generalization of Socket
------------------------------------------------------------------------

-- | The transport underneath a connection: either a TLS session or a
-- plain socket.
data NetworkHandle
  = SSL (Maybe X509) SSL -- ^ TLS session and the client certificate it was configured with, if any
  | Socket Socket        -- ^ plain, unencrypted socket


-- | Build a 'NetworkHandle' from a socket-producing action. When the
-- parameters carry TLS settings, the socket is wrapped in a TLS session
-- for the destination host. Otherwise the socket is used as-is.
openNetworkHandle ::
  ConnectionParams {- ^ parameters             -} ->
  IO Socket        {- ^ socket creation action -} ->
  IO NetworkHandle {- ^ open network handle    -}
openNetworkHandle params mkSocket =
  maybe plain secured (cpTls params)
  where
    plain = Socket <$> mkSocket
    secured tls = uncurry SSL <$> startTls tls (cpHost params) mkSocket


-- | Release the resources behind a 'NetworkHandle'.
--
-- For TLS, a unidirectional shutdown is sent first and then the
-- underlying socket is closed. The socket is closed even if the
-- shutdown throws (e.g. because the peer has already gone away);
-- the shutdown's exception is re-thrown afterwards.
closeNetworkHandle :: NetworkHandle -> IO ()
closeNetworkHandle (Socket s) = Socket.close s
closeNetworkHandle (SSL _ s) =
  SSL.shutdown s SSL.Unidirectional
    `finally` traverse_ Socket.close (SSL.sslSocket s)

-- | Write all of the given bytes to the transport. Uses 'SocketB.sendAll'
-- for plain sockets and 'SSL.write' for TLS sessions.
networkSend :: NetworkHandle -> ByteString -> IO ()
networkSend nh =
  case nh of
    Socket s -> SocketB.sendAll s
    SSL _ s  -> SSL.write s

-- | Read at most the requested number of bytes from the transport. Uses
-- 'SocketB.recv' for plain sockets and 'SSL.read' for TLS sessions.
networkRecv :: NetworkHandle -> Int -> IO ByteString
networkRecv nh =
  case nh of
    Socket s -> SocketB.recv s
    SSL _ s  -> SSL.read s


------------------------------------------------------------------------
-- Sockets with a receive buffer
------------------------------------------------------------------------

-- | An open connection to a network service (plain, SOCKS-proxied, or
-- TLS), paired with a buffer of received bytes that have not been
-- consumed yet. The buffer is what makes line-oriented reading with
-- 'recvLine' possible.
data Connection =
  Connection
    (MVar ByteString) -- received but not yet consumed bytes
    NetworkHandle     -- underlying transport

-- | Open a connection to the TCP service described by the parameters.
--
-- The returned connection MUST eventually be passed to 'close', or its
-- resources will leak.
--
-- Throws 'IOError', 'SocksError', 'SSL.ProtocolError', 'ConnectionFailure'
connect ::
  ConnectionParams {- ^ parameters      -} ->
  IO Connection    {- ^ open connection -}
connect params =
  openNetworkHandle params (openSocket params) >>= \nh ->
    flip Connection nh <$> newMVar B.empty

-- | Wrap a socket that is already connected into a 'Connection'.
--
-- TLS is started if the parameters ask for it. Any SOCKS settings are
-- ignored, because the socket is assumed to already reach the intended
-- service.
--
-- Throws 'SSL.ProtocolError'
connectWithSocket ::
  ConnectionParams {- ^ parameters       -} ->
  Socket           {- ^ connected socket -} ->
  IO Connection    {- ^ open connection  -}
connectWithSocket params sock =
  do nh      <- openNetworkHandle params (return sock)
     pending <- newMVar B.empty
     return (Connection pending nh)

-- | Close a connection by closing its underlying network handle.
close ::
  Connection {- ^ open connection -} ->
  IO ()
close (Connection _ nh) =
  closeNetworkHandle nh

-- | Fetch the next chunk of the stream.
--
-- If the connection's buffer holds any bytes, the whole buffer is
-- returned (and emptied) without touching the network. Otherwise at
-- most the given number of bytes are read from the stream.
--
-- Throws: 'IOError', 'SSL.ConnectionAbruptlyTerminated', 'SSL.ProtocolError'
recv ::
  Connection    {- ^ open connection              -} ->
  Int           {- ^ maximum underlying recv size -} ->
  IO ByteString {- ^ next chunk from stream       -}
recv (Connection buf h) n = swapMVar buf B.empty >>= deliver
  where
    deliver pending
      | B.null pending = networkRecv h n
      | otherwise      = return pending

-- | Read one line from the connection. Both @"\\r\\n"@ and @"\\n"@ are
-- accepted as terminators, and the terminator is stripped from the
-- result.
--
-- 'Nothing' means the peer closed its half of the connection while no
-- partial line was buffered.
--
-- If the peer sends some bytes and then closes without a line
-- terminator, 'LineTruncated' is thrown. If the unterminated data
-- already buffered reaches the maximum length when another read would
-- be needed, 'LineTooLong' is thrown.
--
-- NOTE(review): the limit is only checked before each further read, so
-- a returned line can be longer than the limit. Also, because
-- 'modifyMVar' restores the original buffer when an exception escapes,
-- chunks received during a failed call are discarded.
--
-- Throws: 'SSL.ConnectionAbruptlyTerminated', 'SSL.ProtocolError', 'ConnectionFailure', 'IOError'
recvLine ::
  Connection            {- ^ open connection            -} ->
  Int                   {- ^ maximum line length        -} ->
  IO (Maybe ByteString) {- ^ next line or end-of-stream -}
recvLine (Connection buf h) n =
  modifyMVar buf (\start -> scan (B.length start) start [])
  where
    -- total : combined length of newest and all chunks in older
    -- newest: most recently received chunk (initially the buffer)
    -- older : earlier chunks, most recent first; none contains '\n'
    scan :: Int -> ByteString -> [ByteString] -> IO (ByteString, Maybe ByteString)
    scan total newest older =
      maybe (fetchMore total newest older)
            (splitLine newest older)
            (B8.elemIndex '\n' newest)

    -- Cut the line at the newline at index i. Everything after the
    -- newline becomes the new buffer.
    splitLine :: ByteString -> [ByteString] -> Int -> IO (ByteString, Maybe ByteString)
    splitLine newest older i =
      let (lineEnd, rest) = B.splitAt i newest
          line            = cleanEnd (B.concat (reverse (lineEnd : older)))
      in return (B.drop 1 rest, Just line) -- drop 1 skips the '\n'

    -- No newline yet: enforce the limit, then read another chunk.
    fetchMore :: Int -> ByteString -> [ByteString] -> IO (ByteString, Maybe ByteString)
    fetchMore total newest older =
      do when (total >= n) (throwIO LineTooLong)
         more <- networkRecv h n
         if B.null more -- connection closed
           then if total == 0
                  then return (B.empty, Nothing)
                  else throwIO LineTruncated
           else scan (total + B.length more) more (newest : older)


-- | Prepend bytes to the connection's read buffer so the next 'recv' or
-- 'recvLine' sees them first. For example, this lets the unused tail of
-- a 'recv' result be handed back to the connection.
putBuf ::
  Connection {- ^ connection         -} ->
  ByteString {- ^ new head of buffer -} ->
  IO ()
putBuf (Connection buf _) bs =
  modifyMVar_ buf $ \old ->
    let combined = B.append bs old
    in combined `seq` return combined


-- | Strip a single trailing @'\\r'@, if present. Any other input is
-- returned unchanged.
cleanEnd :: ByteString -> ByteString
cleanEnd bs
  | B8.pack "\r" `B.isSuffixOf` bs = B.init bs
  | otherwise                      = bs


-- | Transmit a chunk on the connection. The whole chunk is sent, even if
-- that takes several underlying send operations.
--
-- Throws: 'IOError', 'SSL.ProtocolError'
send ::
  Connection {- ^ open connection -} ->
  ByteString {- ^ chunk           -} ->
  IO ()
send (Connection _ nh) chunk =
  networkSend nh chunk


------------------------------------------------------------------------


-- | Start a TLS session for the given hostname over a socket produced by
-- the creation action. On success, the connected session is returned
-- together with the client certificate that was loaded, if any.
-- Certificate verification is performed unless 'tpInsecure' is set.
--
-- The socket is created only after the OpenSSL context is fully
-- configured. If setting up or handshaking the session on that socket
-- throws, the socket is closed before the exception propagates.
startTls ::
  TlsParams {- ^ connection params      -} ->
  String    {- ^ hostname               -} ->
  IO Socket {- ^ socket creation action -} ->
  IO (Maybe X509, SSL) {- ^ (client certificate, connected TLS) -}
startTls tp hostname mkSocket = SSL.withOpenSSL $
  do ctx <- SSL.context

     -- configure context
     SSL.contextSetCiphers          ctx (tpCipherSuite tp)
     installVerification            ctx hostname
     SSL.contextSetVerificationMode ctx (verificationMode (tpInsecure tp))
     SSL.contextAddOption           ctx SSL.SSL_OP_ALL
     SSL.contextRemoveOption        ctx SSL.SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS

     -- configure certificates
     setupCaCertificates ctx (tpServerCertificate tp)
     clientCert <- traverse (setupCertificate ctx) (tpClientCertificate tp)
     traverse_ (setupPrivateKey ctx) (tpClientPrivateKey tp)

     -- Creating the socket this late means exceptions in the setup above
     -- cannot leak it. bracketOnError also closes it if attaching it to
     -- the context or the handshake itself (e.g. certificate
     -- verification) fails.
     ssl <- bracketOnError mkSocket Socket.close $ \sock ->
       do conn <- SSL.connection ctx sock
          -- send the hostname (SNI) for the server's benefit
          SSL.setTlsextHostName conn hostname
          SSL.connect conn
          return conn

     return (clientCert, ssl)


-- | Configure which CA certificates the context trusts.
--
-- When an explicit CA file path is given, it is installed with
-- 'SSL.contextSetCAFile'. Otherwise the context is populated via
-- 'contextLoadSystemCerts'.
setupCaCertificates :: SSLContext -> Maybe FilePath -> IO ()
setupCaCertificates ctx Nothing     = contextLoadSystemCerts ctx
setupCaCertificates ctx (Just path) = SSL.contextSetCAFile ctx path


-- | Read a PEM-encoded X.509 certificate from the given file, install it
-- as the context's own (client) certificate, and return it so callers can
-- keep a reference to it.
--
-- Exceptions from reading the file or parsing the PEM data propagate.
setupCertificate :: SSLContext -> FilePath -> IO X509
setupCertificate ctx path =
  do pemText <- readFile path -- may throw
     cert    <- PEM.readX509 pemText
     SSL.contextSetCertificate ctx cert
     return cert


-- | Read a PEM-encoded private key from the given file and install it on
-- the context.
--
-- Only unencrypted keys are supported: no password is supplied to the
-- PEM reader. Exceptions from reading or parsing the file propagate.
setupPrivateKey :: SSLContext -> FilePath -> IO ()
setupPrivateKey ctx path =
  do pemText <- readFile path -- may throw
     -- TODO: add password support
     SSL.contextSetPrivateKey ctx =<< PEM.readPrivateKey pemText PEM.PwNone


-- | Choose the OpenSSL verification mode for a client connection.
--
-- In insecure mode ('True') no verification is requested. Otherwise peer
-- verification is requested, the connection fails if the peer presents
-- no certificate, and no custom verification callback is installed.
verificationMode :: Bool {- ^ insecure -} -> SSL.VerificationMode
verificationMode True  = SSL.VerifyNone
verificationMode False =
  SSL.VerifyPeer
    { SSL.vpFailIfNoPeerCert = True
    , SSL.vpClientOnce       = True
    , SSL.vpCallback         = Nothing
    }

-- | Get the certificate presented by the remote peer, if one exists.
--
-- Plain (non-TLS) connections always yield 'Nothing'. For TLS
-- connections the result comes from the underlying OpenSSL session.
getPeerCertificate :: Connection -> IO (Maybe X509.X509)
getPeerCertificate (Connection _ nh) = peerCert nh
  where
    peerCert (SSL _ session) = SSL.getPeerCertificate session
    peerCert Socket{}        = return Nothing

-- | Get the client certificate that this side loaded for the connection,
-- if one exists.
--
-- This is the local certificate that was stored when the TLS connection
-- was set up, not the remote peer's certificate (see
-- 'getPeerCertificate' for that). Plain (non-TLS) connections, and TLS
-- connections established without a client certificate, yield 'Nothing'.
getClientCertificate :: Connection -> Maybe X509.X509
getClientCertificate (Connection _ h) =
  case h of
    Socket{} -> Nothing
    SSL c _  -> c

-- | SHA-1 fingerprint of the peer's DER-encoded certificate.
-- 'Nothing' when there is no peer certificate or the digest is
-- unavailable.
getPeerCertFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha1 = getPeerCertFingerprint "sha1"

-- | SHA-256 fingerprint of the peer's DER-encoded certificate.
-- 'Nothing' when there is no peer certificate or the digest is
-- unavailable.
getPeerCertFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha256 = getPeerCertFingerprint "sha256"

-- | SHA-512 fingerprint of the peer's DER-encoded certificate.
-- 'Nothing' when there is no peer certificate or the digest is
-- unavailable.
getPeerCertFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha512 = getPeerCertFingerprint "sha512"

-- | Fingerprint the peer's certificate with the named OpenSSL digest
-- (e.g. @"sha256"@).
--
-- The digest is computed over the DER encoding of the certificate. The
-- result is 'Nothing' in two cases:
--
-- * there is no peer certificate (including plain, non-TLS connections)
-- * 'Digest.getDigestByName' does not recognize the name
--
-- The digest bytes are forced before being returned.
getPeerCertFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerCertFingerprint name h =
  getPeerCertificate h >>= maybe (return Nothing) fingerprint
  where
    fingerprint cert =
      do der      <- X509.writeDerX509 cert
         mbDigest <- Digest.getDigestByName name
         -- force both the Maybe and the digest value, as the original did
         return $! case mbDigest of
           Nothing     -> Nothing
           Just digest -> Just $! Digest.digestLBS digest der

-- | SHA-1 fingerprint of the peer certificate's DER-encoded public key.
-- 'Nothing' when there is no peer certificate or the digest is
-- unavailable.
getPeerPubkeyFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha1 = getPeerPubkeyFingerprint "sha1"

-- | SHA-256 fingerprint of the peer certificate's DER-encoded public key.
-- 'Nothing' when there is no peer certificate or the digest is
-- unavailable.
getPeerPubkeyFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha256 = getPeerPubkeyFingerprint "sha256"

-- | SHA-512 fingerprint of the peer certificate's DER-encoded public key.
-- 'Nothing' when there is no peer certificate or the digest is
-- unavailable.
getPeerPubkeyFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha512 = getPeerPubkeyFingerprint "sha512"


-- | Fingerprint the public key of the peer's certificate with the named
-- OpenSSL digest (e.g. @"sha256"@).
--
-- The digest is computed over the bytes produced by 'getPubKeyDer'. The
-- result is 'Nothing' in two cases:
--
-- * there is no peer certificate (including plain, non-TLS connections)
-- * 'Digest.getDigestByName' does not recognize the name
--
-- The digest bytes are forced before being returned.
getPeerPubkeyFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprint name h =
  getPeerCertificate h >>= maybe (return Nothing) fingerprint
  where
    fingerprint cert =
      do der      <- getPubKeyDer cert
         mbDigest <- Digest.getDigestByName name
         -- force both the Maybe and the digest value, as the original did
         return $! case mbDigest of
           Nothing     -> Nothing
           Just digest -> Just $! Digest.digestBS digest der