Question

Is it possible to create a function that rewrites Haskell code at compile time from outside of Template Haskell quotes?

For example:

differentiate :: Floating a => (a -> a) -> a -> (a,a)
differentiate = -- what goes here?

f :: Floating a => a -> a
f = sin

g :: Floating a => a -> (a,a)
g = differentiate f

and at compile time it would transform g to:

g x = (sin x, cos x)

I would like my "differentiate" function to receive the AST of "f" and let me rewrite it before compilation. As far as I'm aware, you cannot do this in Template Haskell without passing it the full syntax of the function, i.e. "g = differentiate sin".

Thank you


Solution

You are talking about macros, as in Scheme. The answer is no. Haskell functions must be "referentially transparent", which means that if you give a function two denotationally equal arguments, the results must be denotationally equal. That is, every f must satisfy

f (1 + 1) = f 2

If f were a macro, this would not necessarily hold, because a macro can inspect the syntax 1 + 1 and treat it differently from 2. This property is essential to the "purity" of the language; it is what makes Haskell so easy to reason about and refactor.
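
To make the "1 + 1 versus 2" point concrete, here is a small sketch using Template Haskell quotation brackets, the closest thing Haskell has to a macro system. The helper names isLiteral, onePlusOne and two are hypothetical, chosen only for illustration. The point is that a function on syntax (an Exp) can tell apart two expressions that any ordinary function must treat as equal.

```haskell
{-# LANGUAGE TemplateHaskell #-}
-- Illustration only: code that sees syntax can distinguish expressions
-- that ordinary (referentially transparent) functions must treat as equal.
import Language.Haskell.TH

-- A syntax-level predicate: is the expression a bare literal?
-- No ordinary function on numbers could make this distinction.
isLiteral :: Exp -> Bool
isLiteral (LitE _) = True
isLiteral _        = False

-- Both quotes denote the value 2, but their ASTs differ:
-- the first is an infix application of (+), the second a literal.
onePlusOne, two :: Q Exp
onePlusOne = [| 1 + 1 |]
two        = [| 2 |]
```

Loaded into GHCi, runQ onePlusOne >>= print . isLiteral prints False, while runQ two >>= print . isLiteral prints True, even though both expressions evaluate to 2. That is exactly the kind of distinction referential transparency rules out for ordinary functions.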

There is, however, extensive work on automatic differentiation in Haskell, none of which needs a macro system: abstract modeling (with type classes to make it look nice) is all that is necessary.
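
As a rough illustration of that "abstract modeling" approach (not taken from any particular library), here is a minimal forward-mode automatic differentiation sketch using dual numbers. A Dual value carries a number together with its derivative, and the Num, Fractional and Floating instances apply the usual differentiation rules. The names Dual and differentiate are hypothetical; f and g mirror the question. The key assumption is that the function being differentiated is polymorphic over Floating (as sin is), so it can be instantiated at Dual a.

```haskell
-- Forward-mode automatic differentiation with dual numbers (a sketch).
-- Dual v d holds a value v and its derivative d with respect to the input.
data Dual a = Dual a a deriving (Show, Eq)

instance Num a => Num (Dual a) where
  Dual x x' + Dual y y' = Dual (x + y) (x' + y')
  Dual x x' - Dual y y' = Dual (x - y) (x' - y')
  Dual x x' * Dual y y' = Dual (x * y) (x * y' + x' * y)  -- product rule
  abs (Dual x x')       = Dual (abs x) (x' * signum x)
  signum (Dual x _)     = Dual (signum x) 0
  fromInteger n         = Dual (fromInteger n) 0          -- constants: derivative 0

instance Fractional a => Fractional (Dual a) where
  Dual x x' / Dual y y' = Dual (x / y) ((x' * y - x * y') / (y * y))  -- quotient rule
  fromRational r        = Dual (fromRational r) 0

-- Each elementary function multiplies by the inner derivative (chain rule).
instance Floating a => Floating (Dual a) where
  pi                = Dual pi 0
  exp   (Dual x x') = Dual (exp x) (x' * exp x)
  log   (Dual x x') = Dual (log x) (x' / x)
  sin   (Dual x x') = Dual (sin x) (x' * cos x)
  cos   (Dual x x') = Dual (cos x) (negate x' * sin x)
  asin  (Dual x x') = Dual (asin x) (x' / sqrt (1 - x * x))
  acos  (Dual x x') = Dual (acos x) (negate x' / sqrt (1 - x * x))
  atan  (Dual x x') = Dual (atan x) (x' / (1 + x * x))
  sinh  (Dual x x') = Dual (sinh x) (x' * cosh x)
  cosh  (Dual x x') = Dual (cosh x) (x' * sinh x)
  asinh (Dual x x') = Dual (asinh x) (x' / sqrt (x * x + 1))
  acosh (Dual x x') = Dual (acosh x) (x' / sqrt (x * x - 1))
  atanh (Dual x x') = Dual (atanh x) (x' / (1 - x * x))

-- Seed the input with derivative 1, then read off (value, derivative).
differentiate :: Num a => (Dual a -> Dual a) -> a -> (a, a)
differentiate h x = let Dual v d = h (Dual x 1) in (v, d)

f :: Floating a => a -> a
f = sin

g :: Floating a => a -> (a, a)
g = differentiate f
```

With this, g (0 :: Double) evaluates to (0.0,1.0), i.e. (sin 0, cos 0). No compile-time rewriting is involved: the derivative is computed alongside the value when g runs, which is why the ordinary type class machinery is enough.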

OTHER TIPS

It is possible in principle, if you are willing to use your own set of math functions and number types. The idea is to build a system of types that records how each value is computed, so that the structure of an expression is reflected in its type. Using either Template Haskell and the reify function, or type class instances, you could then generate the correct derivative code at compile time.
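
The Template Haskell side of that idea can be sketched roughly as follows. Note that this sketch does not use reify: it works on a quoted lambda, which is the "pass the full syntax" route the question mentions. The helpers derivBody and differentiateLam are hypothetical names. Only sin, cos, the bound variable, numeric literals and (+) are supported, and the result is the derivative expression alone rather than a (value, derivative) pair.

```haskell
{-# LANGUAGE TemplateHaskell #-}
-- Sketch: symbolic differentiation of a *quoted* one-argument lambda
-- by rewriting its AST. Unsupported forms produce Nothing.
import Language.Haskell.TH

-- Derivative of an expression with respect to the variable x.
derivBody :: Name -> Exp -> Maybe Exp
derivBody x (VarE v)
  | v == x = Just (LitE (IntegerL 1))                -- dx/dx = 1
derivBody _ (LitE _) = Just (LitE (IntegerL 0))      -- constants
derivBody x (AppE (VarE fn) u)
  | fn == 'sin = chain (AppE (VarE 'cos) u)
  | fn == 'cos = chain (AppE (VarE 'negate) (AppE (VarE 'sin) u))
  where
    -- Chain rule: outer'(u) * du/dx.
    chain outer = do
      u' <- derivBody x u
      Just (InfixE (Just outer) (VarE '(*)) (Just u'))
derivBody x (InfixE (Just a) (VarE op) (Just b))
  | op == '(+) = do                                  -- sum rule
      a' <- derivBody x a
      b' <- derivBody x b
      Just (InfixE (Just a') (VarE '(+)) (Just b'))
derivBody _ _ = Nothing

-- Rewrite \x -> body into \x -> d(body)/dx.
differentiateLam :: Exp -> Maybe Exp
differentiateLam (LamE [VarP x] body) = LamE [VarP x] <$> derivBody x body
differentiateLam _ = Nothing
```

For example, running differentiateLam on the result of runQ [| \x -> sin x + x |] yields an expression equivalent to \x -> cos x * 1 + 1. To splice such a result into a program with $( ... ), the helpers would have to live in a separate module because of GHC's stage restriction, and the caller still has to hand over the syntax explicitly, which is the limitation the question ran into.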

Here is a hacky sample implementation using type classes. It handles constants, the variable itself, sin applied directly to the variable, and addition; cos appears only as the derivative of sin, and there is no chain rule. Implementing the full set of operations would be a lot of work. There is also a fair bit of duplication in the code; if you plan to use this approach, you should try to remove it:

{-# LANGUAGE ScopedTypeVariables, UndecidableInstances, FlexibleInstances, MultiParamTypeClasses, FunctionalDependencies #-}
module TrackedComputation where
-- Hide Prelude's sin, cos and Num methods so the tracked versions below can reuse those names.
import Prelude hiding (sin, cos, Num(..))
import Data.Function (on)
import qualified Prelude as P

-- A tracked computation (TC for short).
-- The phantom type parameter comp records how the value was computed.
newtype TC comp val = TC { getVal :: val }
    deriving (Eq)

instance (Show val) => Show (TC comp val) where
    show = show . getVal


-- Empty tag types, used only at the type level to record each operation.
data SinT comp = SinT
data CosT comp = CosT
data AddT comp1 comp2 = AddT
data ConstantT = ConstantT
data VariableT = VariableT

sin :: (P.Floating a) => TC comp1 a -> TC (SinT comp1) a
sin = TC . P.sin . getVal
cos :: (P.Floating a) => TC comp1 a -> TC (CosT comp1) a
cos = TC . P.cos . getVal

-- Give the tracked (+) the same fixity as Prelude's (+).
infixl 6 +
(+) :: (P.Num a) => TC comp1 a -> TC comp2 a -> TC (AddT comp1 comp2) a
(TC a) + (TC b) = TC $ (P.+) a b

toNum :: a -> TC ConstantT a
toNum = TC

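-- differentiate turns a tracked function into its derivative. The functional
-- dependency computes the derivative's representation (compROut) from the
-- original computation (comp) and the input's representation (compRIn).
class Differentiate comp compRIn compROut | comp compRIn -> compROut where
    differentiate :: P.Floating a => (TC VariableT a -> TC comp a) -> (TC compRIn a -> TC compROut a)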


-- d/dx c = 0
instance Differentiate ConstantT compIn ConstantT where
    differentiate _ = const $ toNum 0

-- d/dx (sin x) = cos x; only sin applied directly to the variable is handled (no chain rule)
instance Differentiate (SinT VariableT) compIn (CosT compIn) where
    differentiate _ = cos

-- d/dx x = 1
instance Differentiate VariableT compIn ConstantT where
    differentiate _ = const $ toNum 1

-- Sum rule: differentiate each summand separately and add the results.
instance (Differentiate add1 compIn add1Out, Differentiate add2 compIn add2Out) =>
    Differentiate (AddT add1 add2) compIn (AddT add1Out add2Out) where
    differentiate _ (val :: TC compIn a) = result where
        first = differentiate (undefined :: TC VariableT a -> TC add1 a) val :: TC add1Out a
        second = differentiate (undefined :: TC VariableT a -> TC add2 a) val :: TC add2Out a
        result = first + second

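-- Plain numeric operations on tracked constants (this also lets numeric
-- literals be used as TC ConstantT values). negate completes the instance.
instance P.Num val => P.Num (TC ConstantT val) where
    (+) = (TC .) . ((P.+) `on` getVal)
    (*) = (TC .) . ((P.*) `on` getVal)
    negate = TC . P.negate . getVal
    abs = TC . P.abs . getVal
    signum = TC . P.signum . getVal
    fromInteger = TC . P.fromInteger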

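f :: P.Floating a => TC comp a -> TC (SinT comp) a
f x = sin x

g :: P.Floating a => TC compIn a -> TC (CosT compIn) a
g = differentiate f

h x = sin x + x + toNum 42 + x

-- The element type defaults to Double: test1 0 shows 0.0, test2 0 shows 1.0, test3 0 shows 3.0
test1 = f . toNum
test2 = g . toNum
test3 = differentiate h . toNum
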
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow