Fix gradient issues with kernelmatrix_diag and use ChainRulesCore #208

Merged Mar 25, 2021 · 50 commits
Changes shown below are from 4 of the 50 commits (not the full merged diff).
Commits
e525614  Use broadcasting instead of map for kerneldiagmatrix (theogf, Dec 9, 2020)
e56492a  Removed method for transformedkernel (theogf, Dec 9, 2020)
35a6306  Restored functions and applied suggestions (theogf, Dec 14, 2020)
25e5efd  Added tests for diagmatrix (theogf, Dec 14, 2020)
2f85ebc  Put changes to the right file and removed utils_AD.jl (theogf, Dec 14, 2020)
cae225f  Apply suggestions from code review (theogf, Dec 14, 2020)
3f16f07  Added colwise and fixed kerneldiagmatrix (theogf, Dec 15, 2020)
8c0d0a2  Added colwise for RowVecs and ColVecs (theogf, Dec 16, 2020)
13a10fd  Removed definition relying on Distances.colwise! (theogf, Dec 21, 2020)
78a2078  Merge branch 'master' into fix_diagmat (theogf, Mar 16, 2021)
5ca94e7  Readapt to kernelmatrix_diag (theogf, Mar 16, 2021)
2c60abd  Fixes for Zygote (theogf, Mar 16, 2021)
9214211  Remove type piracy (theogf, Mar 16, 2021)
87edbc8  Adding some adjoints (not everything fixed yet) (theogf, Mar 17, 2021)
f65556b  Fixed adjoint for polynomials (theogf, Mar 17, 2021)
48e2dcb  Add ChainRulesCore for defining rrule (theogf, Mar 17, 2021)
6cc803d  Replace broadcast by map (theogf, Mar 17, 2021)
0e30941  Missing return for style (theogf, Mar 17, 2021)
61869b1  Fixing ZygoteRules (theogf, Mar 22, 2021)
06bd4f0  Renamed zygote_adjoints to chainrules (theogf, Mar 22, 2021)
8e1e516  Apply formatting suggestions (theogf, Mar 22, 2021)
aaa16de  Added forward rule for Euclidean distance (theogf, Mar 22, 2021)
52b1ae5  Corrected rules for Row/ColVecs constructors (theogf, Mar 22, 2021)
4067a42  Added ZygoteRules back for the "map hack" (theogf, Mar 22, 2021)
641ebee  Corrected the rrules (theogf, Mar 22, 2021)
13d1e39  Type stable frule (theogf, Mar 22, 2021)
4675c2f  Corrected tests (theogf, Mar 23, 2021)
0b97c1a  Adapted the use of Distances.jl (theogf, Mar 23, 2021)
ad9838e  Added methods to make nn work (theogf, Mar 23, 2021)
650dc08  Missing kernelmatrix_diag (theogf, Mar 23, 2021)
1703db1  Formatting suggestions (theogf, Mar 23, 2021)
e2cd167  Added methods for FBM (theogf, Mar 23, 2021)
01ffac0  Last fix on Delta (theogf, Mar 23, 2021)
9bfb6eb  Potential fix for Euclidean (theogf, Mar 23, 2021)
f3fa4bc  Missing Distances. (theogf, Mar 23, 2021)
a0c2a64  Wrong file naming (theogf, Mar 23, 2021)
ff5a66b  Correct formatting (theogf, Mar 23, 2021)
8157b4c  Better error message (theogf, Mar 23, 2021)
e6bfdb1  Moar formatting (theogf, Mar 23, 2021)
db5e7b8  Applied suggestions (theogf, Mar 24, 2021)
a44a762  Fixed the dims issue with pairwise (theogf, Mar 24, 2021)
72889dd  Fixed formatting (theogf, Mar 24, 2021)
25549c1  Missing @thunk (theogf, Mar 24, 2021)
bbe5c7c  Putting back Composite to Any (theogf, Mar 24, 2021)
e08dbf4  add @thunk for -delta a (theogf, Mar 24, 2021)
48bd681  Update src/chainrules.jl (theogf, Mar 25, 2021)
3298d34  Update KernelFunctions.jl (theogf, Mar 25, 2021)
0b99771  Apply suggestions from code review (theogf, Mar 25, 2021)
c26edf3  Update Project.toml (theogf, Mar 25, 2021)
647862a  Merge branch 'master' into fix_diagmat (theogf, Mar 25, 2021)
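
Several of the commits above replace Zygote-specific adjoints with ChainRulesCore rules (the renamed src/chainrules.jl). As a rough sketch of that mechanism, here is a hypothetical rrule for a made-up squared-distance helper; it is not the PR's actual code, and the tangent-type names shown follow current ChainRulesCore (the PR era still used Composite and NO_FIELDS):

using ChainRulesCore

# Hypothetical helper, for illustration only.
sqdist(x::AbstractVector, y::AbstractVector) = sum(abs2, x .- y)

function ChainRulesCore.rrule(::typeof(sqdist), x::AbstractVector, y::AbstractVector)
    Δ = x .- y
    function sqdist_pullback(ȳ)
        # d/dx sum((x - y).^2) = 2(x - y); d/dy = -2(x - y)
        return NoTangent(), @thunk(2 .* ȳ .* Δ), @thunk(-2 .* ȳ .* Δ)
    end
    return sum(abs2, Δ), sqdist_pullback
end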
8 changes: 8 additions & 0 deletions src/kernels/transformedkernel.jl
@@ -68,6 +68,10 @@ function kerneldiagmatrix!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector)
    return kerneldiagmatrix!(K, κ.kernel, _map(κ.transform, x))
end

function kerneldiagmatrix!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
    return kerneldiagmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
end

function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
    return kernelmatrix!(K, kernel(κ), _map(κ.transform, x))
end
@@ -82,6 +86,10 @@ function kerneldiagmatrix(κ::TransformedKernel, x::AbstractVector)
    return kerneldiagmatrix(κ.kernel, _map(κ.transform, x))
end

function kerneldiagmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
    return kerneldiagmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y))
end

function kernelmatrix(κ::TransformedKernel, x::AbstractVector)
    return kernelmatrix(kernel(κ), _map(κ.transform, x))
end
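For context, the new two-argument methods apply the transform to both inputs and then fall back to the wrapped kernel, so the result is the elementwise kernel evaluation on paired observations. A minimal usage sketch, assuming the standard KernelFunctions constructors (kerneldiagmatrix was later renamed to kernelmatrix_diag in this PR):

using KernelFunctions

k = TransformedKernel(SqExponentialKernel(), ScaleTransform(2.0))
x, y = rand(5), rand(5)

# Elementwise kernel evaluation on paired inputs:
kerneldiagmatrix(k, x, y) ≈ map((xi, yi) -> k(xi, yi), x, y)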
9 changes: 8 additions & 1 deletion src/matrix/kernelmatrix.jl
@@ -76,7 +76,6 @@ kerneldiagmatrix(κ::Kernel, x::AbstractVector) = map(x -> κ(x, x), x)
kerneldiagmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x, y)



#
# SimpleKernel optimisations.
#
@@ -104,6 +103,14 @@ function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
    return map(d -> kappa(κ, d), pairwise(metric(κ), x, y))
end

function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector)
    return map(x -> κ(x, x), x)
end

function kerneldiagmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector)
    return map(d -> kappa(κ, d), map(metric(κ), x, y))
end



#
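The SimpleKernel specialisation routes the diagonal through the kernel's metric and its scalar function kappa instead of calling the kernel directly. Roughly, and assuming the internal, unexported kappa and metric helpers, the new two-argument method computes:

using KernelFunctions
using KernelFunctions: kappa, metric

k = SqExponentialKernel()   # a SimpleKernel whose metric is SqEuclidean()
x, y = rand(5), rand(5)

kerneldiagmatrix(k, x, y) ≈ map((xi, yi) -> kappa(k, metric(k)(xi, yi)), x, y)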
180 changes: 180 additions & 0 deletions test/utils_AD.jl
@@ -0,0 +1,180 @@

const FDM = FiniteDifferences.central_fdm(5, 1)

gradient(f, s::Symbol, args) = gradient(f, Val(s), args)

function gradient(f, ::Val{:Zygote}, args)
    g = first(Zygote.gradient(f, args))
    if isnothing(g)
        if args isa AbstractArray{<:Real}
            return zeros(size(args)) # To respect the same output as other ADs
        else
            return zeros.(size.(args))
        end
    else
        return g
    end
end

function gradient(f, ::Val{:ForwardDiff}, args)
    ForwardDiff.gradient(f, args)
end

function gradient(f, ::Val{:ReverseDiff}, args)
    ReverseDiff.gradient(f, args)
end

function gradient(f, ::Val{:FiniteDiff}, args)
    first(FiniteDifferences.grad(FDM, f, args))
end
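
# Hypothetical usage of the unified `gradient` interface above (illustration only):
#   gradient(x -> sum(abs2, x), :Zygote, rand(3))      # AD gradient
#   gradient(x -> sum(abs2, x), :FiniteDiff, rand(3))  # finite-difference reference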

function compare_gradient(f, AD::Symbol, args)
    grad_AD = gradient(f, AD, args)
    grad_FD = gradient(f, :FiniteDiff, args)
    @test grad_AD ≈ grad_FD atol=1e-8 rtol=1e-5
end
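
# Hypothetical example: check Zygote against finite differences on a simple function.
#   compare_gradient(x -> sum(abs2, x), :Zygote, rand(3))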

testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim))
testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim))
testdiagfunction(k, A, dim) = sum(kerneldiagmatrix(k, A, obsdim = dim))
testdiagfunction(k, A, B, dim) = sum(kerneldiagmatrix(k, A, B, obsdim = dim))

function test_ADs(kernelfunction, args = nothing; ADs = [:Zygote, :ForwardDiff, :ReverseDiff], dims = [3, 3])
    test_fd = test_FiniteDiff(kernelfunction, args, dims)
    if !test_fd.anynonpass
        for AD in ADs
            test_AD(AD, kernelfunction, args, dims)
        end
    end
end

function test_FiniteDiff(kernelfunction, args = nothing, dims = [3, 3])
    # Init arguments:
    k = if args === nothing
        kernelfunction()
    else
        kernelfunction(args)
    end
    rng = MersenneTwister(42)
    @testset "FiniteDifferences" begin
        if k isa SimpleKernel
            for d in log.([eps(), rand(rng)])
                @test_nowarn gradient(:FiniteDiff, [d]) do x
                    kappa(k, exp(first(x)))
                end
            end
        end
        ## Testing Kernel Functions
        x = rand(rng, dims[1])
        y = rand(rng, dims[1])
        @test_nowarn gradient(:FiniteDiff, x) do x
            k(x, y)
        end
        if !(args === nothing)
            @test_nowarn gradient(:FiniteDiff, args) do p
                kernelfunction(p)(x, y)
            end
        end
        ## Testing Kernel Matrices
        A = rand(rng, dims...)
        B = rand(rng, dims...)
        for dim in 1:2
            @test_nowarn gradient(:FiniteDiff, A) do a
                testfunction(k, a, dim)
            end
            @test_nowarn gradient(:FiniteDiff, A) do a
                testfunction(k, a, B, dim)
            end
            @test_nowarn gradient(:FiniteDiff, B) do b
                testfunction(k, A, b, dim)
            end
            if !(args === nothing)
                @test_nowarn gradient(:FiniteDiff, args) do p
                    testfunction(kernelfunction(p), A, B, dim)
                end
            end

            @test_nowarn gradient(:FiniteDiff, A) do a
                testdiagfunction(k, a, dim)
            end
            @test_nowarn gradient(:FiniteDiff, A) do a
                testdiagfunction(k, a, B, dim)
            end
            @test_nowarn gradient(:FiniteDiff, B) do b
                testdiagfunction(k, A, b, dim)
            end
            if !(args === nothing)
                @test_nowarn gradient(:FiniteDiff, args) do p
                    testdiagfunction(kernelfunction(p), A, B, dim)
                end
            end
        end
    end
end

function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3])
    @testset "$(AD)" begin
        # Test kappa function
        k = if args === nothing
            kernelfunction()
        else
            kernelfunction(args)
        end
        rng = MersenneTwister(42)
        if k isa SimpleKernel
            for d in log.([eps(), rand(rng)])
                compare_gradient(AD, [d]) do x
                    kappa(k, exp(x[1]))
                end
            end
        end
        # Testing kernel evaluations
        x = rand(rng, dims[1])
        y = rand(rng, dims[1])
        compare_gradient(AD, x) do x
            k(x, y)
        end
        compare_gradient(AD, y) do y
            k(x, y)
        end
        if !(args === nothing)
            compare_gradient(AD, args) do p
                kernelfunction(p)(x, y)
            end
        end
        # Testing kernel matrices
        A = rand(rng, dims...)
        B = rand(rng, dims...)
        for dim in 1:2
            compare_gradient(AD, A) do a
                testfunction(k, a, dim)
            end
            compare_gradient(AD, A) do a
                testfunction(k, a, B, dim)
            end
            compare_gradient(AD, B) do b
                testfunction(k, A, b, dim)
            end
            if !(args === nothing)
                compare_gradient(AD, args) do p
                    testfunction(kernelfunction(p), A, dim)
                end
            end

            compare_gradient(AD, A) do a
                testdiagfunction(k, a, dim)
            end
            compare_gradient(AD, A) do a
                testdiagfunction(k, a, B, dim)
            end
            compare_gradient(AD, B) do b
                testdiagfunction(k, A, b, dim)
            end
            if !(args === nothing)
                compare_gradient(AD, args) do p
                    testdiagfunction(kernelfunction(p), A, dim)
                end
            end
        end
    end
end
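
Downstream, each kernel's test file is expected to call these helpers; the call sites are not part of this diff, but a hypothetical invocation looks like:

# Parameter-free kernel: `kernelfunction()` is called with no arguments.
test_ADs(SqExponentialKernel)

# Parameterised kernel: `args` is forwarded to `kernelfunction`.
test_ADs(θ -> TransformedKernel(SqExponentialKernel(), ScaleTransform(first(θ))), [2.0])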