Skip to content
Open
6 changes: 1 addition & 5 deletions doc/specs/stdlib_specialmatrices.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Experimental

With the exception of `extended precision` and `quadruple precision`, all the types provided by `stdlib_specialmatrices` benefit from specialized kernels for matrix-vector products accessible via the common `spmv` interface.

- For `tridiagonal` matrices, the LAPACK `lagtm` backend is being used.
- For `tridiagonal` matrices, the backend is either LAPACK `lagtm` or the generalized routine `glagtm`, depending on the values and types of `alpha` and `beta`.

#### Syntax

Expand All @@ -110,10 +110,6 @@ With the exception of `extended precision` and `quadruple precision`, all the ty

- `op` (optional) : In-place operator identifier. Shall be a `character(1)` argument. It can have any of the following values: `N`: no transpose, `T`: transpose, `H`: Hermitian (conjugate) transpose.

@warning
Due to limitations of the underlying `lapack` driver, currently `alpha` and `beta` can only take one of the values `[-1, 0, 1]` for `tridiagonal` and `symtridiagonal` matrices. See `lagtm` for more details.
@endwarning

#### Examples

```fortran
Expand Down
1 change: 1 addition & 0 deletions example/specialmatrices/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
ADD_EXAMPLE(specialmatrices_dp_spmv)
ADD_EXAMPLE(specialmatrices_cdp_spmv)
ADD_EXAMPLE(tridiagonal_dp_type)
30 changes: 30 additions & 0 deletions example/specialmatrices/example_specialmatrices_cdp_spmv.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
program example_tridiagonal_matrix_cdp
    !! Example: `spmv` applied to a complex double-precision tridiagonal matrix
    !! with general complex `alpha` and `beta`. The result is printed next to a
    !! reference value obtained from the dense form of the same matrix.
    use stdlib_linalg_constants, only: dp
    use stdlib_specialmatrices, only: tridiagonal_cdp_type, tridiagonal, dense, spmv
    implicit none

    integer, parameter :: n = 5
    type(tridiagonal_cdp_type) :: A
    complex(dp) :: dl(n-1), dv(n), du(n-1)
    complex(dp) :: x(n), y(n), y_dense(n)
    integer :: i
    complex(dp) :: alpha, beta

    ! The three diagonals passed to the constructor (dl, dv, du).
    dl = [(cmplx(i, i, kind=dp), i=1, n - 1)]
    dv = [(cmplx(2*i, 2*i, kind=dp), i=1, n)]
    du = [(cmplx(3*i, 3*i, kind=dp), i=1, n - 1)]

    A = tridiagonal(dl, dv, du)

    x = (1.0_dp, 0.0_dp)
    y = (3.0_dp, -7.0_dp)

    ! `cmplx` without `kind=` returns a default-kind complex even for dp
    ! arguments, so the kind is requested explicitly.
    alpha = cmplx(2.0_dp, 3.0_dp, kind=dp)
    beta = cmplx(-1.0_dp, 5.0_dp, kind=dp)

    ! Reference value, computed before `y` is overwritten by `spmv`.
    y_dense = alpha * matmul(dense(A), x) + beta * y
    call spmv(A, x, y, alpha, beta)

    print *, 'dense :', y_dense
    print *, 'Tridiagonal :', y
end program example_tridiagonal_matrix_cdp
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ if (NOT STDLIB_NO_BITSET)
endif()
add_subdirectory(blas)
add_subdirectory(lapack)
add_subdirectory(lapack_extended)
if (NOT STDLIB_NO_STATS)
add_subdirectory(stats)
endif()
Expand Down Expand Up @@ -115,4 +116,4 @@ configure_stdlib_target(${PROJECT_NAME} f90Files fppFiles cppFiles)
target_link_libraries(${PROJECT_NAME} PUBLIC
$<$<NOT:$<BOOL:${STDLIB_NO_BITSET}>>:bitsets>
$<$<NOT:$<BOOL:${STDLIB_NO_STATS}>>:stats>
blas lapack)
blas lapack lapack_extended)
10 changes: 10 additions & 0 deletions src/lapack_extended/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Build description for the `lapack_extended` target.

# fypp templates handed to configure_stdlib_target as the "fppFiles" list.
set(lapack_extended_fppFiles
../stdlib_kinds.fypp
stdlib_lapack_extended_base.fypp
stdlib_lapack_extended.fypp
)
# fypp templates handed to configure_stdlib_target as the "cppFiles" list.
# NOTE(review): presumably these also need the C preprocessor after fypp - confirm
# against the definition of configure_stdlib_target.
set(lapack_extended_cppFiles
../stdlib_linalg_constants.fypp
)

# Argument order follows the call in src/CMakeLists.txt:
# (target, f90Files, fppFiles, cppFiles). The empty string is passed in the
# f90Files position.
configure_stdlib_target(lapack_extended "" lapack_extended_fppFiles lapack_extended_cppFiles)
85 changes: 85 additions & 0 deletions src/lapack_extended/stdlib_lapack_extended.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#:include "common.fypp"
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES

submodule(stdlib_lapack_extended_base) stdlib_lapack_extended
implicit none
contains
! One implementation of glagtm is generated per integer kind and per
! real/complex kind by the two fypp loops below.
#:for ik,it,ii in LINALG_INT_KINDS_TYPES
#:for k1,t1,s1 in KINDS_TYPES
pure module subroutine stdlib${ii}$_glagtm_${s1}$(trans, n, nrhs, alpha, dl, d, du, x, ldx, beta, b, ldb)
!! Tridiagonal matrix-matrix product b := beta * b + alpha * op(A) * x, with
!! arbitrary scalar values of alpha and beta. A is the n-by-n tridiagonal
!! matrix with sub-diagonal dl, diagonal d and super-diagonal du.
!! The argument list mirrors the LAPACK lagtm routine.
character, intent(in) :: trans
!! Selects op(A). `N`: op(A) = A; `C` (complex kinds only): conjugate
!! transpose of A; any other value: transpose of A.
integer(${ik}$), intent(in) :: ldb, ldx, n, nrhs
!! n: order of A; nrhs: number of columns of x and b that are processed;
!! ldx, ldb: leading dimensions of x and b.
${t1}$, intent(in) :: alpha, beta
!! Scalar multipliers of op(A) * x and of the incoming b, respectively.
${t1}$, intent(inout) :: b(ldb,*)
!! On entry the matrix scaled by beta; on exit the result. Only
!! b(1:n, 1:nrhs) is referenced.
${t1}$, intent(in) :: d(*), dl(*), du(*), x(ldx,*)
!! d(1:n), dl(1:n-1), du(1:n-1): diagonals of A; x(1:n, 1:nrhs): right operand.
Comment on lines +15 to +16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we modernize this interface?

Suggested change
${t1}$, intent(inout) :: b(ldb,*)
${t1}$, intent(in) :: d(*), dl(*), du(*), x(ldx,*)
${t1}$, intent(inout) :: b(:,:)
${t1}$, intent(in) :: d(:), dl(:), du(:), x(:,:)

What is the reason to keep it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jvdp1 Thanks for the review. I kept the LAPACK style interface(*) to stay close to lagtm, but I agree that assumed shape arrays would be more modern. I'm happy to switch to that, if you prefer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@perazz @jalvesz what do you think about that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Such change should not be taken lightly. While assumed-shape is considered more safe than assumed-size for the simple reason that it enforces bound checking and helps with runtime problem detection, it also implies an API behavioral change on what is allowed or not to do with the interface.

With an assumed-size declaration, one can pass a 1D working array to the routine, which will be reinterpreted as a 2D array internally. With assumed-shape, this is no longer possible, as ranks must match between caller and callee.

If this routine is intended to closely match but extending the capabilities of *lagtm, then I would suggest not changing the array declaration style.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Such change should not be taken lightly.

I fully agree with you. As I don't have a full overview of all the old and new LAPACK features, a second opinion is needed.

If this routine is intended to closely match but extending the capabilities of *lagtm, then I would suggest not changing the array declaration style.

I agree with this justification.

@jalvesz @Mahmood-Sinan I consider this suggestion as being resolved. The PR can be merged IMHO.


! Internal variables.
integer(${ik}$) :: i, j
${t1}$ :: temp

! Quick return: nothing to compute for an empty matrix.
if(n == 0) then
    return
endif

! Scale the incoming b by beta. A zero beta overwrites b instead of
! multiplying, so the previous content of b never enters the result.
if(beta == 0.0_${k1}$) then
    b(1:n, 1:nrhs) = 0.0_${k1}$
else
    b(1:n, 1:nrhs) = beta * b(1:n, 1:nrhs)
end if

! The selector is matched case-insensitively, as LAPACK does through LSAME,
! so that lower-case 'n' / 'c' choose the same operation as 'N' / 'C'.
if(trans == 'N' .or. trans == 'n') then
    ! b := b + alpha * A * x
    do j = 1, nrhs
        if(n == 1_${ik}$) then
            ! 1-by-1 matrix: only the diagonal entry exists.
            temp = d(1_${ik}$) * x(1_${ik}$, j)
            b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
        else
            ! First row has no sub-diagonal entry.
            temp = d(1_${ik}$) * x(1_${ik}$, j) + du(1_${ik}$) * x(2_${ik}$, j)
            b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
            do i = 2, n - 1
                temp = dl(i - 1) * x(i - 1, j) + d(i) * x(i, j) + du(i) * x(i + 1, j)
                b(i, j) = b(i, j) + alpha * temp
            end do
            ! Last row has no super-diagonal entry.
            temp = dl(n - 1) * x(n - 1, j) + d(n) * x(n, j)
            b(n, j) = b(n, j) + alpha * temp
        end if
    end do
#:if t1.startswith('complex')
else if(trans == 'C' .or. trans == 'c') then
    ! b := b + alpha * A**H * x (dl and du swap roles and are conjugated)
    do j = 1, nrhs
        if(n == 1_${ik}$) then
            temp = conjg(d(1_${ik}$)) * x(1_${ik}$, j)
            b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
        else
            temp = conjg(d(1_${ik}$)) * x(1_${ik}$, j) + conjg(dl(1_${ik}$)) * x(2_${ik}$, j)
            b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
            do i = 2, n - 1
                temp = conjg(du(i - 1)) * x(i - 1, j) + conjg(d(i)) * x(i, j) + conjg(dl(i)) * x(i + 1, j)
                b(i, j) = b(i, j) + alpha * temp
            end do
            temp = conjg(du(n - 1)) * x(n - 1, j) + conjg(d(n)) * x(n, j)
            b(n, j) = b(n, j) + alpha * temp
        end if
    end do
#:endif
else
    ! Any other selector: b := b + alpha * A**T * x (dl and du swap roles)
    do j = 1, nrhs
        if(n == 1_${ik}$) then
            temp = d(1_${ik}$) * x(1_${ik}$, j)
            b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
        else
            temp = d(1_${ik}$) * x(1_${ik}$, j) + dl(1_${ik}$) * x(2_${ik}$, j)
            b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
            do i = 2, n - 1
                temp = du(i - 1) * x(i - 1, j) + d(i) * x(i, j) + dl(i) * x(i + 1, j)
                b(i, j) = b(i, j) + alpha * temp
            end do
            temp = du(n - 1) * x(n - 1, j) + d(n) * x(n, j)
            b(n, j) = b(n, j) + alpha * temp
        end if
    end do
end if
end subroutine stdlib${ii}$_glagtm_${s1}$
#:endfor
#:endfor

end submodule
22 changes: 22 additions & 0 deletions src/lapack_extended/stdlib_lapack_extended_base.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#:include "common.fypp"
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
module stdlib_lapack_extended_base
    !! Interfaces of routines that extend LAPACK functionality. The
    !! implementations are provided in a submodule.
    use stdlib_linalg_constants
    implicit none

    interface glagtm
        !! Generalized counterpart of the LAPACK `lagtm` routine: tridiagonal
        !! matrix-matrix product b := beta * b + alpha * op(A) * x in which
        !! `alpha` and `beta` are general scalars of the same type as the data.
#:for ik,it,ii in LINALG_INT_KINDS_TYPES
#:for k1,t1,s1 in KINDS_TYPES
        pure module subroutine stdlib${ii}$_glagtm_${s1}$(trans, n, nrhs, alpha, dl, d, du, x, ldx, beta, b, ldb)
            character, intent(in) :: trans
            integer(${ik}$), intent(in) :: n, nrhs
            integer(${ik}$), intent(in) :: ldx, ldb
            ${t1}$, intent(in) :: alpha, beta
            ${t1}$, intent(in) :: dl(*), d(*), du(*)
            ${t1}$, intent(in) :: x(ldx,*)
            ${t1}$, intent(inout) :: b(ldb,*)
        end subroutine stdlib${ii}$_glagtm_${s1}$
#:endfor
#:endfor
    end interface glagtm
end module stdlib_lapack_extended_base
9 changes: 5 additions & 4 deletions src/stdlib_specialmatrices.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module stdlib_specialmatrices
use stdlib_constants
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
use stdlib_lapack_extended_base
implicit none
private
public :: tridiagonal
Expand Down Expand Up @@ -99,7 +100,7 @@ module stdlib_specialmatrices
!! Matrix dimension.
type(tridiagonal_${s1}$_type) :: A
!! Corresponding Tridiagonal matrix.
end function
end function

module function initialize_tridiagonal_impure_${s1}$(dl, dv, du, err) result(A)
!! Construct a `tridiagonal` matrix from the rank-1 arrays
Expand All @@ -122,7 +123,7 @@ module stdlib_specialmatrices
!! Error handling.
type(tridiagonal_${s1}$_type) :: A
!! Corresponding Tridiagonal matrix.
end function
end function
#:endfor
end interface

Expand All @@ -145,8 +146,8 @@ module stdlib_specialmatrices
type(tridiagonal_${s1}$_type), intent(in) :: A
${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
real(${k1}$), intent(in), optional :: alpha
real(${k1}$), intent(in), optional :: beta
${t1}$, intent(in), optional :: alpha
${t1}$, intent(in), optional :: beta
character(1), intent(in), optional :: op
end subroutine
#:endfor
Expand Down
34 changes: 29 additions & 5 deletions src/stdlib_specialmatrices_tridiagonal.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,18 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
type(tridiagonal_${s1}$_type), intent(in) :: A
${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
real(${k1}$), intent(in), optional :: alpha
real(${k1}$), intent(in), optional :: beta
${t1}$, intent(in), optional :: alpha
${t1}$, intent(in), optional :: beta
character(1), intent(in), optional :: op

! Internal variables.
real(${k1}$) :: alpha_, beta_
${t1}$ :: alpha_, beta_
integer(ilp) :: n, nrhs, ldx, ldy
character(1) :: op_
#:if t1.startswith('real')
logical :: is_alpha_special, is_beta_special
#:endif

#:if rank == 1
${t1}$, pointer :: xmat(:, :), ymat(:, :)
#:endif
Expand All @@ -171,6 +175,10 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
alpha_ = 1.0_${k1}$ ; if (present(alpha)) alpha_ = alpha
beta_ = 0.0_${k1}$ ; if (present(beta)) beta_ = beta
op_ = "N" ; if (present(op)) op_ = op
#:if t1.startswith('real')
is_alpha_special = (alpha_ == 1.0_${k1}$ .or. alpha_ == 0.0_${k1}$ .or. alpha_ == -1.0_${k1}$)
is_beta_special = (beta_ == 1.0_${k1}$ .or. beta_ == 0.0_${k1}$ .or. beta_ == -1.0_${k1}$)
#:endif

! Prepare Lapack arguments.
n = A%n ; ldx = n ; ldy = n ;
Expand All @@ -179,9 +187,25 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
#:if rank == 1
! Pointer trick.
xmat(1:n, 1:nrhs) => x ; ymat(1:n, 1:nrhs) => y
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
#:if t1.startswith('complex')
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
#:else
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
if(is_alpha_special .and. is_beta_special) then
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
else
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
end if
#:endif
#:else
#:if t1.startswith('complex')
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
#:else
if(is_alpha_special .and. is_beta_special) then
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
else
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
end if
#:endif
#:endif
end subroutine
#:endfor
Expand Down
33 changes: 33 additions & 0 deletions test/linalg/test_linalg_specialmatrices.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ contains
if (allocated(error)) return
end do
end do

! Test y = A @ x for random values of alpha and beta
y1 = 0.0_wp
call random_number(alpha)
call random_number(beta)
call random_number(y2)
y1 = alpha * matmul(Amat, x) + beta * y2
call spmv(A, x, y2, alpha=alpha, beta=beta)
call check(error, all_close(y1, y2), .true.)
if (allocated(error)) return

! Test y = A.T @ x for random values of alpha and beta
y1 = 0.0_wp
call random_number(alpha)
call random_number(beta)
call random_number(y2)
y1 = alpha * matmul(transpose(Amat), x) + beta * y2
call spmv(A, x, y2, alpha=alpha, beta=beta, op="T")
call check(error, all_close(y1, y2), .true.)
if (allocated(error)) return

#:if t1.startswith('complex')
! Test y = A.H @ x for random values of alpha and beta
y1 = 0.0_wp
call random_number(alpha)
call random_number(beta)
call random_number(y2)
y1 = alpha * matmul(transpose(conjg((Amat))), x) + beta * y2
call spmv(A, x, y2, alpha=alpha, beta=beta, op="H")
call check(error, all_close(y1, y2), .true.)
if (allocated(error)) return
#:endif

end block
#:endfor
end subroutine
Expand Down
Loading