diff --git a/Project.toml b/Project.toml index b02a50b94..542e490ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,8 +1,9 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.8.25" +version = "0.8.26" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -17,8 +18,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] +ChainRulesCore = "0.9" Compat = "3.7" -Distances = "0.9.1, 0.10" +Distances = "0.10" Functors = "0.1" Requires = "1.0.1" SpecialFunctions = "0.8, 0.9, 0.10, 1" diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 7a6ec8a6c..8746b70e4 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -55,12 +55,14 @@ export IndependentMOKernel, LatentFactorMOKernel export tensor, ⊗ using Compat +using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS +using ChainRulesCore: @thunk, InplaceableThunk using Requires using Distances, LinearAlgebra using Functors using SpecialFunctions: loggamma, besselk, polygamma -using ZygoteRules: @adjoint, pullback -using StatsFuns: logtwo +using ZygoteRules: ZygoteRules +using StatsFuns: logtwo, twoπ using StatsBase using TensorCore @@ -112,7 +114,8 @@ include(joinpath("mokernels", "moinput.jl")) include(joinpath("mokernels", "independent.jl")) include(joinpath("mokernels", "slfm.jl")) -include("zygote_adjoints.jl") +include("chainrules.jl") +include("zygoterules.jl") include("test_utils.jl") diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index fad6d70f2..7ea88e110 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -69,3 +69,14 @@ function kernelmatrix!( K .= _fbm.(_mod(x), _mod(y)', K, κ.h) return K end + +function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector) + modx = _mod(x) + modxx = colwise(SqEuclidean(), x) + return _fbm.(modx, modx, modxx, κ.h) +end + +function kernelmatrix_diag(κ::FBMKernel, x::AbstractVector, y::AbstractVector) + modxy = colwise(SqEuclidean(), x, y) + return _fbm.(_mod(x), _mod(y), modxy, κ.h) +end diff --git a/src/basekernels/gabor.jl b/src/basekernels/gabor.jl index 2796901e2..311afcf1b 100644 --- a/src/basekernels/gabor.jl +++ b/src/basekernels/gabor.jl @@ -72,3 +72,7 @@ function kernelmatrix(κ::GaborKernel, x::AbstractVector, y::AbstractVector) end kernelmatrix_diag(κ::GaborKernel, x::AbstractVector) = kernelmatrix_diag(κ.kernel, x) + +function kernelmatrix_diag(κ::GaborKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix_diag(κ.kernel, x, y) +end diff --git a/src/basekernels/nn.jl b/src/basekernels/nn.jl index 52b9c607c..40070075d 100644 --- a/src/basekernels/nn.jl +++ b/src/basekernels/nn.jl @@ -51,6 +51,19 @@ function kernelmatrix(::NeuralNetworkKernel, x::ColVecs) return asin.(XX ./ sqrt.(X_2_1' * X_2_1)) end +function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs) + x_2 = vec(sum(x.X .* x.X; dims=1)) + return asin.(x_2 ./ (x_2 .+ 1)) +end + +function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs, y::ColVecs) + validate_inputs(x, y) + x_2 = vec(sum(x.X .* x.X; dims=1) .+ 1) + y_2 = vec(sum(y.X .* y.X; dims=1) .+ 1) + xy = vec(sum(x.X' .* y.X'; dims=2)) + return asin.(xy ./ sqrt.(x_2 .* y_2)) +end + function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs) validate_inputs(x, y) X_2 = sum(x.X .* x.X; dims=2) @@ -65,4 +78,17 @@ function kernelmatrix(::NeuralNetworkKernel, x::RowVecs) return asin.(XX ./ sqrt.(X_2_1 * X_2_1')) end +function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs) + x_2 = vec(sum(x.X .* x.X; dims=2)) + return asin.(x_2 ./ (x_2 .+ 1)) +end + +function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs, y::RowVecs) + validate_inputs(x, y) + x_2 = vec(sum(x.X .* x.X; dims=2) .+ 1) + y_2 = vec(sum(y.X .* y.X; dims=2) .+ 1) + xy = vec(sum(x.X .* y.X; dims=2)) + return asin.(xy ./ sqrt.(x_2 .* y_2)) +end + Base.show(io::IO, ::NeuralNetworkKernel) = print(io, "Neural Network Kernel") diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 000000000..3c871d89a --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,167 @@ +## Forward Rules + +# Note that this is type piracy as the derivative should be NaN for x == y. +function ChainRulesCore.frule( + (_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector +) + Δ = x - y + D = sqrt(sum(abs2, Δ)) + if !iszero(D) + Δ ./= D + end + return D, dot(Δ, Δx) - dot(Δ, Δy) +end + +## Reverse Rules Delta + +function ChainRulesCore.rrule(dist::Delta, x::AbstractVector, y::AbstractVector) + d = dist(x, y) + function evaluate_pullback(::Any) + return NO_FIELDS, Zero(), Zero() + end + return d, evaluate_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2 +) + P = Distances.pairwise(d, X, Y; dims=dims) + function pairwise_pullback(::AbstractMatrix) + return NO_FIELDS, NO_FIELDS, Zero(), Zero() + end + return P, pairwise_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2 +) + P = Distances.pairwise(d, X; dims=dims) + function pairwise_pullback(::AbstractMatrix) + return NO_FIELDS, NO_FIELDS, Zero() + end + return P, pairwise_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix +) + C = Distances.colwise(d, X, Y) + function colwise_pullback(::AbstractVector) + return NO_FIELDS, NO_FIELDS, Zero(), Zero() + end + return C, colwise_pullback +end + +## Reverse Rules DotProduct + +function ChainRulesCore.rrule(dist::DotProduct, x::AbstractVector, y::AbstractVector) + d = dist(x, y) + function evaluate_pullback(Δ::Any) + return NO_FIELDS, Δ .* y, Δ .* x + end + return d, evaluate_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), + d::DotProduct, + X::AbstractMatrix, + Y::AbstractMatrix; + dims=2, +) + P = Distances.pairwise(d, X, Y; dims=dims) + function pairwise_pullback_cols(Δ::AbstractMatrix) + if dims == 1 + return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X + else + return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ + end + end + return P, pairwise_pullback_cols +end + +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2 +) + P = Distances.pairwise(d, X; dims=dims) + function pairwise_pullback_cols(Δ::AbstractMatrix) + if dims == 1 + return NO_FIELDS, NO_FIELDS, 2 * Δ * X + else + return NO_FIELDS, NO_FIELDS, 2 * X * Δ + end + end + return P, pairwise_pullback_cols +end + +function ChainRulesCore.rrule( + ::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix +) + C = Distances.colwise(d, X, Y) + function colwise_pullback(Δ::AbstractVector) + return NO_FIELDS, NO_FIELDS, Δ' .* Y, Δ' .* X + end + return C, colwise_pullback +end + +## Reverse Rules Sinus + +function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) + d = x - y + sind = sinpi.(d) + abs2_sind_r = abs2.(sind) ./ s.r + val = sum(abs2_sind_r) + gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) + function evaluate_pullback(Δ::Any) + return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx + end + return val, evaluate_pullback +end + +## Reverse Rulse SqMahalanobis + +function ChainRulesCore.rrule( + dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector +) + d = dist(a, b) + function SqMahalanobis_pullback(Δ::Real) + a_b = a - b + ∂qmat = InplaceableThunk( + @thunk((a_b * a_b') * Δ), X̄ -> mul!(X̄, a_b, a_b', true, Δ) + ) + ∂a = InplaceableThunk( + @thunk((2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, true, 2 * Δ) + ) + ∂b = InplaceableThunk( + @thunk((-2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, true, -2 * Δ) + ) + return Composite{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b + end + return d, SqMahalanobis_pullback +end + +## Reverse Rules for matrix wrappers + +function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) + ColVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) + function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) + return error( + "Pullback on AbstractVector{<:AbstractVector}.\n" * + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * + "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", + ) + end + return ColVecs(X), ColVecs_pullback +end + +function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) + RowVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) + function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) + return error( + "Pullback on AbstractVector{<:AbstractVector}.\n" * + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * + "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", + ) + end + return RowVecs(X), RowVecs_pullback +end diff --git a/src/distances/delta.jl b/src/distances/delta.jl index 273804308..979ecc197 100644 --- a/src/distances/delta.jl +++ b/src/distances/delta.jl @@ -1,4 +1,5 @@ -struct Delta <: Distances.PreMetric end +# Delta is not following the PreMetric rules since d(x, x) == 1 +struct Delta <: Distances.UnionPreMetric end @inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) @boundscheck if length(a) != length(b) @@ -12,7 +13,7 @@ struct Delta <: Distances.PreMetric end return a == b end -Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb) +Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool @inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b) @inline (dist::Delta)(a::Number, b::Number) = a == b diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index ef0f64b28..1cef13ab5 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -1,5 +1,5 @@ -struct DotProduct <: Distances.PreMetric end -# struct DotProduct <: Distances.UnionSemiMetric end +## DotProduct is not following the PreMetric rules since d(x, x) != 0 and d(x, y) >= 0 for all x, y +struct DotProduct <: Distances.UnionPreMetric end @inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) @boundscheck if length(a) != length(b) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 7a555222f..7379f0cbd 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -29,3 +29,34 @@ function pairwise!( ) return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1) end + +# Also defines the colwise method for abstractvectors + +function colwise(d::PreMetric, x::AbstractVector) + return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition +end + +## The following is a hack for DotProduct and Delta to still work +function colwise(d::Distances.UnionPreMetric, x::ColVecs) + return Distances.colwise(d, x.X, x.X) +end + +function colwise(d::Distances.UnionPreMetric, x::RowVecs) + return Distances.colwise(d, x.X', x.X') +end + +function colwise(d::Distances.UnionPreMetric, x::AbstractVector) + return map(d, x, x) +end + +function colwise(d::PreMetric, x::ColVecs, y::ColVecs) + return Distances.colwise(d, x.X, y.X) +end + +function colwise(d::PreMetric, x::RowVecs, y::RowVecs) + return Distances.colwise(d, x.X', y.X') +end + +function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) + return map(d, x, y) +end diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index f91884c5f..4bcf4bdf0 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -1,5 +1,4 @@ -struct Sinus{T} <: Distances.SemiMetric - # struct Sinus{T} <: Distances.UnionSemiMetric +struct Sinus{T} <: Distances.UnionSemiMetric r::Vector{T} end diff --git a/src/generic.jl b/src/generic.jl index ef8762fef..f161ca5ff 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -6,4 +6,4 @@ Base.iterate(k::Kernel, ::Any) = nothing printshifted(io::IO, o, shift::Int) = print(io, o) # Fallback implementation of evaluate for `SimpleKernel`s. -(k::SimpleKernel)(x, y) = kappa(k, evaluate(metric(k), x, y)) +(k::SimpleKernel)(x, y) = kappa(k, metric(k)(x, y)) diff --git a/src/kernels/kernelproduct.jl b/src/kernels/kernelproduct.jl index ce39dde58..990b4a1bb 100644 --- a/src/kernels/kernelproduct.jl +++ b/src/kernels/kernelproduct.jl @@ -57,6 +57,10 @@ function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector) return reduce(hadamard, kernelmatrix_diag(k, x) for k in κ.kernels) end +function kernelmatrix_diag(κ::KernelProduct, x::AbstractVector, y::AbstractVector) + return reduce(hadamard, kernelmatrix_diag(k, x, y) for k in κ.kernels) +end + function Base.show(io::IO, κ::KernelProduct) return printshifted(io, κ, 0) end diff --git a/src/kernels/kernelsum.jl b/src/kernels/kernelsum.jl index 5dd068b12..6c4c8d499 100644 --- a/src/kernels/kernelsum.jl +++ b/src/kernels/kernelsum.jl @@ -57,6 +57,10 @@ function kernelmatrix_diag(κ::KernelSum, x::AbstractVector) return sum(kernelmatrix_diag(k, x) for k in κ.kernels) end +function kernelmatrix_diag(κ::KernelSum, x::AbstractVector, y::AbstractVector) + return sum(kernelmatrix_diag(k, x, y) for k in κ.kernels) +end + function Base.show(io::IO, κ::KernelSum) return printshifted(io, κ, 0) end diff --git a/src/kernels/kerneltensorproduct.jl b/src/kernels/kerneltensorproduct.jl index c46e204fc..ce9c69d6c 100644 --- a/src/kernels/kerneltensorproduct.jl +++ b/src/kernels/kerneltensorproduct.jl @@ -123,6 +123,11 @@ function kernelmatrix_diag(k::KernelTensorProduct, x::AbstractVector) return mapreduce(kernelmatrix_diag, hadamard, k.kernels, slices(x)) end +function kernelmatrix_diag(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector) + validate_domain(k, x) + return mapreduce(kernelmatrix_diag, hadamard, k.kernels, slices(x), slices(y)) +end + Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0) function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct) diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 1e786f83a..4ec1fb42a 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -37,6 +37,10 @@ function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector) return κ.σ² .* kernelmatrix_diag(κ.kernel, x) end +function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector, y::AbstractVector) + return κ.σ² .* kernelmatrix_diag(κ.kernel, x, y) +end + function kernelmatrix!( K::AbstractMatrix, κ::ScaledKernel, x::AbstractVector, y::AbstractVector ) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 5daf38361..6cf693dca 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -80,6 +80,12 @@ function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::Abstrac return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x)) end +function kernelmatrix_diag!( + K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector +) + return kernelmatrix_diag!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) +end + function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector) return kernelmatrix!(K, kernel(κ), _map(κ.transform, x)) end @@ -94,6 +100,10 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector) return kernelmatrix_diag(κ.kernel, _map(κ.transform, x)) end +function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix_diag(κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) +end + function kernelmatrix(κ::TransformedKernel, x::AbstractVector) return kernelmatrix(kernel(κ), _map(κ.transform, x)) end diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl index f619368a0..f524fc48e 100644 --- a/src/matrix/kernelmatrix.jl +++ b/src/matrix/kernelmatrix.jl @@ -80,7 +80,7 @@ kernelmatrix_diag(κ::Kernel, x::AbstractVector, y::AbstractVector) = map(κ, x, function kernelmatrix!(K::AbstractMatrix, κ::SimpleKernel, x::AbstractVector) validate_inplace_dims(K, x) pairwise!(K, metric(κ), x) - return map!(d -> kappa(κ, d), K, K) + return map!(Base.Fix1(kappa, κ), K, K) end function kernelmatrix!( @@ -88,16 +88,24 @@ function kernelmatrix!( ) validate_inplace_dims(K, x, y) pairwise!(K, metric(κ), x, y) - return map!(d -> kappa(κ, d), K, K) + return map!(Base.Fix1(kappa, κ), K, K) end function kernelmatrix(κ::SimpleKernel, x::AbstractVector) - return map(d -> kappa(κ, d), pairwise(metric(κ), x)) + return map(Base.Fix1(kappa, κ), pairwise(metric(κ), x)) end function kernelmatrix(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) validate_inputs(x, y) - return map(d -> kappa(κ, d), pairwise(metric(κ), x, y)) + return map(Base.Fix1(kappa, κ), pairwise(metric(κ), x, y)) +end + +function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector) + return map(Base.Fix1(kappa, κ), colwise(metric(κ), x)) +end + +function kernelmatrix_diag(κ::SimpleKernel, x::AbstractVector, y::AbstractVector) + return map(Base.Fix1(kappa, κ), colwise(metric(κ), x, y)) end # diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl deleted file mode 100644 index 8a9696b5d..000000000 --- a/src/zygote_adjoints.jl +++ /dev/null @@ -1,98 +0,0 @@ -## Adjoints Delta -@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector) - return evaluate(s, x, y), Δ -> begin - (nothing, nothing, nothing) - end -end - -@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X, Y; dims=dims) - if dims == 1 - return D, Δ -> (nothing, nothing, nothing) - else - return D, Δ -> (nothing, nothing, nothing) - end -end - -@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims=dims) - if dims == 1 - return D, Δ -> (nothing, nothing) - else - return D, Δ -> (nothing, nothing) - end -end - -## Adjoints DotProduct -@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector) - return dot(x, y), Δ -> begin - (nothing, Δ .* y, Δ .* x) - end -end - -@adjoint function Distances.pairwise( - d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2 -) - D = Distances.pairwise(d, X, Y; dims=dims) - if dims == 1 - return D, Δ -> (nothing, Δ * Y, (X' * Δ)') - else - return D, Δ -> (nothing, (Δ * Y')', X * Δ) - end -end - -@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims=dims) - if dims == 1 - return D, Δ -> (nothing, 2 * Δ * X) - else - return D, Δ -> (nothing, 2 * X * Δ) - end -end - -## Adjoints Sinus -@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) - d = (x - y) - sind = sinpi.(d) - val = sum(abs2, sind ./ s.r) - gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) - return val, Δ -> begin - ((r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx) - end -end - -@adjoint function ColVecs(X::AbstractMatrix) - ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) - ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) - function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) - return throw(error("In slow method")) - end - return ColVecs(X), ColVecs_pullback -end - -@adjoint function RowVecs(X::AbstractMatrix) - RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) - RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) - function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) - return throw(error("In slow method")) - end - return RowVecs(X), RowVecs_pullback -end - -@adjoint function Base.map(t::Transform, X::ColVecs) - return pullback(_map, t, X) -end - -@adjoint function Base.map(t::Transform, X::RowVecs) - return pullback(_map, t, X) -end - -@adjoint function (dist::Distances.SqMahalanobis)(a, b) - function SqMahalanobis_pullback(Δ::Real) - B_Bᵀ = dist.qmat + transpose(dist.qmat) - a_b = a - b - δa = (B_Bᵀ * a_b) * Δ - return (qmat=(a_b * a_b') * Δ,), δa, -δa - end - return evaluate(dist, a, b), SqMahalanobis_pullback -end diff --git a/src/zygoterules.jl b/src/zygoterules.jl new file mode 100644 index 000000000..88016613d --- /dev/null +++ b/src/zygoterules.jl @@ -0,0 +1,7 @@ +ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs) + return ZygoteRules.pullback(_map, t, X) +end + +ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) + return ZygoteRules.pullback(_map, t, X) +end diff --git a/test/Project.toml b/test/Project.toml index 06b69c7c8..166fb590a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" diff --git a/test/basekernels/nn.jl b/test/basekernels/nn.jl index d64312864..c9dabeb69 100644 --- a/test/basekernels/nn.jl +++ b/test/basekernels/nn.jl @@ -8,5 +8,4 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(NeuralNetworkKernel) - @test_broken "Zygote uncompatible with BaseKernel" end diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 000000000..51a545ba1 --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,27 @@ +@testset "Chain Rules" begin + rng = MersenneTwister(123456) + x = rand(rng, 5) + y = rand(rng, 5) + r = rand(rng, 5) + Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0)) + @assert isposdef(Q) + + compare_gradient(:Zygote, [x, y]) do xy + Euclidean()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + SqEuclidean()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.DotProduct()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.Delta()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.Sinus(r)(xy[1], xy[2]) + end + compare_gradient(:Zygote, [Q, x, y]) do xy + SqMahalanobis(xy[1])(xy[2], xy[3]) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 7ad679905..9ca50b15a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using Zygote: Zygote using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using FiniteDifferences: FiniteDifferences +using ChainRulesTestUtils using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils @@ -146,7 +147,8 @@ include("test_utils.jl") end include("generic.jl") - include("zygote_adjoints.jl") + include("chainrules.jl") + include("zygoterules.jl") @testset "doctests" begin DocMeta.setdocmeta!( diff --git a/test/test_utils.jl b/test/test_utils.jl index d3bb435a6..caa04b8d3 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -50,6 +50,8 @@ end testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B; obsdim=dim)) testfunction(k, A, dim) = sum(kernelmatrix(k, A; obsdim=dim)) +testdiagfunction(k, A, dim) = sum(kernelmatrix_diag(k, A; obsdim=dim)) +testdiagfunction(k, A, B, dim) = sum(kernelmatrix_diag(k, A, B; obsdim=dim)) function test_ADs( kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3] @@ -107,6 +109,21 @@ function test_FiniteDiff(kernelfunction, args=nothing, dims=[3, 3]) testfunction(kernelfunction(p), A, B, dim) end end + + @test_nowarn gradient(:FiniteDiff, A) do a + testdiagfunction(k, a, dim) + end + @test_nowarn gradient(:FiniteDiff, A) do a + testdiagfunction(k, a, B, dim) + end + @test_nowarn gradient(:FiniteDiff, B) do b + testdiagfunction(k, A, b, dim) + end + if args !== nothing + @test_nowarn gradient(:FiniteDiff, args) do p + testdiagfunction(kernelfunction(p), A, B, dim) + end + end end end end @@ -159,6 +176,21 @@ function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3]) testfunction(kernelfunction(p), A, dim) end end + + compare_gradient(AD, A) do a + testdiagfunction(k, a, dim) + end + compare_gradient(AD, A) do a + testdiagfunction(k, a, B, dim) + end + compare_gradient(AD, B) do b + testdiagfunction(k, A, b, dim) + end + if args !== nothing + compare_gradient(AD, args) do p + testdiagfunction(kernelfunction(p), A, dim) + end + end end end end diff --git a/test/zygote_adjoints.jl b/test/zygote_adjoints.jl deleted file mode 100644 index 6b349437b..000000000 --- a/test/zygote_adjoints.jl +++ /dev/null @@ -1,53 +0,0 @@ -@testset "zygote_adjoints" begin - rng = MersenneTwister(123456) - x = rand(rng, 5) - y = rand(rng, 5) - r = rand(rng, 5) - Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0)) - @assert isposdef(Q) - - gzeucl = gradient(:Zygote, [x, y]) do xy - evaluate(Euclidean(), xy[1], xy[2]) - end - gzsqeucl = gradient(:Zygote, [x, y]) do xy - evaluate(SqEuclidean(), xy[1], xy[2]) - end - gzdotprod = gradient(:Zygote, [x, y]) do xy - evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]) - end - gzdelta = gradient(:Zygote, [x, y]) do xy - evaluate(KernelFunctions.Delta(), xy[1], xy[2]) - end - gzsinus = gradient(:Zygote, [x, y]) do xy - evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) - end - gzsqmaha = gradient(:Zygote, [Q, x, y]) do xy - evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) - end - - gfeucl = gradient(:FiniteDiff, [x, y]) do xy - evaluate(Euclidean(), xy[1], xy[2]) - end - gfsqeucl = gradient(:FiniteDiff, [x, y]) do xy - evaluate(SqEuclidean(), xy[1], xy[2]) - end - gfdotprod = gradient(:FiniteDiff, [x, y]) do xy - evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]) - end - gfdelta = gradient(:FiniteDiff, [x, y]) do xy - evaluate(KernelFunctions.Delta(), xy[1], xy[2]) - end - gfsinus = gradient(:FiniteDiff, [x, y]) do xy - evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) - end - gfsqmaha = gradient(:FiniteDiff, [Q, x, y]) do xy - evaluate(SqMahalanobis(xy[1]), xy[2], xy[3]) - end - - @test all(gzeucl .≈ gfeucl) - @test all(gzsqeucl .≈ gfsqeucl) - @test all(gzdotprod .≈ gfdotprod) - @test all(gzdelta .≈ gfdelta) - @test all(gzsinus .≈ gfsinus) - @test all(gzsqmaha .≈ gfsqmaha) -end diff --git a/test/zygoterules.jl b/test/zygoterules.jl new file mode 100644 index 000000000..dc3bb98fe --- /dev/null +++ b/test/zygoterules.jl @@ -0,0 +1 @@ +@testset "zygoterules" begin end