Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KernelTensorSum #507

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ TransformedKernel
ScaledKernel
KernelSum
KernelProduct
KernelTensorSum
KernelTensorProduct
NormalizedKernel
```
Expand Down
4 changes: 3 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ export LinearKernel, PolynomialKernel
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
export PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct, KernelTensorProduct
export KernelSum, KernelProduct, KernelTensorSum, KernelTensorProduct
export TransformedKernel, ScaledKernel, NormalizedKernel
export GibbsKernel
export ⊕

export Transform,
SelectTransform,
Expand Down Expand Up @@ -108,6 +109,7 @@ include("kernels/normalizedkernel.jl")
include("matrix/kernelmatrix.jl")
include("kernels/kernelsum.jl")
include("kernels/kernelproduct.jl")
include("kernels/kerneltensorsum.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/overloads.jl")
include("kernels/neuralkernelnetwork.jl")
Expand Down
102 changes: 102 additions & 0 deletions src/kernels/kerneltensorsum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
    KernelTensorSum

Tensor sum of kernels.

# Definition

For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
sum of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\sum_{i=1}^n k_i(x_i, x'_i).
```

# Construction

The simplest way to specify a `KernelTensorSum` is to use the `⊕` operator (can be typed by `\\oplus<tab>`).
```jldoctest kerneltensorsum
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);

julia> kernelmatrix(k1 ⊕ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) + kernelmatrix(k2, X[:, 2])
true
```

You can also specify a `KernelTensorSum` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelTensorSum` is concretely typed but might
lead to large compilation times if the number of kernels is large.
```jldoctest kerneltensorsum
julia> KernelTensorSum(k1, k2) == k1 ⊕ k2
true

julia> KernelTensorSum((k1, k2)) == k1 ⊕ k2
true

julia> KernelTensorSum([k1, k2]) == k1 ⊕ k2
true
```
"""
struct KernelTensorSum{K} <: Kernel
    # Component kernels: a `Tuple` gives a concretely typed kernel, a `Vector`
    # trades type concreteness for cheaper compilation with many kernels.
    kernels::K
end

# Convenience constructor: gather individually passed kernels into a tuple.
KernelTensorSum(k::Kernel, ks::Kernel...) = KernelTensorSum((k, ks...))

# Register the wrapped kernels with Functors.jl so parameter traversal
# (e.g. `fmap`) descends into the component kernels.
@functor KernelTensorSum

# Number of component kernels, which equals the number of feature groups
# the kernel expects per input.
Base.length(kernel::KernelTensorSum) = length(kernel.kernels)

# Evaluate the tensor sum on a pair of inputs: each input must supply exactly
# one feature (group) per component kernel.
#
# Throws `DimensionMismatch` if the lengths of `x`, `y`, and the kernel differ.
function (kernel::KernelTensorSum)(x, y)
    if !(length(x) == length(y) == length(kernel))
        # Single-line message: the original multi-line string literal embedded
        # a newline and indentation into the error text.
        throw(DimensionMismatch("number of kernels and number of features are not consistent"))
    end
    return sum(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end

# Check that the dimensionality of the input matches the number of component
# kernels; returns `true` on success and errors otherwise.
function validate_domain(k::KernelTensorSum, x::AbstractVector)
    return dim(x) == length(k) ||
        error("number of kernels and groups of features are not consistent")
end

# Kernel matrix of the tensor sum: sum of the per-dimension kernel matrices,
# one per (kernel, feature slice) pair.
function kernelmatrix(k::KernelTensorSum, x::AbstractVector)
    validate_domain(k, x)
    return sum(kernelmatrix(ki, xi) for (ki, xi) in zip(k.kernels, slices(x)))
end

# Cross kernel matrix between `x` and `y`: sum of the per-dimension cross
# kernel matrices over the feature slices of both inputs.
function kernelmatrix(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x)
    return mapreduce(kernelmatrix, +, k.kernels, slices(x), slices(y))
end

# Diagonal of the kernel matrix: sum of the per-dimension diagonals.
function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector)
    validate_domain(k, x)
    return sum(kernelmatrix_diag(ki, xi) for (ki, xi) in zip(k.kernels, slices(x)))
end

# Diagonal of the cross kernel matrix between `x` and `y`: sum of the
# per-dimension diagonals over the feature slices of both inputs.
function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
    validate_domain(k, x)
    return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x), slices(y))
end

# Two tensor sums are equal iff they hold the same number of kernels and the
# kernels agree pairwise (a `Tuple`-backed and a `Vector`-backed sum with the
# same components compare equal).
function Base.:(==)(x::KernelTensorSum, y::KernelTensorSum)
    length(x.kernels) == length(y.kernels) || return false
    return all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
end

# Display delegates to `printshifted` starting with zero indentation.
Base.show(io::IO, kernel::KernelTensorSum) = printshifted(io, kernel, 0)

# Print a header line followed by each component kernel on its own line,
# indented by `shift + 1` tabs; nested kernels recurse with `shift + 2`.
function printshifted(io::IO, kernel::KernelTensorSum, shift::Int)
    print(io, "Tensor sum of ", length(kernel), " kernels:")
    indent = "\t"^(shift + 1)
    for component in kernel.kernels
        print(io, "\n", indent)
        printshifted(io, component, shift + 2)
    end
end
3 changes: 3 additions & 0 deletions src/kernels/overloads.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Binary operator constructing a `KernelTensorSum` (typed as `\oplus<tab>`).
# NOTE(review): this generic symbol is defined and exported here because no
# lightweight interface package (e.g. TensorCore) currently owns it — confirm
# before release whether it should be upstreamed instead.
function ⊕ end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems too generic to be defined and exported from KernelFunctions. Is it not part of TensorCore or some other lightweight interface package? We would also like a non-Unicode alias, as for other operators and functions.

Copy link
Author

@martincornejo martincornejo May 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about that, but it is not part of TensorCore or, as far as I know, any other lightweight package (https://juliahub.com/ui/Search?q=%E2%8A%95&type=symbols). It is a helper constructor for the new KernelTensorSum/KernelIndependentSum, so the non-Unicode function is already available.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestions welcome on how to improve this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no non-Unicode alternative similar to +, *, or tensor yet as far as I can tell?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! I see, you're right

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolve?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems too generic to be defined and exported from KernelFunctions.

This problem is not fixed yet, is it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But since TensorCore.jl does not define `⊕`, what should we do? Here are the packages that use `⊕`. Kronecker.jl is one, but I guess we do not want to add this as a dependency, only to re-use the symbol.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make a PR to TensorCore. I think the operator should not be owned by KernelFunctions.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


for (M, op, T) in (
(:Base, :+, :KernelSum),
(:Base, :*, :KernelProduct),
(:TensorCore, :tensor, :KernelTensorProduct),
(:KernelFunctions, :⊕, :KernelTensorSum),
)
@eval begin
$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)
Expand Down
67 changes: 67 additions & 0 deletions test/kernels/kerneltensorsum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
@testset "kerneltensorsum" begin
    # Fixed seed so the random feature vectors are reproducible across runs.
    rng = MersenneTwister(123456)
    u1 = rand(rng, 10)
    u2 = rand(rng, 10)
    v1 = rand(rng, 5)
    v2 = rand(rng, 5)

    # kernels
    k1 = SqExponentialKernel()
    k2 = ExponentialKernel()
    # Tuple-backed and Vector-backed constructions must compare equal.
    kernel1 = KernelTensorSum(k1, k2)
    kernel2 = KernelTensorSum([k1, k2])

    @test kernel1 == kernel2
    # Varargs construction stores a tuple; `===` checks the tuple is identical
    # to the one from the explicit tuple constructor.
    @test kernel1.kernels == (k1, k2) === KernelTensorSum((k1, k2)).kernels
    # The `⊕` operator must flatten: combining plain kernels and/or
    # single-kernel tensor sums yields the same two-kernel tensor sum.
    for (_k1, _k2) in Iterators.product(
        (k1, KernelTensorSum((k1,)), KernelTensorSum([k1])),
        (k2, KernelTensorSum((k2,)), KernelTensorSum([k2])),
    )
        @test kernel1 == _k1 ⊕ _k2
    end
    @test length(kernel1) == length(kernel2) == 2
    # Exact rendering of the multi-line `show` output.
    @test string(kernel1) == (
        "Tensor sum of 2 kernels:\n" *
        "\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" *
        "\tExponential Kernel (metric = Euclidean(0.0))"
    )
    # Inputs of length 3 do not match the 2 component kernels.
    @test_throws DimensionMismatch kernel1(rand(3), rand(3))

    @testset "val" begin
        # Evaluate on tuple and vector inputs; result is the sum of the
        # component kernels applied to the corresponding feature groups.
        for (x, y) in (((v1, u1), (v2, u2)), ([v1, u1], [v2, u2]))
            val = k1(x[1], y[1]) + k2(x[2], y[2])

            @test kernel1(x, y) == kernel2(x, y) == val
        end
    end

    # Standardised tests.
    TestUtils.test_interface(kernel1, ColVecs{Float64})
    TestUtils.test_interface(kernel1, RowVecs{Float64})
    # Non-numeric inputs exercise kernels that only require equality of inputs.
    TestUtils.test_interface(
        KernelTensorSum(WhiteKernel(), ConstantKernel(; c=1.1)), ColVecs{String}
    )
    # AD through a parameterised component kernel.
    test_ADs(
        x -> KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))),
        rand(1);
        dims=[2, 2],
    )
    types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}]
    test_interface_ad_perf(2.1, StableRNG(123456), types) do c
        KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=c))
    end
    # Trainable parameters are exactly the component kernels' parameters.
    test_params(KernelTensorSum(k1, k2), (k1, k2))

    @testset "single kernel" begin
        # Degenerate case: a tensor sum of one kernel behaves like that kernel
        # applied to the single feature group.
        kernel = KernelTensorSum(k1)
        @test length(kernel) == 1

        @testset "eval" begin
            for (x, y) in (((v1,), (v2,)), ([v1], [v2]))
                val = k1(x[1], y[1])

                @test kernel(x, y) == val
            end
        end
    end
end
2 changes: 1 addition & 1 deletion test/kernels/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
k2 = SqExponentialKernel()
k3 = RationalQuadraticKernel()

for (op, T) in ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct))
for (op, T) in ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct), (⊕, KernelTensorSum))
if T === KernelTensorProduct
v2_1 = rand(rng, 2)
v2_2 = rand(rng, 2)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ include("test_utils.jl")
include("kernels/kernelproduct.jl")
include("kernels/kernelsum.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/kerneltensorsum.jl")
include("kernels/overloads.jl")
include("kernels/scaledkernel.jl")
include("kernels/transformedkernel.jl")
Expand Down