{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

module Internal.Sparse(
    GMatrix(..), CSR(..), mkCSR, fromCSR, impureCSR,
    mkSparse, mkDiagR, mkDense,
    AssocMatrix,
    toDense,
    gmXv, (!#>)
)where

import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as M
import Control.Arrow((***))
import Control.Monad(when, foldM)
import Control.Monad.ST (runST)
import Control.Monad.Primitive (PrimMonad)
import Data.List(sort)
import Foreign.C.Types(CInt(..))

import Internal.Devel
import System.IO.Unsafe(unsafePerformIO)
import Foreign(Ptr)
import Text.Printf(printf)

-- | Association-list description of a matrix: each element pairs a matrix
--   index with the 'Double' stored at that position.  The builders in this
--   module destructure the index as a @(row, column)@ pair, and the usage
--   examples in this module use 0-based indices.
type AssocMatrix = [(IndexOf Matrix, Double)]

-- | Compressed sparse row storage, as produced by 'mkCSR' / 'impureCSR'.
data CSR = CSR
        { csrVals  :: Vector Double  -- ^ stored values, in row order
        , csrCols  :: Vector CInt    -- ^ column index of each stored value
                                     --   (the builder writes them 1-based)
        , csrRows  :: Vector CInt    -- ^ per-row offsets into 'csrVals', plus one
                                     --   closing offset (also written 1-based)
        , csrNRows :: Int            -- ^ number of rows
        , csrNCols :: Int            -- ^ number of columns
        } deriving Show

-- | Compressed sparse column storage.  Within this module a 'CSC' is only
--   obtained by transposing a 'CSR', which reuses the three vectors as they
--   are and swaps the two dimensions.
data CSC = CSC
        { cscVals  :: Vector Double  -- ^ stored values
        , cscRows  :: Vector CInt    -- ^ row index of each stored value
        , cscCols  :: Vector CInt    -- ^ per-column offsets into 'cscVals'
        , cscNRows :: Int            -- ^ number of rows
        , cscNCols :: Int            -- ^ number of columns
        } deriving Show


-- | Produce a CSR sparse matrix from an association matrix.
--
--   The entries are sorted first, so they may be given in any order; the
--   sorted list is then fed through 'impureCSR' inside 'runST'.
mkCSR :: AssocMatrix -> CSR
mkCSR ms = runST (impureCSR foldEntries (sort ms))
  where
    -- Plain list driver for 'impureCSR': create the initial state, fold
    -- every entry into it, then extract the finished matrix.
    foldEntries step start finish entries =
      start >>= \s0 -> foldM step s0 entries >>= finish

-- | Produce a CSR sparse matrix by applying a generic folding function.
--
--   This allows one to build a CSR from an effectful streaming source
--   when combined with libraries like pipes, io-streams, or streaming.
--
--   For example
--
--   > impureCSR Pipes.Prelude.foldM :: PrimMonad m => Producer AssocEntry m () -> m CSR
--   > impureCSR Streaming.Prelude.foldM :: PrimMonad m => Stream (Of AssocEntry) m r -> m (Of CSR r)
--
--   The folding function receives a step function, an action creating the
--   initial state, and an extraction function, in that order.
--
--   Entries must arrive with non-decreasing row index; a row smaller than
--   the previously seen one raises an 'error'.  Rows may be skipped: every
--   skipped row becomes an empty row of the result, however large the gap.
impureCSR
    :: PrimMonad m
    => (forall x . (x -> (IndexOf Matrix, Double) -> m x) -> m x -> (x -> m CSR) -> r)
    -> r
impureCSR f = f next begin done
  where
    -- Convert a 0-based Int position into the 1-based CInt that is stored.
    sfi :: Int -> CInt
    sfi = succ . fi

    -- Initial state: three growable buffers (values, row offsets, column
    -- indices), the number of values written, the number of row offsets
    -- written, the largest column seen, and the current row (-1 = none yet).
    begin = do
      mv <- M.unsafeNew 64
      mr <- M.unsafeNew 64
      mc <- M.unsafeNew 64
      return (mv, mr, mc, 0, 0, 0, -1)

    -- Append one entry, opening as many new rows as needed to reach row r.
    next (!mv, !mr, !mc, !idxVC, !idxR, !maxC, !curRow) ((r,c),d) = do
      when (r < curRow) $
        error (printf "impureCSR: row %i specified after %i" r curRow)

      let lenVC   = M.length mv
          lenR    = M.length mr
          maxC'   = max maxC c
          -- Rows opened by this entry (0 when it stays on the current row).
          newRows = fromEnum (r - curRow)
          -- Slots the offset buffer must hold: everything written so far,
          -- one per newly opened row, and the closing offset 'done' appends.
          needR   = idxR + newRows + 1

      -- Values and column indices always grow together (doubling).
      (mv', mc') <-
        if idxVC >= lenVC then do
          gv <- M.unsafeGrow mv lenVC
          gc <- M.unsafeGrow mc lenVC
          return (gv, gc)
        else
          return (mv, mc)

      -- Grow the offset buffer far enough for the whole row gap.  Doubling
      -- once is not sufficient when more rows are skipped than there is
      -- spare capacity; the unchecked writes below would then run past the
      -- end of the buffer.
      mr' <-
        if needR > lenR then
          M.unsafeGrow mr (max lenR (needR - lenR))
        else
          return mr

      M.unsafeWrite mc' idxVC (sfi c)
      M.unsafeWrite mv' idxVC d

      -- Every row opened here starts at the value that was just written.
      idxR' <-
        foldM
          (\i _ -> i + 1 <$ M.unsafeWrite mr' i (sfi idxVC))
          idxR [1 .. (r-curRow)]

      return (mv', mr', mc', idxVC + 1, idxR', maxC', r)

    -- Write the closing row offset, trim the buffers to the used prefix and
    -- freeze them into the immutable CSR record.
    done (!mv, !mr, !mc, !idxVC, !idxR, !maxC, !curR) = do
      M.unsafeWrite mr idxR (sfi idxVC)
      vv <- V.unsafeFreeze (M.unsafeTake idxVC mv)
      vc <- V.unsafeFreeze (M.unsafeTake idxVC mc)
      vr <- V.unsafeFreeze (M.unsafeTake (idxR + 1)  mr)
      return $ CSR vv vc vr (succ curR) (succ maxC)


{- | General matrix with specialized internal representations: sparse
     (compressed by rows or by columns), diagonal, and dense.  Every
     representation carries the overall dimensions in 'nRows' and 'nCols'.

>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
>>> m
SparseR {gmCSR = CSR {csrVals = fromList [1.0,2.0],
                      csrCols = fromList [1000,2000],
                      csrRows = fromList [1,2,3],
                      csrNRows = 2,
                      csrNCols = 2000},
                      nRows = 2,
                      nCols = 2000}

>>> let m = mkDense (mat 2 [1..4])
>>> m
Dense {gmDense = (2><2)
 [ 1.0, 2.0
 , 3.0, 4.0 ], nRows = 2, nCols = 2}

-}
data GMatrix
    = SparseR
        { gmCSR   :: CSR            -- ^ row-compressed storage
        , nRows   :: Int            -- ^ number of rows
        , nCols   :: Int            -- ^ number of columns
        }
    | SparseC
        { gmCSC   :: CSC            -- ^ column-compressed storage
        , nRows   :: Int
        , nCols   :: Int
        }
    | Diag
        { diagVals :: Vector Double -- ^ leading entries of the main diagonal
        , nRows    :: Int
        , nCols    :: Int
        }
    | Dense
        { gmDense :: Matrix Double  -- ^ ordinary dense matrix
        , nRows   :: Int
        , nCols   :: Int
        }
--    | Banded
    deriving Show


-- | Wrap a dense matrix as a 'GMatrix', recording its row and column counts.
mkDense :: Matrix Double -> GMatrix
mkDense m = Dense
    { gmDense = m
    , nRows   = rows m
    , nCols   = cols m
    }

-- | Build a row-compressed sparse 'GMatrix' from an association list:
--   'mkCSR' followed by 'fromCSR'.
mkSparse :: AssocMatrix -> GMatrix
mkSparse assocs = fromCSR (mkCSR assocs)

-- | Wrap a 'CSR' as a 'GMatrix', taking the dimensions from the CSR itself.
fromCSR :: CSR -> GMatrix
fromCSR csr = SparseR
    { gmCSR = csr
    , nRows = csrNRows csr
    , nCols = csrNCols csr
    }


-- | Build a diagonal 'GMatrix' with the given number of rows and columns.
--
--   The vector supplies the leading diagonal entries and must not be longer
--   than the smaller of the two dimensions; otherwise an 'error' is raised.
mkDiagR :: Int -> Int -> Vector Double -> GMatrix
mkDiagR r c v
    | dim v > min r c
        = error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)
    | otherwise
        = Diag { diagVals = v, nRows = r, nCols = c }


-- Argument-shape helpers for the foreign sparse products below.
-- IV: a CInt and a pointer to CInt, followed by the rest of the signature
--     (presumably length + data of a CInt vector -- confirm against the
--     TransArray instance for Vector).
-- V:  the same shape for a vector of Double.
-- SMxV: one Double vector, two CInt vectors and two more Double vectors,
--     returning a CInt in IO.  'gmXv' supplies them as values, two index
--     vectors, the input vector and the result vector.
type IV t = CInt -> Ptr CInt   -> t
type  V t = CInt -> Ptr Double -> t
type SMxV = V (IV (IV (V (V (IO CInt)))))

-- | Product of a general matrix and a dense vector.
--
--   The vector length must equal the matrix's 'nCols'; otherwise an 'error'
--   naming the representation and the offending sizes is raised.  The
--   result has 'nRows' elements.
gmXv :: GMatrix -> Vector Double -> Vector Double
-- Row-compressed: delegate to the foreign routine c_smXv.
gmXv SparseR { gmCSR = CSR { csrVals = vals, csrCols = colIdx, csrRows = rowPtr }
             , nRows = nr, nCols = nc } v
    | dim v /= nc
        = error (printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nr nc (dim v))
    | otherwise = unsafePerformIO $ do
        r <- createVector nr
        (vals # colIdx # rowPtr # v #! r) c_smXv #|"CSRXv"
        return r

-- Column-compressed: delegate to the foreign routine c_smTXv.
gmXv SparseC { gmCSC = CSC { cscVals = vals, cscRows = rowIdx, cscCols = colPtr }
             , nRows = nr, nCols = nc } v
    | dim v /= nc
        = error (printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nr nc (dim v))
    | otherwise = unsafePerformIO $ do
        r <- createVector nr
        (vals # rowIdx # colPtr # v #! r) c_smTXv #|"CSCXv"
        return r

-- Diagonal: scale the leading elements of v, then pad with zeros up to nRows.
gmXv Diag { diagVals = diag, nRows = nr, nCols = nc } v
    | dim v /= nc
        = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
                         nr nc k (dim v)
    | otherwise
        = vjoin [ mul (subVector 0 k v) diag
                , konst 0 (nr - k) ]
  where
    k = dim diag

-- Dense: ordinary matrix-vector product.
gmXv Dense { gmDense = a, nRows = nr, nCols = nc } v
    | dim v /= nc
        = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
                         nr nc (dim v)
    | otherwise
        = mXv a v


{- | General matrix - vector product; infix form of 'gmXv'.

>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
m :: GMatrix
>>> m !#> vector [1..2000]
[1000.0,4000.0]
it :: Vector Double

-}
infixr 8 !#>
(!#>) :: GMatrix -> Vector Double -> Vector Double
m !#> v = gmXv m v

--------------------------------------------------------------------------------

-- Foreign routine "smXv"; 'gmXv' calls it for row-compressed (SparseR)
-- matrices with the CSR values, column indices, row offsets, the input
-- vector and the result vector.  The returned CInt is handed to '#|'
-- (presumably a status code -- confirm in Internal.Devel).
foreign import ccall unsafe "smXv"
  c_smXv :: SMxV

-- Foreign routine "smTXv"; 'gmXv' calls it for column-compressed (SparseC)
-- matrices with the CSC values, row indices, column offsets, the input
-- vector and the result vector.
foreign import ccall unsafe "smTXv"
  c_smTXv :: SMxV

--------------------------------------------------------------------------------

-- | Expand an association list into a dense matrix.
--
--   The result has one more row and column than the largest row and column
--   index present; positions not listed are filled with 0.
--
--   An empty list carries no size information, so it is rejected with a
--   descriptive 'error' rather than the opaque failure of 'maximum'.
toDense :: AssocMatrix -> Matrix Double
toDense [] = error "toDense: empty association list"
toDense asm = assoc (maxRow + 1, maxCol + 1) 0 asm
  where
    (rowIdxs, colIdxs) = unzip (map fst asm)
    -- Safe: asm is non-empty in this equation.
    maxRow = maximum rowIdxs
    maxCol = maximum colIdxs


-- | Transposing a CSR reuses its three vectors unchanged as a CSC and
--   swaps the row and column counts; no data is rearranged.
instance Transposable CSR CSC
  where
    tr CSR { csrVals  = vals
           , csrCols  = colIdx
           , csrRows  = rowPtr
           , csrNRows = nr
           , csrNCols = nc
           } =
        CSC { cscVals  = vals
            , cscRows  = colIdx
            , cscCols  = rowPtr
            , cscNRows = nc
            , cscNCols = nr
            }
    tr' = tr

-- | Transposing a CSC reuses its three vectors unchanged as a CSR and
--   swaps the row and column counts; no data is rearranged.
instance Transposable CSC CSR
  where
    tr CSC { cscVals  = vals
           , cscRows  = rowIdx
           , cscCols  = colPtr
           , cscNRows = nr
           , cscNCols = nc
           } =
        CSR { csrVals  = vals
            , csrCols  = rowIdx
            , csrRows  = colPtr
            , csrNRows = nc
            , csrNCols = nr
            }
    tr' = tr

-- | Transpose a general matrix: the two sparse layouts turn into each
--   other, a diagonal keeps its values, a dense matrix is transposed by its
--   own instance; the stored dimensions are swapped in every case.
instance Transposable GMatrix GMatrix
  where
    tr g = case g of
        SparseR s n m -> SparseC (tr s) m n
        SparseC s n m -> SparseR (tr s) m n
        Diag    v n m -> Diag v m n
        Dense   a n m -> Dense (tr a) m n
    tr' = tr