{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
{-# LANGUAGE RecordWildCards #-}

{-# OPTIONS_GHC -fno-warn-orphans #-}

module Internal.CG(
    cgSolve, cgSolve',
    CGState(..), R, V
) where

import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import Internal.Element
import Internal.IO
import Internal.Container
import Internal.Sparse
import Numeric.Vector()
import Internal.Algorithms(linearSolveLS, linearSolve, relativeError, pnorm, NormType(..))
import Control.Arrow((***))

{-
import Util.Misc(debug, debugMat)

(//) :: Show a => a -> String -> a
infix 0 // -- , ///
a // b = debug b id a

(///) :: V -> String -> V
infix 0 ///
v /// b = debugMat b 2 asRow v
-}

-- | Shorthand for a vector with elements of type 'R'.
type V = Vector R

-- | Snapshot of the conjugate gradient iteration after one step.
data CGState = CGState
    { cgp  :: Vector R  -- ^ conjugate (search) direction for the next step
    , cgr  :: Vector R  -- ^ residual of the current solution
    , cgr2 :: R         -- ^ squared norm of the residual
    , cgx  :: Vector R  -- ^ current approximate solution
    , cgdx :: R         -- ^ size of the last correction, normalized by the solution size
    }

-- | Perform a single conjugate gradient step.
--
-- @cg sym at a s@ advances the state @s@ for the operator @a@.  When @sym@ is
-- 'True' the plain CG update is used.  Otherwise the step is taken for the
-- normal equations: the search direction is mapped through @at . a@ and the
-- curvature term is computed as the squared norm of @a p@ (@at@ is expected
-- to apply the transpose of @a@).  The incoming 'cgdx' is ignored.
cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState
cg sym at a (CGState p r r2 x _) = CGState
    { cgp  = newR + scale beta p
    , cgr  = newR
    , cgr2 = newR2
    , cgx  = x + step
    , cgdx = norm2 step / max 1 (norm2 x)   -- relative to the previous solution
    }
  where
    -- operator applied to the current search direction
    aP = a p
    -- image of p under the effective operator, and the curvature p·(that image)
    (opP, curv)
        | sym       = (aP, p <.> aP)
        | otherwise = (at aP, norm2 aP ** 2)
    -- step length along p and the resulting correction
    alpha = r2 / curv
    step  = scale alpha p
    -- updated residual and its squared norm
    newR  = r - scale alpha opP
    newR2 = newR <.> newR
    -- ratio of successive squared residual norms, used for the next direction
    beta  = newR2 / r2

-- | Conjugate gradient iteration for a sparse matrix.
--
-- Hands multiplication by the transpose of the matrix and by the matrix
-- itself to 'solveG', together with the step function 'cg'.  The remaining
-- arguments are the right-hand side, the initial solution and the residual
-- and correction tolerances.
conjugrad
  :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState]
conjugrad sym m rhs x0 tolRes tolDx =
    solveG sym transposeMul mul (cg sym) rhs x0 tolRes tolDx
  where
    transposeMul v = tr m !#> v
    mul v = m !#> v

-- | Generic driver for the conjugate gradient iteration.
--
-- Builds the initial state from the right-hand side and the initial guess.
-- It then returns the lazy list of states obtained by repeatedly applying the
-- step function, ending with the first state that satisfies either
-- tolerance.  If neither tolerance is ever met the list is infinite, so
-- callers must bound its length.
--
-- In the non-symmetric case the system is replaced by the normal equations:
-- the operator becomes @mat . ma@ and the right-hand side becomes @mat rawb@.
solveG
    :: Bool                                          -- ^ operator is symmetric
    -> (V -> V)                                      -- ^ @mat@: composed after @ma@ in the non-symmetric case
    -> (V -> V)                                      -- ^ @ma@: the operator of the system
    -> ((V -> V) -> (V -> V) -> CGState -> CGState)  -- ^ step function, applied to @mat@ and @ma@
    -> V                                             -- ^ right-hand side
    -> V                                             -- ^ initial solution (@0@ selects a zero vector)
    -> R                                             -- ^ relative tolerance for the residual norm
    -> R                                             -- ^ tolerance for 'cgdx'
    -> [CGState]
solveG sym mat ma meth rawb x0' ϵb ϵx
    = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
  where
    -- effective operator and right-hand side
    a :: V -> V
    a = if sym then ma else mat . ma
    b :: V
    b = if sym then rawb else mat rawb
    -- an initial guess equal to 0 is expanded to a zero vector of dim b
    x0 :: V
    x0  = if x0' == 0 then konst 0 (dim b) else x0'
    r0, p0 :: V
    r0  = b - a x0
    p0  = r0
    r20, nb2 :: R
    r20 = r0 <.> r0
    nb2 = b <.> b
    -- The residual test uses '<=' so that a state whose residual is exactly
    -- zero is accepted even when the effective right-hand side is zero
    -- (nb2 == 0).  With '<', such a state was rejected.  This happens, e.g.,
    -- for b = 0 with the default zero initial guess.  The following step then
    -- divided 0 by 0 (zero search direction), turning the solution into NaN.
    ok :: CGState -> Bool
    ok CGState {..}
        =  cgr2 <= nb2 * ϵb ** 2
        || cgdx < ϵx


-- | Take elements up to and including the first one satisfying the predicate.
--
-- This is 'takeWhile' with the test negated, plus the first matching element.
-- If no element matches, the whole list is returned (lazily, so infinite
-- lists are fine).
--
-- >>> takeUntil (> 3) [1 .. 10]
-- [1,2,3,4]
takeUntil :: (a -> Bool) -> [a] -> [a]
takeUntil _ [] = []
takeUntil q (y:ys)
    | q y       = [y]
    | otherwise = y : takeUntil q ys

-- | Solve a sparse linear system using the conjugate gradient method with default parameters.
--
-- The defaults are:
--
-- * relative residual tolerance 1E-4;
-- * correction tolerance 1E-3;
-- * a zero initial solution;
-- * a limit of @max 10 (round (sqrt (dim b)))@ states.
cgSolve
  :: Bool          -- ^ is symmetric
  -> GMatrix       -- ^ coefficient matrix
  -> Vector R      -- ^ right-hand side
  -> Vector R      -- ^ solution
cgSolve sym a b = cgx finalState
  where
    -- 'last' is safe here: with a positive limit, cgSolve' always yields at
    -- least the initial state, and the limit is at least 10.
    finalState = last (cgSolve' sym 1E-4 1E-3 limit a b 0)
    limit :: Int
    limit = max 10 . round . sqrt $ (fromIntegral (dim b) :: Double)

-- | Solve a sparse linear system using the conjugate gradient method with
-- explicit tolerances, state limit and initial solution.
--
-- Returns the successive states of the iteration, starting with the initial
-- state.  The list ends at the first state that satisfies either tolerance,
-- or when it reaches the given length, whichever comes first.  The
-- approximate solution is 'cgx' of the last element.
--
-- For a non-symmetric matrix the method is applied to the normal equations
-- @Aᵀ A x = Aᵀ b@.
cgSolve'
  :: Bool      -- ^ symmetric
  -> R         -- ^ relative tolerance for the residual (e.g. 1E-4)
  -> R         -- ^ relative tolerance for δx (e.g. 1E-3)
  -> Int       -- ^ maximum length of the result (the initial state counts as one)
  -> GMatrix   -- ^ coefficient matrix
  -> Vector R  -- ^ right-hand side
  -> Vector R  -- ^ initial solution (@0@ selects a zero vector)
  -> [CGState] -- ^ successive states; the solution is 'cgx' of the last one
cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es


--------------------------------------------------------------------------------

instance Testable GMatrix
  where
    -- Self test.  It compares sparse matrix-vector products (plain and
    -- transposed) with their dense counterparts, and compares 'cgSolve'
    -- results with dense solvers.  The matrix argument is not inspected.
    checkT _ = (passed, report)
      where
        -- sparse test matrix: rows 0..39, columns 0..19
        assoc :: AssocMatrix
        assoc = convo2 20 3
        sparseM :: GMatrix
        sparseM = mkSparse assoc
        denseM :: Matrix R
        denseM = toDense assoc

        xs20, xs40 :: Vector R
        xs20 = vect [1 .. 20]
        xs40 = vect [1 .. 40]

        -- general sparse matrix: plain and transposed products
        sp1, dn1, sp2, dn2 :: Vector R
        sp1 = sparseM !#> xs20
        dn1 = denseM #> xs20
        sp2 = tr sparseM !#> xs40
        dn2 = tr denseM #> xs40

        -- rectangular diagonal matrix: plain and transposed products
        diagVals :: Vector R
        diagVals = vect [1 .. 10]
        sparseDiag :: GMatrix
        sparseDiag = mkDiagR 40 20 diagVals
        denseDiag :: Matrix R
        denseDiag = diagRect 0 diagVals 40 20
        sp3, dn3, sp4, dn4 :: Vector R
        sp3 = sparseDiag !#> xs20
        dn3 = denseDiag #> xs20
        sp4 = tr sparseDiag !#> xs40
        dn4 = tr denseDiag #> xs40

        -- non-symmetric system solved by CG vs dense least squares
        rhs40 :: Vector R
        rhs40 = testb 40
        sp5, dn5 :: Vector R
        sp5 = cgSolve False sparseM rhs40
        dn5 = denseSolve denseM rhs40

        -- small symmetric system solved by CG vs dense solve
        symAssoc :: AssocMatrix
        symAssoc = [((0,0),1.0),((1,1),2.0),((0,1),0.5),((1,0),0.5)]
        rhs2 :: Vector R
        rhs2 = vect [3,4]
        sp6, dn6 :: Vector R
        sp6 = cgSolve True (mkSparse symAssoc) rhs2
        dn6 = flatten $ linearSolve (toDense symAssoc) (asColumn rhs2)

        relErr :: Vector R -> Vector R -> R
        relErr = relativeError (pnorm Infinity)

        -- products that must match exactly
        exactPairs :: [(Vector R, Vector R)]
        exactPairs = [(sp1, dn1), (sp2, dn2), (sp3, dn3), (sp4, dn4)]

        report :: IO ()
        report = do
            print sparseM
            disp denseM
            mapM_ (\(s, d) -> print s >> print d) (exactPairs ++ [(sp5, dn5)])
            print (relErr sp5 dn5)
            print sp6
            print dn6
            print (relErr sp6 dn6)

        passed :: Bool
        passed = all (uncurry (==)) exactPairs
              && relErr sp5 dn5 < 1E-10
              && relErr sp6 dn6 < 1E-10

        disp :: Matrix R -> IO ()
        disp = putStr . dispf 2

        vect :: [R] -> Vector R
        vect = fromList

        -- n x n circulant band: k ones per row starting at the diagonal,
        -- with column indices wrapping modulo n
        convomat :: Int -> Int -> AssocMatrix
        convomat n k =
            [ ((i, j `mod` n), 1) | i <- [0 .. n-1], j <- [i .. i+k-1] ]

        -- the band stacked on top of a copy of itself shifted down by n rows
        convo2 :: Int -> Int -> AssocMatrix
        convo2 n k = band ++ map shiftDown band
          where
            band = convomat n k
            shiftDown = ((+ n) *** id) *** id

        -- periodic right-hand side 0,1..10,9..1,0,1,... truncated to n entries
        testb :: Int -> Vector R
        testb n = vect . take n . cycle $ [0 .. 10] ++ [9, 8 .. 1]

        denseSolve :: Matrix R -> Vector R -> Vector R
        denseSolve m = flatten . linearSolveLS m . asColumn

        -- mkDiag v = mkDiagR (dim v) (dim v) v