Skip to content

Commit

Permalink
Merge pull request #42172 from JuliaLang/jn/42168
Browse files Browse the repository at this point in the history
fix collect on stateful iterators
  • Loading branch information
JeffBezanson authored Sep 14, 2021
2 parents 60423e2 + 68e0813 commit 4c90ed9
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 88 deletions.
61 changes: 39 additions & 22 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -643,23 +643,38 @@ julia> collect(Float64, 1:2:5)
"""
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))

_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr)
_collect(::Type{T}, itr, isz::Union{HasLength,HasShape}) where {T} =
copyto!(_array_for(T, isz, _similar_shape(itr, isz)), itr)
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
a = Vector{T}()
for x in itr
push!(a,x)
push!(a, x)
end
return a
end

# make a collection similar to `c` and appropriate for collecting `itr`
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown) where {T} = similar(c, T, 0)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength) where {T} =
similar(c, T, Int(length(itr)::Integer))
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape) where {T} =
similar(c, T, axes(itr))
_similar_for(c, ::Type{T}, itr, isz) where {T} = similar(c, T)
_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T)

_similar_shape(itr, ::SizeUnknown) = nothing
_similar_shape(itr, ::HasLength) = length(itr)::Integer
_similar_shape(itr, ::HasShape) = axes(itr)

_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown, ::Nothing) where {T} =
similar(c, T, 0)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength, len::Integer) where {T} =
similar(c, T, len)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape, axs) where {T} =
similar(c, T, axs)

# make a collection appropriate for collecting `itr::Generator`
_array_for(::Type{T}, ::SizeUnknown, ::Nothing) where {T} = Vector{T}(undef, 0)
_array_for(::Type{T}, ::HasLength, len::Integer) where {T} = Vector{T}(undef, Int(len))
_array_for(::Type{T}, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)

# used by syntax lowering for simple typed comprehensions
_array_for(::Type{T}, itr, isz) where {T} = _array_for(T, isz, _similar_shape(itr, isz))


"""
collect(collection)
Expand Down Expand Up @@ -698,10 +713,10 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))

_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
copyto!(_similar_for(cont, eltype(itr), itr, isz), itr)
copyto!(_similar_for(cont, eltype(itr), itr, isz, _similar_shape(itr, isz)), itr)

function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
a = _similar_for(cont, eltype(itr), itr, isz)
a = _similar_for(cont, eltype(itr), itr, isz, nothing)
for x in itr
push!(a,x)
end
Expand Down Expand Up @@ -759,24 +774,19 @@ else
end
end

_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)

function collect(itr::Generator)
isz = IteratorSize(itr.iter)
et = @default_eltype(itr)
if isa(isz, SizeUnknown)
return grow_to!(Vector{et}(), itr)
else
shape = isz isa HasLength ? length(itr) : axes(itr)
shp = _similar_shape(itr, isz)
y = iterate(itr)
if y === nothing
return _array_for(et, itr.iter, isz)
return _array_for(et, isz, shp)
end
v1, st = y
dest = _array_for(typeof(v1), itr.iter, isz, shape)
dest = _array_for(typeof(v1), isz, shp)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
et′ = et <: Type ? Type : et
Expand All @@ -786,15 +796,22 @@ function collect(itr::Generator)
end

_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz), itr)
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz, nothing), itr)

function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
et = @default_eltype(itr)
shp = _similar_shape(itr, isz)
y = iterate(itr)
if y === nothing
return _similar_for(c, @default_eltype(itr), itr, isz)
return _similar_for(c, et, itr, isz, shp)
end
v1, st = y
collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)
dest = _similar_for(c, typeof(v1), itr, isz, shp)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
et′ = et <: Type ? Type : et
RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
collect_to_with_first!(dest, v1, itr, st)::RT
end

function collect_to_with_first!(dest::AbstractArray, v1, itr, st)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function inline_into_block!(state::CFGInliningState, block::Int)
new_range = state.first_bb+1:block
l = length(state.new_cfg_blocks)
state.bb_rename[new_range] = (l+1:l+length(new_range))
append!(state.new_cfg_blocks, map(copy, state.cfg.blocks[new_range]))
append!(state.new_cfg_blocks, (copy(block) for block in state.cfg.blocks[new_range]))
push!(state.merged_orig_blocks, last(new_range))
end
state.first_bb = block
Expand Down
29 changes: 2 additions & 27 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,6 @@ function try_compute_fieldidx_args(typ::DataType, args::Vector{Any})
return try_compute_fieldidx(typ, field)
end

function lift_defuse(cfg::CFG, ssa::SSADefUse)
# We remove from `uses` any block where all uses are dominated
# by a def. This prevents insertion of dead phi nodes at the top
# of such a block if that block happens to be in a loop
ordered = Tuple{Int, Int, Bool}[(x, block_for_inst(cfg, x), true) for x in ssa.uses]
for x in ssa.defs
push!(ordered, (x, block_for_inst(cfg, x), false))
end
ordered = sort(ordered, by=x->x[1])
bb_defs = Int[]
bb_uses = Int[]
last_bb = last_def_bb = 0
for (_, bb, is_use) in ordered
if bb != last_bb && is_use
push!(bb_uses, bb)
end
last_bb = bb
if last_def_bb != bb && !is_use
push!(bb_defs, bb)
last_def_bb = bb
end
end
SSADefUse(bb_uses, bb_defs, Int[])
end

function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int)
# TODO: This can be much faster by looking at current level and only
# searching for those blocks in a sorted order
Expand Down Expand Up @@ -1209,12 +1184,12 @@ function cfg_simplify!(ir::IRCode)
# Compute (renamed) successors and predecessors given (renamed) block
function compute_succs(i)
orig_bb = follow_merged_succ(result_bbs[i])
return map(i -> bb_rename_succ[i], bbs[orig_bb].succs)
return Int[bb_rename_succ[i] for i in bbs[orig_bb].succs]
end
function compute_preds(i)
orig_bb = result_bbs[i]
preds = bbs[orig_bb].preds
return map(pred -> bb_rename_pred[pred], preds)
return Int[bb_rename_pred[pred] for pred in preds]
end

BasicBlock[
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ show_unquoted(io::IO, val::Argument, indent::Int, prec::Int) = show_unquoted(io,

show_unquoted(io::IO, stmt::PhiNode, indent::Int, ::Int) = show_unquoted_phinode(io, stmt, indent, "%")
function show_unquoted_phinode(io::IO, stmt::PhiNode, indent::Int, prefix::String)
args = map(1:length(stmt.edges)) do i
args = String[let
e = stmt.edges[i]
v = !isassigned(stmt.values, i) ? "#undef" :
sprint() do io′
show_unquoted(io′, stmt.values[i], indent)
end
return "$prefix$e => $v"
end
"$prefix$e => $v"
end for i in 1:length(stmt.edges)
]
print(io, "φ ", '(')
join(io, args, ", ")
print(io, ')')
Expand Down
22 changes: 6 additions & 16 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,6 @@ function scan_entry!(result::Vector{SlotInfo}, idx::Int, @nospecialize(stmt))
end


function lift_defuse(cfg::CFG, defuse)
map(defuse) do slot
SlotInfo(
Int[block_for_inst(cfg, x) for x in slot.defs],
Int[block_for_inst(cfg, x) for x in slot.uses],
slot.any_newvar
)
end
end

function scan_slot_def_use(nargs::Int, ci::CodeInfo, code::Vector{Any})
nslots = length(ci.slotflags)
result = SlotInfo[SlotInfo() for i = 1:nslots]
Expand Down Expand Up @@ -524,7 +514,7 @@ function domsort_ssa!(ir::IRCode, domtree::DomTree)
return new_ir
end

function compute_live_ins(cfg::CFG, defuse)
function compute_live_ins(cfg::CFG, defuse #=::Union{SlotInfo,SSADefUse}=#)
# We remove from `uses` any block where all uses are dominated
# by a def. This prevents insertion of dead phi nodes at the top
# of such a block if that block happens to be in a loop
Expand Down Expand Up @@ -586,8 +576,8 @@ function recompute_type(node::Union{PhiNode, PhiCNode}, ci::CodeInfo, ir::IRCode
return new_typ
end

function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
slottypes::Vector{Any})
function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree,
defuses::Vector{SlotInfo}, slottypes::Vector{Any})
code = ir.stmts.inst
cfg = ir.cfg
left = Int[]
Expand Down Expand Up @@ -616,7 +606,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
for (_, exc) in catch_entry_blocks
phicnodes[exc] = Vector{Tuple{SlotNumber, NewSSAValue, PhiCNode}}()
end
@timeit "idf" for (idx, slot) in Iterators.enumerate(defuse)
@timeit "idf" for (idx, slot) in Iterators.enumerate(defuses)
# No uses => no need for phi nodes
isempty(slot.uses) && continue
# TODO: Restore this optimization
Expand Down Expand Up @@ -671,9 +661,9 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
end
# Perform SSA renaming
initial_incoming_vals = Any[
if 0 in defuse[x].defs
if 0 in defuses[x].defs
Argument(x)
elseif !defuse[x].any_newvar
elseif !defuses[x].any_newvar
undef_token
else
SSAValue(-2)
Expand Down
4 changes: 2 additions & 2 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,6 @@ length(t::ImmutableDict) = count(Returns(true), t)
isempty(t::ImmutableDict) = !isdefined(t, :parent)
empty(::ImmutableDict, ::Type{K}, ::Type{V}) where {K, V} = ImmutableDict{K,V}()

_similar_for(c::Dict, ::Type{Pair{K,V}}, itr, isz) where {K, V} = empty(c, K, V)
_similar_for(c::AbstractDict, ::Type{T}, itr, isz) where {T} =
_similar_for(c::AbstractDict, ::Type{Pair{K,V}}, itr, isz, len) where {K, V} = empty(c, K, V)
_similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} =
throw(ArgumentError("for AbstractDicts, similar requires an element type of Pair;\n if calling map, consider a comprehension instead"))
2 changes: 1 addition & 1 deletion base/set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ empty(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
# by default, a Set is returned
emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()

_similar_for(c::AbstractSet, ::Type{T}, itr, isz) where {T} = empty(c, T)
_similar_for(c::AbstractSet, ::Type{T}, itr, isz, len) where {T} = empty(c, T)

function show(io::IO, s::Set)
if isempty(s)
Expand Down
6 changes: 4 additions & 2 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ void jl_foreach_reachable_mtable(void (*visit)(jl_methtable_t *mt, void *env), v
}
else {
foreach_mtable_in_module(jl_main_module, visit, env, &visited);
foreach_mtable_in_module(jl_core_module, visit, env, &visited);
}
JL_GC_POP();
}
Expand All @@ -493,14 +494,15 @@ static void reset_mt_caches(jl_methtable_t *mt, void *env)


jl_function_t *jl_typeinf_func = NULL;
size_t jl_typeinf_world = 0;
size_t jl_typeinf_world = 1;

JL_DLLEXPORT void jl_set_typeinf_func(jl_value_t *f)
{
size_t newfunc = jl_typeinf_world == 1 && jl_typeinf_func == NULL;
jl_typeinf_func = (jl_function_t*)f;
jl_typeinf_world = jl_get_tls_world_age();
++jl_world_counter; // make type-inference the only thing in this world
if (jl_typeinf_world == 0) {
if (newfunc) {
// give type inference a chance to see all of these
// TODO: also reinfer if max_world != ~(size_t)0
jl_array_t *unspec = jl_alloc_vec_any(0);
Expand Down
8 changes: 3 additions & 5 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -2734,7 +2734,7 @@
(check-no-return expr)
(if (has-break-or-continue? expr)
(error "break or continue outside loop"))
(let ((result (gensy))
(let ((result (make-ssavalue))
(idx (gensy))
(oneresult (make-ssavalue))
(prod (make-ssavalue))
Expand All @@ -2758,16 +2758,14 @@
(let ((overall-itr (if (length= itrs 1) (car iv) prod)))
`(scope-block
(block
(local ,result) (local ,idx)
(local ,idx)
,.(map (lambda (v r) `(= ,v ,(caddr r))) iv itrs)
,.(if (length= itrs 1)
'()
`((= ,prod (call (top product) ,@iv))))
(= ,isz (call (top IteratorSize) ,overall-itr))
(= ,szunk (call (core isa) ,isz (top SizeUnknown)))
(if ,szunk
(= ,result (call (curly (core Array) ,ty 1) (core undef) 0))
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz)))
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz))
(= ,idx (call (top first) (call (top LinearIndices) ,result)))
,(construct-loops (reverse itrs) (reverse iv))
,result)))))
Expand Down
13 changes: 8 additions & 5 deletions test/errorshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ end

# Test that implementation detail of include() is hidden from the user by default
let bt = try
include("testhelpers/include_error.jl")
@noinline include("testhelpers/include_error.jl")
catch
catch_backtrace()
end
Expand All @@ -740,7 +740,7 @@ end
# Test backtrace printing
module B
module C
f(x; y=2.0) = error()
@noinline f(x; y=2.0) = error()
end
module D
import ..C: f
Expand All @@ -749,7 +749,8 @@ module B
end

@testset "backtrace" begin
bt = try B.D.g()
bt = try
B.D.g()
catch
catch_backtrace()
end
Expand Down Expand Up @@ -777,15 +778,17 @@ if Sys.isapple() || (Sys.islinux() && Sys.ARCH === :x86_64)
pair_repeater_b() = pair_repeater_a()

@testset "repeated stack frames" begin
let bt = try single_repeater()
let bt = try
single_repeater()
catch
catch_backtrace()
end
bt_str = sprint(Base.show_backtrace, bt)
@test occursin(r"repeats \d+ times", bt_str)
end

let bt = try pair_repeater_a()
let bt = try
pair_repeater_a()
catch
catch_backtrace()
end
Expand Down
11 changes: 7 additions & 4 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,14 @@ let (a, b) = (1:3, [4 6;
end

# collect stateful iterator
let
itr = (i+1 for i in Base.Stateful([1,2,3]))
let itr
itr = Iterators.Stateful(Iterators.map(identity, 1:5))
@test collect(itr) == 1:5
@test collect(itr) == Int[] # Stateful do not preserve shape
itr = (i+1 for i in Base.Stateful([1, 2, 3]))
@test collect(itr) == [2, 3, 4]
A = zeros(Int, 0, 0)
itr = (i-1 for i in Base.Stateful(A))
@test collect(itr) == Int[] # Stateful do not preserve shape
itr = (i-1 for i in Base.Stateful(zeros(Int, 0, 0)))
@test collect(itr) == Int[] # Stateful do not preserve shape
end

Expand Down

0 comments on commit 4c90ed9

Please sign in to comment.