Skip to content

Commit

Permalink
Remove kludgy VectorStyle and MatrixStyle
Browse files Browse the repository at this point in the history
These broadcast styles were introduced in response to #23939 (comment) as a way to limit the "greediness" of Sparse's broadcasting implementation -- sparse only wanted to allow known combinations of array types (including Array but not any AbstractArray). The idea was to allow us to gradually improve the sparse broadcast implementation over 1.x in a non-breaking manner.  Unfortunately, these special styles for Array make defining new styles in the heirarchy a bit of a pain (ref. #23939 (comment)), and it was making my life harder in getting the 1.0 breaking changes in.

This commit removes these special broadcast styles in favor of just having Sparse identify the cases itself and re-dispatch back into the default implementation in the cases it doesn't know how to handle.
  • Loading branch information
mbauman committed Mar 15, 2018
1 parent 983f0f9 commit e37c0e0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 60 deletions.
39 changes: 1 addition & 38 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,32 +160,6 @@ BroadcastStyle(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where N = a
BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
typeof(a)(_max(Val(M),Val(N)))

# FIXME
# The following definitions are necessary to limit SparseArray broadcasting to "plain Arrays"
# (see https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382).
# They should be deleted once the sparse broadcast infrastucture is capable of handling
# arbitrary AbstractArrays.
struct VectorStyle <: AbstractArrayStyle{1} end
struct MatrixStyle <: AbstractArrayStyle{2} end
const VMStyle = Union{VectorStyle,MatrixStyle}
# These lose to DefaultArrayStyle
VectorStyle(::Val{N}) where N = DefaultArrayStyle{N}()
MatrixStyle(::Val{N}) where N = DefaultArrayStyle{N}()

BroadcastStyle(::Type{<:Vector}) = VectorStyle()
BroadcastStyle(::Type{<:Matrix}) = MatrixStyle()

BroadcastStyle(::MatrixStyle, ::VectorStyle) = MatrixStyle()
BroadcastStyle(a::AbstractArrayStyle{Any}, ::VectorStyle) = a
BroadcastStyle(a::AbstractArrayStyle{Any}, ::MatrixStyle) = a
BroadcastStyle(a::AbstractArrayStyle{N}, ::VectorStyle) where N = typeof(a)(_max(Val(N), Val(1)))
BroadcastStyle(a::AbstractArrayStyle{N}, ::MatrixStyle) where N = typeof(a)(_max(Val(N), Val(2)))
BroadcastStyle(::VectorStyle, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle(_max(Val(N), Val(1)))
BroadcastStyle(::MatrixStyle, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle(_max(Val(N), Val(2)))
# to avoid the VectorStyle(::Val) constructor we also need the following
BroadcastStyle(::VectorStyle, ::MatrixStyle) = MatrixStyle()
# end FIXME

## Allocating the output container
"""
broadcast_similar(f, ::BroadcastStyle, ::Type{ElType}, inds, As...)
Expand All @@ -205,17 +179,6 @@ broadcast_similar(f, ::ArrayConflict, ::Type{ElType}, inds::Indices, As...) wher
broadcast_similar(f, ::ArrayConflict, ::Type{Bool}, inds::Indices, As...) =
similar(BitArray, inds)

# FIXME: delete when we get rid of VectorStyle and MatrixStyle
broadcast_similar(f, ::VectorStyle, ::Type{ElType}, inds::Indices{1}, As...) where ElType =
similar(Vector{ElType}, inds)
broadcast_similar(f, ::MatrixStyle, ::Type{ElType}, inds::Indices{2}, As...) where ElType =
similar(Matrix{ElType}, inds)
broadcast_similar(f, ::VectorStyle, ::Type{Bool}, inds::Indices{1}, As...) =
similar(BitArray, inds)
broadcast_similar(f, ::MatrixStyle, ::Type{Bool}, inds::Indices{2}, As...) =
similar(BitArray, inds)
# end FIXME

## Computing the result's indices. Most types probably won't need to specialize this.
broadcast_indices() = ()
broadcast_indices(::Type{T}) where T = ()
Expand Down Expand Up @@ -628,7 +591,7 @@ julia> string.(("one","two","three","four"), ": ", 1:4)
broadcast(f, s, combine_eltypes(f, A, Bs...), combine_indices(A, Bs...),
A, Bs...)

const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict,VectorStyle,MatrixStyle}
const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict}

@inline function broadcast(f, s::NonleafHandlingTypes, ::Type{ElType}, inds::Indices, As...) where ElType
if !Base.isconcretetype(ElType)
Expand Down
1 change: 0 additions & 1 deletion stdlib/Pkg3/src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,6 @@ precompile(Tuple{typeof(Core.Compiler.getindex), Tuple{String, typeof(Base.info)
precompile(Tuple{typeof(Core.Compiler.getindex), Tuple{UInt128, UInt128}, Int64})
precompile(Tuple{typeof(Core.Compiler.getindex), Tuple{UInt32}, Int64})
precompile(Tuple{typeof(Core.Compiler.getindex), Tuple{typeof(Pkg3.BinaryProvider.parse_tar_list)}, Int64})
precompile(Tuple{typeof(Core.Compiler.getindex), Type{Any}, Core.Compiler.Const, Core.Compiler.Const, Type{Base.Broadcast.VectorStyle}, Core.Compiler.Const, Type{Tuple{Base.OneTo{Int64}}}, Type{Tuple{Array{Base.SubString{String}, 1}}}})
precompile(Tuple{typeof(Core.Compiler.getindex), Type{Any}, GlobalRef, Bool, typeof(Pkg3.Types.parse_toml), Expr})
precompile(Tuple{typeof(Core.Compiler.getindex), Type{Any}, GlobalRef, Core.SSAValue, Bool, Expr})
precompile(Tuple{typeof(Core.Compiler.getindex), Type{Any}, GlobalRef, Core.SlotNumber, QuoteNode, Expr})
Expand Down
46 changes: 25 additions & 21 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -973,35 +973,39 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.

struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()

PromoteToSparse(::Val{0}) = PromoteToSparse()
PromoteToSparse(::Val{1}) = PromoteToSparse()
PromoteToSparse(::Val{2}) = PromoteToSparse()
PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

# FIXME: switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Vector} where T}) = Broadcast.MatrixStyle() # Adjoint not yet defined when broadcast.jl loaded
Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Vector} where T}) = Broadcast.MatrixStyle() # Transpose not yet defined when broadcast.jl loaded
const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.VectorStyle) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.MatrixStyle) = PromoteToSparse()
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(1)))
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(2)))
# end FIXME

broadcast(f, ::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} =
broadcast(f, map(_sparsifystructured, As)...)
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray
# could report itself as a DefaultArrayStyle().
# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details
is_supported_sparse_broadcast() = true
is_supported_sparse_broadcast(::AbstractArray, rest...) = false
is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...)
is_supported_sparse_broadcast(x, rest...) = BroadcastStyle(typeof(x)) === Broadcast.Scalar() && is_supported_sparse_broadcast(rest...)
function broadcast(f, s::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N}
if is_supported_sparse_broadcast(As...)
return broadcast(f, map(_sparsifystructured, As)...)
else
return broadcast(f, Broadcast.ArrayConflict(), nothing, nothing, As...)
end
end

# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
Expand Down

0 comments on commit e37c0e0

Please sign in to comment.