From b150c78e04b26446e7bda0847a75e8e52e17c6a9 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Wed, 8 Sep 2021 18:54:31 -0400 Subject: [PATCH] fix collect on stateful iterators Generalization of #41919 Fixes #42168 --- base/array.jl | 61 ++++++++++++++++++++++++++++---------------- base/dict.jl | 4 +-- base/set.jl | 2 +- src/julia-syntax.scm | 8 +++--- test/iterators.jl | 11 +++++--- 5 files changed, 52 insertions(+), 34 deletions(-) diff --git a/base/array.jl b/base/array.jl index bd9d3b8733541..15fab9d95c566 100644 --- a/base/array.jl +++ b/base/array.jl @@ -643,23 +643,38 @@ julia> collect(Float64, 1:2:5) """ collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr)) -_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr) -_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr) +_collect(::Type{T}, itr, isz::Union{HasLength,HasShape}) where {T} = + copyto!(_array_for(T, isz, _similar_shape(itr, isz)), itr) function _collect(::Type{T}, itr, isz::SizeUnknown) where T a = Vector{T}() for x in itr - push!(a,x) + push!(a, x) end return a end # make a collection similar to `c` and appropriate for collecting `itr` -_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown) where {T} = similar(c, T, 0) -_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength) where {T} = - similar(c, T, Int(length(itr)::Integer)) -_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape) where {T} = - similar(c, T, axes(itr)) -_similar_for(c, ::Type{T}, itr, isz) where {T} = similar(c, T) +_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T) + +_similar_shape(itr, ::SizeUnknown) = nothing +_similar_shape(itr, ::HasLength) = length(itr)::Integer +_similar_shape(itr, ::HasShape) = axes(itr) + +_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown, ::Nothing) where {T} = + similar(c, T, 0) +_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength, len::Integer) where {T} = + similar(c, T, len) +_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape, axs) where {T} = + similar(c, T, axs) + +# make a collection appropriate for collecting `itr::Generator` +_array_for(::Type{T}, ::SizeUnknown, ::Nothing) where {T} = Vector{T}(undef, 0) +_array_for(::Type{T}, ::HasLength, len::Integer) where {T} = Vector{T}(undef, Int(len)) +_array_for(::Type{T}, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs) + +# used by syntax lowering for simple typed comprehensions +_array_for(::Type{T}, itr, isz) where {T} = _array_for(T, isz, _similar_shape(itr, isz)) + """ collect(collection) @@ -698,10 +713,10 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A) collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr)) _collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) = - copyto!(_similar_for(cont, eltype(itr), itr, isz), itr) + copyto!(_similar_for(cont, eltype(itr), itr, isz, _similar_shape(itr, isz)), itr) function _collect(cont, itr, ::HasEltype, isz::SizeUnknown) - a = _similar_for(cont, eltype(itr), itr, isz) + a = _similar_for(cont, eltype(itr), itr, isz, nothing) for x in itr push!(a,x) end @@ -759,24 +774,19 @@ else end end -_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr)) -_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr)) -_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len) -_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs) - function collect(itr::Generator) isz = IteratorSize(itr.iter) et = @default_eltype(itr) if isa(isz, SizeUnknown) return grow_to!(Vector{et}(), itr) else - shape = isz isa HasLength ? length(itr) : axes(itr) + shp = _similar_shape(itr, isz) y = iterate(itr) if y === nothing - return _array_for(et, itr.iter, isz) + return _array_for(et, isz, shp) end v1, st = y - dest = _array_for(typeof(v1), itr.iter, isz, shape) + dest = _array_for(typeof(v1), isz, shp) # The typeassert gives inference a helping hand on the element type and dimensionality # (work-around for #28382) et′ = et <: Type ? Type : et @@ -786,15 +796,22 @@ function collect(itr::Generator) end _collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) = - grow_to!(_similar_for(c, @default_eltype(itr), itr, isz), itr) + grow_to!(_similar_for(c, @default_eltype(itr), itr, isz, nothing), itr) function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape}) + et = @default_eltype(itr) + shp = _similar_shape(itr, isz) y = iterate(itr) if y === nothing - return _similar_for(c, @default_eltype(itr), itr, isz) + return _similar_for(c, et, itr, isz, shp) end v1, st = y - collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st) + dest = _similar_for(c, typeof(v1), itr, isz, shp) + # The typeassert gives inference a helping hand on the element type and dimensionality + # (work-around for #28382) + et′ = et <: Type ? Type : et + RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any + collect_to_with_first!(dest, v1, itr, st)::RT end function collect_to_with_first!(dest::AbstractArray, v1, itr, st) diff --git a/base/dict.jl b/base/dict.jl index 6918677c4f0bb..1978323e88503 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -826,6 +826,6 @@ length(t::ImmutableDict) = count(Returns(true), t) isempty(t::ImmutableDict) = !isdefined(t, :parent) empty(::ImmutableDict, ::Type{K}, ::Type{V}) where {K, V} = ImmutableDict{K,V}() -_similar_for(c::Dict, ::Type{Pair{K,V}}, itr, isz) where {K, V} = empty(c, K, V) -_similar_for(c::AbstractDict, ::Type{T}, itr, isz) where {T} = +_similar_for(c::AbstractDict, ::Type{Pair{K,V}}, itr, isz, len) where {K, V} = empty(c, K, V) +_similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} = throw(ArgumentError("for AbstractDicts, similar requires an element type of Pair;\n if calling map, consider a comprehension instead")) diff --git a/base/set.jl b/base/set.jl index 6511d1dd7e108..dd1400d11dba1 100644 --- a/base/set.jl +++ b/base/set.jl @@ -44,7 +44,7 @@ empty(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}() # by default, a Set is returned emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}() -_similar_for(c::AbstractSet, ::Type{T}, itr, isz) where {T} = empty(c, T) +_similar_for(c::AbstractSet, ::Type{T}, itr, isz, len) where {T} = empty(c, T) function show(io::IO, s::Set) if isempty(s) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 428b0513b7e52..ef533059b3993 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -2734,7 +2734,7 @@ (check-no-return expr) (if (has-break-or-continue? expr) (error "break or continue outside loop")) - (let ((result (gensy)) + (let ((result (make-ssavalue)) (idx (gensy)) (oneresult (make-ssavalue)) (prod (make-ssavalue)) @@ -2758,16 +2758,14 @@ (let ((overall-itr (if (length= itrs 1) (car iv) prod))) `(scope-block (block - (local ,result) (local ,idx) + (local ,idx) ,.(map (lambda (v r) `(= ,v ,(caddr r))) iv itrs) ,.(if (length= itrs 1) '() `((= ,prod (call (top product) ,@iv)))) (= ,isz (call (top IteratorSize) ,overall-itr)) (= ,szunk (call (core isa) ,isz (top SizeUnknown))) - (if ,szunk - (= ,result (call (curly (core Array) ,ty 1) (core undef) 0)) - (= ,result (call (top _array_for) ,ty ,overall-itr ,isz))) + (= ,result (call (top _array_for) ,ty ,overall-itr ,isz)) (= ,idx (call (top first) (call (top LinearIndices) ,result))) ,(construct-loops (reverse itrs) (reverse iv)) ,result))))) diff --git a/test/iterators.jl b/test/iterators.jl index c7d00c4e7e2e8..86c325a85b617 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -293,11 +293,14 @@ let (a, b) = (1:3, [4 6; end # collect stateful iterator -let - itr = (i+1 for i in Base.Stateful([1,2,3])) +let itr + itr = Iterators.Stateful(Iterators.map(identity, 1:5)) + @test collect(itr) == 1:5 + @test collect(itr) == Int[] # Stateful do not preserve shape + itr = (i+1 for i in Base.Stateful([1, 2, 3])) @test collect(itr) == [2, 3, 4] - A = zeros(Int, 0, 0) - itr = (i-1 for i in Base.Stateful(A)) + @test collect(itr) == Int[] # Stateful do not preserve shape + itr = (i-1 for i in Base.Stateful(zeros(Int, 0, 0))) @test collect(itr) == Int[] # Stateful do not preserve shape end