Question

In researching how to do memoization in Scala, I've found some code I didn't grok. I've tried to look this particular "thing" up, but I don't know what to call it; i.e. the term by which to refer to it. Additionally, it's not easy to search for a symbol, ugh!

I saw the following code to do memoization in Scala here:

case class Memo[A,B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A) = cache getOrElseUpdate (x, f(x))
}

And it's what the case class is extending that is confusing me, the extends (A => B) part. First, what is happening? Secondly, why is it even needed? And finally, what do you call this kind of inheritance; i.e. is there some specific name or term I can use to refer to it?

Next, I am seeing Memo used in this way to calculate a Fibonacci number here:

  val fibonacci: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fibonacci(n-1) + fibonacci(n-2)
  }

It's probably my not seeing all of the "simplifications" that are being applied. But, I am not able to figure out the end of the val line, = Memo {. So, if this was typed out more verbosely, perhaps I would understand the "leap" being made as to how the Memo is being constructed.

Any assistance on this is greatly appreciated. Thank you.


Solution 4

This answer is a synthesis of the partial answers provided by both 0__ and Nicolas Rinaudo.

Summary:

There are many convenient (but also highly intertwined) assumptions being made by the Scala compiler.

  1. Scala treats extends (A => B) as synonymous with extends Function1[A, B] (ScalaDoc for Function1[-T1, +R])
  2. A concrete implementation of Function1's inherited abstract method apply(x: A): B must be provided: def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
  3. Scala assumes an implied match for the code block starting with = Memo {
  4. Scala passes the content between {} started in item 3 as a parameter to the Memo case class constructor
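  5. Scala types the block between {} from item 3 against the constructor's parameter type, here Int => BigInt. It can be read as a PartialFunction[Int, BigInt] (a subtype of Function1) whose apply() is the "match" code block and whose isDefinedAt() is supplied by the compiler. Strictly speaking, because the expected type here is a plain function type, the compiler only needs to emit a Function1 with that apply(); for a match that covers every input the two readings behave the same.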

Details:

The first code block defining the case class Memo can be written more verbosely as such:

import scala.collection.mutable

case class Memo[A,B](f: A => B) extends Function1[A, B] {    //replaced (A => B) with what it's translated to mean by the Scala compiler
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))  //concrete implementation of unimplemented method defined in parent class, Function1
}

The second code block defining the val fibonacci can be written more verbosely as such:

lazy val fibonacci: Memo[Int, BigInt] = {
  Memo.apply(
    new PartialFunction[Int, BigInt] {
      override def apply(x: Int): BigInt = {
        x match {
          case 0 => 0
          case 1 => 1
          case n => fibonacci(n-1) + fibonacci(n-2)
        }
      }
      override def isDefinedAt(x: Int): Boolean = true
    }
  )
}

I had to add lazy to the val in the second code block in order to deal with the self-reference in the line case n => fibonacci(n-1) + fibonacci(n-2).

And finally, an example usage of fibonacci is:

val x:BigInt = fibonacci(20) //returns 6765 (almost instantly)
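
To see why the call returns almost instantly, here is a minimal sketch that counts how often the wrapped function body actually runs. The evaluations counter and the MemoCountDemo object are my own additions for illustration and are not part of the original code; Memo is repeated only so the snippet is self-contained.

```scala
import scala.collection.mutable

object MemoCountDemo {
  // Same Memo as above, repeated so this sketch compiles on its own.
  case class Memo[A, B](f: A => B) extends (A => B) {
    private val cache = mutable.Map.empty[A, B]
    def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
  }

  // Hypothetical counter (not in the original): incremented each time the body runs.
  var evaluations = 0

  val fibonacci: Memo[Int, BigInt] = Memo { (n: Int) =>
    evaluations += 1
    if (n < 2) BigInt(n) else fibonacci(n - 1) + fibonacci(n - 2)
  }

  def main(args: Array[String]): Unit = {
    println(fibonacci(20)) // 6765
    println(evaluations)   // 21: one evaluation for each n from 0 to 20
    fibonacci(20)          // served from the cache
    println(evaluations)   // still 21
  }
}
```

Each distinct argument is computed once, so the naive exponential recursion collapses to a linear number of evaluations.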

OTHER TIPS

A => B is short for Function1[A, B], so your Memo extends a function from A to B, most prominently defined through method apply(x: A): B which must be defined.

Because of the "infix" notation, you need to put parentheses around the type, i.e. (A => B). You could also write

case class Memo[A, B](f: A => B) extends Function1[A, B] ...

or

case class Memo[A, B](f: Function1[A, B]) extends Function1[A, B] ...
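
If you want to convince yourself that the two spellings are the same type, a small sketch along these lines should compile. The object and value names are made up for illustration.

```scala
object FunctionTypeDemo {
  // This line compiles only if the two types are identical.
  val sameType = implicitly[(Int => String) =:= Function1[Int, String]]

  // A value written with the long form can be assigned to the arrow form...
  val f: Int => String = new Function1[Int, String] {
    def apply(x: Int): String = x.toString
  }

  // ...and a lambda can be assigned to the long form.
  val g: Function1[Int, String] = (x: Int) => x.toString

  def main(args: Array[String]): Unit =
    println(f(7) == g(7)) // true
}
```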

To complete 0__'s answer, fibonacci is being instantiated through the apply method of Memo's companion object, generated automatically by the compiler since Memo is a case class.

This means that the following code is generated for you:

object Memo {
  def apply[A, B](f: A => B): Memo[A, B] = new Memo(f)
}

Scala has special handling for the apply method: its name need not be typed when calling it. The two following calls are strictly equivalent:

Memo((a: Int) => a * 2)

Memo.apply((a: Int) => a * 2)
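
The same rule applies to instances, not just companion objects, which is why fibonacci(20) works on a Memo value. A minimal sketch, using a hypothetical Doubler class that has nothing to do with the original code:

```scala
object ApplySugarDemo {
  // Hypothetical class for illustration; any class with an apply method behaves this way.
  class Doubler {
    def apply(a: Int): Int = a * 2
  }

  val d = new Doubler

  val viaSugar: Int = d(21)       // the compiler rewrites this to d.apply(21)
  val viaApply: Int = d.apply(21) // the explicit form

  def main(args: Array[String]): Unit =
    println(viaSugar == viaApply) // true
}
```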

The case block is known as pattern matching. Under the hood, it generates a partial function - that is, a function that is defined for some of its input parameters, but not necessarily all of them. I won't go into the details of partial functions as it's beside the point (this is a memo I wrote to myself on that topic, if you're keen), but what it essentially means here is that the case block is in fact an instance of PartialFunction.

If you follow that link, you'll see that PartialFunction extends Function1 - which is the expected argument of Memo.apply.
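
As a rough illustration of both points (a case block typed as a PartialFunction, and a PartialFunction being usable as a Function1), consider the sketch below. The names base and asFunction are made up, and the block deliberately covers only two inputs so that isDefinedAt has something to report.

```scala
object PartialFunctionDemo {
  // A case block given an explicit PartialFunction type; it only handles 0 and 1.
  val base: PartialFunction[Int, BigInt] = {
    case 0 => 0
    case 1 => 1
  }

  // PartialFunction[Int, BigInt] extends Int => BigInt, so this assignment compiles.
  val asFunction: Int => BigInt = base

  def main(args: Array[String]): Unit = {
    println(base.isDefinedAt(1)) // true
    println(base.isDefinedAt(2)) // false
    println(asFunction(1))       // 1
    // asFunction(2) would throw a scala.MatchError, since no case covers 2.
  }
}
```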

So what that bit of code actually means, once desugared (if that's a word), is:

lazy val fibonacci: Memo[Int, BigInt] = Memo.apply(new PartialFunction[Int, BigInt] {
  // the return type must be BigInt to match PartialFunction[Int, BigInt]
  override def apply(v: Int): BigInt =
    if(v == 0)      0
    else if(v == 1) 1
    else            fibonacci(v - 1) + fibonacci(v - 2)

  override def isDefinedAt(v: Int): Boolean = true
})

Note that I've vastly simplified the way the pattern matching is handled, but I thought that starting a discussion about unapply and unapplySeq would be off topic and confusing.

I am the original author of doing memoization this way. You can see some sample usages in that same file. It also works really well when you want to memoize on multiple arguments, because of the way Scala unrolls tuples:

    /**
     * @return memoized function to calculate C(n,r) 
     * see http://mathworld.wolfram.com/BinomialCoefficient.html
     */
     val c: Memo[(Int, Int), BigInt] = Memo {
        case (_, 0) => 1
        case (n, r) if r > n/2 => c(n, n-r)
        case (n, r) => c(n-1, r-1) + c(n-1, r)
     }
     // note how I can invoke a memoized function on multiple args too
     val x = c(10, 3) 
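
To spell out what "unrolling tuples" means here: Memo still takes a single argument, it just happens to be a Tuple2, and a call such as c(10, 3) is adapted by the compiler to c((10, 3)). The sketch below wraps the same definition in an object of my own naming (BinomialDemo) so it is self-contained; depending on compiler flags, some lint settings may flag the adapted call.

```scala
import scala.collection.mutable

object BinomialDemo {
  case class Memo[A, B](f: A => B) extends (A => B) {
    private val cache = mutable.Map.empty[A, B]
    def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
  }

  // The cache key is the whole (n, r) tuple.
  val c: Memo[(Int, Int), BigInt] = Memo {
    case (_, 0)            => 1
    case (n, r) if r > n/2 => c(n, n - r)
    case (n, r)            => c(n - 1, r - 1) + c(n - 1, r)
  }

  val adapted: BigInt  = c(10, 3)   // argument list adapted into a tuple
  val explicit: BigInt = c((10, 3)) // the same call with the tuple written out

  def main(args: Array[String]): Unit = {
    println(adapted)  // 120
    println(explicit) // 120
  }
}
```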

One more word about this extends (A => B): the extends is not required for Memo itself to work, but it is necessary if instances of Memo are to be passed to higher-order functions or used in similar situations where a function type is expected.

Without this extends (A => B), it's totally fine if you use the Memo instance fibonacci in just method calls.

case class Memo[A,B](f: A => B) {
    private val cache = scala.collection.mutable.Map.empty[A, B]
    def apply(x: A):B = cache getOrElseUpdate (x, f(x))
}
val fibonacci: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fibonacci(n-1) + fibonacci(n-2)
}

For example:

scala> fibonacci(30)
res1: BigInt = 832040

But when you want to use it in higher order functions, you'd have a type mismatch error.

scala> Range(1, 10).map(fibonacci)
<console>:11: error: type mismatch;
 found   : Memo[Int,BigInt]
 required: Int => ?
              Range(1, 10).map(fibonacci)
                               ^

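So the extends here only serves to tell the compiler that the instance fibonacci is itself a function (an Int => BigInt), rather than just an object that happens to have an apply method, so it can be passed wherever a function is expected.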
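
For comparison, here is a minimal sketch of the version with extends (A => B) restored, where the same map call compiles. The HigherOrderDemo object and the firstNine name are mine, added only for illustration.

```scala
import scala.collection.mutable

object HigherOrderDemo {
  case class Memo[A, B](f: A => B) extends (A => B) {
    private val cache = mutable.Map.empty[A, B]
    def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
  }

  val fibonacci: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fibonacci(n - 1) + fibonacci(n - 2)
  }

  // Compiles because a Memo[Int, BigInt] is now an Int => BigInt.
  val firstNine = Range(1, 10).map(fibonacci)

  def main(args: Array[String]): Unit =
    println(firstNine.mkString(", ")) // 1, 1, 2, 3, 5, 8, 13, 21, 34
}
```

Without the extends, wrapping the call in a lambda, as in Range(1, 10).map(n => fibonacci(n)), would also get past the type checker, at the cost of writing the wrapper every time.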

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow