Question

I would like to write a Scala macro that takes an instance of a case class as an argument. All objects that can be passed to the macro have to implement a specific marker trait.

The following snippet shows the marker trait and two example case classes implementing it:

trait Domain
case class Country( id: String, name: String ) extends Domain
case class Town( id: String, longitude: Double, latitude: Double ) extends Domain

I would like to write the following code, using a macro to avoid the overhead of runtime reflection and its lack of thread safety:

import Macros.logDomain

object Test extends App {

  // instantiate example domain object
  val myCountry = Country( "CH", "Switzerland" )

  // this is a macro call
  logDomain( myCountry )
} 

The macro logDomain is implemented in a separate project (macro implementations must be compiled before the code that calls them) and looks roughly like this:

import scala.language.experimental.macros
import scala.reflect.macros.Context

object Macros {
  def logDomain( domain: Domain ): Unit = macro logDomainMacroImpl

  def logDomainMacroImpl( c: Context )( domain: c.Expr[Domain] ): c.Expr[Unit] = {
    // Here I would like to introspect the argument object, but I do not know how.
    // I would like to generate code that prints out all vals with their values.
    ???
  }
}

The macro should generate code that, at runtime, prints all values of the given object (here id and name) as follows:

id (String) : CH
name (String) : Switzerland

To achieve this, I would have to inspect the type of the passed argument and determine its members (vals). Then I would have to generate an AST representing the code that produces the log output. The macro should work regardless of which specific object implementing the marker trait Domain is passed to it.
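
To make the goal concrete, here is a hand-written sketch of roughly what the generated code has to do for a Country argument. ExpandedByHand and logCountry are hypothetical names used only for illustration; the macro's job would be to produce the equivalent of logCountry's body automatically for any Domain subclass.

```scala
// Domain and Country are copied from the question.
trait Domain
case class Country(id: String, name: String) extends Domain

// Hypothetical: what logDomain(myCountry) should behave like after expansion.
object ExpandedByHand {
  def logCountry(domain: Country): Unit = {
    // one statement per case class val: name, static type, runtime value
    println("id (String) : " + domain.id)
    println("name (String) : " + domain.name)
  }

  def main(args: Array[String]): Unit =
    logCountry(Country("CH", "Switzerland"))
}
```

The names and type names are fixed strings known at compile time; only the values are read at runtime. That split is exactly what a macro can exploit.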

At this point I am lost. I would appreciate it if someone could give me a starting point or point me to some documentation. I am relatively new to Scala and have not found a solution in the Scala API docs or the macro guide.


Solution

Listing the accessors of a case class is such a common operation when you're working with macros that I tend to keep a method like this around:

def accessors[A: u.WeakTypeTag](u: scala.reflect.api.Universe) = {
  import u._

  u.weakTypeOf[A].declarations.collect {
    case acc: MethodSymbol if acc.isCaseAccessor => acc
  }.toList
}

This will give us all the case class accessor method symbols for A, if it has any. Note that I'm using the general reflection API here; there's no need to make this macro-specific yet.
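
Because the accessor lookup only depends on a Universe, it can be tried out with runtime reflection before any macro is involved. The sketch below assumes Scala 2.11 or later with scala-reflect on the classpath, so it uses decls and decodedName.toString (the 2.11+ names for declarations and decoded), and it takes a TypeTag instead of a Universe parameter. RuntimeAccessors and describe are hypothetical names.

```scala
import scala.reflect.runtime.universe._

case class Country(id: String, name: String)
case class Town(id: String, longitude: Double, latitude: Double)

object RuntimeAccessors {
  // Same filter as accessors[A] above: keep only case class accessor methods.
  def accessors[A: TypeTag]: List[MethodSymbol] =
    typeOf[A].decls.collect {
      case acc: MethodSymbol if acc.isCaseAccessor => acc
    }.toList

  // (val name, type name) pairs; a Set because scope order is not guaranteed
  def describe[A: TypeTag]: Set[(String, String)] =
    accessors[A].map { acc =>
      (acc.name.decodedName.toString,
       acc.returnType.typeSymbol.name.decodedName.toString)
    }.toSet

  def main(args: Array[String]): Unit =
    describe[Town].foreach(println)
}
```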

We can bundle this method with another convenience helper that builds a printf call tree:

trait ReflectionUtils {
  import scala.reflect.api.Universe

  def accessors[A: u.WeakTypeTag](u: Universe) = {
    import u._

    u.weakTypeOf[A].declarations.collect {
      case acc: MethodSymbol if acc.isCaseAccessor => acc
    }.toList
  }

  def printfTree(u: Universe)(format: String, trees: u.Tree*) = {
    import u._

    Apply(
      Select(reify(Predef).tree, "printf"),
      Literal(Constant(format)) :: trees.toList
    )
  }
}

And now we can write the actual macro code pretty concisely:

trait Domain

object Macros extends ReflectionUtils {
  import scala.language.experimental.macros
  import scala.reflect.macros.Context

  def log[D <: Domain](domain: D): Unit = macro log_impl[D]
  def log_impl[D <: Domain: c.WeakTypeTag](c: Context)(domain: c.Expr[D]): c.Expr[Unit] = {
    import c.universe._

    if (!weakTypeOf[D].typeSymbol.asClass.isCaseClass) c.abort(
      c.enclosingPosition,
      "Need something typed as a case class!"
    ) else c.Expr[Unit](
      Block(
        accessors[D](c.universe).map(acc =>
          printfTree(c.universe)(
            "%s (%s) : %%s\n".format(
              acc.name.decoded,
              acc.typeSignature.typeSymbol.name.decoded
            ),
            Select(domain.tree.duplicate, acc.name)
          )
        ),
        c.literalUnit.tree
      )
    )
  }
}

Note that we still need to keep track of the specific case class type we're dealing with, but type inference takes care of that at the call site, so we won't need to specify the type parameter explicitly.
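
The "%s (%s) : %%s\n".format(...) call in log_impl builds each printf format string in two stages: the name and type are filled in while the macro expands, and %% escapes the value slot so that a single %s survives for the generated printf call to fill at runtime. A minimal sketch of the two stages in isolation (FormatStages and formatFor are hypothetical names):

```scala
object FormatStages {
  // Stage 1 (macro-expansion time): fill in name and type; %% becomes a literal %
  def formatFor(name: String, typeName: String): String =
    "%s (%s) : %%s\n".format(name, typeName)

  def main(args: Array[String]): Unit = {
    val fmt = formatFor("id", "String") // "id (String) : %s\n"
    // Stage 2 (runtime, inside the generated printf call): fill in the value
    printf(fmt, "CH")
  }
}
```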

Now we can open a REPL, paste in your case class definitions, and then write the following:

scala> Macros.log(Town("Washington, D.C.", 38.89, 77.03))
id (String) : Washington, D.C.
longitude (Double) : 38.89
latitude (Double) : 77.03

Or:

scala> Macros.log(Country("CH", "Switzerland"))
id (String) : CH
name (String) : Switzerland

As desired.

Other answers

From what I can see, you need to solve two problems: (1) getting the necessary information from the macro argument, and (2) generating trees that represent the code you need.

In Scala 2.10 both are done with the reflection API. See the Stack Overflow question "Is there a tutorial on Scala 2.10's reflection API yet?" for the documentation that is available.

import scala.reflect.macros.Context
import language.experimental.macros

trait Domain
case class Country(id: String, name: String) extends Domain
case class Town(id: String, longitude: Double, latitude: Double) extends Domain

object Macros {
  def logDomain(domain: Domain): Unit = macro logDomainMacroImpl

  def logDomainMacroImpl(c: Context)(domain: c.Expr[Domain]): c.Expr[Unit] = {
    import c.universe._

    // problem 1: getting the list of all declared vals and their types
    //   * declarations returns declared, but not inherited, members
    //   * collect filters out non-methods
    //   * isCaseAccessor only leaves accessors of case class vals
    //   * typeSignature is how you get types of members
    //     (for generic members you might need to use typeSignatureIn)
    //   * the static type of the argument (e.g. Country) comes from domain.tree.tpe;
    //     widen drops a possible singleton type such as myCountry.type
    val tpe = domain.tree.tpe.widen
    val vals = tpe.declarations.toList.collect { case sym if sym.isMethod => sym.asMethod }.filter(_.isCaseAccessor)
    val types = vals map (_.typeSignature)

    // problem 2: generating the code which would print:
    // id (String) : CH
    // name (String) : Switzerland
    //
    // usually reify is of limited usefulness
    // (see https://stackoverflow.com/questions/13795490/how-to-use-type-calculated-in-scala-macro-in-a-reify-clause)
    // but here it's perfectly suitable
    // a subtle detail: the `domain` tree is spliced once per val,
    // so each use needs its own copy (hence duplicate); this also means
    // the argument expression is evaluated once per val at runtime
    val stmts = vals.map(v => c.universe.reify(println(
      c.literal(v.name.toString).splice +
      " (" + c.literal(v.returnType.toString).splice + ")" +
      " : " + c.Expr[Any](Select(domain.tree.duplicate, v)).splice)).tree)

    c.Expr[Unit](Block(stmts, Literal(Constant(()))))
  }
}
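
The comment about typeSignatureIn matters once a case class is generic: the declared type of an accessor mentions the type parameter, and only the type as seen from a concrete instantiation has it substituted. A minimal runtime-reflection sketch, assuming Scala 2.11+ (where infoIn is the newer name for typeSignatureIn); Box and GenericAccessors are hypothetical names:

```scala
import scala.reflect.runtime.universe._

// Hypothetical generic case class, used only to illustrate the difference.
case class Box[T](value: T)

object GenericAccessors {
  val boxOfInt: Type = typeOf[Box[Int]]

  // the case accessor `value`, found with the same filter as above
  val valueAcc: MethodSymbol = boxOfInt.decls.collect {
    case m: MethodSymbol if m.isCaseAccessor => m
  }.head

  // declared result type: still refers to the type parameter T
  def declaredResult: Type = valueAcc.returnType

  // result type as seen from Box[Int]: T is substituted by Int
  def resultInBoxOfInt: Type = valueAcc.infoIn(boxOfInt).resultType

  def main(args: Array[String]): Unit = {
    println(declaredResult)
    println(resultInBoxOfInt)
  }
}
```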

Licensed under: CC-BY-SA with attribution