From b2523ae48ee2eb32aac48f21379221ed91e21296 Mon Sep 17 00:00:00 2001 From: Sacha Verweij Date: Thu, 28 Jul 2016 12:13:56 -0700 Subject: [PATCH] Make concatenations involving combinations of special matrices with special matrices, sparse matrices, or dense matrices/vectors yield sparse arrays. --- base/sparse/sparsematrix.jl | 14 +++++++++----- test/linalg/special.jl | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 821e04d4d18da0..acabcf7a0d1b74 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -3229,19 +3229,23 @@ function hcat(X::SparseMatrixCSC...) end -# Sparse/dense concatenation +# Sparse/special/dense concatenation -function hcat(Xin::Union{Vector, Matrix, SparseMatrixCSC}...) +# TODO: A similar definition also exists in base/linalg/bidiag.jl. These definitions should +# be consolidated in a more appropriate location, for example base/linalg/special.jl. +SpecialArrays = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal} + +function hcat(Xin::Union{Vector, Matrix, SparseMatrixCSC, SpecialArrays}...) X = SparseMatrixCSC[issparse(x) ? x : sparse(x) for x in Xin] hcat(X...) end -function vcat(Xin::Union{Vector, Matrix, SparseMatrixCSC}...) +function vcat(Xin::Union{Vector, Matrix, SparseMatrixCSC, SpecialArrays}...) X = SparseMatrixCSC[issparse(x) ? x : sparse(x) for x in Xin] vcat(X...) end -function hvcat(rows::Tuple{Vararg{Int}}, X::Union{Vector, Matrix, SparseMatrixCSC}...) +function hvcat(rows::Tuple{Vararg{Int}}, X::Union{Vector, Matrix, SparseMatrixCSC, SpecialArrays}...) nbr = length(rows) # number of block rows tmp_rows = Array{SparseMatrixCSC}(nbr) @@ -3253,7 +3257,7 @@ function hvcat(rows::Tuple{Vararg{Int}}, X::Union{Vector, Matrix, SparseMatrixCS vcat(tmp_rows...) end -function cat(catdims, Xin::Union{Vector, Matrix, SparseMatrixCSC}...) +function cat(catdims, Xin::Union{Vector, Matrix, SparseMatrixCSC, SpecialArrays}...) X = SparseMatrixCSC[issparse(x) ? x : sparse(x) for x in Xin] T = promote_eltype(Xin...) Base.cat_t(catdims, T, X...) diff --git a/test/linalg/special.jl b/test/linalg/special.jl index 89ddde5ee4b007..05105acd149a86 100644 --- a/test/linalg/special.jl +++ b/test/linalg/special.jl @@ -128,3 +128,41 @@ for typ in [UpperTriangular,LowerTriangular,Base.LinAlg.UnitUpperTriangular,Base @test Base.LinAlg.A_mul_Bc(atri,qrb[:Q]) ≈ full(atri) * qrb[:Q]' @test Base.LinAlg.A_mul_Bc!(copy(atri),qrb[:Q]) ≈ full(atri) * qrb[:Q]' end + +# Test that concatenations of combinations of special and other matrix types yield sparse arrays +let + N = 4 + # Test concatenating pairwise combinations of special matrices + diagmat = Diagonal(ones(N)) + bidiagmat = Bidiagonal(ones(N), ones(N-1), true) + tridiagmat = Tridiagonal(ones(N-1), ones(N), ones(N-1)) + symtridiagmat = SymTridiagonal(ones(N), ones(N-1)) + specialmats = (diagmat, bidiagmat, tridiagmat, symtridiagmat) + for specialmata in specialmats, specialmatb in specialmats + @test issparse(hcat(specialmata, specialmatb)) + @test issparse(vcat(specialmata, specialmatb)) + @test issparse(hvcat((1,1), specialmata, specialmatb)) + @test issparse(cat((1,2), specialmata, specialmatb)) + end + # Test concatenating pairwise combinations of special matrices with sparse matrices, + # dense matrices, or dense vectors + densevec = ones(N) + densemat = diagm(ones(N)) + spmat = spdiagm(ones(N)) + for specialmat in specialmats + # --> Tests applicable only to pairs of matrices + for othermat in (spmat, densemat) + @test issparse(vcat(specialmat, othermat)) + @test issparse(vcat(othermat, specialmat)) + end + # --> Tests applicable also to pairs including vectors + for specialmat in specialmats, othermatorvec in (spmat, densemat, densevec) + @test issparse(hcat(specialmat, othermatorvec)) + @test issparse(hcat(othermatorvec, specialmat)) + @test issparse(hvcat((2,), specialmat, othermatorvec)) + @test issparse(hvcat((2,), othermatorvec, specialmat)) + @test issparse(cat((1,2), specialmat, othermatorvec)) + @test issparse(cat((1,2), othermatorvec, specialmat)) + end + end +end