diff --git a/Project.toml b/Project.toml index 61fa523..6bcead7 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["Mikkel Paltorp"] version = "1.0.0-DEV" [deps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] Aqua = "0.8" +ArrayInterface = "7.18.0" FLoops = "0.2" LinearAlgebra = "1" SafeTestsets = "0.1" diff --git a/src/BlockDiagonalMatrices.jl b/src/BlockDiagonalMatrices.jl index 94e090b..f5b771c 100644 --- a/src/BlockDiagonalMatrices.jl +++ b/src/BlockDiagonalMatrices.jl @@ -1,5 +1,6 @@ module BlockDiagonalMatrices +using ArrayInterface using FLoops using LinearAlgebra using SparseArrays @@ -10,6 +11,7 @@ blocks(B::AbstractBlockDiagonal) = B.blocks _is_square(A::AbstractMatrix) = size(A, 1) == size(A, 2) include("blockdiagonal.jl") +include("factorization.jl") export BlockDiagonal diff --git a/src/factorization.jl b/src/factorization.jl new file mode 100644 index 0000000..044d5cf --- /dev/null +++ b/src/factorization.jl @@ -0,0 +1,42 @@ + +""" +The result of a LU factorization of a block diagonal matrix. +""" +struct BlockDiagonalLU{T} <: AbstractBlockDiagonal{T} + blocks::Vector{T} +end + +function LinearAlgebra.issuccess(F::BlockDiagonalLU; kwargs...) + for b in blocks(F) + if !LinearAlgebra.issuccess(b; kwargs...) + return false + end + end + return true +end + +function ArrayInterface.lu_instance(A::AbstractBlockDiagonal) + return BlockDiagonalLU([ArrayInterface.lu_instance(b) for b in blocks(A)]) +end + +function LinearAlgebra.lu!(B::AbstractBlockDiagonal, args...; kwargs...) + BlockDiagonalLU([lu!(blk, args...; kwargs...) for blk in blocks(B)]) +end + +function LinearAlgebra.lu(B::AbstractBlockDiagonal, args...; kwargs...) + BlockDiagonalLU([lu(blk, args...; kwargs...) for blk in blocks(B)]) +end + +function LinearAlgebra.ldiv!(x::AbstractVecOrMat, A::BlockDiagonalLU, b::AbstractVecOrMat; kwargs...) + row_i = 1 + @assert size(x) == size(b) "dimensions of x and b must match" + @assert mapreduce(a -> size(a, 1), +, blocks(A)) == size(b, 1) "number of rows must match" + for block in blocks(A) + nrow = size(block, 1) + _x = view(x, row_i:(row_i + nrow - 1), :) + _b = view(b, row_i:(row_i + nrow - 1), :) + ldiv!(_x, block, _b; kwargs...) + row_i += nrow + end + x +end diff --git a/test/runtests.jl b/test/runtests.jl index 4ef9906..688112b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,3 +4,4 @@ using SafeTestsets @safetestset "Aqua testing " begin include("test_aqua.jl") end @safetestset "Square blocks " begin include("test_squareblocks.jl") end @safetestset "Rectangular blocks" begin include("test_rectangularblocks.jl") end +@safetestset "Factorization " begin include("test_factorization.jl") end diff --git a/test/test_factorization.jl b/test/test_factorization.jl new file mode 100644 index 0000000..b9d7ae1 --- /dev/null +++ b/test/test_factorization.jl @@ -0,0 +1,33 @@ +using BlockDiagonalMatrices +using LinearAlgebra +using Test + +@testset "LU" begin + x = BlockDiagonal([rand(3, 3), rand(3, 3)]) + lux1 = lu(Matrix(x)) + lux2 = lu(x) + lux3 = lu!(x) + @test all([b1.L ≈ b2.L && b1.U ≈ b2.U for (b1, b2) in zip(lux2.blocks, lux3.blocks)]) + @test lux1.L[1:3, 1:3] ≈ lux2.blocks[1].L + @test lux1.U[1:3, 1:3] ≈ lux2.blocks[1].U + @test lux1.L[4:6, 4:6] ≈ lux2.blocks[2].L + @test lux1.U[4:6, 4:6] ≈ lux2.blocks[2].U +end + +@testset "ldiv!" begin + @testset "Vector" begin + x = BlockDiagonal([rand(3, 3), rand(3, 3)]) + y = rand(6) + z1 = LinearAlgebra.ldiv!(similar(y), lu(x), y) + z2 = LinearAlgebra.ldiv!(similar(y), lu(Matrix(x)), y) + @test z1 ≈ z2 + end + + @testset "Matrix" begin + x = BlockDiagonal([rand(3, 3), rand(3, 3)]) + y = rand(6, 2) + z1 = LinearAlgebra.ldiv!(similar(y), lu(x), y) + z2 = LinearAlgebra.ldiv!(similar(y), lu(Matrix(x)), y) + @test z1 ≈ z2 + end +end