{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
module Internal.ST (
ST, runST,
STVector, newVector, thawVector, freezeVector, runSTVector,
readVector, writeVector, modifyVector, liftSTVector,
STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
newUndefinedVector,
unsafeReadVector, unsafeWriteVector,
unsafeThawVector, unsafeFreezeVector,
newUndefinedMatrix,
unsafeReadMatrix, unsafeWriteMatrix,
unsafeThawMatrix, unsafeFreezeMatrix
) where
import Internal.Vector
import Internal.Matrix
import Internal.Vectorized
import Control.Monad.ST(ST, runST)
import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
import Control.Monad.ST.Unsafe(unsafeIOToST)
{-# INLINE ioReadV #-}
-- | Read element @k@ (0-based) of a vector through the pointer supplied by
-- 'unsafeWith'. There is no bounds check: callers must guarantee
-- @0 <= k < dim v@ (the safe entry points go through 'safeIndexV').
ioReadV :: Storable t => Vector t -> Int -> IO t
ioReadV v k = unsafeWith v $ \p -> peekElemOff p k
{-# INLINE ioWriteV #-}
-- | Overwrite element @k@ (0-based) of a vector in place, through the pointer
-- supplied by 'unsafeWith'. There is no bounds check: callers must guarantee
-- @0 <= k < dim v@.
ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
ioWriteV v k x = unsafeWith v $ \p -> pokeElemOff p k x
-- | A mutable vector tied to the state thread @s@. It wraps an ordinary
-- 'Vector'; the primitives in this module read and write its buffer in place.
newtype STVector s t
    = STVector (Vector t)
-- | Make a mutable vector from the result of 'cloneVector' on the input.
-- 'cloneVector' is expected to produce an independent buffer, so the
-- original vector is not affected. 'unsafeThawVector' reuses the buffer
-- instead.
thawVector :: Storable t => Vector t -> ST s (STVector s t)
thawVector = unsafeIOToST . fmap STVector . cloneVector
-- | Reinterpret an immutable vector as a mutable one without copying.
-- The result shares the input's buffer, so writes through it are visible to
-- every other reference to the original vector.
unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
unsafeThawVector = pure . STVector
-- | Run an 'ST' computation that builds a mutable vector and return it as an
-- immutable 'Vector'. The final freeze uses 'unsafeFreezeVector' (no copy).
-- This is sound because the rank-2 type stops the mutable handle from
-- escaping 'runST'.
runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
runSTVector st = runST (st >>= unsafeFreezeVector)
{-# INLINE unsafeReadVector #-}
-- | Read element @k@ of a mutable vector with no bounds check.
unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
unsafeReadVector (STVector v) = unsafeIOToST . ioReadV v
{-# INLINE unsafeWriteVector #-}
-- | Write element @k@ of a mutable vector with no bounds check.
unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
unsafeWriteVector (STVector v) k = unsafeIOToST . ioWriteV v k
{-# INLINE modifyVector #-}
-- | Replace element @k@ with @f@ applied to its current value.
-- The index is validated by the checked 'readVector'. The write-back then
-- uses the unchecked 'unsafeWriteVector', because the same index was just
-- validated.
modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
modifyVector x k f = readVector x k >>= unsafeWriteVector x k . f
-- | Apply a pure function to a snapshot of the mutable vector.
-- The function sees the result of 'cloneVector' rather than the live buffer,
-- so later writes to the 'STVector' do not affect the value it returns.
-- (This assumes 'cloneVector' copies.)
liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a
liftSTVector f (STVector x) = unsafeIOToST (f <$> cloneVector x)
-- | Return an immutable copy of the mutable vector's current contents
-- ('liftSTVector' with the identity function).
freezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
freezeVector = liftSTVector id
-- | Expose the mutable vector's buffer as an immutable 'Vector' without
-- copying. Any later write through the 'STVector' will be observed by the
-- returned value, so only use this when no further mutation happens.
unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
unsafeFreezeVector (STVector x) = pure x
{-# INLINE safeIndexV #-}
-- | Bounds-check index @k@ against the vector's 'dim' before delegating to
-- @f@. An out-of-range index calls 'error' with the dimension and position.
-- The vector is rewrapped in 'STVector' so that @f@'s phantom state
-- parameter can be chosen independently of the argument's.
safeIndexV :: Storable t2
           => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t
safeIndexV f (STVector v) k
  | inRange   = f (STVector v) k
  | otherwise = error $ "out of range error in vector (dim="
                     ++ show (dim v) ++ ", pos=" ++ show k ++ ")"
  where
    inRange = k >= 0 && k < dim v
{-# INLINE readVector #-}
-- | Bounds-checked read of element @k@. An out-of-range index calls 'error'
-- (via 'safeIndexV').
readVector :: Storable t => STVector s t -> Int -> ST s t
readVector = safeIndexV unsafeReadVector
{-# INLINE writeVector #-}
-- | Bounds-checked write of element @k@. An out-of-range index calls 'error'
-- (via 'safeIndexV').
writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
writeVector = safeIndexV unsafeWriteVector
-- | Allocate a mutable vector of the given length via 'createVector'.
-- Nothing here initialises the elements ('newVector' fills them afterwards),
-- so treat the contents as undefined until written.
newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
newUndefinedVector = unsafeIOToST . fmap STVector . createVector
{-# INLINE newVector #-}
-- | Allocate a mutable vector of length @n@ with every element set to @x@.
--
-- Elements are written from index @n-1@ down to @0@. The loop stops on any
-- negative index rather than only on exactly @-1@, so a negative @n@ never
-- runs off into out-of-bounds writes.
newVector :: Storable t => t -> Int -> ST s (STVector s t)
newVector x n = do
    v <- newUndefinedVector n
    let fill !k
          | k < 0     = return v
          | otherwise = unsafeWriteVector v k x >> fill (k - 1)
    fill (n - 1)
{-# INLINE ioReadM #-}
-- | Read element @(r, c)@ straight from the matrix's data vector. The flat
-- offset is computed as @r * xRow m + c * xCol m@, so both row-major and
-- column-major layouts are handled by their strides. No bounds check.
ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
ioReadM m r c = ioReadV (xdat m) offset
  where
    offset = r * xRow m + c * xCol m
{-# INLINE ioWriteM #-}
-- | Overwrite element @(r, c)@ in the matrix's data vector in place, using
-- the flat offset @r * xRow m + c * xCol m@. No bounds check.
ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
ioWriteM m r c val = ioWriteV (xdat m) offset val
  where
    offset = r * xRow m + c * xCol m
-- | A mutable matrix tied to the state thread @s@. It wraps an ordinary
-- 'Matrix'; the primitives in this module read and write its storage in place.
newtype STMatrix s t
    = STMatrix (Matrix t)
-- | Make a mutable matrix from a 'cloneMatrix' copy of the input, leaving the
-- original untouched ('unsafeThawMatrix' reuses the storage instead).
thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t)
thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
-- | Reinterpret an immutable matrix as a mutable one without copying.
-- Writes through the result are visible to every other reference to the
-- original matrix.
unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
unsafeThawMatrix = pure . STMatrix
-- | Run an 'ST' computation that builds a mutable matrix and return it as an
-- immutable 'Matrix'. The final freeze uses 'unsafeFreezeMatrix' (no copy).
-- This is sound because the rank-2 type keeps the mutable handle from
-- escaping 'runST'.
runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
{-# INLINE unsafeReadMatrix #-}
-- | Read element @(r, c)@ of a mutable matrix with no bounds check.
unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
{-# INLINE unsafeWriteMatrix #-}
-- | Write element @(r, c)@ of a mutable matrix with no bounds check.
unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
{-# INLINE modifyMatrix #-}
-- | Replace element @(r, c)@ with @f@ applied to its current value.
-- The position is validated by the checked 'readMatrix'. The write-back uses
-- the unchecked 'unsafeWriteMatrix' for the same, already validated,
-- position.
modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
modifyMatrix x r c f = readMatrix x r c >>= unsafeWriteMatrix x r c . f
-- | Apply a pure function to a 'cloneMatrix' snapshot of the mutable matrix.
-- Later writes to the 'STMatrix' therefore do not affect the value returned.
liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a
liftSTMatrix f (STMatrix x) = unsafeIOToST (f <$> cloneMatrix x)
-- | Expose the mutable matrix's storage as an immutable 'Matrix' without
-- copying. Later writes through the 'STMatrix' are observed by the result,
-- so only use this once mutation is finished.
unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
unsafeFreezeMatrix (STMatrix x) = pure x
-- | Return an immutable copy of the mutable matrix's current contents
-- ('liftSTMatrix' with the identity function).
freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
freezeMatrix = liftSTMatrix id
-- | Copy a matrix with 'copy', keeping the source's storage order
-- (as reported by 'orderOf').
cloneMatrix :: Element t => Matrix t -> IO (Matrix t)
cloneMatrix m = copy (orderOf m) m
{-# INLINE safeIndexM #-}
-- | Bounds-check the position @(r, c)@ against 'rows' and 'cols' before
-- delegating to @f@. Out-of-range access calls 'error' with the matrix size
-- and the offending position. The matrix is rewrapped so that @f@'s phantom
-- state parameter can differ from the argument's.
safeIndexM :: (STMatrix s t2 -> Int -> Int -> t)
           -> STMatrix t1 t2 -> Int -> Int -> t
safeIndexM f (STMatrix m) r c
  | rowOk && colOk = f (STMatrix m) r c
  | otherwise      = error $ "out of range error in matrix (size="
                          ++ show (rows m, cols m) ++ ", pos=" ++ show (r, c) ++ ")"
  where
    rowOk = r >= 0 && r < rows m
    colOk = c >= 0 && c < cols m
{-# INLINE readMatrix #-}
-- | Bounds-checked read of element @(r, c)@. Out-of-range access calls
-- 'error' (via 'safeIndexM').
readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
readMatrix = safeIndexM unsafeReadMatrix
{-# INLINE writeMatrix #-}
-- | Bounds-checked write of element @(r, c)@. Out-of-range access calls
-- 'error' (via 'safeIndexM').
writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
writeMatrix = safeIndexM unsafeWriteMatrix
-- | Call 'setRect' with offset @(i, j)@, source @m@ and the mutable matrix's
-- storage as target. Presumably this copies @m@ into the mutable matrix with
-- its top-left corner at row @i@, column @j@ (confirm clipping behaviour
-- against 'setRect').
setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s ()
setMatrix (STMatrix target) i j m = unsafeIOToST (setRect i j m target)
-- | Allocate an @r x c@ mutable matrix in the given storage order via
-- 'createMatrix'. The elements are not initialised here.
newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
newUndefinedMatrix order r c = unsafeIOToST (STMatrix <$> createMatrix order r c)
{-# NOINLINE newMatrix #-}
-- | Allocate an @r x c@ mutable matrix with every element set to @v@
-- (built as a length @r*c@ vector reshaped to @c@ columns).
--
-- The storage is allocated by running 'newVector' inside the returned
-- action, so every execution of the action yields a fresh buffer. Producing
-- the buffer with a pure 'runSTVector' and then 'unsafeThawMatrix'-ing it
-- would capture one buffer in the action's closure. Running such an action
-- twice would then hand out two handles to the same memory.
newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
newMatrix v r c = do
    STVector storage <- newVector v (r * c)
    unsafeThawMatrix (reshape c storage)
-- | A selection of columns. 'getColRange' reduces every index modulo the
-- column count, so negative indices count back from the last column.
data ColRange = AllCols          -- ^ every column
              | ColRange Int Int -- ^ inclusive range from the first to the second index
              | Col Int          -- ^ a single column
              | FromCol Int      -- ^ from the given column through the last one

-- | Resolve a 'ColRange' against a matrix with @c@ columns, giving an
-- inclusive @(first, last)@ pair of column indices.
--
-- Only 'AllCols' avoids dividing by @c@ (it yields @(0, -1)@ when @c == 0@).
-- The other constructors use @`mod` c@ and so throw for a zero-column
-- matrix.
getColRange :: Int -> ColRange -> (Int, Int)
getColRange c range = case range of
    AllCols      -> (0, lastCol)
    ColRange a b -> (wrap a, wrap b)
    Col a        -> (wrap a, wrap a)
    FromCol a    -> (wrap a, lastCol)
  where
    lastCol = c - 1
    wrap i  = i `mod` c
-- | A selection of rows. 'getRowRange' reduces every index modulo the row
-- count, so negative indices count back from the last row.
data RowRange = AllRows          -- ^ every row
              | RowRange Int Int -- ^ inclusive range from the first to the second index
              | Row Int          -- ^ a single row
              | FromRow Int      -- ^ from the given row through the last one

-- | Resolve a 'RowRange' against a matrix with @r@ rows, giving an
-- inclusive @(first, last)@ pair of row indices.
--
-- Only 'AllRows' avoids dividing by @r@ (it yields @(0, -1)@ when @r == 0@).
-- The other constructors use @`mod` r@ and so throw for a zero-row matrix.
getRowRange :: Int -> RowRange -> (Int, Int)
getRowRange r range = case range of
    AllRows      -> (0, lastRow)
    RowRange a b -> (wrap a, wrap b)
    Row a        -> (wrap a, wrap a)
    FromRow a    -> (wrap a, lastRow)
  where
    lastRow = r - 1
    wrap i  = i `mod` r
-- | An elementary row operation, applied in place by 'rowOper'.
data RowOper t
    = AXPY t Int Int ColRange   -- ^ coefficient, two row indices, column range ('rowOp' code 0)
    | SCAL t RowRange ColRange  -- ^ coefficient applied to a block of rows ('rowOp' code 1)
    | SWAP Int Int ColRange     -- ^ two row indices to exchange ('rowOp' code 2)
-- | Apply an elementary row operation to a mutable matrix in place by calling
-- 'rowOp' with an operation code (0 = AXPY, 1 = SCAL, 2 = SWAP).
--
-- Row indices are reduced modulo 'rows' and column ranges are resolved with
-- 'getColRange', so negative indices count from the end. The exact
-- arithmetic (presumably @row i2 += x * row i1@ for AXPY) is implemented by
-- 'rowOp' — NOTE(review): confirm against its definition.
-- 'Num' is required because SWAP passes the literal @0@ as the unused
-- coefficient.
rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s ()
rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m
  where
    (j1, j2) = getColRange (cols m) r
    i1' = i1 `mod` rows m
    i2' = i2 `mod` rows m
rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m
  where
    (i1, i2) = getRowRange (rows m) rr
    (j1, j2) = getColRange (cols m) rc
rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
  where
    (j1, j2) = getColRange (cols m) r
    i1' = i1 `mod` rows m
    i2' = i2 `mod` rows m
-- | Build a 'Matrix' from the rows and columns of a mutable matrix selected by
-- a 'RowRange' and a 'ColRange'. The ranges are resolved with
-- 'getRowRange' / 'getColRange'.
--
-- The index pairs are passed to 'extractR' with mode @0@. Presumably this
-- means "inclusive range" rather than an explicit index list; confirm
-- against 'extractR'.
-- NOTE(review): the state parameter @t@ of the input is independent of the
-- result's @s@, so the type does not tie the read to the matrix's own thread.
extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a)
extractMatrix (STMatrix m) rr rc =
    unsafeIOToST (extractR (orderOf m) m 0 (idxs [i1, i2]) 0 (idxs [j1, j2]))
  where
    (i1, i2) = getRowRange (rows m) rr
    (j1, j2) = getColRange (cols m) rc
-- | A rectangular window into a mutable matrix. The fields are the matrix,
-- the top-left row and column, then the height and width, as consumed by
-- 'slice'.
data Slice s t
    = Slice (STMatrix s t) Int Int Int Int
-- | View the window described by a 'Slice' as a 'Matrix' by calling
-- 'subMatrix' with the start @(r0, c0)@ and size @(nr, nc)@.
-- 'gemmm' writes into such a view, which relies on 'subMatrix' sharing the
-- parent's storage — NOTE(review): confirm this sharing.
slice :: Element a => Slice t a -> Matrix a
slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0, c0) (nr, nc) m
-- | Run 'gemm' on three matrix slices with the coefficient vector
-- @[alpha, beta]@, using the slice @r@ as the output operand.
-- Presumably this follows the BLAS convention @r := alpha * a * b + beta * r@;
-- confirm against 'gemm'.
--
-- Note the argument order: @beta@ and the destination come first.
gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
gemmm beta (slice -> r) alpha (slice -> a) (slice -> b) =
    unsafeIOToST (gemm coeffs a b r)
  where
    coeffs = fromList [alpha, beta]
-- | Run an in-place algorithm on a thawed copy of a matrix.
-- The callback receives the matrix dimensions @(rows, cols)@ and the mutable
-- copy made by 'thawMatrix'. Afterwards the copy is frozen without a further
-- copy and returned together with the callback's result.
mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
mutable f a = runST $ do
    work <- thawMatrix a
    info <- f (rows a, cols a) work
    result <- unsafeFreezeMatrix work
    return (result, info)