Add KernelTensorSum #507

Open · wants to merge 11 commits into base: master
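This PR adds `KernelTensorSum`, the additive counterpart of `KernelTensorProduct`: each component kernel acts on its own input dimension and the results are summed. A minimal usage sketch, distilled from the docstring added below (it assumes this branch of KernelFunctions.jl is checked out):

```julia
using KernelFunctions

k1 = SqExponentialKernel()
k2 = LinearKernel()
k = k1 ⊕ k2  # a KernelTensorSum; ⊕ is typed as \oplus<tab>

X = rand(5, 2)  # 5 observations, 2 features each
K = kernelmatrix(k, RowVecs(X))

# The tensor sum splits into one kernel matrix per feature dimension:
K == kernelmatrix(k1, X[:, 1]) + kernelmatrix(k2, X[:, 2])  # true
```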
1 change: 1 addition & 0 deletions docs/src/kernels.md
@@ -124,6 +124,7 @@ TransformedKernel
ScaledKernel
KernelSum
KernelProduct
KernelTensorSum
KernelTensorProduct
NormalizedKernel
```
4 changes: 3 additions & 1 deletion src/KernelFunctions.jl
@@ -15,9 +15,10 @@ export LinearKernel, PolynomialKernel
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
export PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct, KernelTensorProduct
export KernelSum, KernelProduct, KernelTensorSum, KernelTensorProduct
export TransformedKernel, ScaledKernel, NormalizedKernel
export GibbsKernel
export ⊕

export Transform,
SelectTransform,
@@ -108,6 +109,7 @@ include("kernels/normalizedkernel.jl")
include("matrix/kernelmatrix.jl")
include("kernels/kernelsum.jl")
include("kernels/kernelproduct.jl")
include("kernels/kerneltensorsum.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/overloads.jl")
include("kernels/neuralkernelnetwork.jl")
110 changes: 110 additions & 0 deletions src/kernels/kerneltensorsum.jl
@@ -0,0 +1,110 @@
"""
    KernelTensorSum

Tensor sum of kernels.

# Definition

For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
sum of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\sum_{i=1}^n k_i(x_i, x'_i).
```

# Construction

The simplest way to specify a `KernelTensorSum` is to use the `⊕` operator (which can be typed as `\\oplus<tab>`).
```jldoctest tensorsum
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);

julia> kernelmatrix(k1 ⊕ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) + kernelmatrix(k2, X[:, 2])
true
```

You can also specify a `KernelTensorSum` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelTensorSum` is concretely typed, but might
lead to long compilation times if the number of kernels is large.
```jldoctest tensorsum
julia> KernelTensorSum(k1, k2) == k1 ⊕ k2
true

julia> KernelTensorSum((k1, k2)) == k1 ⊕ k2
true

julia> KernelTensorSum([k1, k2]) == k1 ⊕ k2
true
```
"""
struct KernelTensorSum{K} <: Kernel
    kernels::K
end

function KernelTensorSum(kernel::Kernel, kernels::Kernel...)
    return KernelTensorSum((kernel, kernels...))
end

# Allow Functors.jl to recurse into the wrapped kernels (e.g. for parameter collection).
@functor KernelTensorSum

Base.length(kernel::KernelTensorSum) = length(kernel.kernels)

function (kernel::KernelTensorSum)(x, y)
    if !((nx = length(x)) == (ny = length(y)) == (nkernels = length(kernel)))
        throw(
            DimensionMismatch(
                "number of kernels ($nkernels) and number of features (x=$nx, y=$ny) are not consistent",
            ),
        )
    end
    return sum(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end

function validate_domain(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
    return (dx = dim(x)) == (dy = dim(y)) == (nkernels = length(k)) || error(
        "number of kernels ($nkernels) and number of feature groups (x=$dx, y=$dy) are not consistent",
    )
end

function validate_domain(k::KernelTensorSum, x::AbstractVector)
    return validate_domain(k, x, x)
end

# The kernel matrix of a tensor sum is the sum of the per-dimension kernel
# matrices, computed on the feature slices of the input.
function kernelmatrix(k::KernelTensorSum, x::AbstractVector)
    validate_domain(k, x)
    return mapreduce(kernelmatrix, +, k.kernels, slices(x))
end

function kernelmatrix(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x, y)
    return mapreduce(kernelmatrix, +, k.kernels, slices(x), slices(y))
end

function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector)
    validate_domain(k, x)
    return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x))
end

function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x, y)
    return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x), slices(y))
end

function Base.:(==)(x::KernelTensorSum, y::KernelTensorSum)
    return (
        length(x.kernels) == length(y.kernels) &&
        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
    )
end

Base.show(io::IO, kernel::KernelTensorSum) = printshifted(io, kernel, 0)

function printshifted(io::IO, kernel::KernelTensorSum, shift::Int)
    print(io, "Tensor sum of ", length(kernel), " kernels:")
    for k in kernel.kernels
        print(io, "\n")
        for _ in 1:(shift + 1)
            print(io, "\t")
        end
        printshifted(io, k, shift + 2)
    end
end
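For pointwise evaluation, each input is treated as a collection with one entry per kernel, mirroring the `zip` in the call overload above. A small illustration (hypothetical REPL session, assuming this branch):

```julia
julia> k = SqExponentialKernel() ⊕ LinearKernel();

julia> x, y = (0.3, 0.5), (0.4, 0.1);

julia> k(x, y) ≈ SqExponentialKernel()(0.3, 0.4) + LinearKernel()(0.5, 0.1)
true

julia> k((1.0,), (2.0,))  # 1 feature but 2 kernels
ERROR: DimensionMismatch: number of kernels (2) and number of features (x=1, y=1) are not consistent
[...]
```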
4 changes: 4 additions & 0 deletions src/kernels/overloads.jl
@@ -1,7 +1,11 @@
function tensor_sum end
const ⊕ = tensor_sum

for (M, op, T) in (
    (:Base, :+, :KernelSum),
    (:Base, :*, :KernelProduct),
    (:TensorCore, :tensor, :KernelTensorProduct),
    (:KernelFunctions, :⊕, :KernelTensorSum),
)
    @eval begin
        $M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)
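For the new table entry, the first line of the `@eval` block expands to the method below; the rest of the loop body (truncated above) presumably generates the analogous kernel-splicing methods that already exist for `+`, `*`, and `⊗`. A sketch of the expansion:

```julia
# Expansion of `$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)` for
# (M, op, T) = (:KernelFunctions, :⊕, :KernelTensorSum):
KernelFunctions.⊕(k1::Kernel, k2::Kernel) = KernelTensorSum(k1, k2)

# Presumably also generated, following the existing pattern (not shown in the hunk):
KernelFunctions.⊕(k1::KernelTensorSum, k2::Kernel) = KernelTensorSum(k1.kernels..., k2)
KernelFunctions.⊕(k1::Kernel, k2::KernelTensorSum) = KernelTensorSum(k1, k2.kernels...)
```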
67 changes: 67 additions & 0 deletions test/kernels/kerneltensorsum.jl
@@ -0,0 +1,67 @@
@testset "kerneltensorsum" begin
    rng = MersenneTwister(123456)
    u1 = rand(rng, 10)
    u2 = rand(rng, 10)
    v1 = rand(rng, 5)
    v2 = rand(rng, 5)

    # kernels
    k1 = SqExponentialKernel()
    k2 = ExponentialKernel()
    kernel1 = KernelTensorSum(k1, k2)
    kernel2 = KernelTensorSum([k1, k2])

    @test kernel1 == kernel2
    @test kernel1.kernels == (k1, k2) === KernelTensorSum((k1, k2)).kernels
    for (_k1, _k2) in Iterators.product(
        (k1, KernelTensorSum((k1,)), KernelTensorSum([k1])),
        (k2, KernelTensorSum((k2,)), KernelTensorSum([k2])),
    )
        @test kernel1 == _k1 ⊕ _k2
    end
    @test length(kernel1) == length(kernel2) == 2
    @test string(kernel1) == (
        "Tensor sum of 2 kernels:\n" *
        "\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" *
        "\tExponential Kernel (metric = Euclidean(0.0))"
    )
    @test_throws DimensionMismatch kernel1(rand(3), rand(3))

    @testset "val" begin
        for (x, y) in (((v1, u1), (v2, u2)), ([v1, u1], [v2, u2]))
            val = k1(x[1], y[1]) + k2(x[2], y[2])

            @test kernel1(x, y) == kernel2(x, y) == val
        end
    end

    # Standardised tests.
    TestUtils.test_interface(kernel1, ColVecs{Float64})
    TestUtils.test_interface(kernel1, RowVecs{Float64})
    TestUtils.test_interface(
        KernelTensorSum(WhiteKernel(), ConstantKernel(; c=1.1)), ColVecs{String}
    )
    test_ADs(
        x -> KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))),
        rand(1);
        dims=[2, 2],
    )
    types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}]
    test_interface_ad_perf(2.1, StableRNG(123456), types) do c
        KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=c))
    end
    test_params(KernelTensorSum(k1, k2), (k1, k2))

    @testset "single kernel" begin
        kernel = KernelTensorSum(k1)
        @test length(kernel) == 1

        @testset "eval" begin
            for (x, y) in (((v1,), (v2,)), ([v1], [v2]))
                val = k1(x[1], y[1])

                @test kernel(x, y) == val
            end
        end
    end
end
5 changes: 3 additions & 2 deletions test/kernels/overloads.jl
@@ -5,8 +5,9 @@
    k2 = SqExponentialKernel()
    k3 = RationalQuadraticKernel()

    for (op, T) in ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct))
        if T === KernelTensorProduct
    for (op, T) in
        ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct), (⊕, KernelTensorSum))
        if T === KernelTensorProduct || T === KernelTensorSum
            v2_1 = rand(rng, 2)
            v2_2 = rand(rng, 2)
            v3_1 = rand(rng, 3)
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -125,6 +125,7 @@ include("test_utils.jl")
include("kernels/kernelproduct.jl")
include("kernels/kernelsum.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/kerneltensorsum.jl")
include("kernels/overloads.jl")
include("kernels/scaledkernel.jl")
include("kernels/transformedkernel.jl")