diff --git a/examples/ad/oper.jl b/examples/ad/oper.jl index 36fc431..60fef5e 100644 --- a/examples/ad/oper.jl +++ b/examples/ad/oper.jl @@ -8,16 +8,15 @@ let nothing end -using SciMLSensitivity, Zygote +using LinearAlgebra, SciMLSensitivity, Zygote -n = 3 -N = 9 +N = 16 u0 = rand(N) ps = rand(N) space = FourierSpace(N) -space = make_transform(space; isinplace=false) +space = make_transform(space, u0; isinplace=false) F = transformOp(space) Dx = gradientOp(space)[1] @@ -26,12 +25,12 @@ S = rand() * MatrixOperator(rand(N,N)) Di = DiagonalOperator(rand(N)) M = MatrixOperator(rand(N,N)) Af = AffineOperator(rand(N,N), rand(N,N), rand(N)) -Ad = rand(N,N) + rand(N,N) -Co = rand(N,N) * rand(N,N) +Ad = MatrixOperator(rand(N,N)) + MatrixOperator(rand(N,N)) +Co = MatrixOperator(rand(N,N)) * MatrixOperator(rand(N,N)) T = TensorProductOperator(rand(n,n), rand(n,n)) Id = IdentityOperator{N}() Z = NullOperator{N}() -Ad = SciMLOperators.AdjointOperator(rand(N,N) |> MatrixOperator) +Ao = SciMLOperators.AdjointOperator(rand(N,N) |> MatrixOperator) Tr = SciMLOperators.TransposedOperator(rand(N,N) |> MatrixOperator) α = ScalarOperator(rand()) @@ -46,6 +45,8 @@ loss = function(p) #v = Zygote.hook(Δ -> (println("Δv: ", Δ); Δ), v) v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v) + #w = Dx * v ## Δ vanishes - ComposedOperator # INCORRECT + w = F \ F * v ## Δ vanishes - ComposedOperator # INCORRECT #w = Co * v ## Δ vanishes - ComposedOperator # INCORRECT #w = β * v ## Δ - AddedScalarOperator # ERROR @@ -57,7 +58,7 @@ loss = function(p) #w = M * v ## Δ ok - MatrixOperator #w = Id * v ## Δ ok - IdentityOperator #w = Z * v ## Δ ok - NullOperator - #w = Ad' * v ## Δ ok - AdjointOperator + #w = Ao' * v ## Δ ok - AdjointOperator #w = transpose(Tr) * v ## Δ ok - TransposedOperator #w = F\(F*v) ## Δ ok - FunctionOperator diff --git a/examples/burgers/model1.jl b/examples/burgers/model1.jl index 90f3212..6158fcb 100644 --- a/examples/burgers/model1.jl +++ b/examples/burgers/model1.jl @@ -115,8 +115,8 @@ function setup_model1(N, ν, filename; Dx = cache_operator(Dx, u0.η) ## ERROR This is killing the gradient F = cache_operator(transformOp(space), u0.η) - Dx = IdentityOperator(space) - Dx = F \ F + #Dx = IdentityOperator(space) + #Dx = F \ F """ time discr """ function dudt(u, p, t)