Question

I have this fairly simple function to compute the mean of elements of a big list, using two accumulators to hold the sum so far and the count so far:

mean = go 0 0
    where
      go s l []     = s / fromIntegral l
      go s l (x:xs) = go (s+x) (l+1) xs

main = do
  putStrLn (show (mean [0..10000000]))

Now, in a strict language, this would be tail-recursive, and there would be no problem. However, as Haskell is lazy, my googling has led me to understand that (s+x) and (l+1) will be passed down the recursion as thunks. So this whole thing crashes and burns:

Stack space overflow: current size 8388608 bytes.

After further googling, I found seq and $!. It seems I don't understand them, though, because all my attempts at using them in this context proved futile, with error messages saying something about infinite types.

Finally I found -XBangPatterns, which solves it all by changing the recursive call:

go !s !l (x:xs) = go (s+x) (l+1) xs

But I'm not happy with this, as -XBangPatterns is currently an extension. I would like to know how to make the evaluation strict without the use of -XBangPatterns. (And maybe learn something too!)

Just so you understand my lack of understanding, here's what I tried (the only try that compiled, that is):

go s l (x:xs) = go (seq s (s+x)) (seq l (l+1)) xs

From what I could understand, seq should here force the evaluation of the s and l argument, thus avoiding the problem caused by thunks. But I still get a stack overflow.

Solution

I've written extensively on this, including in a chapter of Real World Haskell (RWH).

Firstly, yes: if you want to require strict evaluation of the accumulators, use seq and stay within Haskell 98:

mean = go 0 0
  where
    go s l []     = s / fromIntegral l
    go s l (x:xs) = s `seq` l `seq`
                      go (s+x) (l+1) xs

main = print $ mean [0..10000000]

*Main> main
5000000.0

Secondly: strictness analysis will kick in if you give some type annotations, and compile with -O2:

mean :: [Double] -> Double
mean = go 0 0
 where
  go :: Double -> Int -> [Double] -> Double
  go s l []     = s / fromIntegral l
  go s l (x:xs) = go (s+x) (l+1) xs

main = print $ mean [0..10000000]

$ ghc -O2 --make A.hs
[1 of 1] Compiling Main             ( A.hs, A.o )
Linking A ...

$ time ./A
5000000.0
./A  0.46s user 0.01s system 99% cpu 0.470 total

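Because Double is a wrapper over the strict primitive type Double#, GHC, given optimizations and a precise type signature, runs strictness analysis and infers that the strict version is safe.

Thirdly, you can use a strict, unboxed array type with a strict pair as the accumulator, here UArr from Data.Array.Vector (the uvector package):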

import Data.Array.Vector

main = print (mean (enumFromToFracU 1 10000000))

data Pair = Pair !Int !Double

mean :: UArr Double -> Double   
mean xs = s / fromIntegral n
  where
    Pair n s       = foldlU k (Pair 0 0) xs
    k (Pair n s) x = Pair (n+1) (s+x)

$ ghc -O2 --make A.hs -funbox-strict-fields
[1 of 1] Compiling Main             ( A.hs, A.o )
Linking A ...

$ time ./A
5000000.5
./A  0.03s user 0.00s system 96% cpu 0.038 total

As described in the RWH chapter above.
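
If uvector is not available, much the same strict-pair fold can be sketched with only base, swapping foldlU for Data.List.foldl' over an ordinary list. This is an illustration of the idea rather than the code from the answer above, and it should not be expected to match the unboxed-array timings. The name step is just a placeholder:

```haskell
import Data.List (foldl')

-- Strict pair: the bangs force both fields whenever a Pair is constructed,
-- so no chain of (+) thunks can build up inside the accumulator.
data Pair = Pair !Int !Double

mean :: [Double] -> Double
mean xs = s / fromIntegral n
  where
    -- foldl' forces the accumulator to WHNF at every step; the strict
    -- fields then force the running count and sum as well.
    Pair n s = foldl' step (Pair 0 0) xs
    step (Pair k acc) x = Pair (k + 1) (acc + x)

main :: IO ()
main = print (mean [1 .. 10000000])
```

With a lazy tuple (Int, Double) in place of Pair, foldl' would only force the tuple constructor, and the thunks would come back inside the fields.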

OTHER TIPS

The seq function forces evaluation of its first argument once seq itself is evaluated. When you pass seq s (s+x) as an argument, seq is not called immediately, because nothing yet demands the value of that argument. You want the call to seq to be evaluated before the recursive call, so that it can in turn force its argument.

Usually this is done like this:

 go s l (x:xs) = s `seq` l `seq` go (s+x) (l+1) xs

This is a syntactic variation of seq s (seq l (go (s+x) (l+1) xs)). Here the calls to seq are the outermost function calls in the expression. Because of Haskell's laziness this causes them to be evaluated first: seq is called with the still unevaluated arguments s and seq l (go (s+x) (l+1) xs), and evaluating those arguments is deferred to the point where somebody actually tries to access their values.

Now seq can force its first parameter to be evaluated before returning the rest of the expression. Then the next step in the evaluation would be the second seq. If the calls to seq are buried somewhere in some parameter they might not be executed for a long time, defeating their purpose.

With the changed positions of the seqs the program executes fine, without using excessive amounts of memory.
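
The difference between an outermost seq and one buried in an argument can be seen in a small sketch. It uses undefined as a stand-in for "a value that must not be evaluated", and the names buried and outermost are made up for this illustration:

```haskell
import Control.Exception (SomeException, evaluate, try)

-- seq buried in an argument that const never demands:
-- the seq is never run, so undefined is never touched.
buried :: Int
buried = const 1 (undefined `seq` 2)

-- seq as the outermost call: demanding the result runs seq first,
-- which forces undefined before const is ever looked at.
outermost :: Int
outermost = undefined `seq` const 1 2

main :: IO ()
main = do
  print buried
  r <- try (evaluate outermost) :: IO (Either SomeException Int)
  putStrLn (either (const "outermost seq forced its first argument") show r)
```

The first line printed is 1. The second reports that the exception from undefined was caught, because the outermost seq really did evaluate its first argument.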

Another solution to the problem would be to simply enable optimizations in GHC when the program is compiled (-O or -O2). The optimizer recognizes the dispensable laziness and produces code that doesn't allocate unnecessary memory.

You are right that seq s (s+x) forces the evaluation of s, but only once that whole expression is itself demanded, and it doesn't force s+x. You are therefore still building up thunks.

By using $! you can force the evaluation of the addition (two times, for both arguments). This achieves the same effect as using the bang patterns:

mean = go 0 0
 where
    go s l []     = s / fromIntegral l
    go s l (x:xs) = ((go $! s+x) $! l+1) xs

The $! operator makes go $! (s+x) equivalent to:

let y = s+x 
in seq y (go y)

Thus y is first forced into weak head normal form, which means that the outermost function is applied. In the case of y, the outermost function is +, thus y is fully evaluated to a number before being passed to go.
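
Weak head normal form only coincides with "fully evaluated" for flat values such as numbers. A minimal sketch of the distinction, again using undefined as a marker for something that must not be evaluated (lazyPair, pairIsFine and sumIsForced are hypothetical names for this illustration):

```haskell
import Control.Exception (SomeException, evaluate, try)

-- A pair whose fields would crash if they were ever evaluated.
lazyPair :: (Int, Int)
lazyPair = (undefined, undefined)

-- seq only evaluates the outermost (,) constructor, so this is fine.
pairIsFine :: Bool
pairIsFine = lazyPair `seq` True

-- For an Int the outermost function is (+), so seq has to evaluate
-- both operands to produce the number.
sumIsForced :: Bool
sumIsForced = (1 + undefined :: Int) `seq` True

main :: IO ()
main = do
  print pairIsFine
  r <- try (evaluate sumIsForced) :: IO (Either SomeException Bool)
  putStrLn (either (const "forcing the sum evaluated its operands") show r)
```

This is why forcing the numeric accumulators s and l is enough here, whereas an accumulator held in a lazy tuple would need its fields forced separately.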


Oh, and you probably got the infinite type error message because you didn't have the parentheses in the right place. I got the same error when I first wrote your program down :-)

Because the $! operator is right associative, without parentheses go $! (s+x) $! (l+1) means the same as go $! ((s+x) $! (l+1)), which is wrong because it tries to apply s+x as a function.
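
One way to sidestep the associativity pitfall is to write out the two $! applications by hand with let and seq. This is a sketch of that expansion; the names s' and l' are introduced here, and the type signatures are added only to keep the example self-contained:

```haskell
mean :: [Double] -> Double
mean = go 0 0
  where
    go :: Double -> Int -> [Double] -> Double
    go s l []     = s / fromIntegral l
    go s l (x:xs) =
      let s' = s + x   -- new running sum
          l' = l + 1   -- new running count
      -- Force both new accumulators before making the recursive call.
      in s' `seq` l' `seq` go s' l' xs

main :: IO ()
main = print (mean [0 .. 10000000])
```

Unlike seq s (s+x), this forces the freshly computed values rather than the old ones, and the seqs sit outermost, in front of the recursive call.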

Licensed under: CC-BY-SA with attribution