Question

Following the tutorial at:

https://github.com/julienrf/lms-tutorial/wiki

I have successfully compiled and understood most of the code. While the concepts and examples are extremely sexy, I immediately wanted to change the code from hard-coding "Double" as the set of scalars to accepting anything that implements the type class Numeric[T] from the standard library. I was, however, unsuccessful.

I tried things like adding the following code to the LinearAlgebraExp trait:

  override type Scalar = Double
  override type Vector = Seq[Scalar]
  implicit val num:Numeric[Scalar] = implicitly[Numeric[Scalar]]

This did not work. My next (probably better) idea was to add implicit Numeric arguments to every implementing function (i.e. all actual implementations of vector_scale). I still couldn't quite wrap my head around it because of exotic compile-time errors.

Is there any support in LMS currently for using numeric types? Looking in the source of LMS, it seems like it might actually be a mess right now.

Was it helpful?

Solution

I reached some working code by adding a Manifest to the context bounds in my program:

import scala.virtualization.lms.common._

trait LinearAlgebra extends Base {
  // `any2rep` below is an implicit conversion; this import enables the SIP-18
  // language feature locally instead of emitting a feature warning.
  import scala.language.implicitConversions

  /** Abstract vector representation over scalars of type `T`. */
  type Vector[T]
  /** Abstract matrix representation over scalars of type `T`. */
  type Matrix[T]

  /** Scales every component of `v` by the factor `k`.
   *
   *  @tparam T scalar type; needs a Manifest (for array creation / staging) and a Numeric instance
   */
  def vector_scale[T:Manifest:Numeric](v: Rep[Vector[T]], k: Rep[T]): Rep[Vector[T]]

  /** Tensor (Kronecker) product between 2 matrices. */
  def tensor_prod[T:Manifest:Numeric](A:Rep[Matrix[T]],B:Rep[Matrix[T]]):Rep[Matrix[T]]

  // Concrete syntax

  /** Infix syntax `v * k`, delegating to `vector_scale`. */
  implicit class VectorOps[T:Numeric:Manifest](v: Rep[Vector[T]]) {
    def *(k: Rep[T]):Rep[Vector[T]] = vector_scale[T](v, k)
  }

  /** Infix syntax `A |* B`, delegating to `tensor_prod`. */
  implicit class MatrixOps[T:Numeric:Manifest](A:Rep[Matrix[T]]) {
    def |*(B:Rep[Matrix[T]]):Rep[Matrix[T]] = tensor_prod(A,B)
  }

  /** Lifts any plain value that has a Manifest into a `Rep` via `unit`.
   *
   *  The result type is spelled out on purpose. In Scala 2, an implicit with an
   *  inferred type is only eligible after its own definition point. Scala 3 rejects
   *  such definitions outright.
   */
  implicit def any2rep[T:Manifest](t:T): Rep[T] = unit(t)

}

trait Interpreter extends Base {
  /** Unstaged interpretation: a "representation" of a value is the value itself. */
  override type Rep[+A] = A

  /** Lifting a constant is the identity, because `Rep[A]` is just `A` here. */
  override protected def unit[A: Manifest](a: A): Rep[A] = a
}

trait LinearAlgebraInterpreter extends LinearAlgebra with Interpreter {

  override type Vector[T] = Array[T]
  override type Matrix[T] = Array[Array[T]]

  /** Multiplies each element of `v` by `k` using the supplied Numeric instance. */
  override def vector_scale[T:Manifest](v: Vector[T], k: T)(implicit num:Numeric[T]):Rep[Vector[T]] =
    v map { x => num.times(x, k) }

  /** Kronecker product of `A` (m x n) and `B` (p x q), giving an (m*p) x (n*q) matrix.
   *
   *  Row `j` of the block for row `aRow` of `A` is the concatenation, over every
   *  scalar `s` in `aRow`, of `B(j)` scaled by `s`. Each output row is built once.
   *  The former `reduce concat` approach re-copied the accumulated rows for every
   *  column and threw on an empty row of `A`. Here an empty row of `A` yields
   *  `B.length` empty rows.
   */
  def tensor_prod[T:Manifest](A:Matrix[T],B:Matrix[T])(implicit num:Numeric[T]):Matrix[T] =
    A flatMap { aRow =>
      B map { bRow =>
        aRow flatMap { s => bRow map { x => num.times(x, s) } }
      }
    }
}


trait LinearAlgebraExp extends LinearAlgebra with BaseExp {
  // A staged Rep[Vector[T]] stands for an Array[T] in the generated Scala code.
  override type Vector[T] = Array[T]
  type Matrix[T] = Array[Array[T]]

  /** IR node reifying "scale vector `v` by the factor `k`".
   *  The Manifest and Numeric evidence are captured along with the node.
   */
  case class VectorScale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]) extends Def[Vector[T]]

  /** Builds a `VectorScale` node and binds it to a symbol. */
  override def vector_scale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]): Exp[Vector[T]] =
    toAtom(VectorScale(v, k))

  // Not staged yet: calling this throws scala.NotImplementedError.
  def tensor_prod[T:Manifest:Numeric](A:Rep[Matrix[T]],B:Rep[Matrix[T]]):Rep[Matrix[T]] = ???
}

trait ScalaGenLinearAlgebra extends ScalaGenBase {
  /** The IR whose `VectorScale` nodes this generator knows how to print. */
  val IR: LinearAlgebraExp
  import IR._

  /** Emits `val sym = v.map(x => x * k)` for `VectorScale`; defers everything else.
   *
   *  NOTE(review): the emitted code uses the `*` method directly. This compiles for
   *  the standard numeric primitives, but a custom Numeric[T] type without a `*`
   *  method would produce uncompilable output.
   */
  override def emitNode(sym: Sym[Any], node: Def[Any]): Unit = node match {
    case VectorScale(vec, factor) =>
      val scaled = s"${quote(vec)}.map(x => x * ${quote(factor)})"
      emitValDef(sym, scaled)
    case other =>
      super.emitNode(sym, other)
  }
}

trait LinearAlgebraExpOpt extends LinearAlgebraExp {
  /** Algebraic simplification: scaling by the multiplicative identity returns `v` unchanged.
   *
   *  The identity comes from `Numeric[T].one` rather than a hard-coded `1.0` literal.
   *  That way the rewrite fires for every Numeric type, including user-defined ones
   *  that never compare equal to a Double. Any other factor builds the regular IR node.
   */
  override def vector_scale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]): Exp[Vector[T]] = k match {
    case Const(c) if c == implicitly[Numeric[T]].one => v
    case _ => super.vector_scale(v, k)
  }
}

trait Prog extends LinearAlgebra {
  /** Scales `v` by 3, built through `Numeric.fromInt` so it works for any Numeric `T`. */
  def f[T:Manifest](v: Rep[Vector[T]])(implicit num:Numeric[T]): Rep[Vector[T]] = {
    val three = unit(num.fromInt(3))
    v * three
  }

  /** Scales `v` by 1. An optimizing IR such as LinearAlgebraExpOpt may elide the scale entirely. */
  def g[T:Manifest](v: Rep[Vector[T]])(implicit num:Numeric[T]): Rep[Vector[T]] = {
    val one = unit(num.fromInt(1))
    v * one
  }

  // Matrix example, disabled until tensor_prod is staged:
  //def h[T:Manifest:Numeric](A:Rep[Matrix[T]],B:Rep[Matrix[T]]):Rep[Matrix[T]] = A |* B

}

object TestLinAlg extends App {

  // Unstaged run: g executes directly on a plain Array and prints the result.
  val interpretedProg = new Prog with LinearAlgebraInterpreter {
    println(g(Array(1.0, 2.0)).mkString(","))
  }

  // Staged run with the scale-by-one rewrite; prints generated source for g[Double] as "optimizedG".
  val optProg = new Prog with LinearAlgebraExpOpt with EffectExp with CompileScala { self =>
    override val codegen = new ScalaGenEffect with ScalaGenLinearAlgebra { val IR: self.type = self }
    codegen.emitSource(g[Double], "optimizedG", new java.io.PrintWriter(System.out))
  }

  // Staged run without the rewrite; prints generated source for g[Double] as "nonOptimizedG".
  val nonOptProg = new Prog with LinearAlgebraExp with EffectExp with CompileScala { self =>
    override val codegen = new ScalaGenEffect with ScalaGenLinearAlgebra { val IR: self.type = self }
    codegen.emitSource(g[Double], "nonOptimizedG", new java.io.PrintWriter(System.out))
  }

  /** Compiles both staged variants of g[Double] and checks them against the interpreter.
   *
   *  @return true when the optimized and non-optimized compiled functions both produce
   *          the same elements as the unstaged interpreter on the same input
   */
  def compareInterpCompiled: Boolean = {
    val optcomp = optProg.compile(optProg.g[Double])
    val nonOptComp = nonOptProg.compile(nonOptProg.g[Double])
    val a = Array(1.0,2.0)
    // Reference result from the unstaged interpreter (the "Interp" in the name).
    val expected = interpretedProg.g[Double](a).toList
    optcomp(a).toList == expected && nonOptComp(a).toList == expected
  }

  println(compareInterpCompiled)

}

My goal was to take the example at https://github.com/julienrf/lms-tutorial/wiki and modify it to use numeric types instead. I wanted to confirm that the overhead of passing the implicit type-class instance is stripped completely. The (successful) output from the above program is here.

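We can see that the call to num.times is indeed stripped: the generated code multiplies the vector elements directly with `*`, without going through the Numeric instance.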

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow
scroll top