diff --git a/Project.toml b/Project.toml
index d1f39b93..260c1624 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SciMLOperators"
 uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
 authors = ["xtalax "]
-version = "0.1.9"
+version = "0.1.10"
 
 [deps]
 ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
diff --git a/src/basic.jl b/src/basic.jl
index 9abc887a..64d1a8c8 100644
--- a/src/basic.jl
+++ b/src/basic.jl
@@ -293,6 +293,7 @@ struct AddedOperator{T,
     ops::O
 
     function AddedOperator(ops)
+        @assert !isempty(ops)
         T = promote_type(eltype.(ops)...)
         new{T,typeof(ops)}(ops)
     end
@@ -414,6 +415,7 @@ struct ComposedOperator{T,O,C} <: AbstractSciMLOperator{T}
     isset::Bool
 
     function ComposedOperator(ops, cache, isset::Bool)
+        @assert !isempty(ops)
        for i in reverse(2:length(ops))
            opcurr = ops[i]
            opnext = ops[i-1]
@@ -518,8 +520,27 @@ for fact in (
 end
 
 # operator application
-Base.:*(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op * acc, reverse(L.ops); init=u)
-Base.:\(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op \ acc, L.ops; init=u)
+# https://github.com/SciML/SciMLOperators.jl/pull/94
+#Base.:*(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op * acc, reverse(L.ops); init=u)
+#Base.:\(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op \ acc, L.ops; init=u)
+
+function Base.:\(L::ComposedOperator, u::AbstractVecOrMat)
+    v = u
+    for op in L.ops
+        v = op \ v
+    end
+
+    v
+end
+
+function Base.:*(L::ComposedOperator, u::AbstractVecOrMat)
+    v = u
+    for op in reverse(L.ops)
+        v = op * v
+    end
+
+    v
+end
 
 function cache_self(L::ComposedOperator, u::AbstractVecOrMat)
     vec = zero(u)
diff --git a/src/scalar.jl b/src/scalar.jl
index f04944fc..b5a7481c 100644
--- a/src/scalar.jl
+++ b/src/scalar.jl
@@ -115,6 +115,7 @@ struct AddedScalarOperator{T,O} <: AbstractSciMLScalarOperator{T}
     ops::O
 
     function AddedScalarOperator(ops::NTuple{N,AbstractSciMLScalarOperator}) where{N}
+        @assert !isempty(ops)
         T = promote_type(eltype.(ops)...)
         new{T,typeof(ops)}(ops)
     end
@@ -141,12 +142,14 @@ for op in (
 end
 
 function Base.convert(::Type{Number}, α::AddedScalarOperator{T}) where{T}
-    sum(op -> convert(Number, op), α.ops; init=zero(T))
+    sum(op -> convert(Number, op), α.ops)
 end
 
 Base.conj(L::AddedScalarOperator) = AddedScalarOperator(conj.(L.ops))
 
 getops(α::AddedScalarOperator) = α.ops
+has_ldiv(α::AddedScalarOperator) = !iszero(convert(Number, α))
+has_ldiv!(α::AddedScalarOperator) = has_ldiv(α)
 
 """
 Lazy composition of Scalar Operators
@@ -155,6 +158,7 @@ struct ComposedScalarOperator{T,O} <: AbstractSciMLScalarOperator{T}
     ops::O
 
     function ComposedScalarOperator(ops::NTuple{N,AbstractSciMLScalarOperator}) where{N}
+        @assert !isempty(ops)
         T = promote_type(eltype.(ops)...)
         new{T,typeof(ops)}(ops)
     end
@@ -188,4 +192,6 @@ Base.conj(L::ComposedScalarOperator) = ComposedScalarOperator(conj.(L.ops))
 Base.:-(α::AbstractSciMLScalarOperator{T}) where{T} = (-one(T)) * α
 
 getops(α::ComposedScalarOperator) = α.ops
+has_ldiv(α::ComposedScalarOperator) = all(has_ldiv, α.ops)
+has_ldiv!(α::ComposedScalarOperator) = all(has_ldiv!, α.ops)
 #
diff --git a/test/Project.toml b/test/Project.toml
index 04f89be7..a84a6e5b 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -5,3 +5,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/runtests.jl b/test/runtests.jl
index 4046ad8b..1ecb3ef3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,5 +11,7 @@ if GROUP == "All" || GROUP == "OperatorInterface"
     @time @safetestset "Matrix Operators" begin include("matrix.jl") end
     @time @safetestset "Function Operator" begin include("func.jl") end
     @time @safetestset "Full tests" begin include("total.jl") end
+
+    @time @safetestset "Zygote.jl" begin include("zygote.jl") end
 end
 end
diff --git a/test/zygote.jl b/test/zygote.jl
new file mode 100644
index 00000000..eefeeb0b
--- /dev/null
+++ b/test/zygote.jl
@@ -0,0 +1,88 @@
+#
+using SciMLOperators, Zygote, LinearAlgebra
+using Random
+
+using SciMLOperators
+using SciMLOperators: AbstractSciMLOperator,
+                      IdentityOperator, NullOperator,
+                      AdjointOperator, TransposedOperator,
+                      InvertedOperator, InvertibleOperator,
+                      BatchedDiagonalOperator, AddedOperator, ComposedOperator,
+                      AddedScalarOperator, ComposedScalarOperator, ScaledOperator,
+                      has_mul, has_ldiv
+
+Random.seed!(0)
+n = 3
+N = n*n
+K = 12
+
+u0 = rand(N, K)
+ps = rand(N)
+
+M = rand(N,N)
+
+for (op_type, A) in
+    (
+     (IdentityOperator, IdentityOperator{N}()),
+     (NullOperator, NullOperator{N}()),
+     (MatrixOperator, MatrixOperator(rand(N,N))),
+     (AffineOperator, AffineOperator(rand(N,N), rand(N,N), rand(N,K))),
+     (ScaledOperator, rand() * MatrixOperator(rand(N,N))),
+     (InvertedOperator, InvertedOperator(rand(N,N) |> MatrixOperator)),
+     (InvertibleOperator, InvertibleOperator(rand(N,N) |> MatrixOperator)),
+     (BatchedDiagonalOperator, DiagonalOperator(rand(N,K))),
+     (AddedOperator, MatrixOperator(rand(N,N)) + MatrixOperator(rand(N,N))),
+     (ComposedOperator, MatrixOperator(rand(N,N)) * MatrixOperator(rand(N,N))),
+     (TensorProductOperator, TensorProductOperator(rand(n,n), rand(n,n))),
+     (FunctionOperator, FunctionOperator((u,p,t)->M*u, op_inverse=(u,p,t)->M\u,
+                                         T=Float64, isinplace=false, size=(N,N),
+                                         input_prototype=u0, output_prototype=u0)),
+
+     ## ignore wrappers
+     #(AdjointOperator, AdjointOperator(rand(N,N) |> MatrixOperator) |> adjoint),
+     #(TransposedOperator, TransposedOperator(rand(N,N) |> MatrixOperator) |> transpose),
+
+     (ScalarOperator, ScalarOperator(rand())),
+     (AddedScalarOperator, ScalarOperator(rand()) + ScalarOperator(rand())),
+     (ComposedScalarOperator, ScalarOperator(rand()) * ScalarOperator(rand())),
+    )
+
+    @assert A isa op_type
+
+    loss_mul = function(p)
+
+        v = Diagonal(p) * u0
+
+        w = A * v
+
+        l = sum(w)
+    end
+
+    loss_div = function(p)
+
+        v = Diagonal(p) * u0
+
+        w = A \ v
+
+        l = sum(w)
+    end
+
+    @testset "$op_type" begin
+        l_mul = loss_mul(ps)
+        g_mul = Zygote.gradient(loss_mul, ps)[1]
+
+        if A isa NullOperator
+            @test isa(g_mul, Nothing)
+        else
+            @test !isa(g_mul, Nothing)
+        end
+
+        if has_ldiv(A)
+            l_div = loss_div(ps)
+            g_div = Zygote.gradient(loss_div, ps)[1]
+
+            @test !isa(g_div, Nothing)
+        end
+    end
+end
+
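A note on the `src/basic.jl` hunk: the `foldl` one-liners for `ComposedOperator` are commented out and replaced by explicit loops, presumably because Zygote differentiates plain loops more reliably than `foldl` over an operator tuple (the linked PR adds the Zygote test suite alongside this change). Below is a minimal sketch of the semantics the new methods preserve, checkable at the REPL; the 4×4 sizes and the names `Am`, `Bm`, `L`, `u` are illustrative, not from the patch:

```julia
using SciMLOperators

Am, Bm = rand(4, 4), rand(4, 4)
L = MatrixOperator(Am) * MatrixOperator(Bm)  # lazy ComposedOperator
u = rand(4)

# `*` applies the operators right-to-left: (A ∘ B) * u == A * (B * u).
@assert L * u ≈ Am * (Bm * u)

# `\` applies the inverses left-to-right, so it undoes `*`:
# L \ w == B⁻¹ * (A⁻¹ * w), hence L \ (L * u) recovers u.
@assert L \ (L * u) ≈ u
```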
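Similarly, the new `has_ldiv` methods in `src/scalar.jl` state when a lazy scalar expression may appear on the left of `\`: a sum of scalar operators is invertible iff its collapsed value is nonzero, and a product iff every factor is. A small illustration, assuming (as the `all(has_ldiv, α.ops)` definition implies) that a plain `ScalarOperator` reports `has_ldiv` when its value is nonzero; `has_ldiv` is imported the same way `test/zygote.jl` does:

```julia
using SciMLOperators
using SciMLOperators: has_ldiv

# A sum collapses via convert(Number, ·) and supports `\` only when nonzero.
@assert !has_ldiv(ScalarOperator(2.0) + ScalarOperator(-2.0))  # collapses to 0.0
@assert  has_ldiv(ScalarOperator(2.0) + ScalarOperator(3.0))   # collapses to 5.0

# A product supports `\` only when every factor does.
@assert !has_ldiv(ScalarOperator(0.0) * ScalarOperator(3.0))   # zero factor
@assert  has_ldiv(ScalarOperator(2.0) * ScalarOperator(3.0))
```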