Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit the lifetime of the matrix handle to MKLSparse call #52

Merged
merged 6 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,56 +9,56 @@
matdescra(A::Transpose) = matdescra(A.parent)
matdescra(A::Adjoint) = matdescra(A.parent)

function cscmv!(transa::Char, α::T, matdescra::String,
function cscmv!(transA::Char, α::T, matdescrA::String,

Check warning on line 12 in src/deprecated.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecated.jl#L12

Added line #L12 was not covered by tests
A::AbstractSparseMatrix{T}, x::StridedVector{T},
β::T, y::StridedVector{T}) where {T <: BlasFloat}
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N')
check_trans(transA)
check_mat_op_sizes(y, A, transA, x, 'N')

Check warning on line 16 in src/deprecated.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecated.jl#L15-L16

Added lines #L15 - L16 were not covered by tests

mkl_call(Val{:mkl_TSmvI}(), typeof(A),
transa, A.m, A.n, α, matdescra,
transA, A.m, A.n, α, matdescrA,
A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y)
return y
end

function cscmm!(transa::Char, α::T, matdescra::String,
function cscmm!(transA::Char, α::T, matdescrA::String,

Check warning on line 24 in src/deprecated.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecated.jl#L24

Added line #L24 was not covered by tests
A::SparseMatrixCSC{T}, B::StridedMatrix{T},
β::T, C::StridedMatrix{T}) where {T <: BlasFloat}
check_transa(transa)
check_mat_op_sizes(C, A, transa, B, 'N')
check_trans(transA)
check_mat_op_sizes(C, A, transA, B, 'N')

Check warning on line 28 in src/deprecated.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecated.jl#L27-L28

Added lines #L27 - L28 were not covered by tests
mB, nB = size(B)
mC, nC = size(C)

mkl_call(Val{:mkl_TSmmI}(), typeof(A),
transa, A.m, nC, A.n, α, matdescra,
transA, A.m, nC, A.n, α, matdescrA,
A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, β, C, mC)
return C
end

function cscsv!(transa::Char, α::T, matdescra::String,
function cscsv!(transA::Char, α::T, matdescrA::String,

Check warning on line 38 in src/deprecated.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecated.jl#L38

Added line #L38 was not covered by tests
A::SparseMatrixCSC{T}, x::StridedVector{T},
y::StridedVector{T}) where {T <: BlasFloat}
n = checksquare(A)
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N')
check_trans(transA)
check_mat_op_sizes(y, A, transA, x, 'N')

Check warning on line 43 in src/deprecated.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecated.jl#L42-L43

Added lines #L42 - L43 were not covered by tests

mkl_call(Val{:mkl_TSsvI}(), typeof(A),
transa, A.m, α, matdescra,
transA, A.m, α, matdescrA,
A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, y)
return y
end

function cscsm!(transa::Char, α::T, matdescra::String,
function cscsm!(transA::Char, α::T, matdescrA::String,

Check warning on line 51 in src/deprecated.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecated.jl#L51

Added line #L51 was not covered by tests
A::SparseMatrixCSC{T}, B::StridedMatrix{T},
C::StridedMatrix{T}) where {T <: BlasFloat}
mB, nB = size(B)
mC, nC = size(C)
n = checksquare(A)
check_transa(transa)
check_mat_op_sizes(C, A, transa, B, 'N')
check_trans(transA)
check_mat_op_sizes(C, A, transA, B, 'N')

Check warning on line 58 in src/deprecated.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecated.jl#L57-L58

Added lines #L57 - L58 were not covered by tests

mkl_call(Val{:mkl_TSsmI}(), typeof(A),
transa, A.n, nC, α, matdescra,
transA, A.n, nC, α, matdescrA,
A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, C, mC)
return C
end
70 changes: 44 additions & 26 deletions src/generic.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Intermediate wrappers for the Sparse BLAS routines
# that check the parameters validity (including matrix dimensions checks)
# and convert Julia's matrix types to the MKL's matrix types.
# See https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-2/inspector-executor-sparse-blas-execution-routines.html
# for the detailed description of the wrapped functions.

# generates the reference to the MKL function from the template
@inline @generated function mkl_function(
::Val{F}, ::Type{S}
Expand All @@ -18,56 +24,68 @@ end
return body
end

function mv!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
# y := alpha * op(A) * x + beta * y
function mv!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedVector{T}, beta::T, y::StridedVector{T}
) where T
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N')
check_trans(transA)
check_mat_op_sizes(y, A, transA, x, 'N')
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_mvI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, x, beta, y)
transA, alpha, hA, descr, x, beta, y)
destroy(hA)
check_status(res)
return y
end

function mm!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedMatrix{T}, beta::T, y::StridedMatrix{T};
# C := alpha * op(A) * B + beta * C
function mm!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
B::StridedMatrix{T}, beta::T, C::StridedMatrix{T};
dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR
) where T
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N'; dense_layout)
columns = size(y, dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR ? 2 : 1)
ldx = stride(x, 2)
ldy = stride(y, 2)
check_trans(transA)
check_mat_op_sizes(C, A, transA, B, 'N'; dense_layout)
columns = size(C, dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR ? 2 : 1)
ldB = stride(B, 2)
ldC = stride(C, 2)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_mmI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, dense_layout, x, columns, ldx, beta, y, ldy)
transA, alpha, hA, descr, dense_layout, B, columns, ldB, beta, C, ldC)
destroy(hA)
check_status(res)
return y
return C
end

function trsv!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
# find y: op(A) * y = alpha * x
function trsv!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedVector{T}, y::StridedVector{T}
) where T
checksquare(A)
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N')
check_trans(transA)
check_mat_op_sizes(y, A, transA, x, 'N')
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_trsvI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, x, y)
transA, alpha, hA, descr, x, y)
destroy(hA)
check_status(res)
return y
end

function trsm!(transa::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedMatrix{T}, y::StridedMatrix{T};
# Y := alpha * inv(op(A)) * X
function trsm!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
X::StridedMatrix{T}, Y::StridedMatrix{T};
dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR
) where T
checksquare(A)
check_transa(transa)
check_mat_op_sizes(y, A, transa, x, 'N'; dense_layout)
columns = size(y, dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR ? 2 : 1)
ldx = stride(x, 2)
ldy = stride(y, 2)
check_trans(transA)
check_mat_op_sizes(Y, A, transA, X, 'N'; dense_layout)
columns = size(Y, dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR ? 2 : 1)
ldX = stride(X, 2)
ldY = stride(Y, 2)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_trsmI}(), typeof(A),
transa, alpha, MKLSparseMatrix(A), descr, dense_layout, x, columns, ldx, y, ldy)
transA, alpha, hA, descr, dense_layout, X, columns, ldX, Y, ldY)
destroy(hA)
check_status(res)
return y
return Y
end
29 changes: 15 additions & 14 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,27 @@ SimpleOrSpecialMat{T, M} = Union{M, SpecialMat{T, <:M}}
SimpleOrSpecialOrAdjMat{T, M} = Union{SimpleOrAdjMat{T, <:SimpleOrSpecialMat{T, <:M}},
SimpleOrSpecialMat{T, <:SimpleOrAdjMat{T, <:M}}}

unwrapa(A::AbstractMatrix) = A
unwrapa(A::Union{Adjoint, Transpose}) = unwrapa(parent(A))
unwrapa(A::SpecialMat) = unwrapa(parent(A))

# returns a tuple of transa, matdescra and unwrapped A
describe_and_unwrap(A::AbstractMatrix) = ('N', matrix_descr(A), unwrapa(A))
describe_and_unwrap(A::Adjoint) = ('C', matrix_descr(A), unwrapa(parent(A)))
describe_and_unwrap(A::Transpose) = ('T', matrix_descr(A), unwrapa(parent(A)))
# unwraps matrix A from Adjoint/Transpose transform
unwrap_trans(A::AbstractMatrix) = A
unwrap_trans(A::Union{Adjoint, Transpose}) = unwrap_trans(parent(A))
unwrap_trans(A::SpecialMat) = unwrap_trans(parent(A))

# returns a tuple of trans, matrix_descr and unwrapped A
describe_and_unwrap(A::AbstractMatrix) = ('N', matrix_descr(A), unwrap_trans(A))
describe_and_unwrap(A::Adjoint) = ('C', matrix_descr(A), unwrap_trans(parent(A)))
describe_and_unwrap(A::Transpose) = ('T', matrix_descr(A), unwrap_trans(parent(A)))
describe_and_unwrap(A::LowerTriangular{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'U', 'N'), unwrapa(A))
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'U', 'N'), unwrap_trans(A))
describe_and_unwrap(A::UpperTriangular{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'L', 'N'), unwrapa(A))
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'L', 'N'), unwrap_trans(A))
describe_and_unwrap(A::UnitLowerTriangular{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'U', 'U'), unwrapa(A))
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'U', 'U'), unwrap_trans(A))
describe_and_unwrap(A::UnitUpperTriangular{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'L', 'U'), unwrapa(A))
(T <: Adjoint ? 'C' : 'T', matrix_descr('T', 'L', 'U'), unwrap_trans(A))
describe_and_unwrap(A::Symmetric{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Transpose || (eltype(A) <: Real) ? 'N' : 'C', matrix_descr('S', A.uplo, 'N'), unwrapa(A))
(T <: Transpose || (eltype(A) <: Real) ? 'N' : 'C', matrix_descr('S', A.uplo, 'N'), unwrap_trans(A))
describe_and_unwrap(A::Hermitian{<:Any, T}) where T <: Union{Adjoint, Transpose} =
(T <: Adjoint || (eltype(A) <: Real) ? 'N' : 'T', matrix_descr('H', A.uplo, 'N'), unwrapa(A))
(T <: Adjoint || (eltype(A) <: Real) ? 'N' : 'T', matrix_descr('H', A.uplo, 'N'), unwrap_trans(A))

# 5-arg mul!()
function mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S},
Expand Down
Loading