Question

I realize this runs counter to the usual kind of SO question, but the following code works even though I think it should not. Below is a small Scala program that uses continuations inside a while loop. According to my understanding of continuation-passing style, this code should produce a stack overflow error by adding a frame to the stack for each iteration of the while loop. However, it works just fine.

import util.continuations.{shift, reset}


class InfiniteCounter extends Iterator[Int] {
    var count = 0
    var callback: Unit=>Unit = null
    reset {
        while (true) {
            shift {f: (Unit=>Unit) =>
                callback = f
            }
            count += 1
        }

    }

    def hasNext: Boolean = true

    def next(): Int = {
        callback()
        count
    }
}

object Experiment3 {

    def main(args: Array[String]): Unit = {
        val counter = new InfiniteCounter()
        println(counter.next())
        println("Hello")
        println(counter.next())
        for (i <- 0 until 100000000) {
            counter.next()
        }
        println(counter.next())
    }

}

The output is:

1
Hello
2
100000003

My question is: why is there no stack overflow? Is the Scala compiler doing tail call optimization (which I thought it couldn't do with continuations), or is something else going on?

(This experiment is on github along with the sbt configuration needed to run it: https://github.com/jcrudy/scala-continuation-experiments. See commit 7cec9befcf58820b925bb222bc25f2a48cbec4a6)


Solution

The reason you don't get a stack overflow here is that the way you're using shift and callback() acts like a trampoline.
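For readers unfamiliar with the term: a trampoline is a driver loop that runs a computation one step at a time, where each step returns the next step instead of calling it, so the stack does not grow with the number of steps. The sketch below is a generic illustration of that definition, not code from the question; the names `Step`, `Done`, `More`, `run` and `countTo` are hypothetical. (The standard library ships a ready-made version of the same idea in `scala.util.control.TailCalls`.)

```scala
import scala.annotation.tailrec

// Hypothetical types: a step either finishes with a value or hands back
// the next step as a thunk, without running it.
sealed trait Step[A]
final case class Done[A](value: A) extends Step[A]
final case class More[A](next: () => Step[A]) extends Step[A]

object Trampoline {
  // The driver: keeps unwrapping More until it reaches Done.
  // Because run is tail-recursive, scalac compiles it to a loop.
  @tailrec
  def run[A](step: Step[A]): A = step match {
    case Done(v) => v
    case More(k) => run(k())
  }

  // Counts up to n one step at a time; no step calls the next one directly.
  def countTo(n: Int, i: Int = 0): Step[Int] =
    if (i >= n) Done(i) else More(() => countTo(n, i + 1))
}
```

In the question's program, the `for` loop in `main` plays roughly the role of `run`: each `counter.next()` executes one step and returns before the next one starts.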

Each time the execution thread reaches the shift construct, it sets callback to the current continuation (a closure) and then immediately returns Unit to the calling context. When you call next() and it invokes callback(), you execute the continuation closure, which runs count += 1, jumps back to the top of the while loop, and reaches the shift again. That shift stores the new continuation in callback and returns, so callback() returns to next() instead of recursing any deeper.
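To make this control flow concrete, here is a hand-written approximation of `InfiniteCounter` that does not use the continuations plugin. It assumes the CPS transform can be modeled as turning the `while` loop into a recursive local function (`loop`, a hypothetical name); it is not the plugin's literal output. The `Iterator` parent is dropped and `callback` is typed `() => Unit` to keep the sketch minimal.

```scala
// Hypothetical hand-desugared model of InfiniteCounter (no CPS plugin needed).
class ManualCounter {
  var count = 0
  var callback: () => Unit = null

  // Models: while (true) { shift { f => callback = f }; count += 1 }
  private def loop(): Unit = {
    // The "shift": package the rest of the loop body as a closure and store it.
    // Nothing invokes the closure here, so loop() returns immediately.
    callback = () => {
      count += 1 // the code that followed the shift
      loop()     // next iteration: stores a fresh closure, then returns
    }
  }

  loop() // like reset: run up to the first "shift", then return

  def next(): Int = {
    callback() // runs exactly one iteration; the stack unwinds back to here
    count
  }
}
```

Calling `next()` on a fresh `ManualCounter` returns 1, then 2, and so on, matching the first lines of the question's output.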

One of the key benefits of the CPS transformation is that it captures the flow of control in the continuation rather than on the stack. When you set callback = f on each "iteration", you overwrite your only reference to the previous continuation (the previous state of the function), which allows it to be garbage collected.

The stack here only ever reaches a depth of a few frames (probably around 10, because of all the nested closures). Each time you execute the shift, it captures the current state in a closure (on the heap), and then the stack unwinds back to your for expression.
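One way to check the bounded-depth claim empirically is to record the JVM stack depth each time the continuation runs. The sketch below reproduces the same store-and-return pattern by hand (again without the plugin; `StackDepthDemo`, `loop` and `measure` are hypothetical names) and uses `Thread.getStackTrace` to measure depth. If each iteration added frames, the recorded depths would keep increasing.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical demo: store-and-return continuations, with depth recording.
object StackDepthDemo {
  var callback: () => Unit = null
  val depths = ArrayBuffer.empty[Int]

  // Stores the "rest of the loop" as a closure and returns without running it.
  def loop(): Unit = {
    callback = () => {
      // Number of frames on the current thread's stack inside the continuation.
      depths += Thread.currentThread().getStackTrace().length
      loop() // re-"shift": store a fresh closure and return
    }
  }

  // Invokes the continuation `iterations` times from the same call site.
  def measure(iterations: Int): List[Int] = {
    depths.clear()
    loop()
    for (_ <- 0 until iterations) callback()
    depths.toList
  }
}
```

Every value in `StackDepthDemo.measure(1000)` should be the same, because each call to `callback()` returns before the next one begins.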

I feel like a diagram would make this clearer, but stepping through the code with your debugger would probably be just as useful. I think the key point here is that, since you've essentially built a trampoline, you'll never blow the stack.
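The counterfactual shows the same point from the other side. If the shift body called the continuation immediately (roughly like a shift whose body runs f right away instead of assigning it to callback), every iteration would become a nested call that never returns. The hypothetical `EagerDemo` below models that by hand; since neither scalac nor the JVM eliminates the non-tail call chain, it overflows the stack.

```scala
// Hypothetical counter-example: the "shift" runs the continuation now
// instead of storing it, so each iteration nests inside the previous one.
object EagerDemo {
  var count = 0

  def loop(): Unit = {
    val k: () => Unit = () => {
      count += 1
      loop() // nested call: earlier frames are still on the stack
    }
    k() // invoke immediately rather than saving k for later
  }

  // Returns true if running the eager version blew the stack.
  def overflows(): Boolean =
    try { loop(); false }
    catch { case _: StackOverflowError => true }
}
```

The only difference from the store-and-return version is whether the closure is run now or saved for a later `next()` call; that choice alone decides whether the stack grows.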

Licensed under: CC-BY-SA with attribution