diff --git a/base/linalg/special.jl b/base/linalg/special.jl index ccfd186b27f52..76e0d8f094cd4 100644 --- a/base/linalg/special.jl +++ b/base/linalg/special.jl @@ -7,13 +7,21 @@ convert{T}(::Type{Bidiagonal}, A::Diagonal{T})=Bidiagonal(A.diag, zeros(T, size( convert{T}(::Type{SymTridiagonal}, A::Diagonal{T})=SymTridiagonal(A.diag, zeros(T, size(A.diag,1)-1)) convert{T}(::Type{Tridiagonal}, A::Diagonal{T})=Tridiagonal(zeros(T, size(A.diag,1)-1), A.diag, zeros(T, size(A.diag,1)-1)) convert(::Type{UpperTriangular}, A::Diagonal) = UpperTriangular(full(A), :L) -convert(::Type{UnitUpperTriangular}, A::Diagonal) = UnitUpperTriangular(full(A), :L) convert(::Type{LowerTriangular}, A::Diagonal) = LowerTriangular(full(A), :L) -convert(::Type{UnitLowerTriangular}, A::Diagonal) = UnitLowerTriangular(full(A), :L) convert(::Type{LowerTriangular}, A::Bidiagonal) = !A.isupper ? LowerTriangular(full(A)) : throw(ArgumentError("Bidiagonal matrix must have lower off diagonal to be converted to LowerTriangular")) convert(::Type{UpperTriangular}, A::Bidiagonal) = A.isupper ? UpperTriangular(full(A)) : throw(ArgumentError("Bidiagonal matrix must have upper off diagonal to be converted to UpperTriangular")) convert(::Type{Matrix}, D::Diagonal) = diagm(D.diag) +function convert(::Type{UnitUpperTriangular}, A::Diagonal) + all(A.diag .== one(eltype(A))) || throw(ArgumentError("Matrix cannot be represented as UnitUpperTriangular")) + UnitUpperTriangular(full(A)) +end + +function convert(::Type{UnitLowerTriangular}, A::Diagonal) + all(A.diag .== one(eltype(A))) || throw(ArgumentError("Matrix cannot be represented as UnitLowerTriangular")) + UnitLowerTriangular(full(A)) +end + function convert(::Type{Diagonal}, A::Union(Bidiagonal, SymTridiagonal)) all(A.ev .== 0) || throw(ArgumentError("Matrix cannot be represented as Diagonal")) Diagonal(A.dv) diff --git a/test/linalg4.jl b/test/linalg4.jl index e7c2e14650f02..cf3995a4e7b4c 100644 --- a/test/linalg4.jl +++ b/test/linalg4.jl @@ -13,6 +13,10 @@ let a=[1.0:n;] debug && println("newtype is $(newtype)") @test full(convert(newtype, A)) == full(A) end + for newtype in [Base.LinAlg.UnitUpperTriangular, Base.LinAlg.UnitLowerTriangular] + @test_throws ArgumentError convert(newtype, A) + @test full(convert(newtype, Diagonal(ones(n)))) == eye(n) + end for isupper in (true, false) debug && println("isupper is $(isupper)") @@ -38,6 +42,8 @@ let a=[1.0:n;] for newtype in [Diagonal, Bidiagonal] @test_throws ArgumentError convert(newtype,A) end + A = SymTridiagonal(a, zeros(n-1)) + @test full(convert(Bidiagonal,A)) == full(A) A = Tridiagonal(zeros(n-1), [1.0:n;], zeros(n-1)) #morally Diagonal for newtype in [Diagonal, Bidiagonal, SymTridiagonal, Matrix]