Question

Just playing with continuations. The goal is to create a function that receives another function as a parameter, plus an execution count, and returns a function that applies the parameter function that many times.

The implementation looks pretty obvious:

import scala.annotation.tailrec

def n_times[T](func: T => T, count: Int): T => T = {
  // wraps `continuation` in one more application of `func` per remaining count
  @tailrec
  def n_times_cont(cnt: Int, continuation: T => T): T => T = cnt match {
    case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
    case 1 => continuation
    case _ => n_times_cont(cnt - 1, i => continuation(func(i)))
  }
  n_times_cont(count, func)
}

def inc(x: Int) = x + 1

val res1 = n_times(inc, 1000)(1)      // works OK, returns 1001

val res = n_times(inc, 10000000)(1)   // FAILS

But there is a problem: this code fails with a StackOverflowError. Why is there no tail-call optimization here?

I'm running it in Eclipse with the Scala plugin, and it throws:

Exception in thread "main" java.lang.StackOverflowError
    at scala.runtime.BoxesRunTime.boxToInteger(Unknown Source)
    at Task_Mult$$anonfun$1.apply(Task_Mult.scala:25)
    at Task_Mult$$anonfun$n_times_cont$1$1.apply(Task_Mult.scala:18)

P.S.

F# code, which is an almost direct translation, works without any issues:

let n_times_cnt func count = 
    let rec n_times_impl count' continuation = 
        match count' with
        | _ when count'<1 -> failwith "wrong count"
        | 1 -> continuation
        | _ -> n_times_impl (count'-1) (func >> continuation) 
    n_times_impl count func

let inc x = x+1
let res = (n_times_cnt inc 10000000) 1

printfn "%d" res

Solution

The Scala standard library has an implementation of trampolines in scala.util.control.TailCalls. So, revisiting your implementation: when you build up the nested calls with continuation(func(t)), those are tail calls, just not ones that get optimized (scalac only eliminates direct self-recursive tail calls, and the JVM itself does no tail-call elimination). So let's build up a T => TailRec[T], where the stack frames are replaced with objects on the heap. Then return a function that takes the argument and passes it to that trampolined function:

import scala.util.control.TailCalls._

def n_times_trampolined[T](func: T => T, count: Int): T => T = {
  @annotation.tailrec
  def n_times_cont(cnt: Int, continuation: T => TailRec[T]): T => TailRec[T] = cnt match {
    case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
    case 1 => continuation
    // tailcall defers the next step: it returns a TailRec object instead of growing the stack
    case _ => n_times_cont(cnt - 1, t => tailcall(continuation(func(t))))
  }
  // last step to run: apply func once more and finish
  val lifted: T => TailRec[T] = t => done(func(t))
  // .result runs the trampoline, evaluating the deferred steps one at a time in a loop
  t => n_times_cont(count, lifted)(t).result
}
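A small self-contained usage sketch (the object name `TrampolineDemo` is just for illustration) that copies the function above and applies it one million times. That depth should be enough to overflow a default JVM stack with the plain closure-chain version; here `.result` walks the steps in a loop instead.

```scala
import scala.util.control.TailCalls._

object TrampolineDemo {
  // Same function as in the answer above.
  def n_times_trampolined[T](func: T => T, count: Int): T => T = {
    @annotation.tailrec
    def n_times_cont(cnt: Int, continuation: T => TailRec[T]): T => TailRec[T] = cnt match {
      case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
      case 1 => continuation
      case _ => n_times_cont(cnt - 1, t => tailcall(continuation(func(t))))
    }
    val lifted: T => TailRec[T] = t => done(func(t))
    t => n_times_cont(count, lifted)(t).result
  }

  def inc(x: Int): Int = x + 1

  def main(args: Array[String]): Unit = {
    // One million deferred steps, evaluated by the trampoline in constant stack depth
    println(n_times_trampolined(inc, 1000000)(1)) // prints 1000001
  }
}
```

Note that the closures themselves still live on the heap, so very large counts trade stack usage for heap usage.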

OTHER TIPS

I could be wrong here, but I suspect that the n_times_cont inner function is correctly compiled into a loop (that is what @tailrec checks); the culprit's not there.

The stack is blown by the accumulated continuation closures (i.e. the i=>continuation(func(i))). Once you apply the result of the main function, they make 10000000 nested calls: each closure calls inc, then calls the next closure, and its own frame stays on the stack until that call returns.

In fact, you can try:

scala> val rs = n_times(inc, 1000000)
rs: Int => Int = <function1> //<- we're happy here

scala> rs(1) //<- this blows up the stack!
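If the goal is only the n-fold application rather than practice with continuations, one way to avoid the closure chain entirely is to apply `func` in a loop when the returned function is called. This is a minimal sketch of a swapped-in technique, not something from the question or the answers; `nTimesLoop` is a hypothetical name:

```scala
import scala.annotation.tailrec

object NTimesLoop {
  // Hypothetical alternative: no chain of closures is built, so stack depth
  // stays constant no matter how large `count` is.
  def nTimesLoop[T](func: T => T, count: Int): T => T = {
    // require throws IllegalArgumentException, like the original's explicit throw
    require(count >= 1, s"count was wrong $count")
    (start: T) => {
      @tailrec
      def loop(remaining: Int, acc: T): T =
        if (remaining == 0) acc else loop(remaining - 1, func(acc))
      loop(count, start)
    }
  }

  def inc(x: Int): Int = x + 1

  def main(args: Array[String]): Unit =
    println(nTimesLoop(inc, 10000000)(1)) // prints 10000001
}
```

The work still happens each time the returned function is applied, just as walking the closure chain does, but `loop` is self-recursive, so scalac turns it into a jump.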

As an aside, you can rewrite

i=>continuation(func(i))

as

continuation compose func

for greater readability.
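For reference, a sketch of the question's function with that rewrite applied (wrapped in a hypothetical `NTimesCompose` object so it compiles on its own). `(f compose g)(x)` is `f(g(x))`, so the behavior is unchanged, including the stack depth when the result is applied; `compose` builds the same nested chain of closures.

```scala
import scala.annotation.tailrec

object NTimesCompose {
  def n_times[T](func: T => T, count: Int): T => T = {
    @tailrec
    def n_times_cont(cnt: Int, continuation: T => T): T => T = cnt match {
      case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
      case 1 => continuation
      // same as i => continuation(func(i))
      case _ => n_times_cont(cnt - 1, continuation compose func)
    }
    n_times_cont(count, func)
  }

  def inc(x: Int): Int = x + 1
}
```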

Licensed under: CC-BY-SA with attribution