Question

I am trying to create a memoized version of a factorial function. When I call factMemoized(4), it computes the factorial of 4 the first time and stores it in a Map. When I call factMemoized(4) again, it returns the stored result instead of recomputing it. This works as expected. But when I call factMemoized(3), it recomputes the value, even though fact(3) was already computed as part of computing fact(4). Is there any way to make sure that the values computed in the recursive calls are also stored in the map, without adding the memoization logic inside the fact() function?

import java.util.HashMap;
import java.util.Map;


public class MemoizeBetter {

    public static <F, T> Function<F, T> memoize(final Function<F, T> inputFunction) {
        return new Function<F, T>() {
            // Holds previous results
            Map<F, T> memoization = new HashMap<F, T>();

            @Override
            public T apply(final F input) {
                // Check for previous results
                if (!memoization.containsKey(input)) {
                    // None exists, so compute and store a new one
                    memoization.put(input, inputFunction.apply(input));
                } else {
                    System.out.println("Cache hit:" + input);
                }

                // At this point a result is guaranteed in the memoization map
                return memoization.get(input);
            }
        };
    }

    public static void main(String[] args) {

        final Function<Integer, Integer> fact = new Function<Integer, Integer>() {
            @Override
            public Integer apply(final Integer input) {
                System.out.println("Fact: " + input);
                if (input == 1)
                    return 1;
                else return input * apply(input - 1);
            }
        };

        final Function<Integer, Integer> factMemoized = MemoizeBetter.memoize(fact);

        System.out.println("Result:" + factMemoized.apply(1));
        System.out.println("Result:" + factMemoized.apply(2));
        System.out.println("Result:" + factMemoized.apply(3));
        System.out.println("Result:" + factMemoized.apply(2));
        System.out.println("Result:" + factMemoized.apply(4));
        System.out.println("Result:" + factMemoized.apply(1));
    }
}

interface Function<F,T>{
    T apply(F input);
}

Solution

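The issue is that your factorial function does not recurse into the memoized version of the function. Inside fact, apply(input - 1) calls the unmemoized object's own apply(), so only the top-level argument passed to factMemoized ever reaches the map.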
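To see this concretely, here is a minimal, self-contained sketch of the question's setup. It is an illustration under assumptions: it uses java.util.function.Function instead of the question's own Function interface, and the class name CacheMissDemo is hypothetical. After a single call with 4, the cache holds only the key 4.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

public class CacheMissDemo {
    // Cache used by the naive wrapper; static so it can be inspected afterwards.
    static final Map<Integer, Integer> cache = new HashMap<>();

    // Plain recursive factorial: its recursive call is this object's own apply(),
    // which never passes through the cache below.
    static final Function<Integer, Integer> fact = new Function<Integer, Integer>() {
        @Override
        public Integer apply(Integer n) {
            return n <= 1 ? 1 : n * apply(n - 1);
        }
    };

    // Naive wrapper in the style of the question's memoize():
    // it caches only the argument it was called with.
    static final Function<Integer, Integer> factMemoized = n -> {
        if (!cache.containsKey(n)) {
            cache.put(n, fact.apply(n));
        }
        return cache.get(n);
    };

    public static void main(String[] args) {
        System.out.println(factMemoized.apply(4)); // prints 24
        System.out.println(cache.keySet());        // prints [4]: 3, 2 and 1 were never cached
    }
}
```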

To fix this, there are a few options.

  1. You could parameterize your factorial function and give it a reference to the Function it should call recursively. In the unmemoized case, this will be the function itself; in the memoized case, it will be the memoizing wrapper.

  2. You could implement memoization by extending the factorial function's class and overriding, rather than delegating to, the unmemoized apply(). This is awkward to do ad hoc, but there are libraries that create subclasses dynamically (this is a common way of implementing AOP, for example).

  3. You could give the base function full knowledge of the memoization to start with.
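Option 2 can be sketched by hand, without any code-generation library. In this sketch, which assumes a simple class-based factorial (the names Factorial and MemoizedFactorial are hypothetical), the recursive call in Factorial.apply() is an ordinary virtual call. When the object is a MemoizedFactorial, that call lands back in the overriding, caching apply(). This is roughly what a dynamically generated subclass would do, not the API of any particular AOP tool.

```java
import java.util.HashMap;
import java.util.Map;

public class SubclassMemoDemo {
    // Unmemoized factorial. The recursive call is a virtual call to apply(),
    // so a subclass that overrides apply() also intercepts the recursion.
    static class Factorial {
        public long apply(int n) {
            return n <= 1 ? 1 : n * apply(n - 1);
        }
    }

    // Memoizing subclass: overrides apply() instead of wrapping the object.
    static class MemoizedFactorial extends Factorial {
        final Map<Integer, Long> cache = new HashMap<>();
        int misses = 0; // counts calls that had to run the original body

        @Override
        public long apply(int n) {
            Long cached = cache.get(n);
            if (cached != null) {
                return cached;
            }
            misses++;
            long result = super.apply(n); // super's apply(n - 1) dispatches back here
            cache.put(n, result);
            return result;
        }
    }

    public static void main(String[] args) {
        MemoizedFactorial f = new MemoizedFactorial();
        System.out.println(f.apply(4));       // prints 24 and caches 1..4
        System.out.println(f.cache.keySet()); // prints [1, 2, 3, 4]
        System.out.println(f.apply(3));       // prints 6, served from the cache
        System.out.println(f.misses);         // prints 4
    }
}
```

Because the override sees every level of the recursion, f.apply(3) after f.apply(4) never reruns the original body.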

Here's the gist of the first option:

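interface MemoizableFunction<I, O> extends Function<I, O> {

    // In apply(), always recurse through the "recursive function"
    O apply(I input);

    void setRecursiveFunction(Function<? super I, ? extends O> recursiveFunction);
}

final MemoizableFunction<Integer, Integer> fact = new MemoizableFunction<Integer, Integer>() {

    // Defaults to this object, so the function still works when not memoized
    private Function<? super Integer, ? extends Integer> recursiveFunction = this;

    @Override
    public Integer apply(final Integer input) {
        System.out.println("Fact: " + input);
        if (input == 1)
            return 1;
        else return input * recursiveFunction.apply(input - 1);
    }

    @Override
    public void setRecursiveFunction(Function<? super Integer, ? extends Integer> recursiveFunction) {
        // The memoizing wrapper calls this so that recursion goes through its cache
        this.recursiveFunction = recursiveFunction;
    }
};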
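The snippet above leaves out the wiring: memoize() has to hand its wrapper back to the function through setRecursiveFunction(). One way to do that is sketched below. It is self-contained and assumes memoize() is changed to accept a MemoizableFunction, a detail the answer does not spell out. RecursiveMemoDemo, newFact() and the bodyCalls counter are hypothetical names added for illustration.

```java
import java.util.HashMap;
import java.util.Map;

public class RecursiveMemoDemo {

    // Same shape as the question's Function interface.
    interface Function<F, T> {
        T apply(F input);
    }

    // The interface from the answer, with a return type and parameter name.
    interface MemoizableFunction<I, O> extends Function<I, O> {
        void setRecursiveFunction(Function<? super I, ? extends O> recursiveFunction);
    }

    // Counts how often the factorial body actually runs (illustration only).
    static int bodyCalls = 0;

    // Like the question's memoize(), but it also redirects the wrapped
    // function's recursion to the caching wrapper.
    static <I, O> Function<I, O> memoize(final MemoizableFunction<I, O> inputFunction) {
        final Map<I, O> memoization = new HashMap<>();
        final Function<I, O> wrapper = new Function<I, O>() {
            @Override
            public O apply(final I input) {
                if (!memoization.containsKey(input)) {
                    // The argument is evaluated before put() runs, so the nested
                    // put() calls made during the recursion complete first.
                    memoization.put(input, inputFunction.apply(input));
                } else {
                    System.out.println("Cache hit:" + input);
                }
                return memoization.get(input);
            }
        };
        inputFunction.setRecursiveFunction(wrapper);
        return wrapper;
    }

    // Hypothetical factory for the factorial from the answer.
    static MemoizableFunction<Integer, Integer> newFact() {
        return new MemoizableFunction<Integer, Integer>() {
            // Defaults to this object, so the function also works unmemoized.
            private Function<? super Integer, ? extends Integer> recursiveFunction = this;

            @Override
            public Integer apply(final Integer input) {
                bodyCalls++;
                System.out.println("Fact: " + input);
                if (input == 1) {
                    return 1;
                }
                Integer smaller = recursiveFunction.apply(input - 1);
                return input * smaller;
            }

            @Override
            public void setRecursiveFunction(Function<? super Integer, ? extends Integer> f) {
                recursiveFunction = f;
            }
        };
    }

    public static void main(String[] args) {
        Function<Integer, Integer> factMemoized = memoize(newFact());
        System.out.println("Result:" + factMemoized.apply(4));
        System.out.println("Result:" + factMemoized.apply(3)); // "Cache hit:3" is printed first
    }
}
```

Running main() prints "Fact: 4" down to "Fact: 1" once, then "Cache hit:3" for the second call. The containsKey/put pair is kept on purpose. HashMap.computeIfAbsent does not allow the mapping function to modify the map, and since Java 9 it throws ConcurrentModificationException when it detects such a modification. Recursion through the wrapper would cause exactly that modification.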

OTHER TIPS

Another way to solve this problem is to store the already computed Fibonacci values in an array. If the value for position n is already stored at index n, it is returned from the array instead of being calculated again.

If the value is not yet present at index n, it is calculated and stored there. Here is such a fibonacci() method:

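// Hypothetical enclosing class; the original only says fibArray is a class-level static array.
public class FibonacciMemo {

    // Assumed size: F(92) is the largest Fibonacci number that fits in a long,
    // so indices 0..92 cover every representable result. 0 means "not computed yet".
    static long[] fibArray = new long[93];

    public static long fibonacci(long n) {
        long fibValue = 0;
        if (n == 0) {
            return 0;
        } else if (n == 1) {
            return 1;
        } else if (fibArray[(int) n] != 0) {
            // Already computed: reuse the stored value
            return fibArray[(int) n];
        } else {
            fibValue = fibonacci(n - 1) + fibonacci(n - 2);
            fibArray[(int) n] = fibValue;
            return fibValue;
        }
    }
}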

Note that this method uses a global (class-level) static array, fibArray[]. For the complete code with an explanation, see http://www.javabrahman.com/gen-java-programs/recursive-fibonacci-in-java-with-memoization/
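The same array idea can be applied to the factorial from the question. Like option 3 in the accepted answer, it puts the caching inside the function itself, which the question hoped to avoid. The payoff is that fact(4) also fills in 2! and 3!. This sketch assumes a cache sized for 0! through 20!, because 20! is the largest factorial that fits in a long. The class name FactorialArrayMemo is hypothetical. A 0 marks an empty slot, which is safe because every factorial is positive.

```java
public class FactorialArrayMemo {
    // Assumed size: indices 0..20, since 21! overflows a long.
    // Calling fact() with n > 20 would throw ArrayIndexOutOfBoundsException.
    static final long[] factArray = new long[21];

    public static long fact(int n) {
        if (n <= 1) {
            return 1;
        }
        if (factArray[n] != 0) {
            return factArray[n]; // computed earlier, possibly while computing a larger n
        }
        long value = n * fact(n - 1);
        factArray[n] = value;
        return value;
    }

    public static void main(String[] args) {
        System.out.println(fact(4));      // prints 24 and stores 2!, 3! and 4!
        System.out.println(factArray[3]); // prints 6: filled in during fact(4)
    }
}
```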

Licensed under: CC-BY-SA with attribution