Question

I am going through the following Shift/Reset tutorial: http://www.is.ocha.ac.jp/~asai/cw2011tutorial/main-e.pdf.

So far I have been able to translate the OchaCaml examples to Scala, all the way up to section 2.11, but now I seem to have hit a wall. The paper by Asai and Kiselyov defines the following recursive function (in OchaCaml, I think):

(* a_normal : term_t -> term_t *)
let rec a_normal term = match term with
    Var (x) -> Var (x)
  | Lam (x, t) -> Lam (x, reset (fun () -> a_normal t))
  | App (t1, t2) ->
      shift (fun k ->
        let t = gensym () in (* generate fresh variable *)
        App (Lam (t, (* let expression *)
                  k (Var (t))), (* continue with new variable *)
             App (a_normal t1, a_normal t2))) ;;

The function is supposed to A-normalize a lambda expression. This is my Scala translation:

// section 2.11
import scala.util.continuations._ // requires the Scala continuations compiler plugin

object ShiftReset extends App {

  sealed trait Term
  case class Var(x: String) extends Term
  case class Lam(x: String, t: Term) extends Term
  case class App(t1: Term, t2: Term) extends Term

  val gensym = {
    var i = 0
    () => { i += 1; "t" + i }
  }

  def a_normal(t: Term): Term@cps[Term] = t match {
    case Var(x) => Var(x)
    case Lam(x, t) => Lam(x, reset(a_normal(t)))
    case App(t1, t2) => shift{ (k:Term=>Term) =>
      val t = gensym()
      val u = Lam(t, k(Var(t)))
      val v = App(a_normal(t1), a_normal(t2))
      App(u, v): Term
    }
  }

}

I am getting the following compilation error:

 found   : ShiftReset.Term @scala.util.continuations.cpsSynth 
             @scala.util.continuations.cpsParam[ShiftReset.Term,ShiftReset.Term]
 required: ShiftReset.Term
    case App(t1, t2) => shift{ (k:Term=>Term) =>
                                              ^
one error found

I think the plugin is telling me that it cannot deal with a nested shift. Is there a way to make the code compile (either a basic error I've overlooked or some workaround)? I have tried converting the pattern match to an if/else if/else chain and introducing more local variables, but I got the same error.

Alternately, would I have more luck using Haskell and the Cont monad (+ the shift/reset from here) or is there going to be the same type of limitation with nested shift? I am adding the Haskell tag as well, since I don't mind switching to Haskell to go through the rest of the tutorial.

Edit: Thanks to James who figured out which line the continuation plugin could not deal with and how to tweak it, it now works. Using the version in his answer and the following formatting code:

def format(t: Term): String = t match {
  case Var(x) => s"$x"
  case Lam(x, t) => s"\\$x.${format(t)}"
  case App(Lam(x, t1), t2) => s"let $x = ${format(t2)} in ${format(t1)}"
  case App(Var(x), Var(y)) => s"$x$y"
  case App(Var(x), t2) => s"$x (${format(t2)})"
  case App(t1, t2) => s"(${format(t1)}) (${format(t2)})"
}

I get the output that the paper mentions (though I don't grasp yet how the continuation actually manages it):

sCombinator:
\x.\y.\z.(xz) (yz)
reset{a_normal(sCombinator)}:
\x.\y.\z.let t1 = xz in let t2 = yz in let t3 = t1t2 in t3

Solution

The problem is the line:

val v = App(a_normal(t1), a_normal(t2))

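I'm not certain, but I think the type inferencer gets confused because a_normal returns a Term@cps[Term], while the function passed to shift is expected to return a plain Term. That matches the error message: the body is found to be a Term annotated with cpsParam[Term,Term], but a Term is required.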

It will compile if you pull the line up out of the shift block:

def a_normal(t: Term): Term@cps[Term] = t match {
  case Var(x) => Var(x)
  case Lam(x, t) => Lam(x, reset(a_normal(t)))
  case App(t1, t2) =>
    val v = App(a_normal(t1), a_normal(t2))
    shift{ (k:Term=>Term) =>
      val t = gensym()
      val u = Lam(t, k(Var(t)))
      App(u, v): Term
    }
}
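As for how the continuation manages to produce the chain of lets, it may help to look at a hand-written continuation-passing version that does not use the plugin at all. This is only a minimal sketch of my own, assuming it mirrors what the version above does; the names ANormalCps, aNormal and go are made up for illustration and come from neither the paper nor the plugin. Each shift becomes code that receives the rest of the computation as an explicit function k, and each reset becomes a call that starts with the identity continuation:

```scala
object ANormalCps {
  sealed trait Term
  case class Var(x: String) extends Term
  case class Lam(x: String, t: Term) extends Term
  case class App(t1: Term, t2: Term) extends Term

  // Plugin-free sketch: A-normalize by passing the continuation explicitly.
  def aNormal(term: Term): Term = {
    // Fresh-name supply, local so that every call starts again at t1.
    var counter = 0
    def gensym(): String = { counter += 1; "t" + counter }

    // k is "the rest of the computation" that shift would have captured.
    def go(term: Term)(k: Term => Term): Term = term match {
      case Var(x) => k(Var(x))
      // reset: normalize the body with a fresh (identity) continuation.
      case Lam(x, body) => k(Lam(x, go(body)(identity)))
      case App(t1, t2) =>
        go(t1) { n1 =>
          go(t2) { n2 =>
            val t = gensym()
            // let t = (n1 n2) in <rest of the computation applied to t>
            App(Lam(t, k(Var(t))), App(n1, n2))
          }
        }
    }

    go(term)(identity)
  }
}
```

Reading the App case: the application n1 n2 is bound to a fresh variable t, and everything that was still waiting for the result (k) is placed inside the body of that let. Since k itself may contain further applications, the lets nest in evaluation order, which is how xz, yz and t1t2 each end up with their own binding for the S combinator.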

Regarding nested shifts in general, you can definitely do this if each nested continuation has a compatible type:

object NestedShifts extends App {

  import scala.util.continuations._

  def foo(x: Int): Int@cps[Unit] = shift { k: (Int => Unit) =>   
    k(x)
  }

  reset {
    val x = foo(1)
    println("x: " + x)

    val y = foo(2)
    println("y: " + y)

    val z = foo(foo(3))
    println("z: " + z)
  }
}

This program prints to stdout:

x: 1
y: 2
z: 3
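To make the nesting more concrete, here is roughly how I picture the same program written by hand in continuation-passing style, without the plugin. This is a sketch under my own assumptions, not the plugin's actual output; NestedShiftsCps, fooK and run are hypothetical names, and the lines are collected in a list instead of being printed so the result is easy to inspect:

```scala
import scala.collection.mutable.ListBuffer

object NestedShiftsCps {
  // Hand-written counterpart of foo: takes the continuation k explicitly.
  def fooK(x: Int)(k: Int => Unit): Unit = k(x)

  // Hand-written counterpart of the reset block above.
  def run(): List[String] = {
    val out = ListBuffer.empty[String]
    fooK(1) { x =>
      out += "x: " + x
      fooK(2) { y =>
        out += "y: " + y
        // foo(foo(3)): the inner call runs first and feeds the outer one.
        fooK(3) { inner =>
          fooK(inner) { z =>
            out += "z: " + z
          }
        }
      }
    }
    out.toList
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```

Every continuation here has the type Int => Unit, which is the sense in which the nested calls are compatible.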
Licensed under: CC-BY-SA with attribution