Skip to content

Commit

Permalink
Fix gradient issues with kernelmatrix_diag and use ChainRulesCore (#208)
Browse files Browse the repository at this point in the history
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 25, 2021
1 parent ae78b73 commit aa2099e
Show file tree
Hide file tree
Showing 26 changed files with 366 additions and 169 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.8.25"
version = "0.8.26"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -17,8 +18,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "0.9"
Compat = "3.7"
Distances = "0.9.1, 0.10"
Distances = "0.10"
Functors = "0.1"
Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1"
Expand Down
9 changes: 6 additions & 3 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ export IndependentMOKernel, LatentFactorMOKernel
export tensor,

using Compat
using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS
using ChainRulesCore: @thunk, InplaceableThunk
using Requires
using Distances, LinearAlgebra
using Functors
using SpecialFunctions: loggamma, besselk, polygamma
using ZygoteRules: @adjoint, pullback
using StatsFuns: logtwo
using ZygoteRules: ZygoteRules
using StatsFuns: logtwo, twoπ
using StatsBase
using TensorCore

Expand Down Expand Up @@ -112,7 +114,8 @@ include(joinpath("mokernels", "moinput.jl"))
include(joinpath("mokernels", "independent.jl"))
include(joinpath("mokernels", "slfm.jl"))

include("zygote_adjoints.jl")
include("chainrules.jl")
include("zygoterules.jl")

include("test_utils.jl")

Expand Down
11 changes: 11 additions & 0 deletions src/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,14 @@ function kernelmatrix!(
K .= _fbm.(_mod(x), _mod(y)', K, κ.h)
return K
end

# Diagonal of the FBM kernel matrix: evaluates the kernel at every pair `(xᵢ, xᵢ)`.
function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector)
    modx = _mod(x)
    # Column-wise squared Euclidean distance of each input with itself.
    modxx = colwise(SqEuclidean(), x)
    return _fbm.(modx, modx, modxx, κ.h)
end

# Diagonal of the FBM cross-kernel matrix: evaluates the kernel at every pair `(xᵢ, yᵢ)`.
function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector, y::AbstractVector)
    # Column-wise squared Euclidean distance between matching elements of `x` and `y`.
    modxy = colwise(SqEuclidean(), x, y)
    return _fbm.(_mod(x), _mod(y), modxy, κ.h)
end
4 changes: 4 additions & 0 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ function kernelmatrix(κ::GaborKernel, x::AbstractVector, y::AbstractVector)
end

# The Gabor kernel forwards the diagonal computation to its wrapped `kernel` field.
kernelmatrix_diag(κ::GaborKernel, x::AbstractVector) = kernelmatrix_diag(κ.kernel, x)

# Two-input diagonal: forwarded to the wrapped `kernel` field, like the one-input method.
function kernelmatrix_diag(κ::GaborKernel, x::AbstractVector, y::AbstractVector)
    return kernelmatrix_diag(κ.kernel, x, y)
end
26 changes: 26 additions & 0 deletions src/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

# Diagonal of the neural-network kernel matrix for column-stored inputs:
# asin(‖xᵢ‖² / (‖xᵢ‖² + 1)) for every column xᵢ of `x.X`.
function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs)
    data = x.X
    sqnorms = vec(sum(data .* data; dims=1))
    return @. asin(sqnorms / (sqnorms + 1))
end

# Diagonal of the neural-network cross-kernel matrix for column-stored inputs:
# asin(⟨xᵢ, yᵢ⟩ / sqrt((‖xᵢ‖² + 1) * (‖yᵢ‖² + 1))) for matching columns.
function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
    validate_inputs(x, y)
    xdata, ydata = x.X, y.X
    sqnorm_x1 = vec(sum(xdata .* xdata; dims=1) .+ 1)
    sqnorm_y1 = vec(sum(ydata .* ydata; dims=1) .+ 1)
    inner = vec(sum(xdata' .* ydata'; dims=2))
    return @. asin(inner / sqrt(sqnorm_x1 * sqnorm_y1))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X; dims=2)
Expand All @@ -65,4 +78,17 @@ function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
return asin.(XX ./ sqrt.(X_2_1 * X_2_1'))
end

# Diagonal of the neural-network kernel matrix for row-stored inputs:
# asin(‖xᵢ‖² / (‖xᵢ‖² + 1)) for every row xᵢ of `x.X`.
function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs)
    data = x.X
    sqnorms = vec(sum(data .* data; dims=2))
    return @. asin(sqnorms / (sqnorms + 1))
end

# Diagonal of the neural-network cross-kernel matrix for row-stored inputs:
# asin(⟨xᵢ, yᵢ⟩ / sqrt((‖xᵢ‖² + 1) * (‖yᵢ‖² + 1))) for matching rows.
function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
    validate_inputs(x, y)
    xdata, ydata = x.X, y.X
    sqnorm_x1 = vec(sum(xdata .* xdata; dims=2) .+ 1)
    sqnorm_y1 = vec(sum(ydata .* ydata; dims=2) .+ 1)
    inner = vec(sum(xdata .* ydata; dims=2))
    return @. asin(inner / sqrt(sqnorm_x1 * sqnorm_y1))
end

# Print a fixed, human-readable name for the kernel.
function Base.show(io::IO, ::NeuralNetworkKernel)
    return print(io, "Neural Network Kernel")
end
167 changes: 167 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
## Forward Rules

# Forward rule for the Euclidean distance between two vectors.
# Note that this is type piracy. The derivative should be NaN for x == y; instead the
# difference is left unnormalised in that case, which yields a zero tangent.
function ChainRulesCore.frule(
    (_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector
)
    Δ = x - y
    D = sqrt(sum(abs2, Δ))
    # Normalise out of place: an in-place `Δ ./= D` throws `InexactError` when `x - y`
    # has an integer element type and fails for immutable array types.
    # Dividing by one when D == 0 keeps the element type the same in both branches.
    scale = iszero(D) ? one(D) : D
    unit = Δ / scale
    return D, dot(unit, Δx) - dot(unit, Δy)
end

## Reverse Rules Delta

# Reverse rule for evaluating `Delta`: the pullback returns zero tangents for both inputs.
function ChainRulesCore.rrule(dist::Delta, x::AbstractVector, y::AbstractVector)
    evaluate_pullback(::Any) = (NO_FIELDS, Zero(), Zero())
    return dist(x, y), evaluate_pullback
end

# Reverse rule for `pairwise` with `Delta` on two matrices: zero tangents for `X` and `Y`.
function ChainRulesCore.rrule(
    ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2
)
    pairwise_pullback(::AbstractMatrix) = (NO_FIELDS, NO_FIELDS, Zero(), Zero())
    return Distances.pairwise(d, X, Y; dims=dims), pairwise_pullback
end

# Reverse rule for `pairwise` with `Delta` on a single matrix: zero tangent for `X`.
function ChainRulesCore.rrule(
    ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2
)
    pairwise_pullback(::AbstractMatrix) = (NO_FIELDS, NO_FIELDS, Zero())
    return Distances.pairwise(d, X; dims=dims), pairwise_pullback
end

# Reverse rule for `colwise` with `Delta`: zero tangents for `X` and `Y`.
function ChainRulesCore.rrule(
    ::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix
)
    colwise_pullback(::AbstractVector) = (NO_FIELDS, NO_FIELDS, Zero(), Zero())
    return Distances.colwise(d, X, Y), colwise_pullback
end

## Reverse Rules DotProduct

# Reverse rule for evaluating `DotProduct` on two vectors.
# The pullback scales each input by the incoming cotangent `Δ`:
# the tangent for `x` is `Δ .* y` and the tangent for `y` is `Δ .* x`.
function ChainRulesCore.rrule(dist::DotProduct, x::AbstractVector, y::AbstractVector)
    d = dist(x, y)
    function evaluate_pullback(Δ::Any)
        return NO_FIELDS, Δ .* y, Δ .* x
    end
    return d, evaluate_pullback
end

# Reverse rule for `pairwise` with `DotProduct` on two matrices.
# The pullback multiplies the cotangent `Δ` with the other input; which side it is
# multiplied on depends on `dims`.
function ChainRulesCore.rrule(
    ::typeof(Distances.pairwise),
    d::DotProduct,
    X::AbstractMatrix,
    Y::AbstractMatrix;
    dims=2,
)
    P = Distances.pairwise(d, X, Y; dims=dims)
    function pairwise_pullback_cols(Δ::AbstractMatrix)
        if dims == 1
            return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X
        else
            return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ
        end
    end
    return P, pairwise_pullback_cols
end

# Reverse rule for `pairwise` with `DotProduct` on a single matrix.
# `X` appears on both sides of every pairwise product, so its tangent uses the
# symmetrised cotangent `Δ + Δ'`. For a symmetric `Δ` this equals the shortcut
# `2 * Δ`; for a non-symmetric `Δ` the shortcut gives the wrong gradient.
function ChainRulesCore.rrule(
    ::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2
)
    P = Distances.pairwise(d, X; dims=dims)
    function pairwise_pullback_cols(Δ::AbstractMatrix)
        Δsym = Δ + Δ'
        if dims == 1
            return NO_FIELDS, NO_FIELDS, Δsym * X
        else
            return NO_FIELDS, NO_FIELDS, X * Δsym
        end
    end
    return P, pairwise_pullback_cols
end

# Reverse rule for `colwise` with `DotProduct`.
# The cotangent vector `Δ` is transposed to a row so that it broadcasts across the
# columns of the other input.
function ChainRulesCore.rrule(
    ::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix
)
    C = Distances.colwise(d, X, Y)
    function colwise_pullback(Δ::AbstractVector)
        return NO_FIELDS, NO_FIELDS, Δ' .* Y, Δ' .* X
    end
    return C, colwise_pullback
end

## Reverse Rules Sinus

# Reverse rule for evaluating `Sinus` on two vectors.
# The primal is computed as sum(sin(π(x - y))^2 ./ r.^2), which matches the `r .^ 2`
# scaling used in the input gradient `gradx`.
# NOTE(review): the forward definition of `Sinus` is not visible here — confirm that it
# also scales by `r .^ 2`.
function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
    d = x - y
    sind = sinpi.(d)
    abs2_sind_r = abs2.(sind ./ s.r)
    val = sum(abs2_sind_r)
    gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2)
    function evaluate_pullback(Δ::Any)
        # d/dr of sin^2 / r^2 is -2 sin^2 / r^3.
        return (r=-2Δ .* abs2_sind_r ./ s.r,), Δ * gradx, -Δ * gradx
    end
    return val, evaluate_pullback
end

## Reverse Rules SqMahalanobis

# Reverse rule for evaluating `SqMahalanobis` on two vectors.
# Each tangent is an `InplaceableThunk` pairing a lazily computed value with an in-place
# accumulator. The accumulator uses the five-argument `mul!(C, A, B, α, β)`, which
# computes `C = A * B * α + C * β`; accumulating into `X̄` therefore needs the cotangent
# scaling as `α` and `true` as `β`.
function ChainRulesCore.rrule(
    dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector
)
    d = dist(a, b)
    function SqMahalanobis_pullback(Δ::Real)
        a_b = a - b
        ∂qmat = InplaceableThunk(
            @thunk((a_b * a_b') * Δ), X̄ -> mul!(X̄, a_b, a_b', Δ, true)
        )
        ∂a = InplaceableThunk(
            @thunk((2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, 2 * Δ, true)
        )
        ∂b = InplaceableThunk(
            @thunk((-2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, -2 * Δ, true)
        )
        return Composite{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b
    end
    return d, SqMahalanobis_pullback
end

## Reverse Rules for matrix wrappers

# Reverse rule for the `ColVecs` constructor.
# A `Composite` cotangent is unwrapped to its `X` field. A vector-of-vectors cotangent
# cannot be mapped back to the matrix and raises an error explaining the likely cause.
function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
    ColVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X)
    function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
        return error(
            "Pullback on AbstractVector{<:AbstractVector}.\n" *
            "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
            "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`",
        )
    end
    return ColVecs(X), ColVecs_pullback
end

# Reverse rule for the `RowVecs` constructor.
# A `Composite` cotangent is unwrapped to its `X` field. A vector-of-vectors cotangent
# cannot be mapped back to the matrix and raises an error explaining the likely cause.
function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
    RowVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X)
    function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
        return error(
            "Pullback on AbstractVector{<:AbstractVector}.\n" *
            "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
            "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
        )
    end
    return RowVecs(X), RowVecs_pullback
end
5 changes: 3 additions & 2 deletions src/distances/delta.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
struct Delta <: Distances.PreMetric end
# Delta is not following the PreMetric rules since d(x, x) == 1
struct Delta <: Distances.UnionPreMetric end

@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
Expand All @@ -12,7 +13,7 @@ struct Delta <: Distances.PreMetric end
return a == b
end

Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
# `Delta` always evaluates to a `Bool`, whatever the input element types are.
function Distances.result_type(::Delta, Ta::Type, Tb::Type)
    return Bool
end

# Calling a `Delta` on two arrays forwards to `Distances._evaluate`.
@inline function (dist::Delta)(a::AbstractArray, b::AbstractArray)
    return Distances._evaluate(dist, a, b)
end
# Calling a `Delta` on two numbers is a plain equality test.
@inline function (dist::Delta)(a::Number, b::Number)
    return a == b
end
4 changes: 2 additions & 2 deletions src/distances/dotproduct.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct DotProduct <: Distances.PreMetric end
# struct DotProduct <: Distances.UnionSemiMetric end
## DotProduct is not following the PreMetric rules since d(x, x) != 0 and d(x, y) >= 0 for all x, y
struct DotProduct <: Distances.UnionPreMetric end

@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
Expand Down
31 changes: 31 additions & 0 deletions src/distances/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,34 @@ function pairwise!(
)
return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
end

# Also defines the colwise method for abstractvectors

# Distance of every element of `x` to itself under a pre-metric: a zero vector.
function colwise(d::PreMetric, x::AbstractVector)
    T = Distances.result_type(d, x, x)
    # Valid since d(x,x) == 0 by definition
    return zeros(T, length(x))
end

## The following is a hack for DotProduct and Delta to still work
# Self-distance for `ColVecs`: evaluated explicitly on the underlying matrix.
function colwise(d::Distances.UnionPreMetric, x::ColVecs)
    data = x.X
    return Distances.colwise(d, data, data)
end

# Self-distance for `RowVecs`: the matrix is passed as `x.X'` so that its rows are
# presented as columns to `Distances.colwise`.
function colwise(d::Distances.UnionPreMetric, x::RowVecs)
    return Distances.colwise(d, x.X', x.X')
end

# Generic self-distance: apply `d` to each element of `x` paired with itself.
colwise(d::Distances.UnionPreMetric, x::AbstractVector) = map(d, x, x)

# Element-wise distance between two `ColVecs`, computed on their underlying matrices.
function colwise(d::PreMetric, x::ColVecs, y::ColVecs)
    xdata, ydata = x.X, y.X
    return Distances.colwise(d, xdata, ydata)
end

# Element-wise distance between two `RowVecs`: both matrices are passed with `'` so
# that their rows are presented as columns to `Distances.colwise`.
function colwise(d::PreMetric, x::RowVecs, y::RowVecs)
    return Distances.colwise(d, x.X', y.X')
end

# Generic fallback: apply `d` to matching elements of `x` and `y`.
colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) = map(d, x, y)
3 changes: 1 addition & 2 deletions src/distances/sinus.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
struct Sinus{T} <: Distances.SemiMetric
# struct Sinus{T} <: Distances.UnionSemiMetric
struct Sinus{T} <: Distances.UnionSemiMetric
r::Vector{T}
end

Expand Down
2 changes: 1 addition & 1 deletion src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ Base.iterate(k::Kernel, ::Any) = nothing
printshifted(io::IO, o, shift::Int) = print(io, o)

# Fallback implementation of evaluate for `SimpleKernel`s.
(k::SimpleKernel)(x, y) = kappa(k, evaluate(metric(k), x, y))
# Evaluate a `SimpleKernel`: apply its metric to the inputs, then `kappa` to the result.
function (k::SimpleKernel)(x, y)
    dist = metric(k)
    return kappa(k, dist(x, y))
end
4 changes: 4 additions & 0 deletions src/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector)
return reduce(hadamard, kernelmatrix_diag(k, x) for k in κ.kernels)
end

# Two-input diagonal of a kernel product: combines the diagonals of all component
# kernels with `hadamard`.
function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector, y::AbstractVector)
    return reduce(hadamard, kernelmatrix_diag(k, x, y) for k in κ.kernels)
end

function Base.show(io::IO, κ::KernelProduct)
return printshifted(io, κ, 0)
end
Expand Down
4 changes: 4 additions & 0 deletions src/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ function kernelmatrix_diag(κ::KernelSum, x::AbstractVector)
return sum(kernelmatrix_diag(k, x) for k in κ.kernels)
end

# Two-input diagonal of a kernel sum: the sum of the diagonals of all component kernels.
function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector)
    return sum(kernelmatrix_diag(k, x, y) for k in κ.kernels)
end

function Base.show(io::IO, κ::KernelSum)
return printshifted(io, κ, 0)
end
Expand Down
5 changes: 5 additions & 0 deletions src/kernels/kerneltensorproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ function kernelmatrix_diag(k::KernelTensorProduct, x::AbstractVector)
return mapreduce(kernelmatrix_diag, hadamard, k.kernels, slices(x))
end

# Two-input diagonal of a tensor-product kernel: each component kernel is applied to its
# own slice of `x` and `y`, and the results are combined with `hadamard`.
# Only `x` is passed to `validate_domain`.
function kernelmatrix_diag(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x)
    xslices = slices(x)
    yslices = slices(y)
    return mapreduce(kernelmatrix_diag, hadamard, k.kernels, xslices, yslices)
end

Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0)

function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
Expand Down
4 changes: 4 additions & 0 deletions src/kernels/scaledkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector)
return κ.σ² .* kernelmatrix_diag.kernel, x)
end

# Two-input diagonal of a scaled kernel: the wrapped kernel's diagonal times `σ²`.
function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector, y::AbstractVector)
    return κ.σ² .* kernelmatrix_diag(κ.kernel, x, y)
end

function kernelmatrix!(
K::AbstractMatrix, κ::ScaledKernel, x::AbstractVector, y::AbstractVector
)
Expand Down
Loading

2 comments on commit aa2099e

@theogf
Copy link
Member Author

@theogf theogf commented on aa2099e Mar 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32800

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.26 -m "<description of version>" aa2099ee03310808eadc9fa0f766f1ab8c925d12
git push origin v0.8.26

Please sign in to comment.