module Network.Mail.SMTP.Auth (
    UserName,
    Password,
    AuthType(..),
    encodeLogin,
    auth,
) where

import Data.Bits
import Data.List

import Crypto.Hash.MD5 (hash)
import qualified Data.ByteString.Base16 as B16  (encode)
import qualified Data.ByteString.Base64 as B64  (encode)

import Data.ByteString  (ByteString)
import qualified Data.ByteString          as B
import qualified Data.ByteString.Builder  as BB    (stringUtf8, toLazyByteString)
import qualified Data.ByteString.Char8    as B8    (unwords)
import qualified Data.ByteString.Lazy     as BL    (toStrict)

-- | Account name handed to 'auth' (and the encoders it dispatches to).
type UserName = String
-- | Secret handed to 'auth' (and the encoders it dispatches to).
type Password = String

-- | Authentication mechanisms understood by 'auth'.
data AuthType
    = PLAIN     -- ^ user, user and password joined by NULs, then base64
    | LOGIN     -- ^ user and password base64-encoded separately
    | CRAM_MD5  -- ^ HMAC-MD5 response to a server challenge
    deriving (Eq)

-- | Renders the mechanism name. Note that 'CRAM_MD5' is rendered with a
-- hyphen (@"CRAM-MD5"@), not an underscore. Presumably this string is the
-- mechanism token sent to the server — confirm at the call site.
instance Show AuthType where
    -- Even though every constructor is nullary, the name is wrapped in
    -- parentheses when shown above application precedence
    -- (e.g. @show (Just PLAIN) == "Just (PLAIN)"@).
    showsPrec prec mech =
        showParen (prec > appPrec) (showString (mechanismName mech))
      where
        appPrec :: Int
        appPrec = 10

        mechanismName PLAIN    = "PLAIN"
        mechanismName LOGIN    = "LOGIN"
        mechanismName CRAM_MD5 = "CRAM-MD5"

-- | Convert a 'String' to the bytes placed in an authentication payload.
--
-- * If every character is below U+0100, each character becomes exactly one
--   byte with the same value (Latin-1). This is the historical behaviour, and
--   it keeps a 'String' obtained from @Data.ByteString.Char8.unpack@ mapping
--   back to the very same bytes.
--
-- * If any character is U+0100 or above, a one-byte-per-character mapping is
--   impossible. The old code crashed here with a @toEnum@ range error. The
--   whole string is now encoded as UTF-8 instead, which is the encoding
--   SASL PLAIN (RFC 4616) specifies for credentials.
--
-- The name is historical: the output is pure ASCII only for ASCII input.
toAscii :: String -> ByteString
toAscii str
    | all fitsInByte str = B.pack (map (toEnum . fromEnum) str)
    | otherwise          = BL.toStrict (BB.toLazyByteString (BB.stringUtf8 str))
  where
    -- A character can be emitted as a single byte without loss.
    fitsInByte c = fromEnum c < 0x100

-- | Base64-encode a 'String' after converting it to bytes with 'toAscii'.
b64Encode :: String -> ByteString
b64Encode str = B64.encode (toAscii str)

-- | HMAC-MD5 of @text@ under @key@:
-- @MD5((K xor opad) <> MD5((K xor ipad) <> text))@.
--
-- Note the argument order: the message comes first and the key second.
-- Keys longer than the 64-byte block are first replaced by their MD5 digest.
-- The (possibly digested) key is then zero-padded to exactly one block.
hmacMD5 :: ByteString -> ByteString -> ByteString
hmacMD5 text key = hash (outerKey <> hash (innerKey <> text))
  where
    blockSize :: Int
    blockSize = 64

    -- Over-long keys are shortened by hashing, as HMAC prescribes.
    shortKey
        | B.length key > blockSize = hash key
        | otherwise                = key

    paddedKey = shortKey <> B.replicate (blockSize - B.length shortKey) 0

    -- XOR every byte of the padded key with the ipad/opad constants.
    innerKey = B.map (xor 0x36) paddedKey
    outerKey = B.map (xor 0x5c) paddedKey

-- | Build the base64 payload for the PLAIN mechanism. The payload is
-- @user NUL user NUL password@: the user name fills both identity slots,
-- followed by the password.
encodePlain :: UserName -> Password -> ByteString
encodePlain user pass = b64Encode message
  where
    message = intercalate separator [user, user, pass]
    separator = "\0"

-- | Base64-encode the user name and the password independently. The result
-- is a pair @(encodedUser, encodedPassword)@, one value per LOGIN prompt.
encodeLogin :: UserName -> Password -> (ByteString, ByteString)
encodeLogin user pass =
    let encodedUser = b64Encode user
        encodedPass = b64Encode pass
    in  (encodedUser, encodedPass)

-- | Build the CRAM-MD5 response:
-- @base64(user <> " " <> hex(HMAC-MD5(key = password, text = challenge)))@.
--
-- NOTE(review): @challenge@ is presumably the already base64-decoded server
-- challenge; confirm at the call site.
cramMD5 :: String -> UserName -> Password -> ByteString
cramMD5 challenge user pass = B64.encode response
  where
    -- 'hmacMD5' takes the message first and the key second.
    digest   = B16.encode (hmacMD5 (toAscii challenge) (toAscii pass))
    response = B8.unwords [toAscii user, digest]

-- | Produce the client credential payload for the chosen mechanism.
--
-- The challenge argument is used only for 'CRAM_MD5'; it is ignored for
-- 'PLAIN' and 'LOGIN'. For 'LOGIN' the result is the base64 user name and
-- the base64 password joined by a single space. Presumably the caller splits
-- them again to answer the two prompts — confirm at the call site.
auth :: AuthType -> String -> UserName -> Password -> ByteString
auth mechanism challenge user pass =
    case mechanism of
        PLAIN    -> encodePlain user pass
        LOGIN    -> uncurry joinedBySpace (encodeLogin user pass)
        CRAM_MD5 -> cramMD5 challenge user pass
  where
    joinedBySpace first second = B8.unwords [first, second]