Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use AbstractGPs #51

Merged
merged 10 commits into from
Mar 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <wt0881@my.bristol.ac.uk>"]
version = "0.4.1"
version = "0.5.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Stheno = "8188c328-b5d6-583d-959b-9690869a5511"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractGPs = "0.2"
BlockDiagonals = "0.1.7"
ChainRulesCore = "0.9"
Distributions = "0.24"
FillArrays = "0.10, 0.11"
KernelFunctions = "0.9"
StaticArrays = "1"
Stheno = "0.6.6"
StructArrays = "0.5"
Zygote = "0.6"
ZygoteRules = "0.2"
julia = "1.4"
julia = "1.5"
19 changes: 8 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,27 @@
[![Build Status](https://github.com/willtebbutt/TemporalGPs.jl/workflows/CI/badge.svg)](https://github.com/willtebbutt/TemporalGPs.jl/actions)
[![Codecov](https://codecov.io/gh/willtebbutt/TemporalGPs.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/willtebbutt/TemporalGPs.jl)

TemporalGPs.jl is a tool to make Gaussian processes (GPs) defined using [Stheno.jl](https://github.com/willtebbutt/Stheno.jl/) fast for time-series. It provides a single-function public API that lets you specify that this package should perform inference, rather than Stheno.jl.
TemporalGPs.jl is a tool to make Gaussian processes (GPs) defined using [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) fast for time-series. It provides a single-function public API that lets you specify that this package should perform inference, rather than AbstractGPs.jl.

[JuliaCon 2020 Talk](https://www.youtube.com/watch?v=dysmEpX1QoE)

# Installation

TemporalGPs.jl is registered, so simply type the following at the REPL:
```julia
] add Stheno TemporalGPs
] add AbstractGPs KernelFunctions TemporalGPs
```
While you can install TemporalGPs without Stheno, in practice the latter is needed for all common tasks in TemporalGPs.
While you can install TemporalGPs without AbstractGPs and KernelFunctions, in practice the latter are needed for all common tasks in TemporalGPs.

# Example Usage

This is a small problem by TemporalGPs' standard. See timing results below for expected performance on larger problems.

```julia
using Stheno, TemporalGPs
using AbstractGPs, KernelFunctions, TemporalGPs

# Specify a Stheno.jl GP as usual
f_naive = GP(Matern32(), GPC())
# Specify a AbstractGPs.jl GP as usual
f_naive = GP(Matern32Kernel())

# Wrap it in an object that TemporalGPs knows how to handle.
f = to_sde(f_naive, SArrayStorage(Float64))
Expand Down Expand Up @@ -68,7 +68,7 @@ This tells TemporalGPs that you want all parameters of `f` and anything derived

![](/examples/benchmarks.png)

"naive" timings are with the usual [Stheno.jl](https://github.com/willtebbutt/Stheno.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.
"naive" timings are with the usual [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.

Gradient computations use Zygote. Custom adjoints have been implemented to achieve this level of performance.

Expand All @@ -79,11 +79,8 @@ Gradient computations use Zygote. Custom adjoints have been implemented to achie
- Optimisation
+ in-place implementation with `ArrayStorage` to reduce allocations
+ input data types for posterior inference - the `RegularSpacing` type is great for expressing that the inputs are regularly spaced. A carefully constructed data type to let the user build regularly-spaced data when working with posteriors would also be very beneficial.
- Feature coverage
+ only a subset of `Stheno.jl`'s probabilistic-programming functionality is currently available, but it's possible to cover much more.
+ reverse-mode through posterior inference. This is quite straightforward in principle, it just requires a couple of extra ChainRules.
- Interfacing with other packages
+ Both Stheno and this package will move over to the AbstractGPs.jl interface at some point, which will enable both to interface more smoothly with other packages in the ecosystem.
+ When [Stheno.jl](https://github.com/willtebbutt/Stheno.jl/) moves over to the AbstractGPs interface, it should be possible to get some interesting process decomposition functionality in this package.

If you're interested in helping out with this stuff, please get in touch by opening an issue, commenting on an open one, or messaging me on the Julia Slack.

Expand Down
4 changes: 2 additions & 2 deletions bench/lgssm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end
function block_diagonal_dynamics_constructor(rng, N_space, N_time, N_blocks)

# Construct kernel.
k = Separable(EQ(), Matern52())
k = Separable(SEKernel(), Matern52Kernel())
for n in 1:(N_blocks - 1)
k += k
end
Expand Down Expand Up @@ -309,7 +309,7 @@ using Profile, ProfileView
rng = MersenneTwister(123456);
T = 1_000_000;
x = range(0.0; step=0.3, length=T);
f = GP(Matern52() + Matern52() + Matern52() + Matern52(), GPC());
f = GP(Matern52Kernel() + Matern52Kernel() + Matern52Kernel() + Matern52Kernel(), GPC());
fx_sde_dense = to_sde(f)(x);
fx_sde_static = to_sde(f, SArrayStorage(Float64))(x);

Expand Down
2 changes: 1 addition & 1 deletion bench/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ Q = randn(rng, Dlat, Dlat);

# Generate filtering (input) distribution.
mf = randn(rng, Float64, Dlat);
Pf = Symmetric(Stheno.pw(EQ(), range(-10.0, 10.0; length=Dlat)));
Pf = Symmetric(Stheno.kernelmatrix(SEKernel(), range(-10.0, 10.0; length=Dlat)));


# Generate corresponding dense dynamics.
Expand Down
8 changes: 4 additions & 4 deletions bench/single_output_gps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const data_dir = joinpath(datadir(), exp_dir_name)



build_gp(k_base, σ², l) = GP(σ² * stretch(k_base, 1 / l), GPC())
build_gp(k_base, σ², l) = GP(σ² * transform(k_base, 1 / l), GPC())

# Naive implementation.
function build(::Val{:naive}, k_base, σ², l, x, σ²_n)
Expand Down Expand Up @@ -79,9 +79,9 @@ tagsave(
2_000_000, 5_000_000, 10_000_000,
],
:kernels => [
# (k=Matern12(), sym=:Matern12, name="Matern12"),
# (k=Matern32(), sym=:Matern32, name="Matern32"),
(k=Matern52(), sym=:Matern52, name="Matern52"),
# (k=Matern12Kernel(), sym=:Matern12Kernel, name="Matern12Kernel"),
# (k=Matern32Kernel(), sym=:Matern32Kernel, name="Matern32Kernel"),
(k=Matern52Kernel(), sym=:Matern52Kernel, name="Matern52Kernel"),
],
:implementations => [
(
Expand Down
24 changes: 13 additions & 11 deletions src/TemporalGPs.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
module TemporalGPs

using AbstractGPs
using BlockDiagonals
using ChainRulesCore
using Distributions
using FillArrays
using LinearAlgebra
using KernelFunctions
using Random
using StaticArrays
using Stheno
using StructArrays
using Zygote
using ZygoteRules

using FillArrays: AbstractFill
using Zygote: _pullback

import Stheno:
mean,
cov,
pairwise,
logpdf,
AV,
AM,
FiniteGP,
AbstractGP
import AbstractGPs: mean, cov, logpdf, FiniteGP, AbstractGP, posterior, dtc, elbo

using KernelFunctions:
SimpleKernel,
KernelSum,
ScaleTransform,
ScaledKernel,
TransformedKernel

export
to_sde,
Expand All @@ -32,10 +32,12 @@ module TemporalGPs
RegularSpacing,
checkpointed,
posterior,
logpdf_and_rand
logpdf_and_rand,
Separable

# Various bits-and-bobs. Often commiting some type piracy.
include(joinpath("util", "harmonise.jl"))
include(joinpath("util", "linear_algebra.jl"))
include(joinpath("util", "scan.jl"))
include(joinpath("util", "zygote_friendly_map.jl"))
include(joinpath("util", "zygote_rules.jl"))
Expand Down
Loading