The Scala standard library implements trampolines in `scala.util.control.TailCalls`. Back to your implementation: when you build up the nested calls with `continuation(func(t))`, those are tail calls, but nothing optimizes them away. scalac only eliminates self-recursive tail calls, and the JVM has no general tail-call optimization, so every call still pushes a stack frame. So let's build a `T => TailRec[T]` instead, where the stack frames are replaced by objects on the heap. Then return a function that takes the argument, passes it to that trampolined function, and extracts the value with `.result`:
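```scala
import scala.util.control.TailCalls._

def n_times_trampolined[T](func: T => T, count: Int): T => T = {
  // Wraps `continuation` in one more application of `func`, count - 1 times.
  // Each layer defers its work with `tailcall`, so running the result
  // bounces through the heap instead of nesting JVM stack frames.
  @annotation.tailrec
  def n_times_cont(cnt: Int, continuation: T => TailRec[T]): T => TailRec[T] = cnt match {
    case _ if cnt < 1 => throw new IllegalArgumentException(s"count must be at least 1, was $count")
    case 1 => continuation
    case _ => n_times_cont(cnt - 1, t => tailcall(continuation(func(t))))
  }
  // The innermost step: apply `func` once and finish.
  val lifted: T => TailRec[T] = t => done(func(t))
  // Build the closure chain once (this also validates `count` eagerly),
  // then run the trampoline with `.result` on every call.
  val trampolined = n_times_cont(count, lifted)
  t => trampolined(t).result
}
```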
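If `TailCalls` is new to you, here is a minimal sketch of the API on its own, using the mutually recursive even/odd check that the `TailCalls` Scaladoc also uses. The object name `TailCallsPrimer` is just for illustration. `tailcall` takes its argument by name, so `isOdd(n - 1)` is not evaluated until `.result` drives the trampoline loop, which keeps the JVM stack flat even for a million steps.

```scala
import scala.util.control.TailCalls._

// Hypothetical demo object; only `TailRec`, `done`, `tailcall` and `.result` come from the library.
object TailCallsPrimer {
  // Each step returns a description of the next call (a TailRec)
  // instead of making that call directly.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))

  def main(args: Array[String]): Unit = {
    // `.result` runs the trampoline in a loop, so 1,000,000 bounces do not overflow the stack.
    println(isEven(1000000).result)
  }
}
```

Calling `isEven(1000000)` directly without the `TailRec` wrapper would most likely throw a `StackOverflowError` with default JVM stack sizes.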
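To check the stack-safety claim, here is a self-contained sketch that restates `n_times_trampolined` (building the closure chain once, outside the returned function) and applies `_ + 1` a million times. The wrapping object `NTimesDemo` and the chosen count are assumptions for the demo only.

```scala
import scala.util.control.TailCalls._

// Hypothetical wrapper object so the sketch compiles as a standalone program.
object NTimesDemo {
  def n_times_trampolined[T](func: T => T, count: Int): T => T = {
    @annotation.tailrec
    def n_times_cont(cnt: Int, continuation: T => TailRec[T]): T => TailRec[T] = cnt match {
      case _ if cnt < 1 => throw new IllegalArgumentException(s"count must be at least 1, was $count")
      case 1 => continuation
      case _ => n_times_cont(cnt - 1, t => tailcall(continuation(func(t))))
    }
    // Build the chain of `count` closures once; each call just runs the trampoline.
    val trampolined = n_times_cont(count, t => done(func(t)))
    t => trampolined(t).result
  }

  def main(args: Array[String]): Unit = {
    val addMillion = n_times_trampolined[Int](_ + 1, 1000000)
    // The closure chain lives on the heap, so this does not overflow the stack.
    println(addMillion(0))
  }
}
```

The trade-off is memory: the chain holds one closure per application, so very large counts cost heap space instead of stack space.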