Skip to content

Commit

Permalink
Merge pull request #25238 from Sacha0/higho2
Browse files Browse the repository at this point in the history
optimize and fix map/broadcast over Adjoint/Transpose vectors, take 2
  • Loading branch information
Sacha0 authored Dec 28, 2017
2 parents eb91796 + 5ff0863 commit 18bed58
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
33 changes: 9 additions & 24 deletions base/linalg/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,30 +136,15 @@ hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...)
### higher order functions
# preserve Adjoint/Transpose wrapper around vectors
# to retain the associated semantics post-map/broadcast

# vectorfy takes an Adoint/Transpose-wrapped vector and builds
# an unwrapped vector with the entrywise-same contents
vectorfy(x::Number) = x
vectorfy(adjvec::AdjointAbsVec) = map(Adjoint, adjvec.parent)
vectorfy(transvec::TransposeAbsVec) = map(Transpose, transvec.parent)
vectorfyall(transformedvecs...) = (map(vectorfy, transformedvecs)...,)

# map over collections of Adjoint/Transpose-wrapped vectors
# note that the caller's operation `f` should be applied to the entries of the wrapped
# vectors, rather than the entires of the wrapped vector's parents. so first we use vectorfy
# to build unwrapped vectors with entrywise-same contents as the wrapped input vectors.
# then we map the caller's operation over that set of unwrapped vectors. but now re-wrapping
# the resulting vector would inappropriately transform the result vector's entries. so
# instead of simply mapping the caller's operation over the set of unwrapped vectors,
# we map Adjoint/Transpose composed with the caller's operationt over the set of unwrapped
# vectors. then re-wrapping the result vector yields a wrapped vector with the correct entries.
map(f, avs::AdjointAbsVec...) = Adjoint(map(Adjointf, vectorfyall(avs...)...))
map(f, tvs::TransposeAbsVec...) = Transpose(map(Transposef, vectorfyall(tvs...)...))

# broadcast over collections of Adjoint/Transpose-wrapped vectors and numbers
# similar explanation for these definitions as for map above
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast(Adjointf, vectorfyall(avs...)...))
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast(Transposef, vectorfyall(tvs...) ...))
#
# note that the caller's operation f operates in the domain of the wrapped vectors' entries.
# hence the Adjoint->f->Adjoint shenanigans applied to the parent vectors' entries.
map(f, avs::AdjointAbsVec...) = Adjoint(map((xs...) -> Adjoint(f(Adjoint.(xs)...)), parent.(avs)...))
map(f, tvs::TransposeAbsVec...) = Transpose(map((xs...) -> Transpose(f(Transpose.(xs)...)), parent.(tvs)...))
quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in the defs below
quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast((xs...) -> Adjoint(f(Adjoint.(xs)...)), quasiparenta.(avs)...))
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast((xs...) -> Transpose(f(Transpose.(xs)...)), quasiparentt.(tvs)...))


### linear algebra
Expand Down
10 changes: 10 additions & 0 deletions test/linalg/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,16 @@ end
# trinary broadcast over wrapped vectors with concrete scalar eltype and numbers
@test broadcast(+, Adjoint(vec), 1, Adjoint(vec))::Adjoint{Complex{Int},Vector{Complex{Int}}} == avec + avec .+ 1
@test broadcast(+, Transpose(vec), 1, Transpose(vec))::Transpose{Complex{Int},Vector{Complex{Int}}} == tvec + tvec .+ 1
@test broadcast(+, Adjoint(vec), 1im, Adjoint(vec))::Adjoint{Complex{Int},Vector{Complex{Int}}} == avec + avec .+ 1im
@test broadcast(+, Transpose(vec), 1im, Transpose(vec))::Transpose{Complex{Int},Vector{Complex{Int}}} == tvec + tvec .+ 1im
# ascertain inference friendliness, ref. https://github.com/JuliaLang/julia/pull/25083#issuecomment-353031641
sparsevec = SparseVector([1.0, 2.0, 3.0])
@test map(-, Adjoint(sparsevec), Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
@test map(-, Transpose(sparsevec), Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
@test broadcast(-, Adjoint(sparsevec), Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
@test broadcast(-, Transpose(sparsevec), Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
@test broadcast(+, Adjoint(sparsevec), 1.0, Adjoint(sparsevec)) isa Adjoint{Float64,SparseVector{Float64,Int}}
@test broadcast(+, Transpose(sparsevec), 1.0, Transpose(sparsevec)) isa Transpose{Float64,SparseVector{Float64,Int}}
end

@testset "Adjoint/Transpose-wrapped vector multiplication" begin
Expand Down

0 comments on commit 18bed58

Please sign in to comment.