fixed vanishing gradient problem in SciML/SciMLOperators.jl#94
vpuri3 committed Aug 2, 2022
1 parent fd1b163 commit 2ecde1c
Showing 2 changed files with 11 additions and 10 deletions.
17 changes: 9 additions & 8 deletions examples/ad/oper.jl
@@ -8,16 +8,15 @@ let
nothing
end

using SciMLSensitivity, Zygote
using LinearAlgebra, SciMLSensitivity, Zygote

n = 3
N = 9
N = 16

u0 = rand(N)
ps = rand(N)

space = FourierSpace(N)
space = make_transform(space; isinplace=false)
space = make_transform(space, u0; isinplace=false)

F = transformOp(space)
Dx = gradientOp(space)[1]
@@ -26,12 +25,12 @@ S = rand() * MatrixOperator(rand(N,N))
Di = DiagonalOperator(rand(N))
M = MatrixOperator(rand(N,N))
Af = AffineOperator(rand(N,N), rand(N,N), rand(N))
Ad = rand(N,N) + rand(N,N)
Co = rand(N,N) * rand(N,N)
Ad = MatrixOperator(rand(N,N)) + MatrixOperator(rand(N,N))
Co = MatrixOperator(rand(N,N)) * MatrixOperator(rand(N,N))
T = TensorProductOperator(rand(n,n), rand(n,n))
Id = IdentityOperator{N}()
Z = NullOperator{N}()
Ad = SciMLOperators.AdjointOperator(rand(N,N) |> MatrixOperator)
Ao = SciMLOperators.AdjointOperator(rand(N,N) |> MatrixOperator)
Tr = SciMLOperators.TransposedOperator(rand(N,N) |> MatrixOperator)

α = ScalarOperator(rand())
@@ -46,6 +45,8 @@ loss = function(p)
#v = Zygote.hook(Δ -> (println("Δv: ", Δ); Δ), v)
v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v)

#w = Dx * v ## Δ vanishes - ComposedOperator # INCORRECT
w = F \ F * v ## Δ vanishes - ComposedOperator # INCORRECT
#w = Co * v ## Δ vanishes - ComposedOperator # INCORRECT
#w = β * v ## Δ - AddedScalarOperator # ERROR

@@ -57,7 +58,7 @@ loss = function(p)
#w = M * v ## Δ ok - MatrixOperator
#w = Id * v ## Δ ok - IdentityOperator
#w = Z * v ## Δ ok - NullOperator
#w = Ad' * v ## Δ ok - AdjointOperator
#w = Ao' * v ## Δ ok - AdjointOperator
#w = transpose(Tr) * v ## Δ ok - TransposedOperator
#w = F\(F*v) ## Δ ok - FunctionOperator

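For context, here is a minimal sketch (not part of the commit; the loss and names are illustrative, and it assumes SciMLOperators and Zygote are available) of the kind of gradient check oper.jl performs: wrap each matrix in a MatrixOperator so the product stays a lazy ComposedOperator, push a parameter-dependent vector through it, and confirm the pullback reaching the parameters is nonzero.

using SciMLOperators, Zygote

N  = 16
u0 = rand(N)
ps = rand(N)

# Wrapping each factor keeps the product lazy, mirroring the new `Co`
# definition in the diff above.
Co = MatrixOperator(rand(N, N)) * MatrixOperator(rand(N, N))

function loss(p)
    v = u0 .* p        # make the output depend on the parameters
    w = Co * v         # apply the composed operator
    sum(w)
end

g = Zygote.gradient(loss, ps)[1]
@assert !iszero(g)     # Δ should not vanish through the composition (the behavior SciML/SciMLOperators.jl#94 addresses)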
4 changes: 2 additions & 2 deletions examples/burgers/model1.jl
@@ -115,8 +115,8 @@ function setup_model1(N, ν, filename;
Dx = cache_operator(Dx, u0.η) ## ERROR This is killing the gradient
F = cache_operator(transformOp(space), u0.η)

Dx = IdentityOperator(space)
Dx = F \ F
#Dx = IdentityOperator(space)
#Dx = F \ F

""" time discr """
function dudt(u, p, t)
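As a side note, a hedged sketch (stand-in operators and an assumed scalar parameter; this is not the model's actual right-hand side) of how a cached operator like Dx is applied inside a dudt-style function, since cache_operator sizes the operator's internal work buffers against a prototype vector before the operator is used in the RHS:

using SciMLOperators

N  = 16
u0 = rand(N)

# Stand-in for gradientOp(space)[1]; cache_operator allocates any internal
# buffers the operator needs, sized against the prototype vector u0.
Dx = cache_operator(MatrixOperator(rand(N, N)), u0)

function dudt(u, p, t)
    ν = p              # assumed scalar parameter
    ν * (Dx * u)       # apply the cached operator out of place
end

du = dudt(u0, 0.01, 0.0)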
