Error with Zygote.gradient for foldl, sum #1279

Open
vpuri3 opened this issue Aug 2, 2022 · 2 comments
Labels
ChainRules (adjoint -> rrule, and further integration)

Comments


vpuri3 commented Aug 2, 2022

MWE

using Zygote, LinearAlgebra

N = 4
u0 = rand(N)
ps = rand(N)

mats = (rand(N,N), rand(N,N),) # (A, B,)
nums = (rand(), rand(),)       # (α, β,)

loss_m = function(p)
    v = Diagonal(p) * u0
    # Zygote.hook prints each incoming cotangent and passes it through unchanged
    v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v)

    w = foldl((acc, op) -> op * acc, mats; init=v) # w = B * A * v
    w = Zygote.hook(Δ -> (println("Δw: ", Δ); Δ), w)

    l = sum(w)
    l = Zygote.hook(Δ -> (println("Δl: ", Δ); Δ), l)
end

println("fwd"); @time loss_m(ps) |> display
println("bwd"); @time Zygote.gradient(loss_m, ps) |> display # INCORRECT - should not vanish

loss_n = function(p)
    v = Diagonal(p) * u0
    v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v)

    w = sum(a -> convert(Number, a), nums; init=zero(eltype(nums))) * v # w = (α + β) * v
    w = Zygote.hook(Δ -> (println("Δw: ", Δ); Δ), w)

    l = sum(w)
    l = Zygote.hook(Δ -> (println("Δl: ", Δ); Δ), l)
end

println("fwd"); @time loss_n(ps) |> display
println("bwd"); @time Zygote.gradient(loss_n, ps) |> display # ERRORS
julia> include("examples/ad/zy.jl")
fwd
4.339451806053281
  0.021413 seconds (44.18 k allocations: 2.637 MiB, 99.38% compilation time)
bwd
Δl: 1.0
Δw: Fill(1.0, 4)
Δv: Nothing
(nothing,)
  0.139943 seconds (444.37 k allocations: 23.545 MiB, 99.45% compilation time)
fwd
1.5660193401267022
  0.355185 seconds (1.11 M allocations: 65.174 MiB, 99.67% compilation time)
bwd
ERROR: LoadError: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:872
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at range.jl:872
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at dict.jl:712
  ...
Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:91
  [2] chain_rrule_kw
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:229 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0 [inlined]
  [4] _pullback(::Zygote.Context, ::Base.var"#sum##kw", ::NamedTuple{(:init,), Tuple{Float64}}, ::typeof(sum), ::var"#49#54", ::Tuple{Float64, Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:9
  [5] _pullback
    @ ~/.julia/dev/PDEInterfaces/examples/ad/zy.jl:29 [inlined]
  [6] _pullback(ctx::Zygote.Context, f::var"#47#52", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
  [7] _pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:34
  [8] pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:40
  [9] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:75
 [10] top-level scope
    @ ./timing.jl:242
 [11] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [12] top-level scope
    @ REPL[2]:1
 [13] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
in expression starting at /home/vedantpu/.julia/dev/PDEInterfaces/examples/ad/zy.jl:38

ref - SciML/SciMLOperators.jl#94
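For reference, the gradient the MWE should produce can be worked out by hand: with M = B * A, loss_m(p) = sum(M * (p .* u0)), so its gradient is u0 .* (M' * ones(N)), and loss_n has gradient (α + β) .* u0. A minimal plain-Julia sketch (no AD involved; the names loss_m_plain, grad_m, fd_grad etc. are hypothetical) that checks these closed forms against central finite differences, so an AD result such as the (nothing,) above has something to be compared against:

```julia
using LinearAlgebra, Random

Random.seed!(0)
N = 4
u0 = rand(N)
ps = rand(N)
A, B = rand(N, N), rand(N, N)
α, β = rand(), rand()

# Same forward passes as loss_m / loss_n, without the hooks
loss_m_plain(p) = sum(B * (A * (Diagonal(p) * u0)))
loss_n_plain(p) = sum((α + β) * (Diagonal(p) * u0))

# Closed forms: l = 1ᵀ M (p .* u0)  =>  ∇l = u0 .* (Mᵀ * 1)
grad_m(p) = u0 .* (A' * (B' * ones(N)))   # M = B * A
grad_n(p) = (α + β) .* u0                 # M = (α + β) * I

# Central finite differences, one coordinate at a time
function fd_grad(f, p; h = 1e-6)
    g = similar(p)
    for i in eachindex(p)
        e = zeros(length(p))
        e[i] = h
        g[i] = (f(p + e) - f(p - e)) / (2h)
    end
    return g
end

@show grad_m(ps) fd_grad(loss_m_plain, ps)
@show grad_n(ps) fd_grad(loss_n_plain, ps)
```

Since both losses are linear in p, the finite-difference estimate should agree with the closed forms up to rounding, and neither gradient is identically zero.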

vpuri3 (author) commented Aug 2, 2022

The sum case works when I remove the init keyword argument, but I'm still curious why it doesn't work otherwise.
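Following the observation that dropping init makes the sum case work, here is a minimal sketch of init-free rewrites of both losses (hooks omitted; the _noinit/_loop names are hypothetical). The foldl variant seeds the accumulator by putting v at the front of the tuple, and the loop variant avoids foldl altogether. The sketch only guarantees the same forward values; whether a given Zygote/ChainRules version differentiates each variant correctly is an assumption that should still be checked, e.g. against finite differences.

```julia
using LinearAlgebra

N = 4
u0 = rand(N)
ps = rand(N)
mats = (rand(N, N), rand(N, N))  # (A, B,)
nums = (rand(), rand())          # (α, β,)

# foldl without init: v is the first element, so the fold computes B * (A * v)
loss_m_noinit = function (p)
    v = Diagonal(p) * u0
    w = foldl((acc, op) -> op * acc, (v, mats...))
    return sum(w)
end

# Same computation as an explicit loop, with no foldl rule involved
loss_m_loop = function (p)
    w = Diagonal(p) * u0
    for op in mats
        w = op * w
    end
    return sum(w)
end

# sum without init: (α + β) * v
loss_n_noinit = function (p)
    v = Diagonal(p) * u0
    w = sum(nums) * v
    return sum(w)
end

# e.g. Zygote.gradient(loss_m_loop, ps) once Zygote is loaded
```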

vpuri3 added a commit to CalculustJL/CalculustCore.jl that referenced this issue Aug 2, 2022
mcabbott (member) commented Aug 3, 2022

foldl not tracking the init keyword is JuliaDiff/ChainRules.jl#567; you could try JuliaDiff/ChainRules.jl#569.
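To illustrate what "tracking init" means for this particular fold, foldl((acc, op) -> op * acc, mats; init = v), here is a hand-written reverse pass. The helper foldl_matmul_with_pullback is hypothetical (not part of ChainRules or Zygote) and takes v positionally; it is only a sketch of the init cotangent that, per the linked issue, is dropped, which is consistent with the Δv: Nothing printed by the hook. With Δw = ones(N), the cotangent for init comes out as A' * (B' * ones(N)) rather than nothing.

```julia
using LinearAlgebra

# Hypothetical helper, not a ChainRules/Zygote API: runs
#   foldl((acc, op) -> op * acc, mats; init = v)
# and returns a pullback giving cotangents for every operator *and* for init.
function foldl_matmul_with_pullback(mats, v)
    accs = Any[v]                    # accs[k] is the input to mats[k]
    for op in mats
        push!(accs, op * accs[end])
    end
    w = accs[end]
    function pullback(Δw)
        Δacc = Δw
        Δmats = Vector{Any}(undef, length(mats))
        for k in length(mats):-1:1
            Δmats[k] = Δacc * accs[k]'   # cotangent of op in op * acc
            Δacc = mats[k]' * Δacc       # cotangent of acc in op * acc
        end
        return Tuple(Δmats), Δacc        # Δacc now belongs to init
    end
    return w, pullback
end

N = 4
A, B, v = rand(N, N), rand(N, N), rand(N)
w, back = foldl_matmul_with_pullback((A, B), v)
Δmats, Δv = back(ones(N))   # Δw = ones(N), like the Fill(1.0, 4) printed by the hook
```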

sum not supporting init is also bad; could you open an issue on ChainRules.jl?

julia> ChainRules.rrule(sum, [1,2,3]; init=4)
ERROR: MethodError: no method matching rrule(::typeof(sum), ::Vector{Int64}; init::Int64)

Closest candidates are:
  rrule(::typeof(sum), ::AbstractArray; dims) got unsupported keyword argument "init"
   @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/mapreduce.jl:28
  rrule(::typeof(sum), ::Any, ::AbstractArray{Bool}; sum_pullback) got unsupported keyword argument "init"
   @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:82
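The math a sum-with-init rule has to encode is simple: sum(x; init) = init + x[1] + ... + x[end], so the pullback passes ȳ unchanged to every element of x and to init. A minimal plain-Julia sketch with a hypothetical helper (not the ChainRules API; a real ChainRulesCore.rrule would also return NoTangent() for sum itself, and would have to handle init arriving as a keyword argument rather than positionally as it does here):

```julia
# Hypothetical hand-written rule (not the ChainRules API): value and pullback
# of y = sum(x; init) = init + x[1] + ... + x[end]
function sum_init_with_pullback(x::AbstractArray, init)
    y = sum(x; init = init)
    # ∂y/∂x[i] = 1 and ∂y/∂init = 1, so both cotangents are just ȳ
    sum_init_pullback(ȳ) = (fill(ȳ, size(x)), ȳ)
    return y, sum_init_pullback
end

y, back = sum_init_with_pullback([1.0, 2.0, 3.0], 4.0)
Δx, Δinit = back(1.0)   # Δx == [1.0, 1.0, 1.0], Δinit == 1.0
```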

mcabbott added the ChainRules (adjoint -> rrule, and further integration) label Aug 3, 2022