Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix collect on stateful iterators #42172

Merged
merged 4 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -643,23 +643,38 @@ julia> collect(Float64, 1:2:5)
"""
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))

_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr)
_collect(::Type{T}, itr, isz::Union{HasLength,HasShape}) where {T} =
copyto!(_array_for(T, isz, _similar_shape(itr, isz)), itr)
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
a = Vector{T}()
for x in itr
push!(a,x)
push!(a, x)
end
return a
end

# make a collection similar to `c` and appropriate for collecting `itr`
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown) where {T} = similar(c, T, 0)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength) where {T} =
similar(c, T, Int(length(itr)::Integer))
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape) where {T} =
similar(c, T, axes(itr))
_similar_for(c, ::Type{T}, itr, isz) where {T} = similar(c, T)
_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T)

_similar_shape(itr, ::SizeUnknown) = nothing
_similar_shape(itr, ::HasLength) = length(itr)::Integer
_similar_shape(itr, ::HasShape) = axes(itr)

_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown, ::Nothing) where {T} =
similar(c, T, 0)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength, len::Integer) where {T} =
similar(c, T, len)
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape, axs) where {T} =
similar(c, T, axs)

# make a collection appropriate for collecting `itr::Generator`
_array_for(::Type{T}, ::SizeUnknown, ::Nothing) where {T} = Vector{T}(undef, 0)
_array_for(::Type{T}, ::HasLength, len::Integer) where {T} = Vector{T}(undef, Int(len))
_array_for(::Type{T}, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)

# used by syntax lowering for simple typed comprehensions
_array_for(::Type{T}, itr, isz) where {T} = _array_for(T, isz, _similar_shape(itr, isz))


"""
collect(collection)
Expand Down Expand Up @@ -698,10 +713,10 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))

_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
copyto!(_similar_for(cont, eltype(itr), itr, isz), itr)
copyto!(_similar_for(cont, eltype(itr), itr, isz, _similar_shape(itr, isz)), itr)

function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
a = _similar_for(cont, eltype(itr), itr, isz)
a = _similar_for(cont, eltype(itr), itr, isz, nothing)
for x in itr
push!(a,x)
end
Expand Down Expand Up @@ -759,24 +774,19 @@ else
end
end

_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)

function collect(itr::Generator)
isz = IteratorSize(itr.iter)
et = @default_eltype(itr)
if isa(isz, SizeUnknown)
return grow_to!(Vector{et}(), itr)
else
shape = isz isa HasLength ? length(itr) : axes(itr)
shp = _similar_shape(itr, isz)
y = iterate(itr)
if y === nothing
return _array_for(et, itr.iter, isz)
return _array_for(et, isz, shp)
end
v1, st = y
dest = _array_for(typeof(v1), itr.iter, isz, shape)
dest = _array_for(typeof(v1), isz, shp)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
et′ = et <: Type ? Type : et
Expand All @@ -786,15 +796,22 @@ function collect(itr::Generator)
end

_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz), itr)
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz, nothing), itr)

function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
et = @default_eltype(itr)
shp = _similar_shape(itr, isz)
y = iterate(itr)
if y === nothing
return _similar_for(c, @default_eltype(itr), itr, isz)
return _similar_for(c, et, itr, isz, shp)
end
v1, st = y
collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)
dest = _similar_for(c, typeof(v1), itr, isz, shp)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
et′ = et <: Type ? Type : et
RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
collect_to_with_first!(dest, v1, itr, st)::RT
end

function collect_to_with_first!(dest::AbstractArray, v1, itr, st)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function inline_into_block!(state::CFGInliningState, block::Int)
new_range = state.first_bb+1:block
l = length(state.new_cfg_blocks)
state.bb_rename[new_range] = (l+1:l+length(new_range))
append!(state.new_cfg_blocks, map(copy, state.cfg.blocks[new_range]))
append!(state.new_cfg_blocks, (copy(block) for block in state.cfg.blocks[new_range]))
push!(state.merged_orig_blocks, last(new_range))
end
state.first_bb = block
Expand Down
29 changes: 2 additions & 27 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,6 @@ function try_compute_fieldidx_args(typ::DataType, args::Vector{Any})
return try_compute_fieldidx(typ, field)
end

function lift_defuse(cfg::CFG, ssa::SSADefUse)
# We remove from `uses` any block where all uses are dominated
# by a def. This prevents insertion of dead phi nodes at the top
# of such a block if that block happens to be in a loop
ordered = Tuple{Int, Int, Bool}[(x, block_for_inst(cfg, x), true) for x in ssa.uses]
for x in ssa.defs
push!(ordered, (x, block_for_inst(cfg, x), false))
end
ordered = sort(ordered, by=x->x[1])
bb_defs = Int[]
bb_uses = Int[]
last_bb = last_def_bb = 0
for (_, bb, is_use) in ordered
if bb != last_bb && is_use
push!(bb_uses, bb)
end
last_bb = bb
if last_def_bb != bb && !is_use
push!(bb_defs, bb)
last_def_bb = bb
end
end
SSADefUse(bb_uses, bb_defs, Int[])
end

function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int)
# TODO: This can be much faster by looking at current level and only
# searching for those blocks in a sorted order
Expand Down Expand Up @@ -1209,12 +1184,12 @@ function cfg_simplify!(ir::IRCode)
# Compute (renamed) successors and predecessors given (renamed) block
function compute_succs(i)
orig_bb = follow_merged_succ(result_bbs[i])
return map(i -> bb_rename_succ[i], bbs[orig_bb].succs)
return Int[bb_rename_succ[i] for i in bbs[orig_bb].succs]
end
function compute_preds(i)
orig_bb = result_bbs[i]
preds = bbs[orig_bb].preds
return map(pred -> bb_rename_pred[pred], preds)
return Int[bb_rename_pred[pred] for pred in preds]
end

BasicBlock[
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ show_unquoted(io::IO, val::Argument, indent::Int, prec::Int) = show_unquoted(io,

show_unquoted(io::IO, stmt::PhiNode, indent::Int, ::Int) = show_unquoted_phinode(io, stmt, indent, "%")
function show_unquoted_phinode(io::IO, stmt::PhiNode, indent::Int, prefix::String)
args = map(1:length(stmt.edges)) do i
args = String[let
e = stmt.edges[i]
v = !isassigned(stmt.values, i) ? "#undef" :
sprint() do io′
show_unquoted(io′, stmt.values[i], indent)
end
return "$prefix$e => $v"
end
"$prefix$e => $v"
end for i in 1:length(stmt.edges)
]
print(io, "φ ", '(')
join(io, args, ", ")
print(io, ')')
Expand Down
22 changes: 6 additions & 16 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,6 @@ function scan_entry!(result::Vector{SlotInfo}, idx::Int, @nospecialize(stmt))
end


function lift_defuse(cfg::CFG, defuse)
map(defuse) do slot
SlotInfo(
Int[block_for_inst(cfg, x) for x in slot.defs],
Int[block_for_inst(cfg, x) for x in slot.uses],
slot.any_newvar
)
end
end

function scan_slot_def_use(nargs::Int, ci::CodeInfo, code::Vector{Any})
nslots = length(ci.slotflags)
result = SlotInfo[SlotInfo() for i = 1:nslots]
Expand Down Expand Up @@ -524,7 +514,7 @@ function domsort_ssa!(ir::IRCode, domtree::DomTree)
return new_ir
end

function compute_live_ins(cfg::CFG, defuse)
function compute_live_ins(cfg::CFG, defuse #=::Union{SlotInfo,SSADefUse}=#)
# We remove from `uses` any block where all uses are dominated
# by a def. This prevents insertion of dead phi nodes at the top
# of such a block if that block happens to be in a loop
Expand Down Expand Up @@ -586,8 +576,8 @@ function recompute_type(node::Union{PhiNode, PhiCNode}, ci::CodeInfo, ir::IRCode
return new_typ
end

function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
slottypes::Vector{Any})
function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree,
defuses::Vector{SlotInfo}, slottypes::Vector{Any})
code = ir.stmts.inst
cfg = ir.cfg
left = Int[]
Expand Down Expand Up @@ -616,7 +606,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
for (_, exc) in catch_entry_blocks
phicnodes[exc] = Vector{Tuple{SlotNumber, NewSSAValue, PhiCNode}}()
end
@timeit "idf" for (idx, slot) in Iterators.enumerate(defuse)
@timeit "idf" for (idx, slot) in Iterators.enumerate(defuses)
# No uses => no need for phi nodes
isempty(slot.uses) && continue
# TODO: Restore this optimization
Expand Down Expand Up @@ -671,9 +661,9 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
end
# Perform SSA renaming
initial_incoming_vals = Any[
if 0 in defuse[x].defs
if 0 in defuses[x].defs
Argument(x)
elseif !defuse[x].any_newvar
elseif !defuses[x].any_newvar
undef_token
else
SSAValue(-2)
Expand Down
4 changes: 2 additions & 2 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,6 @@ length(t::ImmutableDict) = count(Returns(true), t)
isempty(t::ImmutableDict) = !isdefined(t, :parent)
empty(::ImmutableDict, ::Type{K}, ::Type{V}) where {K, V} = ImmutableDict{K,V}()

_similar_for(c::Dict, ::Type{Pair{K,V}}, itr, isz) where {K, V} = empty(c, K, V)
_similar_for(c::AbstractDict, ::Type{T}, itr, isz) where {T} =
_similar_for(c::AbstractDict, ::Type{Pair{K,V}}, itr, isz, len) where {K, V} = empty(c, K, V)
_similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} =
throw(ArgumentError("for AbstractDicts, similar requires an element type of Pair;\n if calling map, consider a comprehension instead"))
2 changes: 1 addition & 1 deletion base/set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ empty(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
# by default, a Set is returned
emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()

_similar_for(c::AbstractSet, ::Type{T}, itr, isz) where {T} = empty(c, T)
_similar_for(c::AbstractSet, ::Type{T}, itr, isz, len) where {T} = empty(c, T)

function show(io::IO, s::Set)
if isempty(s)
Expand Down
6 changes: 4 additions & 2 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ void jl_foreach_reachable_mtable(void (*visit)(jl_methtable_t *mt, void *env), v
}
else {
foreach_mtable_in_module(jl_main_module, visit, env, &visited);
foreach_mtable_in_module(jl_core_module, visit, env, &visited);
}
JL_GC_POP();
}
Expand All @@ -493,14 +494,15 @@ static void reset_mt_caches(jl_methtable_t *mt, void *env)


jl_function_t *jl_typeinf_func = NULL;
size_t jl_typeinf_world = 0;
size_t jl_typeinf_world = 1;

JL_DLLEXPORT void jl_set_typeinf_func(jl_value_t *f)
{
size_t newfunc = jl_typeinf_world == 1 && jl_typeinf_func == NULL;
jl_typeinf_func = (jl_function_t*)f;
jl_typeinf_world = jl_get_tls_world_age();
++jl_world_counter; // make type-inference the only thing in this world
if (jl_typeinf_world == 0) {
if (newfunc) {
// give type inference a chance to see all of these
// TODO: also reinfer if max_world != ~(size_t)0
jl_array_t *unspec = jl_alloc_vec_any(0);
Expand Down
8 changes: 3 additions & 5 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -2734,7 +2734,7 @@
(check-no-return expr)
(if (has-break-or-continue? expr)
(error "break or continue outside loop"))
(let ((result (gensy))
(let ((result (make-ssavalue))
(idx (gensy))
(oneresult (make-ssavalue))
(prod (make-ssavalue))
Expand All @@ -2758,16 +2758,14 @@
(let ((overall-itr (if (length= itrs 1) (car iv) prod)))
`(scope-block
(block
(local ,result) (local ,idx)
(local ,idx)
,.(map (lambda (v r) `(= ,v ,(caddr r))) iv itrs)
,.(if (length= itrs 1)
'()
`((= ,prod (call (top product) ,@iv))))
(= ,isz (call (top IteratorSize) ,overall-itr))
(= ,szunk (call (core isa) ,isz (top SizeUnknown)))
(if ,szunk
(= ,result (call (curly (core Array) ,ty 1) (core undef) 0))
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz)))
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz))
(= ,idx (call (top first) (call (top LinearIndices) ,result)))
,(construct-loops (reverse itrs) (reverse iv))
,result)))))
Expand Down
13 changes: 8 additions & 5 deletions test/errorshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ end

# Test that implementation detail of include() is hidden from the user by default
let bt = try
include("testhelpers/include_error.jl")
@noinline include("testhelpers/include_error.jl")
catch
catch_backtrace()
end
Expand All @@ -740,7 +740,7 @@ end
# Test backtrace printing
module B
module C
f(x; y=2.0) = error()
@noinline f(x; y=2.0) = error()
end
module D
import ..C: f
Expand All @@ -749,7 +749,8 @@ module B
end

@testset "backtrace" begin
bt = try B.D.g()
bt = try
B.D.g()
catch
catch_backtrace()
end
Expand Down Expand Up @@ -777,15 +778,17 @@ if Sys.isapple() || (Sys.islinux() && Sys.ARCH === :x86_64)
pair_repeater_b() = pair_repeater_a()

@testset "repeated stack frames" begin
let bt = try single_repeater()
let bt = try
single_repeater()
catch
catch_backtrace()
end
bt_str = sprint(Base.show_backtrace, bt)
@test occursin(r"repeats \d+ times", bt_str)
end

let bt = try pair_repeater_a()
let bt = try
pair_repeater_a()
catch
catch_backtrace()
end
Expand Down
11 changes: 7 additions & 4 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,14 @@ let (a, b) = (1:3, [4 6;
end

# collect stateful iterator
let
itr = (i+1 for i in Base.Stateful([1,2,3]))
let itr
itr = Iterators.Stateful(Iterators.map(identity, 1:5))
@test collect(itr) == 1:5
@test collect(itr) == Int[] # Stateful do not preserve shape
itr = (i+1 for i in Base.Stateful([1, 2, 3]))
@test collect(itr) == [2, 3, 4]
A = zeros(Int, 0, 0)
itr = (i-1 for i in Base.Stateful(A))
@test collect(itr) == Int[] # Stateful do not preserve shape
itr = (i-1 for i in Base.Stateful(zeros(Int, 0, 0)))
@test collect(itr) == Int[] # Stateful do not preserve shape
end

Expand Down