Declaring a type class for multiplication of an N-by-N-element matrix and an N-element column vector

StackOverflow https://stackoverflow.com/questions/14811282

Question

In Haskell, if you have a "family" of types (say, N-by-N-element matrices, for some values of N), and a parallel family of "related" types (say, N-element vectors, for the same values of N), and an operation that requires one specific type from each family (say, multiplying an N-by-N-element matrix and an N-element column vector), is it then possible to declare a type class for that operation?

For this specific example, I imagine it would look something like this:

class MatrixNxN m where

  -- | Multiplication of two N-by-N-element matrices
  mmul :: Num a => m a -> m a -> m a

  -- | Multiplication of an N-by-N-element matrix and an N-element column vector
  vmul :: Num a => m a -> v a -> v a

I don't know how to constrain the type v, however. Is it possible to do something like this?

Please note that I welcome answers to the general question of declaring a type class over multiple, related types, as well as answers to the specific question of declaring a type class for matrix-vector multiplication. In my specific case, there is only a small, known set of values of N (2, 3, and 4), but I'm generally interested in understanding what is possible to encode in Haskell's type system.

EDIT: I implemented this using MultiParamTypeClasses and FunctionalDependencies as suggested by Gabriel Gonzalez and MFlamer below. This is what the relevant bits of my implementation ended up looking like:

{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}

class MatrixVectorMultiplication m v | m -> v, v -> m where
  vmul :: Num a => m a -> v a -> v a

data Matrix3x3 a = ...
data Vector3 a = ...

instance MatrixVectorMultiplication Matrix3x3 Vector3 where
  vmul = ...

These are the types of vmul on its own and partially applied, where m :: Matrix3x3 Integer and v :: Vector3 Integer:

vmul :: (Num a, MatrixVectorMultiplication m v) => m a -> v a -> v a
(`vmul` v) :: Matrix3x3 Integer -> Vector3 Integer
(m `vmul`) :: Vector3 Integer -> Vector3 Integer

I find this all very elegant. Thanks for the answers! :)
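A self-contained version of the EDIT above, as a sketch: the question elides the data declarations, so the row-based Matrix3x3 and Vector3 representations and the dot3 helper below are hypothetical choices, not the asker's actual code.

```haskell
{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}

-- Hypothetical representations (the original elides them as "...").
-- A 3x3 matrix is stored as its three rows.
data Matrix3x3 a = Matrix3x3 (Vector3 a) (Vector3 a) (Vector3 a)
  deriving (Show, Eq)

data Vector3 a = Vector3 a a a
  deriving (Show, Eq)

-- m -> v: the matrix type determines the vector type; v -> m: and vice versa.
class MatrixVectorMultiplication m v | m -> v, v -> m where
  vmul :: Num a => m a -> v a -> v a

-- Dot product of two 3-vectors (hypothetical helper).
dot3 :: Num a => Vector3 a -> Vector3 a -> a
dot3 (Vector3 a b c) (Vector3 x y z) = a * x + b * y + c * z

instance MatrixVectorMultiplication Matrix3x3 Vector3 where
  -- Each result component is one matrix row dotted with the column vector.
  vmul (Matrix3x3 r1 r2 r3) v = Vector3 (dot3 r1 v) (dot3 r2 v) (dot3 r3 v)

m :: Matrix3x3 Integer
m = Matrix3x3 (Vector3 1 2 3) (Vector3 4 5 6) (Vector3 7 8 9)

v :: Vector3 Integer
v = Vector3 1 0 (-1)

main :: IO ()
main = do
  print (m `vmul` v)    -- fully applied
  print ((`vmul` v) m)  -- section fixing the vector
  print ((m `vmul`) v)  -- section fixing the matrix
```

All three lines compute the same product, so each prints Vector3 (-2) (-2) (-2).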


Solution 2

This is a very tiny variation on MFlamer's answer, which also makes m dependent on v:

{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}

class MatrixNxN m v | m -> v, v -> m where
    mmul :: Num a => m a -> m a -> m a
    vmul :: Num a => m a -> v a -> v a

That way, if you do:

(`vmul` someVector)

... then the compiler can select the correct instance on the basis of someVector's type alone.

The type family solution won't work here, because if you declare the v type constructor to be a type function of the m type constructor, that type function is not necessarily one-to-one, so the compiler can't infer m from v. This is why people say that functional dependencies are more powerful than type families.
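For comparison, a minimal sketch of the type family variant this answer refers to, assuming a hypothetical associated type Vec and hypothetical 2x2 types. The vector type is computed from the matrix type, so every call where m is known resolves; knowing only Vector2 would not tell GHC which m satisfies Vec m ~ Vector2, because Vec is not declared injective.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)

-- Hypothetical 2x2 types; entries are listed row by row.
data Matrix2x2 a = Matrix2x2 a a a a deriving (Show, Eq)
data Vector2 a = Vector2 a a deriving (Show, Eq)

class MatrixNxN m where
  -- Associated type synonym: the vector type is a function of the matrix type.
  type Vec m :: Type -> Type
  mmul :: Num a => m a -> m a -> m a
  vmul :: Num a => m a -> Vec m a -> Vec m a

instance MatrixNxN Matrix2x2 where
  type Vec Matrix2x2 = Vector2
  mmul (Matrix2x2 a b c d) (Matrix2x2 e f g h) =
    Matrix2x2 (a * e + b * g) (a * f + b * h)
              (c * e + d * g) (c * f + d * h)
  vmul (Matrix2x2 a b c d) (Vector2 x y) =
    Vector2 (a * x + b * y) (c * x + d * y)

main :: IO ()
main = do
  let m = Matrix2x2 1 2 3 4 :: Matrix2x2 Int
  print (m `mmul` m)            -- Matrix2x2 7 10 15 22
  print (m `vmul` Vector2 1 1)  -- Vector2 3 7
```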

Other tips

Note that the matrix/vector dimension can be encoded using type-level numerals, which allows a more generic definition, class MatrixNum (n :: Nat), instead of manually coding Matrix3x3, Matrix4x4, etc., while also preventing, at compile time, multiplication of incompatible objects. In GHC 7.4.* it can be defined in the following way.

{-# LANGUAGE TypeFamilies, DataKinds, FlexibleInstances #-}

data Nat = Zero | Succ Nat

class MatrixNum (n :: Nat) where
    type Matrix n :: * -> *
    type Vector n :: * -> *
    mmul :: Num a => Matrix n a -> Matrix n a -> Matrix n a
    vmul :: Num a => Matrix n a -> Vector n a -> Vector n a

newtype ListMatrix (n :: Nat) a = ListMatrix [[a]] deriving Show
newtype ListVector (n :: Nat) a = ListVector [a] deriving Show

instance MatrixNum n where
    type Matrix n = ListMatrix n
    type Vector n = ListVector n
    mmul (ListMatrix xss) (ListMatrix yss) = ListMatrix $ error "Not implemented"
    vmul (ListMatrix xss) (ListVector ys) = ListVector $ error "Not implemented"
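The instance above leaves both methods as error stubs. A sketch of the list arithmetic they could use, written as standalone functions over plain lists (the names dot, listVMul and listMMul are hypothetical) and assuming well-formed square inputs:

```haskell
import Data.List (transpose)

-- Dot product of two equal-length lists.
dot :: Num a => [a] -> [a] -> a
dot xs ys = sum (zipWith (*) xs ys)

-- Matrix-vector product: each output entry is one row dotted with the vector.
listVMul :: Num a => [[a]] -> [a] -> [a]
listVMul xss ys = map (`dot` ys) xss

-- Matrix-matrix product: entry (i, j) is row i of xss dotted with column j of yss.
listMMul :: Num a => [[a]] -> [[a]] -> [[a]]
listMMul xss yss = [[dot row col | col <- transpose yss] | row <- xss]

main :: IO ()
main = do
  let m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] :: [[Int]]
  print (listVMul m [1, 2, 3])  -- [14,32,50]
  print (listMMul m m)          -- [[30,36,42],[66,81,96],[102,126,150]]
```

With these, the stub bodies would become ListMatrix (listMMul xss yss) and ListVector (listVMul xss ys). Note that zipWith silently truncates mismatched lengths, which is exactly the kind of error the phantom dimension n is meant to rule out at the type level.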

It is even nicer in GHC 7.6.*, which supports promoted type-level literals: you can drop the Nat definition above, use Nat from GHC.TypeLits instead, and use numeric literals in types to specify the dimensions of your objects:

m1 :: ListMatrix 3 Int
m1 = ListMatrix [[1,2,3],[4,5,6],[7,8,9]]

v1 :: ListVector 3 Int
v1 = ListVector [1,2,3]

v2 = m1 `vmul` v1 -- has type ListVector 3 Int
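On newer GHC versions the same idea can be written without a class at all. The sketch below assumes DataKinds and GHC.TypeLits (KnownNat and natVal); the Matrix and Vector newtypes and the mkMatrix/mkVector smart constructors are hypothetical names, and the runtime length is checked once, at construction time.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables #-}

import Data.List (transpose)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- The dimension n is a phantom type-level number; the lists carry the data.
newtype Matrix (n :: Nat) a = Matrix [[a]] deriving (Show, Eq)
newtype Vector (n :: Nat) a = Vector [a] deriving (Show, Eq)

-- Hypothetical smart constructors: compare the list lengths against n.
mkVector :: forall n a. KnownNat n => [a] -> Maybe (Vector n a)
mkVector xs
  | length xs == dim = Just (Vector xs)
  | otherwise = Nothing
  where dim = fromIntegral (natVal (Proxy :: Proxy n))

mkMatrix :: forall n a. KnownNat n => [[a]] -> Maybe (Matrix n a)
mkMatrix xss
  | length xss == dim && all ((== dim) . length) xss = Just (Matrix xss)
  | otherwise = Nothing
  where dim = fromIntegral (natVal (Proxy :: Proxy n))

-- Both arguments must share the same n, so mismatched sizes do not type-check.
mmul :: Num a => Matrix n a -> Matrix n a -> Matrix n a
mmul (Matrix xss) (Matrix yss) =
  Matrix [[sum (zipWith (*) row col) | col <- transpose yss] | row <- xss]

vmul :: Num a => Matrix n a -> Vector n a -> Vector n a
vmul (Matrix xss) (Vector ys) = Vector [sum (zipWith (*) row ys) | row <- xss]

main :: IO ()
main = do
  let m1 = mkMatrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]] :: Maybe (Matrix 3 Int)
      v1 = mkVector [1, 2, 3] :: Maybe (Vector 3 Int)
  print (vmul <$> m1 <*> v1)                       -- Just (Vector [14,32,50])
  print (mkVector [1, 2] :: Maybe (Vector 3 Int))  -- Nothing
```

Because vmul requires both arguments to share n, applying it to a Matrix 3 Int and a Vector 2 Int is a type error. A real module would export only the smart constructors, so that the raw Matrix and Vector constructors cannot bypass the length check.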

This is how I have done it using MultiParamTypeClasses, as also suggested in the responses above. You also need the FunctionalDependencies extension, because not every class method mentions both types (mmul only uses m). Someone else will probably provide a more complete answer, but I have been using this pattern a lot lately, so I thought it might help.

{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}

module Test where

class MatrixNxN m v | m -> v where 
  mmul :: Num a => m a -> m a -> m a
  vmul :: Num a => m a -> v a -> v a
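A minimal sketch of this one-directional version, with hypothetical Matrix2x2 and Vector2 types. The m -> v dependency is what keeps the type of mmul, which never mentions v, from being ambiguous: once the matrix type is known, the vector type is fixed.

```haskell
{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}

-- Hypothetical 2x2 types; entries are listed row by row.
data Matrix2x2 a = Matrix2x2 a a a a deriving (Show, Eq)
data Vector2 a = Vector2 a a deriving (Show, Eq)

-- m -> v: the matrix type determines the vector type, so mmul is unambiguous.
class MatrixNxN m v | m -> v where
  mmul :: Num a => m a -> m a -> m a
  vmul :: Num a => m a -> v a -> v a

instance MatrixNxN Matrix2x2 Vector2 where
  mmul (Matrix2x2 a b c d) (Matrix2x2 e f g h) =
    Matrix2x2 (a * e + b * g) (a * f + b * h)
              (c * e + d * g) (c * f + d * h)
  vmul (Matrix2x2 a b c d) (Vector2 x y) =
    Vector2 (a * x + b * y) (c * x + d * y)

main :: IO ()
main = do
  let rot = Matrix2x2 0 (-1) 1 0 :: Matrix2x2 Int  -- 90-degree rotation
  print (rot `mmul` rot)          -- Matrix2x2 (-1) 0 0 (-1)
  print (rot `vmul` Vector2 1 0)  -- Vector2 0 1
```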

I suppose it's also possible to implement this using type families (here, associated data families). To avoid circular functional dependencies, we declare that the types of vectors and matrices depend on the scalar type. While this requires a newtype for scalars as well as for their corresponding vectors and matrices, it also has some advantages. In particular, we don't need to keep the constraint Num a in the declarations of mmul and vmul; each instance can decide what constraints to impose on its scalar values.

{-# LANGUAGE TypeFamilies #-}

class MatrixNxN a where
    data Matrix a :: *
    data Vector a :: *

    -- | Multiplication of two N-by-N-element matrices
    mmul :: Matrix a -> Matrix a -> Matrix a

    -- | Multiplication of an N-by-N-element matrix
    -- and an N-element column vector
    vmul :: Matrix a -> Vector a -> Vector a

-- List matrices on any kind of numbers:
newtype ListScalar a = ListScalar a
instance Num a => MatrixNxN (ListScalar a) where
    newtype Matrix (ListScalar a) = ListMatrix [[a]]
    newtype Vector (ListScalar a) = ListVector [a]

    vmul (ListMatrix mss)  (ListVector vs)   = ...
    mmul (ListMatrix m1ss) (ListMatrix m2ss) = ...


-- We can have matrices that have no `Num` instance for
-- their scalars, like Z2 implemented as `Bool`:
newtype ListBool = ListBool Bool
instance MatrixNxN ListBool where
    newtype Matrix ListBool = ListBoolMatrix [[Bool]]
    newtype Vector ListBool = ListBoolVector [Bool]

    vmul (ListBoolMatrix mss)  (ListBoolVector vs)   = ...
    mmul (ListBoolMatrix m1ss) (ListBoolMatrix m2ss) = ...
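A runnable sketch of this answer with the elided method bodies filled in. The bodies, the dotZ2 helper and the scalarVector/boolVector unwrappers are hypothetical; over Z2, addition is exclusive or (written /= on Bool) and multiplication is logical and.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)
import Data.List (transpose)

class MatrixNxN a where
  data Matrix a :: Type
  data Vector a :: Type
  mmul :: Matrix a -> Matrix a -> Matrix a
  vmul :: Matrix a -> Vector a -> Vector a

-- Scalars with an ordinary Num instance.
newtype ListScalar a = ListScalar a

instance Num a => MatrixNxN (ListScalar a) where
  newtype Matrix (ListScalar a) = ListMatrix [[a]]
  newtype Vector (ListScalar a) = ListVector [a]
  vmul (ListMatrix mss) (ListVector vs) =
    ListVector [sum (zipWith (*) row vs) | row <- mss]
  mmul (ListMatrix m1ss) (ListMatrix m2ss) =
    ListMatrix [[sum (zipWith (*) row col) | col <- transpose m2ss] | row <- m1ss]

-- Z2 represented as Bool; no Num instance is needed.
newtype ListBool = ListBool Bool

instance MatrixNxN ListBool where
  newtype Matrix ListBool = ListBoolMatrix [[Bool]]
  newtype Vector ListBool = ListBoolVector [Bool]
  vmul (ListBoolMatrix mss) (ListBoolVector vs) =
    ListBoolVector [dotZ2 row vs | row <- mss]
  mmul (ListBoolMatrix m1ss) (ListBoolMatrix m2ss) =
    ListBoolMatrix [[dotZ2 row col | col <- transpose m2ss] | row <- m1ss]

-- Dot product over Z2: AND elementwise, then XOR the results together.
dotZ2 :: [Bool] -> [Bool] -> Bool
dotZ2 xs ys = foldr (/=) False (zipWith (&&) xs ys)

-- Unwrappers so results can be printed and compared.
scalarVector :: Vector (ListScalar a) -> [a]
scalarVector (ListVector xs) = xs

boolVector :: Vector ListBool -> [Bool]
boolVector (ListBoolVector bs) = bs

main :: IO ()
main = do
  let m = ListMatrix [[1, 2], [3, 4]] :: Matrix (ListScalar Int)
  print (scalarVector (vmul m (ListVector [1, 1])))          -- [3,7]
  let b = ListBoolMatrix [[True, True], [False, True]]
  print (boolVector (vmul b (ListBoolVector [True, True])))  -- [False,True]
```

Only the ListScalar instance needs Num; the ListBool instance never mentions it, which is the flexibility described above. Because data families are injective, GHC can recover the scalar type from Matrix s or Vector s without any functional dependency.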
License: CC-BY-SA with attribution