Question

I was explaining to a friend that I expected non-tail-recursive functions in Scala to be slower than tail-recursive ones, so I decided to verify it. I wrote a good old factorial function both ways and compared the results. Here's the code:

import scala.annotation.tailrec

object FactorialTiming {
  def main(args: Array[String]): Unit = {
    val N = 2000 // not too large, or the non-tail-recursive version overflows the stack
    var spent1: Long = 0
    var spent2: Long = 0
    for ( i <- 1 to 100 ) { // repeat to average the results
      val t0 = System.nanoTime
      factorial(N)
      val t1 = System.nanoTime
      tailRecFact(N)
      val t2 = System.nanoTime
      spent1 += t1 - t0
      spent2 += t2 - t1
    }
    println(spent1/1000000f) // get milliseconds
    println(spent2/1000000f)
  }

  @tailrec
  def tailRecFact(n: BigInt, s: BigInt = 1): BigInt = if (n == 1) s else tailRecFact(n - 1, s * n)

  def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)
}

The results confuse me. I get output like this:

578.2985

870.22125

This means the non-tail-recursive function takes about a third less time than the tail-recursive one (578 ms vs 870 ms), even though both perform the same number of operations!

What would explain those results?


Solution 2

In addition to the problem shown by @monkjack (i.e. multiplying small * big is faster than big * small, which accounts for the larger part of the difference), your algorithm is different in each case, so the two versions are not really comparable.

In the tail-recursive version you're multiplying big-to-small:

((n * (n-1)) * (n-2)) * ... * 2 * 1

In the non-tail recursive version you're multiplying small-to-big:

n * ((n-1) * ((n-2) * (... * (2 * 1))))
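One way to see why the order matters is to track how large the running product gets. The sketch below (the names `MultiplicationOrder`, `bigToSmall` and `smallToBig` are hypothetical) uses `scanLeft` to produce every intermediate product in each order and reports its bit length. Both orders perform the same n-1 multiplications and end at the same value, but the big-to-small order carries an accumulator that is at least as large at every step, and much larger in the middle of the run.

```scala
// Hypothetical helper object: compares the size of the running product
// in the two multiplication orders used by the factorial variants.
object MultiplicationOrder {
  // Tail-recursive order: the accumulator holds n, n*(n-1), n*(n-1)*(n-2), ...
  def bigToSmall(n: Int): Seq[Int] =
    (n to 2 by -1).scanLeft(BigInt(1))(_ * _).tail.map(_.bitLength)

  // Non-tail-recursive order: 2*1 is formed first, then 3*2, 4*6, ...
  def smallToBig(n: Int): Seq[Int] =
    (2 to n).scanLeft(BigInt(1))(_ * _).tail.map(_.bitLength)

  def main(args: Array[String]): Unit = {
    val n = 2000
    val a = bigToSmall(n)
    val b = smallToBig(n)
    println(s"multiplications: ${a.size} vs ${b.size}")
    println(s"sum of intermediate bit lengths: ${a.sum} vs ${b.sum}")
  }
}
```

For n = 4 the intermediate products are 4, 12, 24 (big-to-small) versus 2, 6, 24 (small-to-big): same count, same result, but bigger operands along the way.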

If you alter the tail-recursive version so it multiplies small-to-big:

def tailRecFact2(n: BigInt) = {
  def loop(x: BigInt, out: BigInt): BigInt =
    if (x > n) out else loop(x + 1, x * out)
  loop(1, 1)
}

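then the tail-recursive version comes out about 20% faster than the normal recursive one, instead of about 10% slower, which is what you get with only monkjack's correction applied. This is because multiplying small BigInts together is faster than multiplying large ones, and multiplying from the small end keeps the intermediate products smaller.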
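To reproduce the comparison, here is a crude timing sketch that runs each variant a number of times before measuring it, so the JIT has a chance to compile it first. It is not a rigorous benchmark (a dedicated harness such as JMH would give more reliable numbers). `FactorialBench` and `timeMs` are hypothetical names; the three variants are the ones discussed above, with `tailRecFact` in the question's original `s * n` form. Absolute timings depend on the JVM and machine, so no output is shown.

```scala
import scala.annotation.tailrec

// Hypothetical benchmark object; a rough sketch, not a rigorous measurement.
object FactorialBench {
  def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)

  // Question's version: multiplies the large accumulator by the small n.
  @tailrec
  def tailRecFact(n: BigInt, s: BigInt = 1): BigInt =
    if (n == 1) s else tailRecFact(n - 1, s * n)

  // Solution 2's version: multiplies from the small end upwards.
  def tailRecFact2(n: BigInt): BigInt = {
    @tailrec
    def loop(x: BigInt, out: BigInt): BigInt =
      if (x > n) out else loop(x + 1, x * out)
    loop(1, 1)
  }

  // Total wall-clock time in milliseconds for `reps` evaluations of `f`.
  def timeMs(reps: Int)(f: => BigInt): Double = {
    var last: BigInt = 0
    val t0 = System.nanoTime
    var i = 0
    while (i < reps) { last = f; i += 1 }
    (System.nanoTime - t0) / 1e6
  }

  def main(args: Array[String]): Unit = {
    val n = 2000
    val variants = Seq[(String, () => BigInt)](
      "factorial" -> (() => factorial(n)),
      "tailRecFact" -> (() => tailRecFact(n)),
      "tailRecFact2" -> (() => tailRecFact2(n))
    )
    for ((_, f) <- variants) timeMs(200)(f()) // warm-up runs, timings discarded
    for ((name, f) <- variants) println(f"$name%-13s ${timeMs(100)(f())}%.1f ms")
  }
}
```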

Other suggestions

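The cause is actually not where you would look first. In your tail-recursive method, the multiplication does more work: it computes `s * n`, calling `*` on the large accumulator `s` with the small `n` as the argument, whereas `factorial` computes `n * factorial(n - 1)` with the small number as the receiver. Try swapping the operands `n` and `s` in the recursive call and the timings will largely even out: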

def tailRecFact(n: BigInt, s: BigInt): BigInt = if (n == 1) s else tailRecFact(n - 1, n * s)

Moreover, most of the time in this sample is spent in the BigInt operations, which dwarf the cost of the recursive calls themselves. If we switch to Ints (compiled to Java primitives), you can see how tail recursion (compiled to a `goto`) compares to method invocation:

import scala.annotation.tailrec

object Test extends App {

  val N = 2000

  val t0 = System.nanoTime()
  for ( i <- 1 to 1000 ) {
    factorial(N)
  }
  val t1 = System.nanoTime
  for ( i <- 1 to 1000 ) {
    tailRecFact(N, 1)
  }
  val t2 = System.nanoTime

  println((t1 - t0) / 1000000f) // get milliseconds
  println((t2 - t1) / 1000000f)

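  // With Int the products overflow (both versions return 0 for N = 2000); that is
  // fine here, since only the cost of the recursion itself is being compared.
  def factorial(n: Int): Int = if (n == 1) 1 else n * factorial(n - 1)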

  @tailrec
  final def tailRecFact(n: Int, s: Int): Int = if (n == 1) s else tailRecFact(n - 1, s * n)
}

95.16733
3.987605

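For interest, here is the disassembled bytecode of the BigInt versions (the tail-recursive one uses the swapped `n * s` multiplication). Note that `tailRecFact` jumps back to the start with `goto 0` instead of calling itself, whereas `factorial` recurses through `invokevirtual`: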

  public final scala.math.BigInt tailRecFact(scala.math.BigInt, scala.math.BigInt);
    Code:
       0: aload_1       
       1: iconst_1      
       2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
       5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
       8: ifeq          13
      11: aload_2       
      12: areturn       
      13: aload_1       
      14: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      17: iconst_1      
      18: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      21: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
      24: aload_1       
      25: aload_2       
      26: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
      29: astore_2      
      30: astore_1      
      31: goto          0

  public scala.math.BigInt factorial(scala.math.BigInt);
    Code:
       0: aload_1       
       1: iconst_1      
       2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
       5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
       8: ifeq          21
      11: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      14: iconst_1      
      15: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      18: goto          40
      21: aload_1       
      22: aload_0       
      23: aload_1       
      24: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      27: iconst_1      
      28: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      31: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
      34: invokevirtual #47                 // Method factorial:(Lscala/math/BigInt;)Lscala/math/BigInt;
      37: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
      40: areturn   
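The `goto 0` at offset 31 shows that the tail call has become a loop. As an illustration, the method below is a hand-written Scala equivalent of that bytecode; the name `tailRecFactLoop` is hypothetical, and this is a sketch of what the instructions do, not the compiler's actual output. Both new argument values are computed from the old `n` before being stored back, just as the two `astore` instructions do.

```scala
// Hypothetical hand-written loop mirroring the tailRecFact bytecode above.
object TailRecAsLoop {
  def tailRecFactLoop(n0: BigInt, s0: BigInt): BigInt = {
    var n = n0
    var s = s0
    while (n != 1) {   // offsets 0-8: compare n with 1, fall through to return if equal
      val next = n - 1 // offsets 13-21: compute n - 1 (left on the operand stack)
      s = n * s        // offsets 24-29: compute n * s, astore_2 stores it into s
      n = next         // offset 30: astore_1 stores n - 1 into n
    }                  // offset 31: goto 0
    s                  // offsets 11-12: aload_2, areturn
  }
}
```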
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow