From f62a380368484913dd022c99055056a027268134 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 23 Sep 2024 18:32:04 +0530 Subject: [PATCH] Specialize indexing triangular matrices with BandIndex (#55644) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With this, certain indexing operations involving a `BandIndex` may be evaluated as constants. This isn't used directly presently, but might allow for more performant broadcasting in the future. With this, ```julia julia> n = 3; T = Tridiagonal(rand(n-1), rand(n), rand(n-1)); julia> @code_warntype ((T,j) -> UpperTriangular(T)[LinearAlgebra.BandIndex(2,j)])(T, 1) MethodInstance for (::var"#17#18")(::Tridiagonal{Float64, Vector{Float64}}, ::Int64) from (::var"#17#18")(T, j) @ Main REPL[12]:1 Arguments #self#::Core.Const(var"#17#18"()) T::Tridiagonal{Float64, Vector{Float64}} j::Int64 Body::Float64 1 ─ %1 = Main.UpperTriangular(T)::UpperTriangular{Float64, Tridiagonal{Float64, Vector{Float64}}} │ %2 = LinearAlgebra.BandIndex::Core.Const(LinearAlgebra.BandIndex) │ %3 = (%2)(2, j)::Core.PartialStruct(LinearAlgebra.BandIndex, Any[Core.Const(2), Int64]) │ %4 = Base.getindex(%1, %3)::Core.Const(0.0) └── return %4 ``` The indexing operation may be evaluated at compile-time, as the band index is constant-propagated. --- stdlib/LinearAlgebra/src/bidiag.jl | 5 ++-- stdlib/LinearAlgebra/src/dense.jl | 2 +- stdlib/LinearAlgebra/src/triangular.jl | 14 +++++++++ stdlib/LinearAlgebra/test/triangular.jl | 38 ++++++++++++++++++++++++- 4 files changed, 55 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 0aab9ceeca6b9..e5482cbba5595 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -166,10 +166,11 @@ end end @inline function getindex(A::Bidiagonal{T}, b::BandIndex) where T - @boundscheck checkbounds(A, _cartinds(b)) + @boundscheck checkbounds(A, b) if b.band == 0 return @inbounds A.dv[b.index] - elseif b.band == _offdiagind(A.uplo) + elseif b.band ∈ (-1,1) && b.band == _offdiagind(A.uplo) + # we explicitly compare the possible bands as b.band may be constant-propagated return @inbounds A.ev[b.index] else return bidiagzero(A, Tuple(_cartinds(b))...) diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index 62096cbb172f2..aacc5479bfa9d 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -110,7 +110,7 @@ norm2(x::Union{Array{T},StridedVector{T}}) where {T<:BlasFloat} = # Conservative assessment of types that have zero(T) defined for themselves haszero(::Type) = false haszero(::Type{T}) where {T<:Number} = isconcretetype(T) -@propagate_inbounds _zero(M::AbstractArray{T}, i, j) where {T} = haszero(T) ? zero(T) : zero(M[i,j]) +@propagate_inbounds _zero(M::AbstractArray{T}, inds...) where {T} = haszero(T) ? zero(T) : zero(M[inds...]) """ triu!(M, k::Integer) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 03634aa7d68e1..e1d61e4035966 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -236,6 +236,20 @@ Base.isstored(A::UpperTriangular, i::Int, j::Int) = @propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) = i <= j ? A.data[i,j] : _zero(A.data,j,i) +# these specialized getindex methods enable constant-propagation of the band +Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitLowerTriangular{T}, b::BandIndex) where {T} + b.band < 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T)) +end +Base.@constprop :aggressive @propagate_inbounds function getindex(A::LowerTriangular, b::BandIndex) + b.band <= 0 ? A.data[b] : _zero(A.data, b) +end +Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitUpperTriangular{T}, b::BandIndex) where {T} + b.band > 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T)) +end +Base.@constprop :aggressive @propagate_inbounds function getindex(A::UpperTriangular, b::BandIndex) + b.band >= 0 ? A.data[b] : _zero(A.data, b) +end + _zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower" _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper" diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 42c5494f73e6f..ec9a3079e2643 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -6,7 +6,7 @@ debug = false using Test, LinearAlgebra, Random using LinearAlgebra: BlasFloat, errorbounds, full!, transpose!, UnitUpperTriangular, UnitLowerTriangular, - mul!, rdiv!, rmul!, lmul! + mul!, rdiv!, rmul!, lmul!, BandIndex const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") @@ -1286,4 +1286,40 @@ end end end +@testset "indexing with a BandIndex" begin + # these tests should succeed even if the linear index along + # the band isn't a constant, or type-inferred at all + M = rand(Int,2,2) + f(A,j, v::Val{n}) where {n} = Val(A[BandIndex(n,j)]) + function common_tests(M, ind) + j = ind[] + @test @inferred(f(UpperTriangular(M), j, Val(-1))) == Val(0) + @test @inferred(f(UnitUpperTriangular(M), j, Val(-1))) == Val(0) + @test @inferred(f(UnitUpperTriangular(M), j, Val(0))) == Val(1) + @test @inferred(f(LowerTriangular(M), j, Val(1))) == Val(0) + @test @inferred(f(UnitLowerTriangular(M), j, Val(1))) == Val(0) + @test @inferred(f(UnitLowerTriangular(M), j, Val(0))) == Val(1) + end + common_tests(M, Any[1]) + + M = Diagonal([1,2]) + common_tests(M, Any[1]) + # extra tests for banded structure of the parent + for T in (UpperTriangular, UnitUpperTriangular) + @test @inferred(f(T(M), 1, Val(1))) == Val(0) + end + for T in (LowerTriangular, UnitLowerTriangular) + @test @inferred(f(T(M), 1, Val(-1))) == Val(0) + end + + M = Tridiagonal([1,2], [1,2,3], [1,2]) + common_tests(M, Any[1]) + for T in (UpperTriangular, UnitUpperTriangular) + @test @inferred(f(T(M), 1, Val(2))) == Val(0) + end + for T in (LowerTriangular, UnitLowerTriangular) + @test @inferred(f(T(M), 1, Val(-2))) == Val(0) + end +end + end # module TestTriangular