Skip to content

Commit

Permalink
Hacky fix for approx periodic (#122)
Browse files Browse the repository at this point in the history
* Add failing test and fix formatting

* Write to_sde for approx periodic

* Bump patch

* Restrict usage to Arrays

* Remove known failure case

* Fix ambiguity

* Remove test for disallowed implementation
  • Loading branch information
willtebbutt committed Dec 22, 2023
1 parent 3c56beb commit 2e6ccab
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <wt0881@my.bristol.ac.uk> and contributors"]
version = "0.6.5"
version = "0.6.6"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
63 changes: 24 additions & 39 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:R
return Gaussian(m, P)
end

# Approximate Periodic Kernel
# ApproxPeriodicKernel

# The periodic kernel is approximated by a sum of cosine kernels with different frequencies.
struct ApproxPeriodicKernel{N,K<:PeriodicKernel} <: KernelFunctions.SimpleKernel
kernel::K
Expand All @@ -279,53 +280,37 @@ function Base.show(io::IO, κ::ApproxPeriodicKernel{N}) where {N}
return print(io, "Approximate Periodic Kernel, (r = $(only.kernel.r))) approximated with $N cosine kernels")
end

function lgssm_components(approx::ApproxPeriodicKernel{N}, t::Union{StepRangeLen, RegularSpacing}, storage::StorageType{T}) where {N,T<:Real}
Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
nt = length(t)
As = map(F -> Fill(time_exp(F, T(step(t))), nt), Fs)
return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
end
function lgssm_components(approx::ApproxPeriodicKernel{N}, t::AbstractVector{<:Real}, storage::StorageType{T}) where {N,T<:Real}
Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
t = vcat([first(t) - 1], t)
nt = length(diff(t))
As = _map(F -> _map(Δt -> time_exp(F, T(Δt)), diff(t)), Fs)
return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
end
# Can't use approx periodic kernel with static arrays -- the dimensions become too large.
_ap_error() = throw(error("Unable to construct an ApproxPeriodicKernel for SArrayStorage"))
to_sde(::ApproxPeriodicKernel, ::SArrayStorage) = _ap_error()
stationary_distribution(::ApproxPeriodicKernel, ::SArrayStorage) = _ap_error()

function _init_periodic_kernel_lgssm(kernel::PeriodicKernel, storage, N::Int=7)
r = kernel.r
l⁻² = inv(4 * only(r)^2)

function to_sde(::ApproxPeriodicKernel{N}, storage::ArrayStorage{T}) where {T<:Real, N}

# Compute F and H for component processes.
F, _, H = to_sde(CosineKernel(), storage)
Fs = ntuple(N) do i
2π * (i - 1) * F
end
Hs = Fill(H, N)

# Combine component processes into a single whole.
F = block_diagonal(collect.(Fs)...)
q = zero(T)
H = repeat(collect(H), N)
return F, q, H
end

function stationary_distribution(kernel::ApproxPeriodicKernel{N}, storage::ArrayStorage{<:Real}) where {N}
x0 = stationary_distribution(CosineKernel(), storage)
ms = Fill(x0.m, N)
P = x0.P
m = collect(repeat(x0.m, N))
r = kernel.kernel.r
l⁻² = inv(4 * only(r)^2)
Ps = ntuple(N) do j
qⱼ = (1 + (j !== 1) ) * besseli(j - 1, l⁻²) / exp(l⁻²)
qⱼ * P
end

Fs, Hs, ms, Ps
end

function _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
as = Fill(Fill(Zeros{T}(size(first(first(As)), 1)), nt), N)
Qs = _map((P, A) -> _map(A -> Symmetric(P) - A * Symmetric(P) * A', A), Ps, As)
Hs = Fill(vcat(Hs...), nt)
h = Fill(zero(T), nt)
As = _map(block_diagonal, As...)
as = -map(vcat, as...)
Qs = _map(block_diagonal, Qs...)
m = reduce(vcat, ms)
P = block_diagonal(Ps...)
x0 = Gaussian(m, P)
return As, as, Qs, (Hs, h), x0
return qⱼ * x0.P
end
P = collect(block_diagonal(Ps...))
return Gaussian(m, P)
end

# Constant
Expand Down
29 changes: 19 additions & 10 deletions test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end
@testset "$(typeof(t)), $storage, $N" for t in (
sort(rand(Nt)), RegularSpacing(0.0, 0.1, Nt)
),
storage in (SArrayStorage{Float64}(), ArrayStorage{Float64}()),
storage in (ArrayStorage{Float64}(), ),
N in (5, 8)

k = ApproxPeriodicKernel{N}()
Expand Down Expand Up @@ -131,6 +131,12 @@ println("lti_sde:")
val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(),
to_vec_grad=nothing,
),
# THIS IS KNOWN NOT TO WORK!
# (
# name="prod-(Matern32Kernel + ConstantKernel) * Matern52Kernel",
# val=(Matern32Kernel() + ConstantKernel()) * Matern52Kernel(),
# to_vec_grad=nothing,
# ),

# Summed kernels.
(
Expand All @@ -149,18 +155,21 @@ println("lti_sde:")
)

# Construct a Gauss-Markov model with either dense storage or static storage.
storages = ((name="dense storage Float64", val=ArrayStorage(Float64)),
# (name="static storage Float64", val=SArrayStorage(Float64)),
)
storages = (
(name="dense storage Float64", val=ArrayStorage(Float64)),
# (name="static storage Float64", val=SArrayStorage(Float64)),
)

# Either regular spacing or irregular spacing in time.
ts = ((name="irregular spacing", val=collect(RegularSpacing(0.0, 0.3, N))),
# (name="regular spacing", val=RegularSpacing(0.0, 0.3, N)),
)
ts = (
(name="irregular spacing", val=collect(RegularSpacing(0.0, 0.3, N))),
# (name="regular spacing", val=RegularSpacing(0.0, 0.3, N)),
)

σ²s = ((name="homoscedastic noise", val=(0.1,)),
# (name="heteroscedastic noise", val=(rand(rng, N) .+ 1e-1, )),
)
σ²s = (
(name="homoscedastic noise", val=(0.1,)),
# (name="heteroscedastic noise", val=(rand(rng, N) .+ 1e-1, )),
)

means = (
(name="Zero Mean", val=ZeroMean()),
Expand Down

2 comments on commit 2e6ccab

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/97642

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.6 -m "<description of version>" 2e6ccabae93a4c293856f8d47f2422e352e0f523
git push origin v0.6.6

Please sign in to comment.