Question

I'm trying to implement the Sieve of Eratosthenes in Scala.

I start by initializing a sequence of all odd numbers plus 2:

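// (end goal is to find all prime factors of bigNumber; bigNumber: Long is assumed to be defined elsewhere)
val largestPrime : Long = math.ceil(math.sqrt(bigNumber.toDouble)).toLong
var nums : Seq[Long] = (3L to largestPrime by 2L).toSeq
nums = 2L +: nums // +: prepends and returns a new Seq, so the result must be assigned back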

Now nums contains Seq(2, 3, 5, 7, 9, 11, 13, 15, ..., largestPrime). Then, following the Sieve, I want to iterate over each element and filter all multiples of that element out of the Seq. It would look something like this, except this simply iterates over every odd number:

for(i : Long <- 3L to largestPrime by 2L) {
    nums = nums.filter((j : Long) => j == i || j % i != 0)
}

So instead, I would want to use something like this:

for(i <- nums) {
    // filter
}

But of course, this simply iterates over every value in nums as it was at the beginning of the for loop, no matter how nums is reassigned inside the body (so in this case, it's exactly equivalent to the previous example). I want each iteration to grab the next value from the current nums.

What is the best way to implement this? Should I use an index variable and a while loop? I'm not sure how to get an element from a sequence by position (i.e. element x of the sequence, where x is the index). Or is there a more functional way to do this?


Edit: I just found the scanLeft function. I'm trying to grasp how to use it, as I suspect it might be useful in this case...
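The index-plus-while-loop idea from the question does work. Element x of a Seq is simply nums(x), which is fast on a Vector and linear on a List. Below is a minimal sketch, assuming a Vector[Long] and a hypothetical helper named primesUpTo. It also stops once p * p exceeds the limit, because every remaining composite has already been removed by then.

```scala
object IndexSieve {
  // Hypothetical helper: re-reads the *current* nums through an index on each
  // pass instead of iterating over a snapshot taken when a loop started.
  def primesUpTo(limit: Long): Vector[Long] = {
    var nums: Vector[Long] =
      if (limit < 2L) Vector.empty else 2L +: (3L to limit by 2L).toVector
    var idx = 1 // position 0 holds 2, and nums contains no other even numbers
    // Once p * p > limit, every composite <= limit has already been filtered out.
    while (idx < nums.length && nums(idx) * nums(idx) <= limit) {
      val p = nums(idx) // element at position idx of the current nums
      nums = nums.filter(j => j == p || j % p != 0)
      idx += 1 // filtering only removes elements after p, so idx + 1 is the next survivor
    }
    nums
  }

  def main(args: Array[String]): Unit =
    println(primesUpTo(30L)) // Vector(2, 3, 5, 7, 11, 13, 17, 19, 23, 29)
}
```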


Solution

Let's start with what I consider the biggest problem above. You have this:

for (i <- nums) { nums = something else }

This will not change the nums that is being iterated over; that sequence stays the same throughout the loop. Reassigning the variable has no effect on the loop, and mutating a mutable collection in place while iterating over it is not reliable either.
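Here is a small illustration, using a hypothetical object name. A for loop with no yield desugars to nums.foreach(...), and nums is evaluated once, before the loop starts. Rebinding the var inside the body therefore has no effect on the elements the loop visits.

```scala
object SnapshotDemo {
  // Records which elements the loop body sees while the body rebinds nums.
  def visitedWhileRebinding(): Vector[Long] = {
    var nums: Seq[Long] = Seq(3L, 5L, 7L, 9L)
    var visited = Vector.empty[Long]
    // Desugars to nums.foreach(i => ...): nums is evaluated once, before the loop.
    for (i <- nums) {
      visited :+= i
      nums = Seq.empty // rebinds the var; the running foreach still uses the old Seq
    }
    visited
  }

  def main(args: Array[String]): Unit =
    println(visitedWhileRebinding()) // Vector(3, 5, 7, 9)
}
```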

So, how do you do it? You don't use for comprehensions, or at least not this way. You can look at my own version here, which iterates through a different collection than the one being mutated. Or here is a one-liner:

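import scala.util.chaining._ // pipe (Scala 2.13+) plays the role of Scalaz's |> operator
// x <= n / x stops striking once x * x would exceed n (and avoids Int overflow)
(n: Int) => (2 to n).pipe(r => r.foldLeft(r.toSet)((ps, x) => if (ps(x) && x <= n / x) ps -- (x * x to n by x) else ps))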
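Written out with named steps, the fold in the one-liner reads as follows. This is a sketch: the object name FoldSieve is only for illustration, and the x <= n / x guard is an added overflow check. Every number in 2..n starts as a candidate. For each x that is still present, the fold removes x*x, x*x + x, ... up to n.

```scala
object FoldSieve {
  // Expanded form of the one-liner sieve using foldLeft over a Set of candidates.
  def primes(n: Int): List[Int] = {
    val candidates = 2 to n
    val survivors = candidates.foldLeft(candidates.toSet) { (ps, x) =>
      // x is still present, so it is prime: strike its multiples from x * x on.
      if (ps(x) && x <= n / x) ps -- (x * x to n by x)
      else ps
    }
    survivors.toList.sorted // a Set has no order, so sort for display
  }

  def main(args: Array[String]): Unit =
    println(primes(30)) // List(2, 3, 5, 7, 11, 13, 17, 19, 23, 29)
}
```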

Now, back to what you want to do... When you use a for comprehension, you are actually calling the method foreach, map or flatMap on the collection. So you'd need a collection that can handle one of these methods without trouble when the "next" element changes from one iteration to the next. As I said, I'm not sure any of Scala's collections fit the bill. If you go this way, you'd be better off using a while loop and keeping track of things yourself. For instance:

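def primes(n: Int) = {
    // mutable.LinkedList is available up to Scala 2.12 (deprecated since 2.11, removed in 2.13)
    import scala.collection.mutable.LinkedList
    // Odd candidates 3, 5, ..., n; 2 is added back at the end
    val primes = LinkedList(3 to n by 2: _*)
    var p = primes
    while (p.nonEmpty) {
        var scanner = p
        // Unlink every later element that is a multiple of p.head
        while (scanner.next.nonEmpty) {
            if (scanner.next.head % p.head == 0)
                scanner.next = scanner.next.next
            else
                scanner = scanner.next
        }
        p = p.next
    }
    2 :: primes.toList // assumes n >= 2
}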

Note that I keep a pointer to the start of the LinkedList, move p through each known prime, and move scanner through all remaining numbers to cut the non-primes.
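scala.collection.mutable.LinkedList was deprecated in Scala 2.11 and removed in 2.13. For newer versions, here is a minimal sketch of the same pointer-walking loop. It replaces LinkedList with a hypothetical hand-rolled Node class, where null marks the end of the list, to mirror the low-level style of the original.

```scala
object LinkedSieve {
  // Hypothetical minimal node type standing in for mutable.LinkedList.
  final class Node(val head: Int, var next: Node)

  def primes(n: Int): List[Int] = {
    // Build 3 -> 5 -> 7 -> ... -> n; a null next marks the end of the list.
    val start: Node = (3 to n by 2).foldRight(null: Node)((x, rest) => new Node(x, rest))
    var p = start
    while (p != null) {
      var scanner = p
      // Unlink every later node whose value is a multiple of p.head.
      while (scanner.next != null) {
        if (scanner.next.head % p.head == 0) scanner.next = scanner.next.next
        else scanner = scanner.next
      }
      p = p.next
    }
    // Walk the surviving nodes; 2 was never in the list, so add it back.
    val out = List.newBuilder[Int]
    if (n >= 2) out += 2
    var q = start
    while (q != null) { out += q.head; q = q.next }
    out.result()
  }

  def main(args: Array[String]): Unit =
    println(primes(30)) // List(2, 3, 5, 7, 11, 13, 17, 19, 23, 29)
}
```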

OTHER TIPS

The example in the scaladoc for scala.collection.immutable.Stream is a sieve:

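object Main extends App {

  // The infinite stream n, n + 1, n + 2, ...
  def from(n: Int): Stream[Int] =
    Stream.cons(n, from(n + 1))

  // Keep the head, then drop its multiples from the (lazily evaluated) tail.
  // Strictly speaking this is trial division, not the Sieve of Eratosthenes.
  def sieve(s: Stream[Int]): Stream[Int] =
    Stream.cons(s.head, sieve(s.tail filter { _ % s.head != 0 }))

  def primes = sieve(from(2))

  println((primes take 10).toList)
}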
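On Scala 2.13 and later, Stream is deprecated in favour of LazyList. The same example written with LazyList and #:: might look like the sketch below (the object name LazySieve is only for illustration). Like the Stream version, it tests each candidate against every earlier prime, so it is trial division rather than a true sieve.

```scala
object LazySieve {
  // The infinite lazy list n, n + 1, n + 2, ...
  def from(n: Int): LazyList[Int] = n #:: from(n + 1)

  // Keep the head; lazily filter its multiples out of the tail.
  def sieve(s: LazyList[Int]): LazyList[Int] =
    s.head #:: sieve(s.tail.filter(_ % s.head != 0))

  def primes: LazyList[Int] = sieve(from(2))

  def main(args: Array[String]): Unit =
    println(primes.take(10).toList) // List(2, 3, 5, 7, 11, 13, 17, 19, 23, 29)
}
```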

This is neither a good functional solution nor one that reveals obscure Scala library treasures, but it's rather easy to build a specialised iterator yourself.

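class ModifyingIterator(var collection: Seq[Long]) extends Iterator[Long] {
  // Start just below the head so that the head itself is also returned
  var current = collection.head - 1
  // First element of the *current* collection larger than the last one returned (collection sorted ascending)
  def next(): Long = {
    current = collection.find(_ > current).get
    current
  }
  def hasNext: Boolean = collection.exists(_ > current)
}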

val mi = new ModifyingIterator(nums)

for (i <- mi) {
    mi.collection = mi.collection.filter((j : Long) => j == i || j % i != 0)
}
println(mi.collection)

ModifyingIterator keeps track of the current item and allows for reassigning the collection which is used for iterating. The next item is always larger than the current item.

Of course, one should probably use a better data structure that keeps a pointer to the current position instead of the current value, to get rid of the needless search on every step.
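One way to keep a position instead of a value is sketched below. IndexedModifyingIterator is a hypothetical class, not a library type. It assumes the collection is an ascending Vector and that each reassignment removes only elements after the current position. That holds for the sieve filter, because no number smaller than i is a multiple of i.

```scala
object IndexedIteratorSieve {
  // Stores a position rather than a value, so next() is a direct lookup, not a search.
  class IndexedModifyingIterator(var collection: Vector[Long]) extends Iterator[Long] {
    private var pos = -1
    def hasNext: Boolean = pos + 1 < collection.length
    def next(): Long = { pos += 1; collection(pos) }
  }

  def primesUpTo(limit: Long): Vector[Long] = {
    val start = if (limit < 2L) Vector.empty[Long] else 2L +: (3L to limit by 2L).toVector
    val mi = new IndexedModifyingIterator(start)
    // foreach calls hasNext/next on each step, so it always sees the latest collection.
    for (i <- mi) {
      mi.collection = mi.collection.filter(j => j == i || j % i != 0)
    }
    mi.collection
  }

  def main(args: Array[String]): Unit =
    println(primesUpTo(30L)) // Vector(2, 3, 5, 7, 11, 13, 17, 19, 23, 29)
}
```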

There is an interesting paper on this, Melissa O'Neill's "The Genuine Sieve of Eratosthenes": http://www.cs.hmc.edu/~oneill/papers/Sieve-JFP.pdf

I tried to translate the Haskell code given in that paper to Scala, but I didn't test the performance.

object primes {

    type SI = Stream[Int]

    def sieve:SI = {
        def wheel2357:SI = Stream(4,2,4,6,2,6,4,2,4,6,6,
            2,6,4,2,6,4,6,8,4,2,4,2,4,8,6,4,6,2,4,6,2,6,
            6,4,2,4,6,2,6,4,2,4,2,10,2,10,2) append wheel2357
        def spin(s:SI, n:Int):SI = Stream.cons(n, spin(s.tail, n + s.head))

        case class It(value:Int, step:Int) {
            def next = new It(value + step, step)

            def atLeast(c:Int):It =
            if (value >= c) this
            else new It(value + step, step).atLeast(c)
        }

        implicit object ItOrdering extends Ordering[It] {
            def compare(thiz:It, that:It) = {
                val r = thiz.value - that.value
                if (r == 0) thiz.step - that.step else r
            }

        }

        import scala.collection.immutable.TreeSet

        // set holds the pending multiples of the primes found so far, ordered by value
        def sieve(cand:SI, set:TreeSet[It]):SI = {
            val c = cand.head
            val set1 = TreeSet[It]() ++ set.dropWhile(_.value < c) ++
               set.takeWhile(_.value < c).map(_.atLeast(c))
            if (set1.head.value == c) { // c equals the smallest pending multiple, so it is composite
                val set2 = TreeSet[It]() ++ set1.dropWhile(_.value == c) ++
                    set1.takeWhile(_.value == c).map(_.next)
                sieve(cand.tail, set2)
            } else {
                Stream.cons(c, sieve(cand.tail, set1 + It(c*c,2*c)))
            }
        }
        Stream(2,3,5,7,11) append sieve(spin(wheel2357,13),
                  new TreeSet[It] + It(121,22))
    }

    def main(args:Array[String]): Unit = {
        sieve.takeWhile(_ < 1000).foreach(println)
    }
}
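For comparison, here is a much shorter sketch of the paper's central idea: an unbounded, incremental sieve that remembers, for each upcoming composite, which primes will cross it off. It is a simplification under stated assumptions. It keeps the pending composites in a mutable HashMap rather than the ordered set used above, leaves out the wheel, and uses only the standard library. The object name IncrementalSieve is only for illustration.

```scala
import scala.collection.mutable

object IncrementalSieve {
  // Unbounded prime generator. `upcoming` maps each known future composite to
  // the primes that produce it; a number with no entry is prime.
  def primes: Iterator[Long] = {
    val upcoming = mutable.HashMap.empty[Long, List[Long]]
    Iterator.from(2).map(_.toLong).filter { n =>
      upcoming.remove(n) match {
        case Some(ps) =>
          // n is composite: move each of its primes on to that prime's next multiple.
          ps.foreach(p => upcoming(n + p) = p :: upcoming.getOrElse(n + p, Nil))
          false
        case None =>
          // n is prime: its first multiple not already covered by a smaller prime is n * n.
          upcoming(n * n) = List(n)
          true
      }
    }
  }

  def main(args: Array[String]): Unit =
    println(primes.takeWhile(_ < 50).toList)
}
```

Each call to primes builds a fresh map, so independent iterators do not interfere with each other.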
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow