diff --git a/ext/ArrayLayoutsSparseArraysExt.jl b/ext/ArrayLayoutsSparseArraysExt.jl index 21d4a64..3d74958 100644 --- a/ext/ArrayLayoutsSparseArraysExt.jl +++ b/ext/ArrayLayoutsSparseArraysExt.jl @@ -1,25 +1,69 @@ module ArrayLayoutsSparseArraysExt using ArrayLayouts -using ArrayLayouts: _copyto! +using ArrayLayouts: _copyto!, Factorization using SparseArrays # Specifying the full namespace is necessary because of https://github.com/JuliaLang/julia/issues/48533 # See https://github.com/JuliaStats/LogExpFunctions.jl/pull/63 import ArrayLayouts.LinearAlgebra -import Base: copyto! +import Base: copyto!, \, / # ambiguity from sparsematrix.jl copyto!(dest::LayoutMatrix, src::SparseArrays.AbstractSparseMatrixCSC) = - _copyto!(dest, src) + _copyto!(dest, src) copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SparseArrays.AbstractSparseMatrixCSC) = - _copyto!(dest, src) + _copyto!(dest, src) @inline LinearAlgebra.dot(a::LayoutArray{<:Number}, b::SparseArrays.SparseVectorUnion{<:Number}) = - ArrayLayouts.dot(a,b) + ArrayLayouts.dot(a,b) @inline LinearAlgebra.dot(a::SparseArrays.SparseVectorUnion{<:Number}, b::LayoutArray{<:Number}) = - ArrayLayouts.dot(a,b) + ArrayLayouts.dot(a,b) + +# disambiguiate sparse matrix dispatches +macro _layoutldivsp(Typ) + ret = quote + (\)(x::SparseArrays.AbstractSparseMatrixCSC, A::$Typ; kwds...) = ArrayLayouts.ldiv(x,A; kwds...) + (/)(x::SparseArrays.AbstractSparseMatrixCSC, A::$Typ; kwds...) = ArrayLayouts.ldiv(x,A; kwds...) + end + esc(ret) +end + +macro layoutldivsp(Typ) + esc(quote + ArrayLayoutsSparseArraysExt.@_layoutldivsp $Typ + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UpperTriangular{T, <:$Typ{T}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UnitUpperTriangular{T, <:$Typ{T}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.LowerTriangular{T, <:$Typ{T}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UnitLowerTriangular{T, <:$Typ{T}} where T + + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UpperTriangular{T, <:SubArray{T,2,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UnitUpperTriangular{T, <:SubArray{T,2,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.LowerTriangular{T, <:SubArray{T,2,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UnitLowerTriangular{T, <:SubArray{T,2,<:$Typ{T}}} where T + + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UpperTriangular{T, <:LinearAlgebra.Adjoint{T,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UnitUpperTriangular{T, <:LinearAlgebra.Adjoint{T,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.LowerTriangular{T, <:LinearAlgebra.Adjoint{T,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UnitLowerTriangular{T, <:LinearAlgebra.Adjoint{T,<:$Typ{T}}} where T + + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UpperTriangular{T, <:LinearAlgebra.Transpose{T,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UnitUpperTriangular{T, <:LinearAlgebra.Transpose{T,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.LowerTriangular{T, <:LinearAlgebra.Transpose{T,<:$Typ{T}}} where T + ArrayLayoutsSparseArraysExt.@_layoutldivsp LinearAlgebra.UnitLowerTriangular{T, <:LinearAlgebra.Transpose{T,<:$Typ{T}}} where T + end) +end + +@_layoutldivsp LayoutVector + +macro layoutmatrixsp(Typ) + esc(quote + ArrayLayoutsSparseArraysExt.@layoutldivsp $Typ + end) +end + +@layoutmatrixsp LayoutMatrix end diff --git a/src/ldiv.jl b/src/ldiv.jl index 77ffd94..c9ec3a5 100644 --- a/src/ldiv.jl +++ b/src/ldiv.jl @@ -151,6 +151,8 @@ macro _layoutldiv(Typ) LinearAlgebra.ldiv!(A::LU, x::$Typ; kwds...) = ArrayLayouts.ldiv!(A,x; kwds...) LinearAlgebra.ldiv!(A::Cholesky, x::$Typ; kwds...) = ArrayLayouts.ldiv!(A,x; kwds...) LinearAlgebra.ldiv!(A::LinearAlgebra.QRCompactWY, x::$Typ; kwds...) = ArrayLayouts.ldiv!(A,x; kwds...) + # Type restriction to disambiguiate calls, see https://github.com/JuliaArrays/BlockArrays.jl/issues/319 + LinearAlgebra.ldiv!(A::LinearAlgebra.QRCompactWY{T2}, x::$Typ{T2}; kwds...) where T2<:BlasFloat = ArrayLayouts.ldiv!(A,x; kwds...) LinearAlgebra.ldiv!(A::Bidiagonal, B::$Typ; kwds...) = ArrayLayouts.ldiv!(A,B; kwds...) diff --git a/test/test_layoutarray.jl b/test/test_layoutarray.jl index ee66e17..4f77376 100644 --- a/test/test_layoutarray.jl +++ b/test/test_layoutarray.jl @@ -140,11 +140,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor() else @test ldiv!(lu(A), MyVector(copy(c))) ≈ A \ c end - if VERSION < v"1.9" || VERSION >= v"1.10-" - @test_throws ErrorException ldiv!(qr(A), MyVector(copy(c))) - else - @test_throws MethodError ldiv!(qr(A), MyVector(copy(c))) - end + @test_throws ErrorException ldiv!(qr(A), MyVector(copy(c))) # Missing materialize! overload @test_throws ErrorException ldiv!(eigen(randn(5,5)), c) @test ArrayLayouts.ldiv!(svd(A.A), copy(c)) ≈ ArrayLayouts.ldiv!(similar(c), svd(A.A), c) ≈ A \ c if VERSION ≥ v"1.8" @@ -213,6 +209,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor() @testset "layoutldiv" begin A = MyMatrix(randn(5,5)) + Asym = A'*A x = randn(5) X = randn(5,5) t = view(randn(10),[1,3,4,6,7]) @@ -221,8 +218,11 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor() T̃ = copy(T) B = Bidiagonal(randn(5),randn(4),:U) D = Diagonal(randn(5)) + S = sprand(5, 5, 0.8) + @test ldiv!(A, copy(x)) ≈ A\x @test A\t ≈ A\t̃ + @test A\t ≈ A\SparseVector(t̃) # QR is not general enough @test_broken ldiv!(A, t) ≈ A\t @test ldiv!(A, copy(X)) ≈ A\X @@ -238,6 +238,32 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor() @test A\MyVector(x) ≈ A\x @test A\MyMatrix(X) ≈ A\X + # Regression for https://github.com/JuliaArrays/BlockArrays.jl/issues/319 + @test qr(A)\MyVector(x) ≈ A\x + @test ldiv!(qr(A), MyVector(copy(x))) ≈ A\x + @test qr(A)\MyMatrix(X) ≈ A\X + @test ldiv!(qr(A), MyMatrix(copy(X))) ≈ A\X + + @test lu(A)\MyVector(x) ≈ A\x + @test ldiv!(lu(A), MyVector(copy(x))) ≈ A\x + @test lu(A)\MyMatrix(X) ≈ A\X + @test ldiv!(lu(A), MyMatrix(copy(X))) ≈ A\X + + @test cholesky(Asym)\MyVector(x) ≈ Asym\x + @test ldiv!(cholesky(Asym), MyVector(copy(x))) ≈ Asym\x + @test cholesky(Asym)\MyMatrix(X) ≈ Asym\X + @test ldiv!(cholesky(Asym), MyMatrix(copy(X))) ≈ Asym\X + + @test S\MyVector(x) ≈ S\x + @test S\MyMatrix(X) ≈ S\X + @test MyMatrix(S)\MyVector(x) ≈ S\x + @test MyMatrix(S)\MyMatrix(X) ≈ S\X + + @test MyVector(x)'/MyMatrix(S) ≈ x'/S + @test X/MyMatrix(S) ≈ X/S + @test MyVector(x)'/MyMatrix(S) ≈ x'/S + @test X/MyMatrix(S) ≈ X/S + if VERSION >= v"1.9" @test A/A ≈ A.A / A.A @test x' / A ≈ x' / A.A