Question

I was trying to solve a puzzle in Haskell and had written the following code:

u 0 p = 0.0
u 1 p = 1.0
u n p = 1.0 + minimum [((1.0-q)*(s k p)) + (u (n-k) p) | k <-[1..n], let q = (1.0-p)**(fromIntegral k)]

s 1 p = 0.0
s n p = 1.0 + minimum [((1.0-q)*(s (n-k) p)) + q*((s k p) + (u (n-k) p)) | k <-[1..(n-1)], let q = (1.0-(1.0-p)**(fromIntegral k))/(1.0-(1.0-p)**(fromIntegral n))]

This code was terribly slow though. I suspect the reason for this is that the same things get calculated again and again. I therefore made a memoized version:

import Data.Array

memoUa = array (0,10000) ((0,0.0):(1,1.0):[(k,mua k) | k<- [2..10000]])
mua n = (1.0) + minimum [((1.0-q)*(memoSa ! k)) + (memoUa ! (n-k)) | k <-[1..n], let q = (1.0-0.02)**(fromIntegral k)]

memoSa = array (0,10000) ((0,0.0):(1,0.0):[(k,msa k) | k<- [2..10000]])
msa n = (1.0) + minimum [((1.0-q) * (memoSa ! (n-k))) + q*((memoSa ! k) + (memoUa ! (n-k))) | k <-[1..(n-1)], let q = (1.0-(1.0-0.02)**(fromIntegral k))/(1.0-(1.0-0.02)**(fromIntegral n))]

This seems to be a lot faster, but now I get an out-of-memory error. I do not understand why this happens (the same strategy in Java, without recursion, has no problems). Could somebody point me in the right direction on how to improve this code?

EDIT: I am adding my Java version here (as I don't know where else to put it). I realize that the code isn't really reader-friendly (no meaningful names, etc.), but I hope it is clear enough.

public class Main {

    public static double calc(double p) {
        // Bottom-up fill of u[] and s[], mirroring the memoized Haskell arrays
        // (but with p passed in instead of hard-coded).
        double[] u = new double[10001];
        double[] s = new double[10001];

        u[0] = 0.0;
        u[1] = 1.0;
        s[0] = 0.0;
        s[1] = 0.0;

        for (int n = 2; n < 10001; n++) {
            double q = 1.0;
            double denom = 1.0;
            for (int k = 1; k <= n; k++) {
                denom = denom * (1.0 - p);
            }
            denom = 1.0 - denom;   // denom = 1 - (1-p)^n
            s[n] = (double) n;     // initial values before minimizing over k
            u[n] = (double) n;
            for (int k = 1; k <= n; k++) {
                q = (1.0 - p) * q; // q = (1-p)^k
                if (k < n) {
                    double qs = (1.0 - q) / denom;
                    double bs = (1.0 - qs) * s[n - k] + qs * (s[k] + u[n - k]) + 1.0;
                    if (bs < s[n]) {
                        s[n] = bs;
                    }
                }
                double bu = (1.0 - q) * s[k] + 1.0 + u[n - k];
                if (bu < u[n]) {
                    u[n] = bu;
                }
            }
        }
        return u[10000];
    }

    public static void main(String[] args) {
        double s = 0.0;
        int i = 2;
        //for (int i = 1; i < 51; i++) {
            s = s + calc(i * 0.01);
        //}
        System.out.println("result = " + s);
    }
}

Solution

I don't run out of memory when I run the compiled version, but there is a significant difference between how the Java version and the Haskell version work, which I'll illustrate here.

The first thing to do is to add some important type signatures. In particular, you don't want Integer array indices, so I added:

memoUa :: Array Int Double
memoSa :: Array Int Double

I found these using ghc-mod check. I also added a main so that you can run it from the command line:

import System.Environment

main = do
  (arg:_) <- getArgs
  let n = read arg
  print $ mua n

Now to gain some insight into what's going on, we can compile the program using profiling:

ghc -O2 -prof memo.hs

Then when we invoke the program like this:

memo 1000 +RTS -s

we will get profiling output which looks like:

164.31333233347755
      98,286,872 bytes allocated in the heap
      29,455,360 bytes copied during GC
         657,080 bytes maximum residency (29 sample(s))
          38,260 bytes maximum slop
               3 MB total memory in use (0 MB lost due to fragmentation)

                                    Tot time (elapsed)  Avg pause  Max pause
  Gen  0       161 colls,     0 par    0.03s    0.03s     0.0002s    0.0011s
  Gen  1        29 colls,     0 par    0.03s    0.03s     0.0011s    0.0017s

  INIT    time    0.00s  (  0.00s elapsed)
  MUT     time    0.21s  (  0.21s elapsed)
  GC      time    0.06s  (  0.06s elapsed)
  RP      time    0.00s  (  0.00s elapsed)
  PROF    time    0.00s  (  0.00s elapsed)
  EXIT    time    0.00s  (  0.00s elapsed)
  Total   time    0.27s  (  0.27s elapsed)

  %GC     time      21.8%  (22.3% elapsed)

  Alloc rate    468,514,624 bytes per MUT second

  Productivity  78.2% of total user, 77.3% of total elapsed

Important things to pay attention to are:

  • maximum residency
  • Total time
  • %GC time (or Productivity)

Maximum residency is a measure of how much memory the program needs. %GC time is the proportion of the run time spent in garbage collection, and Productivity is its complement (100% - %GC time); in the run above they are 21.8% and 78.2% respectively.

If you run the program for various input values you will see a productivity of around 80%:

   n       Max Res.  Prod.   Time   Output
   2000     779,076  79.4%   1.10s  328.54535361588535
   4000   1,023,016  80.7%   4.41s  657.0894961398351
   6000   1,299,880  81.3%   9.91s  985.6071032981068
   8000   1,539,352  81.5%  17.64s  1314.0968411684714
  10000   1,815,600  81.7%  27.57s  1642.5891214360522

This means that about 20% of the run time is spent in garbage collection. Also, we see increasing memory usage as n increases.

It turns out we can dramatically improve productivity and memory usage by telling Haskell the order in which to evaluate the array elements instead of relying on lazy evaluation:

import Control.Monad (forM_)

main = do
  (arg:_) <- getArgs
  let n = read arg
  forM_ [1..n] $ \i -> mua i `seq` return ()
  print $ mua n

And the new profiling stats are:

   n        Max Res. Prod.   Time   Output
   2000     482,800  99.3%   1.31s  328.54535361588535
   4000     482,800  99.6%   5.88s  657.0894961398351
   6000     482,800  99.5%  12.09s  985.6071032981068
   8000     482,800  98.1%  21.71s  1314.0968411684714
  10000     482,800  96.1%  34.58s  1642.5891214360522

Some interesting observations here: productivity is up and memory usage is down (it is now constant over the range of inputs), but run time is up. This suggests that we forced more computations than we needed to. In an imperative language like Java you have to give an evaluation order, so you know exactly which computations need to be performed. It would be interesting to see your Java code to find out which computations it is performing.
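One way to avoid that extra work, sketched below under the assumption that memoUa and memoSa are defined as in the question (with the type signatures added above), is to force the array elements themselves instead of calling mua a second time, so each stored thunk is evaluated exactly once. I haven't measured this variant; treat it as a sketch rather than a tested fix.

import Control.Monad (forM_)
import Data.Array ((!))
import System.Environment (getArgs)

main :: IO ()
main = do
  (arg:_) <- getArgs
  let n = read arg :: Int
  -- Force the thunks already stored in the arrays, in index order.
  -- Indices 0 and 1 are base cases, so start at 2.
  forM_ [2..n] $ \i ->
    (memoSa ! i) `seq` (memoUa ! i) `seq` return ()
  print (memoUa ! n)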

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow