Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove zygoterules #451

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open

remove zygoterules #451

wants to merge 4 commits into from

Conversation

st--
Copy link
Member

@st-- st-- commented Apr 13, 2022

Curious to see what breaks...

@codecov
Copy link

codecov bot commented Apr 13, 2022

Codecov Report

Merging #451 (0f20455) into master (8e805ef) will increase coverage by 0.28%.
The diff coverage is n/a.

@@            Coverage Diff             @@
##           master     #451      +/-   ##
==========================================
+ Coverage   93.18%   93.46%   +0.28%     
==========================================
  Files          52       51       -1     
  Lines        1261     1255       -6     
==========================================
- Hits         1175     1173       -2     
+ Misses         86       82       -4     
Impacted Files Coverage Δ
src/KernelFunctions.jl 100.00% <ø> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 8e805ef...0f20455. Read the comment docs.

@st-- st-- marked this pull request as ready for review April 13, 2022 12:36
@willtebbutt
Copy link
Member

Huh, this is weird. Is it clear if we're hitting these adjoints at all in our tests?

@st--
Copy link
Member Author

st-- commented Apr 13, 2022

Huh, this is weird. Is it clear if we're hitting these adjoints at all in our tests?

_map is used by the transforms, and most of the transforms have passing AD tests, so I would say "probably"? I've not yet added print() statements to the custom adjoints to actually see...

@devmotion
Copy link
Member

_map is used by the transforms, and most of the transforms have passing AD tests, so I would say "probably"?

The rules were defined for map, not _map.

@st--
Copy link
Member Author

st-- commented Apr 13, 2022

_map is used by the transforms, and most of the transforms have passing AD tests, so I would say "probably"?

The rules were defined for map, not _map.

The rules were defined for map(t::Transform, x), which is defined in src/transform/transform.jl as _map(t, x). So what are you trying to say by that?

@devmotion
Copy link
Member

That

ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs)
return ZygoteRules.pullback(_map, t, X)
end
ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs)
return ZygoteRules.pullback(_map, t, X)
end
are rules for map. Internally, typically we work with _map though (due to AD issues and map being handled to generally by Zygote), e.g., in
function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag!(K, κ.kernel, _map.transform, x))
end
function kernelmatrix_diag!(
K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix_diag!(K, κ.kernel, _map.transform, x), _map.transform, y))
end
function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix!(K, κ.kernel, _map.transform, x))
end
function kernelmatrix!(
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix!(K, κ.kernel, _map.transform, x), _map.transform, y))
end
function kernelmatrix_diag::TransformedKernel, x::AbstractVector)
return kernelmatrix_diag.kernel, _map.transform, x))
end
function kernelmatrix_diag::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix_diag.kernel, _map.transform, x), _map.transform, y))
end
function kernelmatrix::TransformedKernel, x::AbstractVector)
return kernelmatrix.kernel, _map.transform, x))
end
function kernelmatrix::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix.kernel, _map.transform, x), _map.transform, y))
end
, and hence rules for map will not be hit.

@willtebbutt
Copy link
Member

Just verified this locally by adding print statements to the zygote rules -- they don't seem to be hit. Given that we document that one can call map(t, x), I would be in favour of adding a couple of unit tests to ensure that this actually works. We don't need to do that here, but an issue to this effect so that we don't forget about it would be good.

@st--
Copy link
Member Author

st-- commented Apr 13, 2022

@devmotion ah okay, I think I see what you mean now.

Given the map(t::Transform, x) = _map(t, x) definition, should we then just replace all _map calls inside e.g. kernelmatrix methods with just map?

@willtebbutt what would you like to have added unit tests for ?

@devmotion
Copy link
Member

Given the map(t::Transform, x) = _map(t, x) definition, should we then just replace all _map calls inside e.g. kernelmatrix methods with just map?

I don't think this will work. The main reason why _map was introduced was that map did not work because Zygote handles map(f, ::AbstractVector). See #113, #114, and FluxML/Zygote.jl#646.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants