rrule for Base.sum should accept init keyword #656

Open
vpuri3 opened this issue Aug 3, 2022 · 3 comments

Comments

@vpuri3

vpuri3 commented Aug 3, 2022

The issue linked below contains an MWE of Zygote erroring on the init keyword argument to Base.sum. The best way to fix it would be to define an rrule for sum.

FluxML/Zygote.jl#1279

@vpuri3
Author

vpuri3 commented Aug 3, 2022

cc @mcabbott

@mcabbott
Member

mcabbott commented Aug 3, 2022

There's already a rule for sum:

function rrule(::typeof(sum), x::AbstractArray; dims=:)
    project = ProjectTo(x)
    y = sum(x; dims=dims)
    function sum_pullback(dy_raw)
        dy = unthunk(dy_raw)
        x_thunk = InplaceableThunk(
            # Protect `dy` from broadcasting, for when `x` is an array of arrays:
            dx -> dx .+= (dims isa Colon ? Ref(dy) : dy),
            @thunk project(_unsum(x, dy, dims))  # `_unsum` handles Ref internally
        )
        return (NoTangent(), x_thunk)
    end
    return y, sum_pullback
end

However, it doesn't handle the init keyword, which was added in Julia 1.6:

julia> sum([1 2 3; 4 5 6]; init=10)
31

julia> sum([1 2 3; 4 5 6]; init=10, dims=1)
1×3 Matrix{Int64}:
 15  17  19
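A minimal sketch, in plain Julia without ChainRulesCore, of what an init-aware rule would need to do. It assumes, as the REPL output above suggests, that init is added once to each output element, i.e. sum(x; dims, init) == sum(x; dims) .+ init. Under that assumption the derivative with respect to x does not depend on init, so the pullback only has to forward init to the primal call and spread dy back over the shape of x. The name sum_with_init_and_pullback is hypothetical, and the sketch ignores the array-of-arrays case that the Ref handles in the real rule. Note also that the Base docstring for sum says init must be the additive identity, since it is unspecified whether init is used for non-empty collections, so a non-zero init as in these examples relies on current behaviour.

```julia
# Hypothetical helper, not part of ChainRules: forwards `init` to the primal
# `sum` and returns a pullback with respect to `x` only. Like `dims`, `init`
# is a keyword argument, so no tangent is returned for it.
function sum_with_init_and_pullback(x::AbstractArray; dims=:, init)
    y = sum(x; dims=dims, init=init)
    function pullback(dy)
        # Each x[i] enters the sum with coefficient 1 (assuming init is added
        # once per output element), so dx is dy broadcast to the shape of x.
        # dy is a scalar when dims=:, otherwise it has size 1 along `dims`.
        dx = similar(x, promote_type(eltype(x), eltype(dy)))
        dx .= dy
        return dx
    end
    return y, pullback
end

y, back = sum_with_init_and_pullback([1 2 3; 4 5 6]; init=10, dims=1)
# y == [15 17 19]; back([1.0 2.0 3.0]) repeats that row for both rows of x
```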

The same applies to the sum(f, x) method:

julia> sum(abs2, [1 2 3; 4 5 6]; init=10)
101

julia> sum(abs2, [1 2 3; 4 5 6]; init=10, dims=1)
1×3 Matrix{Int64}:
 27  39  55
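The same reasoning extends to the sum(f, x) form. Below is a sketch for f = abs2 on real arrays, again assuming init is added once per output element: the derivative of abs2(t) is 2t, so the pullback is 2 .* x .* dy and init plays no part in it. The name sum_abs2_with_init_and_pullback is hypothetical; a general rule for arbitrary f would also have to differentiate f itself, which this sketch does not attempt.

```julia
# Hypothetical helper for sum(abs2, x; dims, init) with real-valued x.
# d(abs2(t))/dt = 2t, so each x[i] receives 2x[i] times the matching dy.
function sum_abs2_with_init_and_pullback(x::AbstractArray{<:Real}; dims=:, init)
    y = sum(abs2, x; dims=dims, init=init)
    # dy is a scalar when dims=:, otherwise it broadcasts along `dims`.
    pullback(dy) = 2 .* x .* dy
    return y, pullback
end

y, back = sum_abs2_with_init_and_pullback([1 2 3; 4 5 6]; init=10)
# y == 101 and back(1) == [2 4 6; 8 10 12]
```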

@mcabbott changed the title from "Need rrule for Base.sum" to "rrule for Base.sum should accept init keyword" on Aug 3, 2022
@mcabbott
Member

Xref JuliaDiff/ChainRulesCore.jl#384. The lowered form is:

julia> Meta.@lower sum(x; init=10)
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Core.tuple(:init)
│   %2 = Core.apply_type(Core.NamedTuple, %1)
│   %3 = Core.tuple(10)
│   %4 = (%2)(%3)
│   %5 = Core.kwfunc(sum)
│   %6 = (%5)(%4, sum, x)
└──      return %6
))))

However, defining a rule as rrule(Core.kwfunc(f), kwargs, args...) doesn't currently work.
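For reference, the first four lowered statements only pack the keyword arguments into a NamedTuple. The Base-only snippet below rebuilds that NamedTuple with the same Core calls and splats it back into an ordinary keyword call. It deliberately leaves out the Core.kwfunc step (%5 and %6), since the way keyword calls are lowered has changed between Julia versions; the variable names are illustrative only.

```julia
# Rebuild the keyword NamedTuple the way lowered statements %1 to %4 do.
x = [1 2 3; 4 5 6]
kwnames = Core.tuple(:init)                       # %1: (:init,)
KW = Core.apply_type(Core.NamedTuple, kwnames)    # %2: NamedTuple{(:init,)}
kwargs = KW(Core.tuple(10))                       # %3, %4: (init = 10,)

# Splatting the NamedTuple gives the same call as sum(x; init=10).
y = sum(x; kwargs...)
```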
