Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround #28126, support SIMDing broadcast in more cases #30973

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,8 @@ julia> Broadcast.combine_axes(1, 1, 1)
()
```
"""
@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...))
@inline combine_axes(A, B, C...) = broadcast_shape(axes(A), combine_axes(B, C...))
@inline combine_axes(A, B) = broadcast_shape(axes(A), axes(B))
combine_axes(A) = axes(A)

# shape (i.e., tuple-of-indices) inputs
Expand Down Expand Up @@ -502,7 +503,7 @@ function check_broadcast_shape(shp, Ashp::Tuple)
_bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination"))
check_broadcast_shape(tail(shp), tail(Ashp))
end
check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A))
@inline check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A))
# comparing many inputs
@inline function check_broadcast_axes(shp, A, As...)
check_broadcast_axes(shp, A)
Expand Down Expand Up @@ -864,13 +865,14 @@ broadcast_unalias(::Nothing, src) = src

# Preprocessing a `Broadcasted` does two things:
# * unaliases any arguments from `dest`
# * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices
@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes)
preprocess(dest, x) = extrude(broadcast_unalias(dest, x))
# * calls `f` on the arguments (typically `extrude`, which pre-computes the broadcasted indices where advantageous)
@inline preprocess(dest, bc) = preprocess(extrude, dest, bc)
@inline preprocess(f, dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(f, dest, bc.args), bc.axes)
preprocess(f, dest, x) = f(broadcast_unalias(dest, x))

@inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...)
preprocess_args(dest, args::Tuple{Any}) = (preprocess(dest, args[1]),)
preprocess_args(dest, args::Tuple{}) = ()
@inline preprocess_args(f, dest, args::Tuple) = (preprocess(f, dest, args[1]), preprocess_args(f, dest, tail(args))...)
@inline preprocess_args(f, dest, args::Tuple{Any}) = (preprocess(f, dest, args[1]),)
preprocess_args(f, dest, args::Tuple{}) = ()

# Specialize this method if all you want to do is specialize on typeof(dest)
@inline function copyto!(dest::AbstractArray, bc::Broadcasted{Nothing})
Expand All @@ -882,13 +884,48 @@ preprocess_args(dest, args::Tuple{}) = ()
return copyto!(dest, A)
end
end
bc′ = preprocess(dest, bc)
@simd for I in eachindex(bc′)
@inbounds dest[I] = bc′[I]
# Ugly performance hack around issue #28126: determine if all arguments to the
# broadcast are sized such that the broadcasting core can statically determine
# whether a given dimension is "extruded" or not. If so, we don't need to check
# any array sizes within the inner loop. Ideally this really should be something
# that Julia and/or LLVM could figure out and eliminate... and indeed they can
# for limited numbers of arguments.
if _is_static_broadcast_28126(dest, bc)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want to emphasize how awful this is: we now are generating two independent inner loops for every broadcast expression — one that might be slightly more likely to vectorize (depending on array sizes), and one normal one. They've both gotta inline to avoid allocating our broadcast expression tree.

bcs′ = preprocess(_nonextrude_28126, dest, bc)
@simd for I in eachindex(bcs′)
@inbounds dest[I] = bcs′[I]
end
else
bc′ = preprocess(extrude, dest, bc)
@simd for I in eachindex(bc′)
@inbounds dest[I] = bc′[I]
end
end
return dest
end

@inline _is_static_broadcast_28126(dest, bc::Broadcasted{Style}) where {Style} = _is_static_broadcast_28126_args(dest, bc.args)
_is_static_broadcast_28126(dest, x) = false
_is_static_broadcast_28126(dest, x::Union{Ref, Tuple, Type, Number, AbstractArray{<:Any,0}}) = true
_is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray{<:Any,0}) = true
_is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray{<:Any,1}) = axes(dest, 1) == axes(x, 1)
_is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray) = axes(dest) == axes(x) # This can be better with other missing dimensions

@inline _is_static_broadcast_28126_args(dest, args::Tuple) = _is_static_broadcast_28126(dest, args[1]) && _is_static_broadcast_28126_args(dest, tail(args))
@inline _is_static_broadcast_28126_args(dest, args::Tuple{Any}) = _is_static_broadcast_28126(dest, args[1])
_is_static_broadcast_28126_args(dest, args::Tuple{}) = true

struct _NonExtruded28126{T}
x::T
end
@inline axes(b::_NonExtruded28126) = axes(b.x)
Base.@propagate_inbounds _broadcast_getindex(b::_NonExtruded28126, i) = _broadcast_getindex(b, i)
Base.@propagate_inbounds _broadcast_getindex(b::_NonExtruded28126{<:AbstractArray{<:Any,0}}, i) = b.x[]
Base.@propagate_inbounds _broadcast_getindex(b::_NonExtruded28126{<:AbstractVector}, i) = b.x[i[1]]
Base.@propagate_inbounds _broadcast_getindex(b::_NonExtruded28126{<:AbstractArray}, i) = b.x[i]
_nonextrude_28126(x::AbstractArray) = _NonExtruded28126(x)
_nonextrude_28126(x) = x

# Performance optimization: for BitArray outputs, we cache the result
# in a "small" Vector{Bool}, and then copy in chunks into the output
@inline function copyto!(dest::BitArray, bc::Broadcasted{Nothing})
Expand Down
8 changes: 8 additions & 0 deletions test/boundscheck_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,5 +251,13 @@ if bc_opt == bc_default || bc_opt == bc_off
@test occursin("vector.body", sprint(code_llvm, g27079, Tuple{Vector{Int}}))
end

# Ensure broadcasting can vectorize when bounds checks are off
if bc_opt != bc_on
function goo28126(u, uprev, k1, k2, k3, k4, k5, k6, k7)
@. u = uprev + 0.1*(0.1*k1 + 0.2*k2 + 0.3*k3 + 0.4*k4 + 0.5*k5 + 0.6*k6 + 0.7*k7)
nothing
end
@test occursin("vector.body", sprint(code_llvm, goo28126, NTuple{9, Vector{Float32}}))
end

end
10 changes: 10 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,16 @@ let
@test Dict(c .=> d) == Dict("foo" => 1, "bar" => 2)
end

@testset "large fusions vectorize and don't allocate (#28126)" begin
u, uprev, k1, k2, k3, k4, k5, k6, k7 = (ones(1000) for i in 1:9)
function goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)
@. u = uprev + 0.1*(0.1*k1 + 0.2*k2 + 0.3*k3 + 0.4*k4 + 0.5*k5 + 0.6*k6 + 0.7*k7)
nothing
end
@allocated goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)
@test @allocated(goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)) == 0
end

# Broadcasted iterable/indexable APIs
let
bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5))
Expand Down