{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts,
    UndecidableInstances,
    TemplateHaskell
  #-}

{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}

module Data.Random.Distribution.Bernoulli where

import Data.Random.Internal.TH

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform

import Data.Ratio
import Data.Complex

-- |Generate a Bernoulli variate with the given probability.  For @Bool@ results,
-- @bernoulli p@ will return True (p*100)% of the time and False otherwise.
-- For numerical types, True is replaced by 1 and False by 0.
bernoulli :: Distribution (Bernoulli b) a => b -> RVar a
bernoulli p = rvar (Bernoulli p)

-- |Generate a Bernoulli process with the given probability.  For @Bool@ results,
-- @bernoulliT p@ will return True (p*100)% of the time and False otherwise.
-- For numerical types, True is replaced by 1 and False by 0.
--
-- This is the 'RVarT' (transformer) counterpart of 'bernoulli'.
bernoulliT :: Distribution (Bernoulli b) a => b -> RVarT m a
bernoulliT p = rvarT (Bernoulli p)

-- |A random variable whose value is 'True' the given fraction of the time
-- and 'False' the rest.
--
-- Draws one standard-uniform value and reports whether it is @<= p@.
boolBernoulli :: (Fractional a, Ord a, Distribution StdUniform a) => a -> RVarT m Bool
boolBernoulli p = (<= p) <$> stdUniformT

-- | Cumulative distribution function of the 'Bool'-valued Bernoulli
-- distribution with success probability @p@, ordering 'False' before 'True':
--
-- * @boolBernoulliCDF p False = 1 - p@ (the probability of drawing 'False')
--
-- * @boolBernoulliCDF p True  = 1@
--
-- The probability is converted to 'Double' with 'realToFrac'.
boolBernoulliCDF :: (Real a) => a -> Bool -> Double
boolBernoulliCDF _ True  = 1
boolBernoulliCDF p False = 1 - realToFrac p

-- | @generalBernoulli f t p@ generates a random variable whose value is @t@
-- with probability @p@ and @f@ with probability @1-p@.
--
-- Note the argument order: the \"failure\" value @f@ comes first and the
-- \"success\" value @t@ second, so @generalBernoulli 0 1 p@ yields @1@ with
-- probability @p@.
generalBernoulli :: Distribution (Bernoulli b) Bool => a -> a -> b -> RVarT m a
generalBernoulli f t p = pick <$> bernoulliT p
  where
    -- Map the underlying Bool draw onto the caller's two outcomes.
    pick success = if success then t else f

-- | @generalBernoulliCDF gte f t p x@ evaluates, at @x@, the CDF of the
-- two-point distribution produced by @generalBernoulli f t p@.
--
-- @gte@ is the \"greater than or equal\" test used to order values.  Taking it
-- as an argument lets types without a suitable 'Ord' (e.g. 'Complex') supply
-- their own ordering.
--
-- Calls 'error' unless @f@ is strictly below @t@ according to @gte@.
generalBernoulliCDF :: CDF (Bernoulli b) Bool => (a -> a -> Bool) -> a -> a -> b -> a -> Double
generalBernoulliCDF gte f t p x
    | f `gte` t = error "generalBernoulliCDF: f >= t"
    | x `gte` t = cdf (Bernoulli p) True   -- at or above the success value
    | x `gte` f = cdf (Bernoulli p) False  -- between the two outcomes
    | otherwise = 0                        -- below both outcomes

-- | A Bernoulli distribution, wrapping its success probability of type @b@.
--
-- The type parameter @a@ does not appear on the right-hand side.  It selects
-- the type that samples are reported in ('Bool', or a numeric type where
-- 'True' becomes 1 and 'False' becomes 0).  See the 'Distribution' and 'CDF'
-- instances in this module for the supported result types.
newtype Bernoulli b a = Bernoulli b

-- Sampling a 'Bool' delegates to 'boolBernoulli': 'True' when a
-- standard-uniform draw is at most the wrapped probability.
instance (Fractional b, Ord b, Distribution StdUniform b)
       => Distribution (Bernoulli b) Bool
    where
        rvarT (Bernoulli p) = boolBernoulli p
-- The 'Bool' CDF delegates to 'boolBernoulliCDF' ('False' ordered before 'True').
instance (Distribution (Bernoulli b) Bool, Real b)
       => CDF (Bernoulli b) Bool
    where
        cdf (Bernoulli p) = boolBernoulliCDF p

-- Integral result types: a successful draw is reported as 1 and a failed one
-- as 0 (via generalBernoulli / generalBernoulliCDF with f = 0, t = 1).
-- The instances are written once with ''Int as the placeholder type and
-- expanded by replicateInstances (Data.Random.Internal.TH).  Presumably one
-- copy is emitted per type in integralTypes, with Int substituted.
-- NOTE(review): confirm the substitution rule in Data.Random.Internal.TH.
$( replicateInstances ''Int integralTypes [d|
        instance Distribution (Bernoulli b) Bool 
              => Distribution (Bernoulli b) Int
              where
                  rvarT (Bernoulli p) = generalBernoulli 0 1 p
        instance CDF (Bernoulli b) Bool
              => CDF (Bernoulli b) Int
              where
                  cdf  (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
    |] )

-- Floating-point result types: a successful draw is reported as 1 and a failed
-- one as 0 (via generalBernoulli / generalBernoulliCDF with f = 0, t = 1).
-- The instances are written once with ''Float as the placeholder type and
-- expanded by replicateInstances (Data.Random.Internal.TH).  Presumably one
-- copy is emitted per type in realFloatTypes, with Float substituted.
-- NOTE(review): confirm the substitution rule in Data.Random.Internal.TH.
$( replicateInstances ''Float realFloatTypes [d|
        instance Distribution (Bernoulli b) Bool 
              => Distribution (Bernoulli b) Float
              where
                  rvarT (Bernoulli p) = generalBernoulli 0 1 p
        instance CDF (Bernoulli b) Bool
              => CDF (Bernoulli b) Float
              where
                  cdf  (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
    |] )

-- 'Ratio' results: 1 on success, 0 on failure.
instance (Distribution (Bernoulli b) Bool, Integral a)
       => Distribution (Bernoulli b) (Ratio a)
       where
           rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- CDF over the two 'Ratio' outcomes 0 and 1, using the ordinary 'Ord' ordering.
instance (CDF (Bernoulli b) Bool, Integral a)
       => CDF (Bernoulli b) (Ratio a)
       where
           cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- 'Complex' results: 1 on success, 0 on failure (both with zero imaginary part).
instance (Distribution (Bernoulli b) Bool, RealFloat a)
       => Distribution (Bernoulli b) (Complex a)
       where
           rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- CDF over the two 'Complex' outcomes 0 and 1.
instance (CDF (Bernoulli b) Bool, RealFloat a)
       => CDF (Bernoulli b) (Complex a)
       where
           cdf (Bernoulli p) = generalBernoulliCDF gteReal 0 1 p
             where
               -- 'Complex' has no 'Ord' instance, so values are ordered
               -- by their real parts only.
               gteReal x y = realPart x >= realPart y