I've adapted Petr's suggestion into what I think is a more generally usable solution, in that it doesn't place restrictions on the positioning of `if` filters in the `for` comprehension (although it has a little more syntactic overhead).
The idea is again to enclose the underlying stream in a wrapper object, which delegates the `flatMap`, `map` and `filter` methods without modification, except that it first applies a `takeWhile` call to the underlying stream with the predicate `!isTruncated`, where `isTruncated` is a field of the wrapper object. Calling `truncate` on the wrapper at any point flips the `isTruncated` flag and effectively terminates further iteration over the stream. This relies heavily on the `takeWhile` call on the underlying stream being lazily evaluated: the predicate is only checked as each element is forced, so code executed at later stages of the iteration can still affect where the stream ends.
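To see the laziness this depends on in isolation, here is a small sketch (the names `stop` and `guarded` are just for illustration, and it assumes Scala 2.x, where `Stream` is memoized and lazy in its tail). The `takeWhile` predicate reads a mutable flag only when each element is forced, so flipping the flag part-way through cuts the stream off at that point.

```scala
// Illustrative names only (stop, guarded); assumes Scala 2.x Stream semantics.
var stop = false

// takeWhile checks its predicate as each element is forced,
// so it sees whatever value `stop` has at that moment.
val guarded: Stream[Int] = Stream.from(1).takeWhile(_ => !stop)

val first = guarded.head          // forced while stop == false
val second = guarded(1)           // forced while stop == false
stop = true                       // flip the flag part-way through
val rest = guarded.drop(2).toList // third element is checked now and rejected
```

Because the stream is memoized, `guarded.toList` afterwards gives just the two elements that were forced before the flag flipped.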
The downside is that you have to hold on to references to the streams you want to be able to truncate mid-iteration, and append `|| s.truncate` to the filter expression (where `s` is the reference to the wrapped stream). Because `truncate` returns `false`, the element that triggers it is filtered out as well. You also need to call `reset` on the wrapper object (or use a new wrapper object) before each new iteration through the stream, unless you know that the repeat iterations will behave identically each time through.
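A hypothetical, cut-down wrapper makes the need for `reset` concrete. `Truncatable` and its `stream` accessor are illustration-only names, not part of the class in this answer. Because the `takeWhile` stream is memoized, once one pass has truncated it, a later pass over the same wrapper only sees the elements that were already forced, even if its own condition would let it go further.

```scala
// Hypothetical cut-down wrapper with the same flag + takeWhile mechanics.
class Truncatable[A](underlying: Stream[A]) {
  private var isTruncated = false
  private var active = underlying.takeWhile(_ => !isTruncated)

  def stream: Stream[A] = active

  // Returns false so it can sit on the right-hand side of `||` in a filter.
  def truncate(): Boolean = { isTruncated = true; false }

  // Build a fresh takeWhile over the (memoized) underlying stream.
  def reset(): Unit = {
    isTruncated = false
    active = underlying.takeWhile(_ => !isTruncated)
  }
}

val t = new Truncatable(Stream.from(1))

// First pass: truncates at 3 (3 itself is dropped because truncate returns false).
val pass1 = t.stream.filter(x => x < 3 || t.truncate()).toList

// Second pass WITHOUT reset: the memoized takeWhile already ended after 3.
val pass2NoReset = t.stream.filter(x => x < 5 || t.truncate()).toList

// Second pass WITH reset: a fresh takeWhile, so it runs until 5 triggers truncate.
t.reset()
val pass2Reset = t.stream.filter(x => x < 5 || t.truncate()).toList
```

Here the second pass without `reset` yields `List(1, 2, 3)` rather than `List(1, 2, 3, 4)`. If every pass uses the same condition, the memoized prefix is the same each time, which is why `reset` can be skipped in that case.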
```scala
import scala.collection._
import scala.collection.generic._

class TruncatableStream[A](private val underlying: Stream[A]) {
  private var isTruncated = false
  // Lazily cut the stream off as soon as isTruncated becomes true.
  private var active = underlying.takeWhile(a => !isTruncated)

  def flatMap[B, That](f: A => GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That =
    active.flatMap(f)
  def map[B, That](f: A => B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That =
    active.map(f)
  def filter(p: A => Boolean): Stream[A] = active.filter(p)
  // `if` guards in a for comprehension call withFilter; delegate to filter.
  def withFilter(p: A => Boolean): Stream[A] = filter(p)

  // Flip the flag; returns false so it can be used as `cond || s.truncate` in a guard.
  def truncate(): Boolean = {
    isTruncated = true
    false
  }

  // Start a fresh takeWhile so the next pass isn't cut short by the previous one.
  def reset(): Unit = {
    isTruncated = false
    active = underlying.takeWhile(a => !isTruncated)
  }
}

val s1 = new TruncatableStream(Stream.from(1))
val s2 = new TruncatableStream(Stream.from(1))

val pairs = for {
  i <- s1
  // reset the nested iteration at the start of each outer iteration loop
  // (not strictly required here as the repeat iterations are all identical)
  // alternatively, could just write: s2 = new TruncatableStream(Stream.from(1))
  _ = s2.reset()
  j <- s2
  if i < 3 || s1.truncate
  if j < 3 || s2.truncate
} yield (i, j)

pairs.take(2).toList // res1: List[(Int, Int)] = List((1,1), (1,2))
pairs.take(4).toList // res2: List[(Int, Int)] = List((1,1), (1,2), (2,1), (2,2))
```
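One caveat, as far as I can tell from how `Stream.filter` works: with the `s1` guard placed after `j <- s2` as above, `pairs.take(4)` is fine, but forcing a fifth element (e.g. `pairs.toList`) never returns. For `i = 3`, the guard `i < 3 || s1.truncate` is evaluated against each element of `s2`; it keeps returning `false`, and `s2` is never truncated because its own guard is never reached, so `filter`'s eager search for the next match runs forever. Placing each stream's guard directly after its own generator seems to avoid this. The sketch below is a trimmed copy of the wrapper (only `map`, `flatMap`, `withFilter`, `truncate` and `reset`, with simplified signatures that drop `CanBuildFrom`), written against Scala 2.11/2.12 `Stream`.

```scala
// Trimmed copy of TruncatableStream (simplified signatures, no CanBuildFrom).
class TruncatableStream[A](underlying: Stream[A]) {
  private var isTruncated = false
  private var active = underlying.takeWhile(_ => !isTruncated)

  def map[B](f: A => B): Stream[B] = active.map(f)
  def flatMap[B](f: A => Stream[B]): Stream[B] = active.flatMap(f)
  // `if` guards desugar to withFilter; filter the truncatable view.
  def withFilter(p: A => Boolean): Stream[A] = active.filter(p)

  // Flip the flag; returning false drops the element that triggered it.
  def truncate(): Boolean = { isTruncated = true; false }

  // Fresh takeWhile for the next pass over the stream.
  def reset(): Unit = {
    isTruncated = false
    active = underlying.takeWhile(_ => !isTruncated)
  }
}

val s1 = new TruncatableStream(Stream.from(1))
val s2 = new TruncatableStream(Stream.from(1))

val pairs = for {
  i <- s1
  if i < 3 || s1.truncate() // s1's guard sits right after its own generator
  _ = s2.reset()
  j <- s2
  if j < 3 || s2.truncate()
} yield (i, j)

// Forcing the whole stream now terminates.
val all = pairs.toList // List((1,1), (1,2), (2,1), (2,2))
```

The rule this suggests is that a guard which truncates a stream should come before any later generator; otherwise it is re-evaluated once per element of that generator.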
No doubt this could be improved upon, but it seems a reasonable solution to the problem.