Question

I have a tree structure:

sealed trait Tree
case class Node(l: Tree, r: Tree) extends Tree
case class Leaf(n: Int) extends Tree

And a function that modifies the tree:

def scale(n: Int, tree: Tree): Tree = tree match {
  case l: Leaf => Leaf(l.n * n)
  case Node(l, r) => Node(scale(n, l), scale(n, r))
}

What should the signature of the method above be so that it returns the appropriate subtype of Tree and makes the following line compile?

scale(100, Leaf(1)).n // DOES NOT COMPILE

The closest answer I found so far is here and talks about F-Bounded Quantification. But I can't find a way to apply it to recursive structures such as trees! Any ideas?


Solution

Aside from OOP and overriding, here is another solution:

You can actually define a type-class Scaler that represents a function from (T, Int) to T where T is a subtype of Tree. In the companion object of Scaler you can then define the corresponding implicits which will do the actual work. The static type of the return value follows from the definition of the type-class. Note that I extended your definition of Node with two generic parameters. There may be an alternative by using abstract types instead. This is left as an exercise to the reader :)

Well, here we go:

import language.implicitConversions

sealed trait Tree
case class Node[L <: Tree, R <: Tree](l: L, r: R) extends Tree
case class Leaf(n: Int) extends Tree

trait Scaler[T <: Tree] extends ((T, Int) => T)

object Scaler {
  implicit object scalesLeafs extends Scaler[Leaf] {
    def apply(l: Leaf, s: Int) =
      Leaf(l.n * s)
  }

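  // Explicit result type so the implicit is typed as Scaler[Node[L, R]] rather than an inferred anonymous-class type
  implicit def scalesNodes[L <: Tree: Scaler, R <: Tree: Scaler]: Scaler[Node[L, R]] = new Scaler[Node[L, R]] {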
    val ls = implicitly[Scaler[L]]
    val rs = implicitly[Scaler[R]]

    def apply(n: Node[L,R], s: Int) =
      Node(ls(n.l, s), rs(n.r, s))
  }
}

object demo extends App {
  def scale[T <: Tree](t: T, s: Int)(implicit ev: Scaler[T]): T =
    ev(t, s)

  val check1 = scale(Leaf(3), 5)
  val check2 = scale(Node(Leaf(3), Leaf(7)), 5)

  Console println check1.n
  Console println check2.l.n
  Console println check2.r.n
}

EDIT: As a side note: if your tree has an invariant, for example that every leaf on the right-hand side is greater than or equal to the ones on the left-hand side, you might want to implement scale in terms of mapping over the tree, thereby effectively creating a new tree, because scaling by a negative number could violate the invariant.
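One way to read the side note above is sketched below, assuming the original non-generic Tree from the question and assuming the invariant is "leaf values are non-decreasing from left to right". The names TreeOps, map, mirror and scaleOrdered are hypothetical and do not appear in the answer. Multiplying by a negative factor reverses the ordering of the leaves, so this sketch swaps the children of every node in that case; it ignores integer overflow and returns a plain Tree rather than the precise subtype.

```scala
sealed trait Tree
case class Node(l: Tree, r: Tree) extends Tree
case class Leaf(n: Int) extends Tree

object TreeOps {
  // Builds a new tree of the same shape, applying f to every leaf value.
  def map(tree: Tree)(f: Int => Int): Tree = tree match {
    case Leaf(n)    => Leaf(f(n))
    case Node(l, r) => Node(map(l)(f), map(r)(f))
  }

  // Swaps left and right children at every node, reversing the leaf order.
  def mirror(tree: Tree): Tree = tree match {
    case leaf: Leaf => leaf
    case Node(l, r) => Node(mirror(r), mirror(l))
  }

  // Assumption: the invariant is "leaves are non-decreasing left to right".
  // A negative factor reverses that order, so the scaled tree is mirrored.
  def scaleOrdered(tree: Tree, s: Int): Tree = {
    val scaled = map(tree)(_ * s)
    if (s < 0) mirror(scaled) else scaled
  }
}
```

For a different invariant, or for a function that is not monotone, mirroring is not enough and the tree would have to be rebuilt from its mapped leaf values instead.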

Other tips

In Scala there are often two ways to define a function for a type. One is more functional (with pattern matching, as you did); the other is more object-oriented.

Unfortunately I have no solution for the pattern-matching version (although I would be interested in one). But I do have a solution for the more object-oriented version, if this is an option for you:

sealed trait Tree {
  def scale(n: Int): Tree
}
case class Node(l: Tree, r: Tree) extends Tree {
  def scale(n: Int): Node = Node(l.scale(n), r.scale(n))
}
case class Leaf(n: Int) extends Tree {
  def scale(m: Int): Leaf = Leaf(n * m)
}

Leaf(1).scale(100).n // does compile.

This solution relies on the fact that method return types are covariant under overriding: the implementations of scale in Node and Leaf may declare a return type that is a subtype of the return type of the abstract method scale in Tree.
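A short sketch of what this buys you and where it stops, reusing the definitions above; the object CovariantDemo and its value names are made up for illustration. The precise type is only recovered when the receiver's static type is already Leaf or Node. The children of a Node are still typed as Tree, so reaching into them needs a pattern match.

```scala
sealed trait Tree {
  def scale(n: Int): Tree
}
case class Node(l: Tree, r: Tree) extends Tree {
  def scale(n: Int): Node = Node(l.scale(n), r.scale(n))
}
case class Leaf(n: Int) extends Tree {
  def scale(m: Int): Leaf = Leaf(n * m)
}

object CovariantDemo {
  // Receiver is statically a Leaf, so the result is statically a Leaf.
  val leaf: Leaf = Leaf(1).scale(100)

  // Receiver is statically a Node, so the result is statically a Node.
  val node: Node = Node(Leaf(1), Leaf(2)).scale(10)

  // Receiver is only known as a Tree, so the result is only a Tree;
  // `asTree.n` would not compile here.
  val upcast: Tree = Leaf(1)
  val asTree: Tree = upcast.scale(100)

  // `node.l.n` would not compile either, because l is declared as Tree.
  // A pattern match recovers the leaf value at runtime.
  val leftValue: Option[Int] = node.l match {
    case Leaf(v) => Some(v)
    case _       => None
  }
}
```

This is the trade-off against the type-class solution, which tracks the types of the children as well because its Node is generic in both of them.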

License: CC-BY-SA with attribution
Not affiliated with StackOverflow