Skip to content

Commit

Permalink
Combine diag methods for SymTridiagonal (#56014)
Browse files Browse the repository at this point in the history
Currently, there are two branches, one for an `eltype` that is a
`Number`, and the other that deals with generic `eltype`s. They do
similar things, so we may combine these, and use branches wherever
necessary to retain the performance. We also may replace explicit
materialized arrays by generators in `copyto!`. Overall, this improves
performance in `diag` for matrices of matrices, whereas the performance
in the common case of matrices of numbers remains unchanged.
```julia
julia> using StaticArrays, LinearAlgebra

julia> s = SMatrix{2,2}(1:4);

julia> S = SymTridiagonal(fill(s,100), fill(s,99));

julia> @Btime diag($S);
  1.292 μs (5 allocations: 7.16 KiB) # nightly, v"1.12.0-DEV.1317"
  685.012 ns (3 allocations: 3.19 KiB) # This PR
```
This PR also allows computing the `diag` for more values of the band
index `n`:
```julia
julia> diag(S,99)
1-element Vector{SMatrix{2, 2, Int64, 4}}:
 [0 0; 0 0]
```
This would work as long as `getindex` works for the `SymTridiagonal` for
that band, and the zero element may be converted to the `eltype`.
  • Loading branch information
jishnub authored Oct 11, 2024
1 parent 055e37e commit 41b1778
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 37 deletions.
41 changes: 12 additions & 29 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,44 +183,27 @@ issymmetric(S::SymTridiagonal) = true

tr(S::SymTridiagonal) = sum(symmetric, S.dv)

@noinline function throw_diag_outofboundserror(n, sz)
sz1, sz2 = sz
throw(ArgumentError(LazyString(lazy"requested diagonal, $n, must be at least $(-sz1) ",
lazy"and at most $sz2 for an $(sz1)-by-$(sz2) matrix")))
end
_diagiter(M::SymTridiagonal{<:Number}) = M.dv
_diagiter(M::SymTridiagonal) = (symmetric(x, :U) for x in M.dv)
_eviter_transposed(M::SymTridiagonal{<:Number}) = _evview(M)
_eviter_transposed(M::SymTridiagonal) = (transpose(x) for x in _evview(M))

function diag(M::SymTridiagonal{T}, n::Integer=0) where T<:Number
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
absn = abs(n)
if absn == 0
return copyto!(similar(M.dv, length(M.dv)), M.dv)
elseif absn == 1
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
elseif absn <= size(M,1)
v = similar(M.dv, size(M,1)-absn)
for i in eachindex(v)
v[i] = M[BandIndex(n,i)]
end
return v
else
throw_diag_outofboundserror(n, size(M))
end
end
function diag(M::SymTridiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
v = similar(M.dv, max(0, length(M.dv)-abs(n)))
if n == 0
return copyto!(similar(M.dv, length(M.dv)), symmetric.(M.dv, :U))
return copyto!(v, _diagiter(M))
elseif n == 1
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
return copyto!(v, _evview(M))
elseif n == -1
return copyto!(similar(M.ev, length(M.dv)-1), transpose.(_evview(M)))
elseif n <= size(M,1)
throw(ArgumentError("requested diagonal contains undefined zeros of an array type"))
return copyto!(v, _eviter_transposed(M))
else
throw_diag_outofboundserror(n, size(M))
for i in eachindex(v)
v[i] = M[BandIndex(n,i)]
end
end
return v
end

+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, _evview(A)+_evview(B))
Expand Down
29 changes: 21 additions & 8 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,13 @@ end
@test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl)
@test (@inferred diag(A, -1))::typeof(d) == dl
@test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1)
@test_throws ArgumentError diag(A, -n - 1)
@test_throws ArgumentError diag(A, n + 1)
if A isa SymTridiagonal
@test isempty(@inferred diag(A, -n - 1))
@test isempty(@inferred diag(A, n + 1))
else
@test_throws ArgumentError diag(A, -n - 1)
@test_throws ArgumentError diag(A, n + 1)
end
GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...)
@test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d)
@test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl)
Expand Down Expand Up @@ -501,10 +506,11 @@ end
@test @inferred diag(A, 1) == fill(M, n-1)
@test @inferred diag(A, 0) == fill(Symmetric(M), n)
@test @inferred diag(A, -1) == fill(transpose(M), n-1)
@test_throws ArgumentError diag(A, -2)
@test_throws ArgumentError diag(A, 2)
@test_throws ArgumentError diag(A, n+1)
@test_throws ArgumentError diag(A, -n-1)
@test_broken diag(A, -2) == fill(M, n-2)
@test_broken diag(A, 2) == fill(M, n-2)
@test isempty(@inferred diag(A, n+1))
@test isempty(@inferred diag(A, -n-1))

A[1,1] = Symmetric(2M)
@test A[1,1] == Symmetric(2M)
@test_throws ArgumentError A[1,1] = M
Expand All @@ -519,8 +525,8 @@ end
@test @inferred diag(A, 1) == fill(M, n-1)
@test @inferred diag(A, 0) == fill(M, n)
@test @inferred diag(A, -1) == fill(M, n-1)
@test_throws MethodError diag(A, -2)
@test_throws MethodError diag(A, 2)
@test_broken diag(A, -2) == fill(M, n-2)
@test_broken diag(A, 2) == fill(M, n-2)
@test_throws ArgumentError diag(A, n+1)
@test_throws ArgumentError diag(A, -n-1)

Expand All @@ -532,6 +538,13 @@ end
A = Tridiagonal(ev, dv, ev)
@test A == Matrix{eltype(A)}(A)
end

M = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
S = SymTridiagonal(fill(M,4), fill(M,3))
@test diag(S,2) == fill(zero(M), 2)
@test diag(S,-2) == fill(zero(M), 2)
@test isempty(diag(S,4))
@test isempty(diag(S,-4))
end

@testset "Issue 12068" begin
Expand Down

0 comments on commit 41b1778

Please sign in to comment.