Question

I sometimes find myself wanting to perform nested iterations over infinite streams in Scala for comprehensions, but specifying the loop termination condition can be a little tricky. Is there a better way of doing this kind of thing?

The use case I have in mind is where I don't necessarily know up front how many elements will be needed from each of the infinite streams I'm iterating over (but obviously I know it's not going to be an infinite number). Assume the termination condition for each stream might depend on the values of other elements in the for expression in some complicated way.

My first thought was to write the stream termination condition as an if filter clause in the for expression. However, this runs into trouble with nested infinite streams: a filter cannot short-circuit iteration over an infinite stream, so once the condition stops matching, the filter keeps scanning for a further matching element that never comes, which ultimately leads to an OutOfMemoryError. I understand why this happens, given how for expressions are translated into map, flatMap and withFilter calls. My question is whether there's a better idiom for this kind of thing (perhaps not involving for comprehensions at all).

To give a somewhat contrived example to illustrate the issue just described, consider the following (very naive) code to generate all pairings of the numbers 1 and 2:

val pairs = for {
  i <- Stream.from(1) 
  if i < 3 
  j <- Stream.from(1) 
  if j < 3
} 
yield (i, j)

pairs.take(2).toList 
// result: List[(Int, Int)] = List((1,1), (1,2)) 

pairs.take(4).toList
// 'hoped for' result: List[(Int, Int)] = List((1,1), (1,2), (2,1), (2,2))
// actual result:
//  java.lang.OutOfMemoryError: Java heap space
//      at scala.collection.immutable.Stream$.from(Stream.scala:1105)

Obviously, in this simple example, the problem could easily be averted by moving the if filters into takeWhile method calls on the original streams, as follows:

val pairs = for {
  i <- Stream.from(1).takeWhile(_ < 3) 
  j <- Stream.from(1).takeWhile(_ < 3) 
}    
yield (i, j)

but for the purposes of this question, imagine a more complex use case where the stream termination condition can't easily be moved onto the stream expression itself.
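The takeWhile form does stretch a little further than the example suggests: because the inner generator expression is evaluated afresh for every i, its takeWhile predicate can refer to i. What it cannot express is a condition on an outer stream that depends on values bound later in the comprehension. A small sketch, where the condition i + j < 5 is a hypothetical one chosen only for illustration:

```scala
object DependentTakeWhile {
  // Hypothetical condition: keep pairs whose sum is below 5.
  // The inner stream is rebuilt for each i, so its predicate can use i.
  val pairs: Stream[(Int, Int)] = for {
    i <- Stream.from(1).takeWhile(_ < 3)
    j <- Stream.from(1).takeWhile(j => i + j < 5)
  } yield (i, j)

  def main(args: Array[String]): Unit =
    println(pairs.toList) // List((1,1), (1,2), (1,3), (2,1), (2,2))
}
```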


Solution 2

I've adapted Petr's suggestion to come up with what I think is a more generally usable solution, in that it doesn't place restrictions on the positioning of if filters in the for comprehension (although it has a little more syntactic overhead).

The idea is again to wrap the underlying stream in an object that delegates flatMap, map and filter without modification, except that it first applies takeWhile(a => !isTruncated) to the underlying stream, where isTruncated is a field of the wrapper object. Calling truncate on the wrapper at any point flips the isTruncated flag and so terminates further iteration over the stream. This relies on the takeWhile call being lazily evaluated: its predicate is only checked for an element when that element is first reached, so code executed at later stages of the iteration can still affect its behaviour.
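The laziness this relies on can be checked in isolation. Stream is strict in its head but lazy in its tail, so a takeWhile predicate that reads a mutable flag is consulted for each element only when that element is first forced, and already-forced elements stay memoized. A minimal sketch, where the flag stop is a hypothetical stand-in for isTruncated:

```scala
object LazyTakeWhileFlag {
  // Hypothetical stand-in for the wrapper's isTruncated field.
  var stop = false

  // The predicate runs for element 1 now, and for each later element
  // only when that element is first forced.
  val s: Stream[Int] = Stream.from(1).takeWhile(_ => !stop)

  // Forces (and memoizes) elements 1 and 2 while the flag is still false.
  val firstTwo: List[Int] = List(s(0), s(1))

  // Flip the flag after the first two elements have been evaluated.
  stop = true

  // 1 and 2 are memoized; the predicate now fails for 3, so the stream ends.
  val afterStop: List[Int] = s.toList

  def main(args: Array[String]): Unit =
    println(afterStop) // List(1, 2)
}
```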

The downside is that you have to hold on to references to the streams you want to be able to truncate mid-iteration, so that you can append || s.truncate to the filter expression (where s is the reference to the wrapped stream). You also need to call reset on the wrapper object (or use a new wrapper object) before each new iteration through the stream, unless you know that repeat iterations will behave identically each time through.
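Why reset matters can be seen with a hypothetical cut-down stand-in, TruncatableView, that keeps only the flag, the lazily truncated view, truncate and reset from the wrapper class below (no CanBuildFrom plumbing). A second pass with a looser bound and no reset still stops where the first pass truncated, because the part of the view evaluated during the first pass, including its empty end, is memoized. After reset, the looser bound takes effect. The guards also show the cond || view.truncate() idiom.

```scala
object ResetDemo {
  // Hypothetical stand-in for TruncatableStream: flag, lazy view, truncate, reset.
  final class TruncatableView[A](underlying: Stream[A]) {
    private var isTruncated = false
    private var active: Stream[A] = underlying.takeWhile(_ => !isTruncated)

    def stream: Stream[A] = active

    // Returns false so it can be used as `cond || view.truncate()` in a filter.
    def truncate(): Boolean = { isTruncated = true; false }

    // Clears the flag and rebuilds the lazy view for a fresh pass.
    def reset(): Unit = {
      isTruncated = false
      active = underlying.takeWhile(_ => !isTruncated)
    }
  }

  val view = new TruncatableView(Stream.from(1))

  // First pass: 3 fails the bound and truncates the view.
  val firstPass: List[Int] = view.stream.filter(x => x < 3 || view.truncate()).toList

  // Looser bound but no reset: the memoized view still ends after 3.
  val noReset: List[Int] = view.stream.filter(x => x < 5 || view.truncate()).toList

  view.reset()

  // After reset the looser bound is honoured.
  val withReset: List[Int] = view.stream.filter(x => x < 5 || view.truncate()).toList

  def main(args: Array[String]): Unit =
    println(List(firstPass, noReset, withReset)) // List(List(1, 2), List(1, 2, 3), List(1, 2, 3, 4))
}
```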

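import scala.collection._
import scala.collection.generic._

// Targets Scala 2.12 and earlier (CanBuildFrom was removed in 2.13)
class TruncatableStream[A](private val underlying: Stream[A]) {
  private var isTruncated = false

  // Lazy view: once truncate() is called, iteration stops at the next element
  private var active = underlying.takeWhile(a => !isTruncated)

  def flatMap[B, That](f: (A) => GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That = active.flatMap(f)

  def map[B, That](f: (A) => B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That = active.map(f)

  def filter(p: A => Boolean): Stream[A] = active.filter(p)

  // if guards in a for comprehension call withFilter; delegating here avoids
  // relying on the compiler's deprecated fallback to filter
  def withFilter(p: A => Boolean): Stream[A] = active.filter(p)

  // Returns false so it can be appended to a guard as `cond || s.truncate`
  def truncate(): Boolean = {
    isTruncated = true
    false
  }

  // Clears the flag and rebuilds the lazy view for a fresh pass
  def reset(): Unit = {
    isTruncated = false
    active = underlying.takeWhile(a => !isTruncated)
  }
}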

val s1 = new TruncatableStream(Stream.from(1))
val s2 = new TruncatableStream(Stream.from(1))

val pairs = for {
  i <- s1

  // reset the nested iteration at the start of each outer iteration loop 
  // (not strictly required here as the repeat iterations are all identical)
  // alternatively, could just write: s2 = new TruncatableStream(Stream.from(1))  
  _ = s2.reset()

  j <- s2
  if i < 3 || s1.truncate
  if j < 3 || s2.truncate
} 
yield (i, j)

pairs.take(2).toList  // res1: List[(Int, Int)] = List((1,1), (1,2))
pairs.take(4).toList  // res2: List[(Int, Int)] = List((1,1), (1,2), (2,1), (2,2))
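One caveat worth checking before relying on this: with the guards placed as above, only the first four pairs can be forced. Asking for a fifth (for example with pairs.toList) reaches i = 3, where i < 3 || s1.truncate is false for every j, so s2 is never truncated and the filter on s2 keeps scanning the inner stream, much like the question's original example. Placing the s1 guard directly after i <- s1 lets that guard truncate s1 itself, so the whole stream ends. The sketch assumes a hypothetical cut-down wrapper, Truncatable, that keeps the flag and takeWhile idea but drops CanBuildFrom so it also compiles on Scala 2.13 (where Stream is deprecated).

```scala
object GuardOrderDemo {
  // Hypothetical wrapper with the same flag + lazy takeWhile idea as TruncatableStream,
  // without CanBuildFrom so it compiles on Scala 2.12 and 2.13.
  final class Truncatable[A](underlying: Stream[A]) {
    private var isTruncated = false
    private var active: Stream[A] = underlying.takeWhile(_ => !isTruncated)

    def map[B](f: A => B): Stream[B] = active.map(f)
    def flatMap[B](f: A => Stream[B]): Stream[B] = active.flatMap(f)
    def withFilter(p: A => Boolean): Stream[A] = active.filter(p)

    def truncate(): Boolean = { isTruncated = true; false }
    def reset(): Unit = {
      isTruncated = false
      active = underlying.takeWhile(_ => !isTruncated)
    }
  }

  val s1 = new Truncatable(Stream.from(1))
  val s2 = new Truncatable(Stream.from(1))

  val pairs: Stream[(Int, Int)] = for {
    i <- s1
    if i < 3 || s1.truncate() // guard on s1 placed directly after its generator
    _ = s2.reset()
    j <- s2
    if j < 3 || s2.truncate()
  } yield (i, j)

  def main(args: Array[String]): Unit =
    println(pairs.toList) // List((1,1), (1,2), (2,1), (2,2))
}
```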

No doubt this could be improved upon, but it seems a reasonable solution to the problem.

OTHER TIPS

One possibility is to wrap Stream in your own class that handles filter differently, in this case as takeWhile:

import scala.collection._
import scala.collection.generic._

class MyStream[+A](val underlying: Stream[A]) {
  def flatMap[B, That](f: (A) => GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That = underlying.flatMap(f)

  def map[B, That](f: (A) => B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That = underlying.map(f)

  def filter(p: A => Boolean): Stream[A] = underlying.takeWhile(p)
  //                                       ^^^^^^^^^^^^^^^^^^^^^^^^

  // if guards in a for comprehension call withFilter; route them to the same takeWhile
  def withFilter(p: A => Boolean): Stream[A] = underlying.takeWhile(p)
}

object MyStream extends App {
  val pairs = for {
    i <- new MyStream(Stream.from(1))
    if i < 3
    j <- new MyStream(Stream.from(1))
    if j < 3
  } yield (i, j)

  print(pairs.toList)
}

This prints List((1,1), (1,2), (2,1), (2,2)).
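Because every guard becomes a takeWhile, this only behaves like a filter when the guard stays false once it first turns false, as a bound such as i < 3 does. A guard that fails for some elements and passes for later ones ends the stream at the first failure instead of skipping it. A sketch using a hypothetical cut-down TakeWhileStream (the same idea as MyStream, without CanBuildFrom so it compiles on Scala 2.12 and 2.13):

```scala
object TakeWhileAsFilter {
  // Hypothetical cut-down version of MyStream: guards are routed to takeWhile.
  final class TakeWhileStream[A](underlying: Stream[A]) {
    def map[B](f: A => B): Stream[B] = underlying.map(f)
    def flatMap[B](f: A => Stream[B]): Stream[B] = underlying.flatMap(f)
    // Ends the stream at the first element that fails p, rather than skipping it.
    def withFilter(p: A => Boolean): Stream[A] = underlying.takeWhile(p)
  }

  // A bound works as intended.
  val bounded: List[Int] =
    (for { i <- new TakeWhileStream(Stream.from(1)); if i < 4 } yield i).toList

  // A non-monotone guard does not: 1 fails, so 2, 4 and 6 are never reached.
  val evens: List[Int] =
    (for { i <- new TakeWhileStream(Stream.from(1).take(6)); if i % 2 == 0 } yield i).toList

  def main(args: Array[String]): Unit = {
    println(bounded) // List(1, 2, 3)
    println(evens)   // List()
  }
}
```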

Licensed under: CC-BY-SA with attribution