Question

I have this Scala function that, because of performance problems, needs to be rewritten to be tail-recursive. The stack overflows when processing a not-so-big dataset, so I conclude it can be fixed only by making it tail-recursive. This is the function:

private def carve2(x: Int, y: Int): Unit = {
  var rand: Int = random.nextInt(4)

  (1 to 4) foreach { _ =>
    val (x1, y1, x2, y2) = randomize(x, y, rand)

    if (canUpdate(x1, y1, x2, y2)) {
      maze(y1)(x1) = 0
      maze(y2)(x2) = 0
      carve2(x2, y2)
    }
    rand = (rand + 1) % 4
  }
}

The main problem is:

> How to get rid of the foreach/for loop

For this I tried a recursive approach, but it is tricky to get the semantics right, especially because the rand var is mutated after the self-call inside the if block returns...

What I tried is to push the state modification of rand out of the body and pass the updated value as a parameter:

def carve3(x: Int, y: Int, rand: Int): Unit = {
  for (i <- 1 to 4) {
    val (x1, y1, x2, y2) = randomize(x, y, rand)

    if (canUpdate(x1, y1, x2, y2)) {
      maze(y1)(x1) = 0
      maze(y2)(x2) = 0
      if (i == 1) carve3(x2, y2, random.nextInt(4))
      else carve3(x2, y2, (rand + 1) % 4)
    }
  }
}

This does not work...

One more thing: I know this coding approach is not functional, but I am trying to get there; this is code I am trying to refactor. Also, the randomize and canUpdate functions are not relevant in this context.

Any suggestions? Thanks a lot in advance...


Solution 2

I'm assuming here that your profiler has identified this as a hotspot :)

Your problem, as you correctly infer, is the for-comprehension, which adds an extra level of indirection (a `foreach` call plus a closure) each time through the loop. The cost of this would be negligible for just 4 passes, but the method calls itself recursively from inside that closure, so the extra frames are paid again at every level of recursion.

What I wouldn't do is start by refactoring to use tail recursion; you have two better options to try first:

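  1. Change the foreach to a for-comprehension, then compile with the optimizer enabled (`-optimise` on older Scala 2 compilers); this should cause the compiler to emit a while loop. (A `for` without `yield` desugars to the same `foreach` call, so it is the optimizer that makes the difference.)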

  2. If that doesn't help, convert the comprehension to a while loop by hand.

Then... and only then, you might want to try tail recursion to see if it's faster than the while loop. Chances are that it won't be.
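To make option 2 concrete, here is a minimal sketch of `carve2` with the `foreach` replaced by a hand-written `while` loop. The question doesn't show `maze`, `random`, `randomize` or `canUpdate`, so the definitions here are hypothetical stand-ins: a grid of `Int` where 1 is a wall and 0 a passage, and `randomize` assumed to return the wall cell one step away and the target cell two steps away in direction `d`. Note that the recursive call is still inside the loop, so this removes the per-level `foreach`/closure overhead but is not tail-recursive; the stack still grows with the size of the maze.

```scala
import scala.util.Random

// Hypothetical stand-ins for the question's maze, random, randomize and
// canUpdate; the real definitions are not shown in the question.
object WhileLoopCarver {
  val height = 11
  val width = 11
  val maze: Array[Array[Int]] = Array.fill(height, width)(1) // 1 = wall, 0 = passage
  val random = new Random(42)

  // Assumed: direction d (0..3) maps to a unit step (dx, dy).
  private val dirs = Array((0, -1), (1, 0), (0, 1), (-1, 0))

  // Assumed: returns the wall cell one step away and the target two steps away.
  def randomize(x: Int, y: Int, d: Int): (Int, Int, Int, Int) = {
    val (dx, dy) = dirs(d)
    (x + dx, y + dy, x + 2 * dx, y + 2 * dy)
  }

  // Assumed: the target must be strictly inside the border and still a wall.
  def canUpdate(x1: Int, y1: Int, x2: Int, y2: Int): Boolean =
    x2 > 0 && x2 < width - 1 && y2 > 0 && y2 < height - 1 && maze(y2)(x2) == 1

  // Option 2: the foreach replaced by a while loop. The recursive call is
  // still inside the loop, so this is NOT tail-recursive.
  def carve2(x: Int, y: Int): Unit = {
    var rand = random.nextInt(4)
    var i = 0
    while (i < 4) {
      val (x1, y1, x2, y2) = randomize(x, y, rand)
      if (canUpdate(x1, y1, x2, y2)) {
        maze(y1)(x1) = 0
        maze(y2)(x2) = 0
        carve2(x2, y2)
      }
      rand = (rand + 1) % 4
      i += 1
    }
  }
}
```

Usage: mark the start cell first, e.g. `WhileLoopCarver.maze(1)(1) = 0`, then call `WhileLoopCarver.carve2(1, 1)`.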

UPDATE

I was heading towards Petr's solution anyway :)

So here's the full thing tidied up to be a bit more idiomatic:

import scala.annotation.tailrec

private def carve2(x: Int, y: Int): Unit = {
  carve2(List((x, y)))
}

// a recursive (and overloaded) method needs an explicit result type
@tailrec private def carve2(coords: List[(Int, Int)]): Unit = coords match {
  case (x, y) :: rest =>
    val rand: Int = random.nextInt(4)

    // note that this won't necessarily yield four pairs of co-ords
    // due to the guard condition
    val add = for {
      i <- 0 until 4 // tries rand, rand + 1, ... (mod 4), as in the original loop
      (x1, y1, x2, y2) = randomize(x, y, (i + rand) % 4)
      if canUpdate(x1, y1, x2, y2)
    } yield {
      maze(y1)(x1) = 0
      maze(y2)(x2) = 0
      (x2, y2)
    }

    // tail recursion happens here...
    carve2(rest ++ add)

  case Nil => () // nothing left to carve
}
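One refinement worth considering, offered as a hedged sketch: `rest ++ add` on a `List` copies all of `rest` on every pass, so each step costs time proportional to the queue length. `scala.collection.immutable.Queue` provides `dequeueOption`, and on Scala 2.13+ (including Scala 3) its `++` appends without copying the existing elements, so the same loop can use it instead. The `QueueCarver` class and its `maze`, `randomize` and `canUpdate` are hypothetical stand-ins, since the question doesn't show the real ones.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import scala.util.Random

// Hypothetical harness: maze, random, randomize and canUpdate stand in for
// the question's definitions, which are not shown.
class QueueCarver(width: Int, height: Int, seed: Long) {
  val maze: Array[Array[Int]] = Array.fill(height, width)(1) // 1 = wall, 0 = passage
  private val random = new Random(seed)
  private val dirs = Array((0, -1), (1, 0), (0, 1), (-1, 0))

  // Assumed: wall cell one step away, target cell two steps away.
  private def randomize(x: Int, y: Int, d: Int): (Int, Int, Int, Int) = {
    val (dx, dy) = dirs(d)
    (x + dx, y + dy, x + 2 * dx, y + 2 * dy)
  }

  // Assumed: target strictly inside the border and not yet carved.
  private def canUpdate(x1: Int, y1: Int, x2: Int, y2: Int): Boolean =
    x2 > 0 && x2 < width - 1 && y2 > 0 && y2 < height - 1 && maze(y2)(x2) == 1

  def carve2(x: Int, y: Int): Unit = carve2(Queue((x, y)))

  // Same loop as the List version, but taking from the front and appending
  // to the back of an immutable Queue instead of copying a List each pass.
  @tailrec
  private def carve2(coords: Queue[(Int, Int)]): Unit = coords.dequeueOption match {
    case Some(((x, y), rest)) =>
      val rand = random.nextInt(4)
      val add = for {
        i <- 0 until 4
        (x1, y1, x2, y2) = randomize(x, y, (i + rand) % 4)
        if canUpdate(x1, y1, x2, y2)
      } yield {
        maze(y1)(x1) = 0 // carve the wall between the cells
        maze(y2)(x2) = 0 // carve the target cell
        (x2, y2)
      }
      carve2(rest ++ add) // tail call: the pending cells live on the heap
    case None => () // queue empty: every reachable cell has been carved
  }
}
```

Usage: `val c = new QueueCarver(801, 801, 7L); c.maze(1)(1) = 0; c.carve2(1, 1)` carves a 400 x 400-cell grid without growing the stack.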

OTHER TIPS

If a function calls itself multiple times, it can't be converted to a tail-recursive one directly: a call is a tail call only if it is the last thing the function does, so that the function doesn't need to remember anything afterwards, and at most one of several calls can be last.
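A minimal illustration of the difference, unrelated to the maze code: `sumNonTail` still has work to do (`head + ...`) after its recursive call returns, while `sumTail` carries that work in an accumulator, so its recursive call is the last thing it does. The `@tailrec` annotation from `scala.annotation` makes compilation fail if the compiler cannot turn the call into a loop. The names here are illustrative only.

```scala
import scala.annotation.tailrec

object TailCallDemo {
  // Not a tail call: after sumNonTail(rest) returns, this frame still has
  // to add `head`, so the stack grows by one frame per element.
  def sumNonTail(xs: List[Int]): Long = xs match {
    case Nil          => 0L
    case head :: rest => head + sumNonTail(rest)
  }

  // Tail call: the pending addition is carried in `acc`, so the recursive
  // call is the last thing the method does. @tailrec makes compilation fail
  // if the compiler cannot turn the call into a loop.
  @tailrec
  def sumTail(xs: List[Int], acc: Long = 0L): Long = xs match {
    case Nil          => acc
    case head :: rest => sumTail(rest, acc + head)
  }
}
```

`TailCallDemo.sumTail(List.range(1, 1000001))` returns `500000500000L` in constant stack space; `sumNonTail` on a list that long is likely to throw a `StackOverflowError` with default thread stack sizes. With a single recursive call the accumulator is enough; `carve2` makes several, which is why the pending work has to go into a collection instead.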

A standard trick for solving this kind of problem is to use the heap instead of the stack by keeping the tasks to be computed in a queue. For example:

private def carve2(x: Int, y: Int): Unit = {
  carve2(Seq((x, y)))
}

@annotation.tailrec
private def carve2(coords: Seq[(Int, Int)]): Unit = {
  // pick a pair of coordinates
  val ((x, y), rest) = coords match {
    case first +: others => (first, others)
    case _               => return // empty queue: done
  }
  // This is a functional approach, although perhaps slower.
  // Using a `while` loop instead would result in faster code.
  val rand: Int = random.nextInt(4)
  val add: Seq[(Int, Int)] =
    for {
      i <- 0 until 4
      (x1, y1, x2, y2) = randomize(x, y, (i + rand) % 4)
      if canUpdate(x1, y1, x2, y2)
    } yield {
      // carve the wall and the target cell, as in the original loop
      maze(y1)(x1) = 0
      maze(y2)(x2) = 0
      // return the new coordinates to be added to the queue
      (x2, y2)
    }
  // the tail-recursive call
  carve2(rest ++ add)
}

(I haven't tried to compile the code as the code sample you posted isn't self-contained.)

Here carve2 runs in a tail-recursive loop. Each pass may add new coordinates to the end of the queue, and the loop finishes when the queue is empty.
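One caveat: both queue-based versions take cells from the front of a queue, so the maze is carved breadth-first, and all four directions of a cell are checked before any neighbour is explored. The question's recursion is depth-first, so the carved maze can differ. If the original order matters, one option (a swapped-in technique, not taken from either answer) is an explicit stack of suspended calls: each hypothetical `Frame` records the cell, its starting direction `rand` and the index `i` of the next direction to try, which is exactly the state the recursive version keeps in its stack frame. The `Carver` class, `maze`, `randomize` and `canUpdate` below are hypothetical stand-ins; the question's recursive `carve2` is kept as `carveRecursive` so the two can be compared with the same seed.

```scala
import scala.annotation.tailrec
import scala.util.Random

// Hypothetical harness: the question does not show maze, random, randomize
// or canUpdate, so these definitions are assumptions for illustration.
class Carver(width: Int, height: Int, seed: Long) {
  val maze: Array[Array[Int]] = Array.fill(height, width)(1) // 1 = wall, 0 = passage
  private val random = new Random(seed)
  private val dirs = Array((0, -1), (1, 0), (0, 1), (-1, 0))

  // Assumed: wall cell one step away, target cell two steps away.
  private def randomize(x: Int, y: Int, d: Int): (Int, Int, Int, Int) = {
    val (dx, dy) = dirs(d)
    (x + dx, y + dy, x + 2 * dx, y + 2 * dy)
  }

  // Assumed: target strictly inside the border and not yet carved.
  private def canUpdate(x1: Int, y1: Int, x2: Int, y2: Int): Boolean =
    x2 > 0 && x2 < width - 1 && y2 > 0 && y2 < height - 1 && maze(y2)(x2) == 1

  // The question's original (non-tail) recursion, kept as a reference.
  def carveRecursive(x: Int, y: Int): Unit = {
    var rand = random.nextInt(4)
    (1 to 4).foreach { _ =>
      val (x1, y1, x2, y2) = randomize(x, y, rand)
      if (canUpdate(x1, y1, x2, y2)) {
        maze(y1)(x1) = 0
        maze(y2)(x2) = 0
        carveRecursive(x2, y2)
      }
      rand = (rand + 1) % 4
    }
  }

  // One suspended call: the cell, its starting direction `rand`, and the
  // index `i` (0..4) of the next direction to try.
  private case class Frame(x: Int, y: Int, rand: Int, i: Int)

  def carveIterative(x: Int, y: Int): Unit =
    loop(List(Frame(x, y, random.nextInt(4), 0)))

  @tailrec
  private def loop(stack: List[Frame]): Unit = stack match {
    case Nil                       => ()
    case Frame(_, _, _, 4) :: rest => loop(rest) // all four directions tried: "return"
    case Frame(x, y, rand, i) :: rest =>
      val (x1, y1, x2, y2) = randomize(x, y, (rand + i) % 4)
      val resume = Frame(x, y, rand, i + 1) :: rest // where this "call" continues
      if (canUpdate(x1, y1, x2, y2)) {
        maze(y1)(x1) = 0
        maze(y2)(x2) = 0
        // "Recursive call": push the child on top of the resume point,
        // drawing its random start direction at the same moment as the original.
        loop(Frame(x2, y2, random.nextInt(4), 0) :: resume)
      } else loop(resume)
  }
}
```

Because `loop` only makes tail calls, the stack depth stays constant; the pending work lives in the `List` of frames on the heap. With the same seed, `carveIterative` makes the same `random` and `canUpdate` calls in the same order as `carveRecursive`, so it should carve the same maze.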

Licensed under: CC-BY-SA with attribution