Question

Why does Haskell require multiple rewrite rules depending on the function composition technique and the length of the chain? Is there a way to avoid this?

For example, given the following code...

{-# RULES
"f/f" forall a. f ( f a ) = 4*a
  #-}
f a = 2 * a

this works for

test1 = f ( f 1 )

however we need to add a rule for

test2 = f . f $ 1

and

test3 = f $ f 1

leaving us with the following rules

{-# RULES
"f/f1" forall a. f ( f a ) = 4 * a
"f/f2" forall a. f . f $ a  = 4 * a
"f/f3" forall a. f $ f $ a  = 4 * a
   #-}

However, when we string these together, or use some other form of composition, the rules do not fire.

test4 = f . f . f $ 1
test5 = f $ f $ f $ 1
test6 = f $ 1

Why is this? Do I have to write rewrite rules for every possible implementation?


Solution

The rule doesn't fire in many cases because the very simple function f is inlined before the rule has a chance to fire. If you delay the inlining,

{-# INLINE [1] f #-}

the rule

{-# RULES "f/f" forall a. f (f a) = 4*a #-}

should fire for all these cases (worked here with GHC 7.2.2 and 7.4.1).

The reason is that the rule matcher is not very elaborate: it only matches expressions that have the syntactic form of the rule (this is not entirely true, since the rule body undergoes some normalisation too). The expressions f $ f 3 and f . f $ 4 do not have the syntactic form of the rule. For the rule to match, some rewriting has to take place first: ($) and (.) have to be inlined before the rule can match the expression. But if you do not prevent f from being inlined in the first phase of the simplifier, it is replaced by its body in the same run in which ($) and (.) are inlined. In the next iteration the simplifier no longer sees f; it only sees 2*(2*x), which does not match the rule.
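Put together, the delayed inlining and the single rule can be sketched as one small module. This is only an illustration: the Int type signatures and the Main module layout are assumptions added here so that the file compiles on its own. Because 4 * a equals f (f a) for f a = 2 * a, the printed values are the same whether or not the rule fires.

```haskell
module Main (main) where

-- Delay inlining of f until simplifier phase 1, so that calls to f are
-- still visible to the rule after ($) and (.) have been inlined.
{-# INLINE [1] f #-}
f :: Int -> Int   -- assumed type; the original f is untyped
f a = 2 * a

{-# RULES
"f/f" forall a. f (f a) = 4 * a
  #-}

-- The test expressions from the question.
test1, test2, test3, test4, test5, test6 :: Int
test1 = f (f 1)
test2 = f . f $ 1
test3 = f $ f 1
test4 = f . f . f $ 1
test5 = f $ f $ f $ 1
test6 = f $ 1

main :: IO ()
main = mapM_ print [test1, test2, test3, test4, test5, test6]
```

Rewrite rules are only applied when optimisation is enabled, so compile with -O. Adding -ddump-rule-firings makes GHC list the rules it applied, which is the way to check whether "f/f" fired, since the output alone cannot show it in this sketch.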

OTHER TIPS

I would have thought that this would work by default, but you can add two more rewrite rules that reduce (.) and ($) to lambdas and plain application, so that the rule will always match:

{-# RULES
"f/f" forall a. f ( f a ) = 4*a

"app" forall f x. f $ x = f x
"comp" forall f g. f . g = (\x -> f (g x))
  #-}

f a = 3 * a -- 3*a rather than 2*a, so we can see whether the rule fired

A test:

main = do
    print (f . f $ 1)
    print (f (f 1))
    print (f $ f 1)
    print (f $ f $ 1)
    print (f $ f $ f $ f $ 1)
    print (f . f . f . f $ 1)
    print (f $ f $ f $ 1)
    print (f . f . f $ 1)
    print (f $ 1)

Output:

4
4
4
4
16
16
12
12
3

This will also work in some (but not all) more obscure cases, due to other rewrite rules. For example, all of these will work:

mapf x = map f $ map f $ [x]
mapf' x = map (f.f) $ [x]
mapf'' x = map (\x -> f (f x)) $ [x]
Licensed under: CC-BY-SA with attribution