As others have said, the problem is that you are memoizing the top-level calls to the function, but you are not using that information to avoid recomputing the recursive calls. Let's see how to make sure that the recursive calls are cached as well.
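For reference, here is a hypothetical reconstruction of the kind of code that has this problem (the question's exact code is not reproduced here; fib_top is an illustrative name). Wrapping an ordinary recursive fib in memoize caches only the outermost call, because the recursive calls inside fib go straight back to fib rather than through the memo table.

```haskell
import Data.Function.Memoize (memoize)

-- Plain recursive Fibonacci: its recursive calls refer to fib itself,
-- so they never consult any memo table.
fib :: Int -> Int
fib 0 = 0
fib 1 = 1
fib n = fib (n - 1) + fib (n - 2)

-- Hypothetical top-level-only memoization: asking for fib_top 30 twice
-- hits the table the second time, but a new argument such as fib_top 31
-- still recomputes everything from scratch through fib.
fib_top :: Int -> Int
fib_top = memoize fib
```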
First, we have the obvious import.
import Data.Function.Memoize
Next, we describe the call graph of fib. To do that, we write a higher-order function that uses its argument instead of making a recursive call:
fib_rec :: (Int -> Int) -> Int -> Int
fib_rec _ 0 = 0
fib_rec _ 1 = 1
fib_rec f n = f (n - 1) + f (n - 2)
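To see that fib_rec really is not recursive, you can feed it any function at all. In this small sketch, oracle is a hypothetical helper that knows only the two values fib_rec needs in order to compute fib 9 (fib 8 = 21 and fib 7 = 13).

```haskell
-- fib_rec as defined above: one layer of Fibonacci, with the
-- recursive calls delegated to the argument f.
fib_rec :: (Int -> Int) -> Int -> Int
fib_rec _ 0 = 0
fib_rec _ 1 = 1
fib_rec f n = f (n - 1) + f (n - 2)

-- Hypothetical stand-in for the "recursive call": it knows only
-- fib 8 and fib 7 and fails loudly on any other argument.
oracle :: Int -> Int
oracle 8 = 21
oracle 7 = 13
oracle k = error ("oracle: no value for " ++ show k)
```

In GHCi, fib_rec oracle 9 evaluates to 34 without fib_rec ever calling itself: the argument decides where the "recursive" calls go.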
Now, what we want is an operator that takes such a higher-order function and "ties the knot", making sure that the recursive calls go to the very function we are defining. We could write fix:
fix :: ((a -> b) -> (a -> b)) -> (a -> b)
fix f = f (fix f)
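This fix is also exported by Data.Function, so the knot can be tied with the library version directly. A minimal sketch (fib_slow is a hypothetical name) that computes the right answer but reuses nothing:

```haskell
import Data.Function (fix)

fib_rec :: (Int -> Int) -> Int -> Int
fib_rec _ 0 = 0
fib_rec _ 1 = 1
fib_rec f n = f (n - 1) + f (n - 2)

-- fix fib_rec behaves like fib_rec (fix fib_rec): the recursive calls go
-- back into fib_rec, but a plain function does not store its results,
-- so fib_slow (n - 2) is recomputed inside fib_slow (n - 1).
fib_slow :: Int -> Int
fib_slow = fix fib_rec
```

fib_slow 20 gives 6765, but computing fib_slow n takes 2 * fib (n + 1) - 1 calls to fib_rec, so the cost grows as fast as the Fibonacci numbers themselves.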
but then we are back to an inefficient solution, because nothing is ever memoized. The alternative is to write something that looks like fix but makes sure that every call, including the recursive ones, goes through the same memo table. Let's call it memoized_fix:
memoized_fix :: Memoizable a => ((a -> b) -> (a -> b)) -> (a -> b)
memoized_fix f = g
  where g = memoize (f g) -- g is shared, so every recursive call uses the same table
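The essential trick is that the memoized function is defined in terms of itself, so a single memo table serves both the top-level call and every recursive call. To make that visible without depending on the memoize package, here is a sketch with a hypothetical memoNat helper standing in for memoize; it only handles non-negative Int arguments and uses a lazy list as the table. memoNatFix and fib_list are likewise illustrative names.

```haskell
-- Hypothetical stand-in for memoize, valid only for n >= 0: the table
-- is a lazy list, so each entry is computed at most once and then reused.
memoNat :: (Int -> b) -> Int -> b
memoNat f = (table !!)
  where
    table = map f [0 ..]

-- Knot-tying with a shared binding: g refers to itself, so the
-- recursive calls made by f look up the same table as the outermost call.
memoNatFix :: ((Int -> b) -> Int -> b) -> Int -> b
memoNatFix f = g
  where
    g = memoNat (f g)

fib_rec :: (Int -> Int) -> Int -> Int
fib_rec _ 0 = 0
fib_rec _ 1 = 1
fib_rec f n = f (n - 1) + f (n - 2)

fib_list :: Int -> Int
fib_list = memoNatFix fib_rec
```

fib_list 50 returns 12586269025 almost immediately. If the recursive position held a fresh memoNatFix f instead of the shared g, each level of recursion would build its own table and most of the sharing would be lost.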
And now you have your efficient function fib_mem:
fib_mem :: Int -> Int
fib_mem = memoized_fix fib_rec
You don't even have to write memoized_fix yourself: the memoize package already provides it as memoFix.
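With the library function, the whole thing comes down to a couple of lines. This sketch assumes the memoize package is installed; fib_lib is a hypothetical name.

```haskell
import Data.Function.Memoize (memoFix)

-- Same open-recursive definition as before.
fib_rec :: (Int -> Int) -> Int -> Int
fib_rec _ 0 = 0
fib_rec _ 1 = 1
fib_rec f n = f (n - 1) + f (n - 2)

-- memoFix ties the knot through one shared memo table, so each
-- fib_lib k is computed at most once.
fib_lib :: Int -> Int
fib_lib = memoFix fib_rec
```

In GHCi, fib_lib 50 evaluates to 12586269025 without any noticeable delay.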