Question

Why is it that in F#, I can do this...

open System.Collections

let s = seq { for i in 0 .. 4095 do yield i } :?> IEnumerator

... but this throws a System.InvalidCastException?

let s = Seq.init 4095 (fun i -> i) :?> IEnumerator

Solution

A sequence expression creates an object that implements both IEnumerable<T> and IEnumerator<T>:

open System.Collections.Generic

let s = seq { for i in 0 .. 4095 do yield i }
printfn "%b" (s :? IEnumerable<int>) // true
printfn "%b" (s :? IEnumerator<int>) // true

But Seq.init does not:

let s = Seq.init 4095 (fun i -> i)
printfn "%b" (s :? IEnumerable<int>) // true
printfn "%b" (s :? IEnumerator<int>) // false
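To see why the downcast in the question can succeed at all, here is a minimal sketch of a hypothetical type (CountUpTo is an illustrative name, not an FSharp.Core type, and not what the compiler actually generates) that, like the object a sequence expression produces, is both an IEnumerable<int> and an IEnumerator<int>. Casting it to IEnumerator succeeds for the same reason the first line of the question does.

```fsharp
open System.Collections
open System.Collections.Generic

// Hypothetical type: one object that is both a sequence and an enumerator,
// yielding 0 .. last. Illustrative only, not the compiler's implementation.
type CountUpTo(last: int) =
    let mutable current = -1

    interface IEnumerator<int> with
        member this.Current = current

    interface IEnumerator with
        member this.Current = box current
        member this.MoveNext() =
            if current < last then
                current <- current + 1
                true
            else
                false
        member this.Reset() = current <- -1

    interface System.IDisposable with
        member this.Dispose() = ()

    interface IEnumerable<int> with
        // Hand out a fresh enumerator so every enumeration starts at 0.
        member this.GetEnumerator() = (new CountUpTo(last)) :> IEnumerator<int>

    interface IEnumerable with
        member this.GetEnumerator() = (new CountUpTo(last)) :> IEnumerator

let both = (new CountUpTo(4095)) :> seq<int>
// Succeeds: the runtime object really does implement IEnumerator.
let asEnumerator = both :?> IEnumerator
```

Handing out a new instance from GetEnumerator keeps repeated enumerations independent, while the object itself still satisfies a type test for IEnumerator.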

You could refactor your code to work with IEnumerable<T> (seq<'T> in F#) instead of IEnumerator, since both constructs produce an IEnumerable<T>.
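A minimal sketch of that refactoring, assuming the consuming code can accept a seq<int> (sumAll is a hypothetical name): because both constructs are IEnumerable<int>, the same function takes either one without any cast. Note in passing that 0 .. 4095 yields 4096 values, while Seq.init 4095 yields 4095 (indices 0 to 4094).

```fsharp
// Hypothetical consumer that takes the sequence itself (IEnumerable<int>)
// instead of an IEnumerator, so no downcast is needed.
let sumAll (xs: seq<int>) =
    // Seq.fold enumerates the sequence on our behalf.
    Seq.fold (+) 0 xs

let fromSeqExpr = sumAll (seq { for i in 0 .. 4095 do yield i })  // 8386560
let fromInit = sumAll (Seq.init 4095 (fun i -> i))                // 8382465
```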

Alternatively, if you really need an IEnumerator, call GetEnumerator to obtain an enumerator from the enumerable:

let s = (Seq.init 4095 (fun i -> i)).GetEnumerator()
printfn "%b" (s :? IEnumerable<int>) // false
printfn "%b" (s :? IEnumerator<int>) // true
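If you go the GetEnumerator route, the enumerator should be disposed when you are done with it, which use takes care of. A minimal sketch (takeFirst is a hypothetical helper) that drives the enumerator by hand through the non-generic IEnumerator from the question; converting to it is an upcast, which always succeeds because IEnumerator<int> inherits from IEnumerator.

```fsharp
open System.Collections

// Hypothetical helper written against the non-generic IEnumerator used in
// the question; Current is typed obj there, so each item is unboxed.
let takeFirst n (e: IEnumerator) =
    let items = ResizeArray<int>()
    while items.Count < n && e.MoveNext() do
        items.Add(unbox<int> e.Current)
    List.ofSeq items

let firstThree =
    // `use` disposes the enumerator when this block finishes.
    use e = (Seq.init 4095 (fun i -> i)).GetEnumerator()
    // Upcast (:>), not downcast (:?>): this cannot throw InvalidCastException.
    takeFirst 3 (e :> IEnumerator)
// firstThree = [0; 1; 2]
```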

OTHER TIPS

If you look at the F# specification, your sequence expression is converted to:

Seq.collect (fun pat -> Seq.singleton(pat)) (0 .. 4095)
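As a quick sanity check of that translation (a sketch in which seq { 0 .. 4095 } stands in for the bare range and i for pat), the desugared form yields the same elements as the original expression:

```fsharp
// The sequence expression from the question.
let viaExpr = seq { for i in 0 .. 4095 do yield i }

// The translation from the specification: one singleton per element, flattened.
let viaCollect = Seq.collect (fun i -> Seq.singleton i) (seq { 0 .. 4095 })

// Same length and the same elements in the same order.
let sameElements =
    Seq.length viaExpr = Seq.length viaCollect
    && Seq.forall2 (=) viaExpr viaCollect
```

This only compares the elements produced; it says nothing about which interfaces the two objects implement.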

If you look at the source for Seq.collect, it is defined as:

let collect f sources = map f sources |> concat
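In other words, Seq.collect f xs maps every element to a sequence and then flattens the results. A small sketch checking that the two spellings agree, written with the qualified Seq.map and Seq.concat (the mapping function f is a hypothetical example):

```fsharp
// Hypothetical mapping that produces a small sequence per input element.
let f i = seq { yield i; yield i * 10 }

let source = [ 1; 2; 3 ]

// Seq.collect maps each element to a sequence and flattens the results...
let collected = Seq.collect f source |> List.ofSeq
// ...which is exactly mapping first and then concatenating.
let mappedThenConcatenated = source |> Seq.map f |> Seq.concat |> List.ofSeq
// both are [1; 10; 2; 20; 3; 30]
```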

And the definition of concat is:

let concat sources =
    checkNonNull "sources" sources
    mkConcatSeq sources

mkConcatSeq is defined as:

let mkConcatSeq (sources: seq<'U :> seq<'T>>) =
    mkSeq (fun () -> new ConcatEnumerator<_,_>(sources) :> IEnumerator<'T>)

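So the enumerator it hands out (a ConcatEnumerator) implements IEnumerator<'T>, and therefore IEnumerator. Careful, though: the sequence object itself still comes from mkSeq, which, as shown below, implements only IEnumerable<'T>, so this desugaring alone does not explain why the cast in the question succeeds. What typically happens is that the compiler does not emit the Seq.collect form at all, but compiles a sequence expression like this one into a state-machine class derived from GeneratedSequenceBase<'T> (in FSharp.Core), and that class implements both IEnumerable<'T> and IEnumerator<'T>.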

Now Seq.init is defined as:

let init count f =
    if count < 0 then invalidArg "count" (SR.GetString(SR.inputMustBeNonNegative))
    mkSeq (fun () -> IEnumerator.upto (Some (count-1)) f)

And mkSeq is defined as:

let mkSeq f =
    { new IEnumerable<'U> with
        member x.GetEnumerator() = f()
      interface IEnumerable with
        member x.GetEnumerator() = (f() :> IEnumerator) }

So the object it returns implements only IEnumerable<'T> (and the non-generic IEnumerable), not IEnumerator<'T> or IEnumerator, which is why the cast throws.
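To see this concretely, here is a sketch that mirrors the quoted mkSeq with a public object expression. mkSeqSketch and initSketch are hypothetical names, not FSharp.Core functions, and the internal IEnumerator.upto is replaced by a plain list enumerator. The resulting object implements only the two enumerable interfaces, so a type test for IEnumerator<int> fails; and because f() runs on every GetEnumerator call, each call returns a fresh, independent enumerator.

```fsharp
open System.Collections
open System.Collections.Generic

// Mirrors the shape of mkSeq: the object only knows how to *create* enumerators.
let mkSeqSketch (f: unit -> IEnumerator<'U>) =
    { new IEnumerable<'U> with
        member x.GetEnumerator() = f ()
      interface IEnumerable with
        member x.GetEnumerator() = (f () :> IEnumerator) }

// Simplified stand-in for Seq.init: enumerates a pre-built list gen 0 .. gen (count-1).
let initSketch count (gen: int -> 'T) =
    if count < 0 then invalidArg "count" "count must be non-negative"
    mkSeqSketch (fun () -> (List.init count gen :> seq<'T>).GetEnumerator())

let s = initSketch 4095 (fun i -> i)
let isEnumerable = box s :? IEnumerable<int>   // true
let isEnumerator = box s :? IEnumerator<int>   // false

// Two enumerators from the same sequence advance independently.
let positions =
    use e1 = s.GetEnumerator()
    use e2 = s.GetEnumerator()
    e1.MoveNext() |> ignore
    e1.MoveNext() |> ignore
    e2.MoveNext() |> ignore
    (e1.Current, e2.Current)                   // (1, 0)
```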

Licensed under: CC-BY-SA with attribution