Skip to content

Commit

Permalink
Fix collect on stateful generator (JuliaLang#41919)
Browse files Browse the repository at this point in the history
Previously this code would drop 1 from the length of some generators.

Fixes JuliaLang#35530
  • Loading branch information
jakobnissen authored and LilithHafner committed Feb 22, 2022
1 parent 178472e commit 1d5c29a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
10 changes: 7 additions & 3 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -758,21 +758,25 @@ else
end
end

_array_for(::Type{T}, itr, ::HasLength) where {T} = Vector{T}(undef, Int(length(itr)::Integer))
_array_for(::Type{T}, itr, ::HasShape{N}) where {T,N} = similar(Array{T,N}, axes(itr))
_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)

function collect(itr::Generator)
isz = IteratorSize(itr.iter)
et = @default_eltype(itr)
if isa(isz, SizeUnknown)
return grow_to!(Vector{et}(), itr)
else
shape = isz isa HasLength ? length(itr) : axes(itr)
y = iterate(itr)
if y === nothing
return _array_for(et, itr.iter, isz)
end
v1, st = y
collect_to_with_first!(_array_for(typeof(v1), itr.iter, isz), v1, itr, st)
arr = _array_for(typeof(v1), itr.iter, isz, shape)
return collect_to_with_first!(arr, v1, itr, st)
end
end

Expand Down
11 changes: 10 additions & 1 deletion test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,15 @@ let (a, b) = (1:3, [4 6;
end
end

# collect stateful iterator
let
itr = (i+1 for i in Base.Stateful([1,2,3]))
@test collect(itr) == [2, 3, 4]
A = zeros(Int, 0, 0)
itr = (i-1 for i in Base.Stateful(A))
@test collect(itr) == Int[] # Stateful do not preserve shape
end

# with 1D inputs
let a = 1:2,
b = 1.0:10.0,
Expand Down Expand Up @@ -860,4 +869,4 @@ end
@test Iterators.peel(1:10)[2] |> collect == 2:10
@test Iterators.peel(x^2 for x in 2:4)[1] == 4
@test Iterators.peel(x^2 for x in 2:4)[2] |> collect == [9, 16]
end
end

0 comments on commit 1d5c29a

Please sign in to comment.