-- |
-- Module      : Crypto.PubKey.RSA.Prim
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
module Crypto.PubKey.RSA.Prim
    (
    -- * decrypt primitive
      dp
    -- * encrypt primitive
    , ep
    ) where

import Data.ByteString (ByteString)
import Crypto.PubKey.RSA.Types (Blinder(..))
import Crypto.Types.PubKey.RSA
import Crypto.Number.ModArithmetic (expFast, expSafe)
import Crypto.Number.Serialize (os2ip, i2ospOf_)

-- | Decrypt primitive that uses none of the precomputed CRT values.
--
-- Computes @c^d mod n@ with 'expSafe' and encodes the result with
-- 'i2ospOf_', using 'private_size' as the output length. Only @n@ and @d@
-- need to be valid in the key.
dpSlow :: PrivateKey -> ByteString -> ByteString
dpSlow pk c = i2ospOf_ outLen plain
  where
    outLen = private_size pk
    plain  = expSafe (os2ip c) (private_d pk) (private_n pk)

-- | CRT decrypt primitive with blinding.
--
-- Exponentiating modulo @p@ and @q@ separately is cheaper than working
-- modulo @n = pq@, so this is used when the precomputed CRT values
-- (@p@, @q@, @dP@, @dQ@, @qinv@) are available.
--
-- Steps:
--
--   1. blind the ciphertext: @c' = c * r^e mod n@;
--   2. compute @m1 = c'^dP mod p@ and @m2 = c'^dQ mod q@;
--   3. recombine with @h = qinv * (m1 - m2) mod p@, @m' = m2 + h * q@;
--   4. unblind: @m = rm1 * m' mod n@.
--
-- NOTE(review): correctness of step 4 assumes @rm1@ is the inverse of @r@
-- modulo @n@ — confirm at the 'Blinder' construction site.
--
-- The blinding factor @r@ must stay secret, so @r^e@ is computed with
-- 'expSafe' rather than 'expFast' (which is meant for public-only inputs).
dpFast :: Blinder -> PrivateKey -> ByteString -> ByteString
dpFast (Blinder r rm1) pk c =
    i2ospOf_ (private_size pk) (multiplication rm1 (m2 + h * private_q pk) (private_n pk))
  where
    -- r is secret; use the side-channel-hardened exponentiation even
    -- though the exponent e itself is public.
    re = expSafe r (public_e $ private_pub pk) (private_n pk)
    iC = multiplication re (os2ip c) (private_n pk)
    m1 = expSafe iC (private_dP pk) (private_p pk)
    m2 = expSafe iC (private_dQ pk) (private_q pk)
    -- `mod` with a positive divisor is non-negative, so h is in [0, p)
    -- even when m1 < m2.
    h  = (private_qinv pk * (m1 - m2)) `mod` private_p pk

-- | CRT decrypt primitive without blinding.
--
-- Exponentiates the ciphertext integer modulo @p@ (with @dP@) and modulo
-- @q@ (with @dQ@), then recombines as @mq + q * (qinv * (mp - mq) mod p)@.
-- The result is encoded with 'i2ospOf_' using 'private_size' as the length.
dpFastNoBlinder :: PrivateKey -> ByteString -> ByteString
dpFastNoBlinder pk c = i2ospOf_ (private_size pk) recombined
  where
    cInt       = os2ip c
    p          = private_p pk
    q          = private_q pk
    mp         = expSafe cInt (private_dP pk) p
    mq         = expSafe cInt (private_dQ pk) q
    recombined = mq + q * ((private_qinv pk * (mp - mq)) `mod` p)

-- | Compute the RSA decrypt primitive.
--
-- When the key carries both CRT primes (@p@ and @q@ non-zero), the CRT path
-- is used: 'dpFast' if a 'Blinder' is supplied, 'dpFastNoBlinder'
-- otherwise. Without them, 'dpSlow' is used, which only needs @d@ and @n@.
-- Note that a supplied 'Blinder' is not used on that slow path.
dp :: Maybe Blinder -> PrivateKey -> ByteString -> ByteString
dp blinder pk
    | not hasCrt = dpSlow pk
    | otherwise  = case blinder of
        Just b  -> dpFast b pk
        Nothing -> dpFastNoBlinder pk
  where
    hasCrt = private_p pk /= 0 && private_q pk /= 0

-- | Compute the RSA encrypt primitive: @m^e mod n@.
--
-- The result is encoded with 'i2ospOf_' using 'public_size' as the length.
--
-- NOTE(review): this uses 'expFast', which its library documents as meant
-- for public parameters; here the base is the message representative —
-- confirm that is acceptable for encryption callers.
ep :: PublicKey -> ByteString -> ByteString
ep pk m = i2ospOf_ (public_size pk) cipher
  where
    e      = public_e pk
    n      = public_n pk
    cipher = expFast (os2ip m) e n

-- | Multiply two integers in @Z_m@: returns @(a * b) mod m@.
--
-- The reduction modulo @m@ is always performed. Because Haskell's `mod`
-- takes the sign of the divisor, the result lies in @[0, m)@ for any
-- positive @m@, even when @a * b@ is negative.
multiplication :: Integer -> Integer -> Integer -> Integer
multiplication a b m = (a * b) `mod` m