Skip to content

Commit

Permalink
Use AbstractGPs (#51)
Browse files Browse the repository at this point in the history
* Remove Stheno add AbstractGPs and KernelFunctions

* Bump minor version

* Remove Stheno references from README

* Change Stheno for AbstractGPs

* Replace more of Stheno with AbstractGPs

* Avoid method ambiguitty in posterior

* kerneldiagmatrix -> kernelmatrix_diag

* Require KernelFunctions@0.9

* Resolve outstanding AbstractGPs problems
  • Loading branch information
willtebbutt committed Mar 27, 2021
1 parent babba0a commit 3e28705
Show file tree
Hide file tree
Showing 29 changed files with 321 additions and 258 deletions.
10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <wt0881@my.bristol.ac.uk>"]
version = "0.4.1"
version = "0.5.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Stheno = "8188c328-b5d6-583d-959b-9690869a5511"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractGPs = "0.2"
BlockDiagonals = "0.1.7"
ChainRulesCore = "0.9"
Distributions = "0.24"
FillArrays = "0.10, 0.11"
KernelFunctions = "0.9"
StaticArrays = "1"
Stheno = "0.6.6"
StructArrays = "0.5"
Zygote = "0.6"
ZygoteRules = "0.2"
julia = "1.4"
julia = "1.5"
19 changes: 8 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,27 @@
[![Build Status](https://github.com/willtebbutt/TemporalGPs.jl/workflows/CI/badge.svg)](https://github.com/willtebbutt/TemporalGPs.jl/actions)
[![Codecov](https://codecov.io/gh/willtebbutt/TemporalGPs.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/willtebbutt/TemporalGPs.jl)

TemporalGPs.jl is a tool to make Gaussian processes (GPs) defined using [Stheno.jl](https://github.com/willtebbutt/Stheno.jl/) fast for time-series. It provides a single-function public API that lets you specify that this package should perform inference, rather than Stheno.jl.
TemporalGPs.jl is a tool to make Gaussian processes (GPs) defined using [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) fast for time-series. It provides a single-function public API that lets you specify that this package should perform inference, rather than AbstractGPs.jl.

[JuliaCon 2020 Talk](https://www.youtube.com/watch?v=dysmEpX1QoE)

# Installation

TemporalGPs.jl is registered, so simply type the following at the REPL:
```julia
] add Stheno TemporalGPs
] add AbstractGPs KernelFunctions TemporalGPs
```
While you can install TemporalGPs without Stheno, in practice the latter is needed for all common tasks in TemporalGPs.
While you can install TemporalGPs without AbstractGPs and KernelFunctions, in practice the latter are needed for all common tasks in TemporalGPs.

# Example Usage

This is a small problem by TemporalGPs' standard. See timing results below for expected performance on larger problems.

```julia
using Stheno, TemporalGPs
using AbstractGPs, KernelFunctions, TemporalGPs

# Specify a Stheno.jl GP as usual
f_naive = GP(Matern32(), GPC())
# Specify a AbstractGPs.jl GP as usual
f_naive = GP(Matern32Kernel())

# Wrap it in an object that TemporalGPs knows how to handle.
f = to_sde(f_naive, SArrayStorage(Float64))
Expand Down Expand Up @@ -68,7 +68,7 @@ This tells TemporalGPs that you want all parameters of `f` and anything derived

![](/examples/benchmarks.png)

"naive" timings are with the usual [Stheno.jl](https://github.com/willtebbutt/Stheno.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.
"naive" timings are with the usual [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.

Gradient computations use Zygote. Custom adjoints have been implemented to achieve this level of performance.

Expand All @@ -79,11 +79,8 @@ Gradient computations use Zygote. Custom adjoints have been implemented to achie
- Optimisation
+ in-place implementation with `ArrayStorage` to reduce allocations
+ input data types for posterior inference - the `RegularSpacing` type is great for expressing that the inputs are regularly spaced. A carefully constructed data type to let the user build regularly-spaced data when working with posteriors would also be very beneficial.
- Feature coverage
+ only a subset of `Stheno.jl`'s probabilistic-programming functionality is currently available, but it's possible to cover much more.
+ reverse-mode through posterior inference. This is quite straightforward in principle, it just requires a couple of extra ChainRules.
- Interfacing with other packages
+ Both Stheno and this package will move over to the AbstractGPs.jl interface at some point, which will enable both to interface more smoothly with other packages in the ecosystem.
+ When [Stheno.jl](https://github.com/willtebbutt/Stheno.jl/) moves over to the AbstractGPs interface, it should be possible to get some interesting process decomposition functionality in this package.

If you're interested in helping out with this stuff, please get in touch by opening an issue, commenting on an open one, or messaging me on the Julia Slack.

Expand Down
4 changes: 2 additions & 2 deletions bench/lgssm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end
function block_diagonal_dynamics_constructor(rng, N_space, N_time, N_blocks)

# Construct kernel.
k = Separable(EQ(), Matern52())
k = Separable(SEKernel(), Matern52Kernel())
for n in 1:(N_blocks - 1)
k += k
end
Expand Down Expand Up @@ -309,7 +309,7 @@ using Profile, ProfileView
rng = MersenneTwister(123456);
T = 1_000_000;
x = range(0.0; step=0.3, length=T);
f = GP(Matern52() + Matern52() + Matern52() + Matern52(), GPC());
f = GP(Matern52Kernel() + Matern52Kernel() + Matern52Kernel() + Matern52Kernel(), GPC());
fx_sde_dense = to_sde(f)(x);
fx_sde_static = to_sde(f, SArrayStorage(Float64))(x);

Expand Down
2 changes: 1 addition & 1 deletion bench/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ Q = randn(rng, Dlat, Dlat);

# Generate filtering (input) distribution.
mf = randn(rng, Float64, Dlat);
Pf = Symmetric(Stheno.pw(EQ(), range(-10.0, 10.0; length=Dlat)));
Pf = Symmetric(Stheno.kernelmatrix(SEKernel(), range(-10.0, 10.0; length=Dlat)));


# Generate corresponding dense dynamics.
Expand Down
8 changes: 4 additions & 4 deletions bench/single_output_gps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const data_dir = joinpath(datadir(), exp_dir_name)



build_gp(k_base, σ², l) = GP(σ² * stretch(k_base, 1 / l), GPC())
build_gp(k_base, σ², l) = GP(σ² * transform(k_base, 1 / l), GPC())

# Naive implementation.
function build(::Val{:naive}, k_base, σ², l, x, σ²_n)
Expand Down Expand Up @@ -79,9 +79,9 @@ tagsave(
2_000_000, 5_000_000, 10_000_000,
],
:kernels => [
# (k=Matern12(), sym=:Matern12, name="Matern12"),
# (k=Matern32(), sym=:Matern32, name="Matern32"),
(k=Matern52(), sym=:Matern52, name="Matern52"),
# (k=Matern12Kernel(), sym=:Matern12Kernel, name="Matern12Kernel"),
# (k=Matern32Kernel(), sym=:Matern32Kernel, name="Matern32Kernel"),
(k=Matern52Kernel(), sym=:Matern52Kernel, name="Matern52Kernel"),
],
:implementations => [
(
Expand Down
24 changes: 13 additions & 11 deletions src/TemporalGPs.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
module TemporalGPs

using AbstractGPs
using BlockDiagonals
using ChainRulesCore
using Distributions
using FillArrays
using LinearAlgebra
using KernelFunctions
using Random
using StaticArrays
using Stheno
using StructArrays
using Zygote
using ZygoteRules

using FillArrays: AbstractFill
using Zygote: _pullback

import Stheno:
mean,
cov,
pairwise,
logpdf,
AV,
AM,
FiniteGP,
AbstractGP
import AbstractGPs: mean, cov, logpdf, FiniteGP, AbstractGP, posterior, dtc, elbo

using KernelFunctions:
SimpleKernel,
KernelSum,
ScaleTransform,
ScaledKernel,
TransformedKernel

export
to_sde,
Expand All @@ -32,10 +32,12 @@ module TemporalGPs
RegularSpacing,
checkpointed,
posterior,
logpdf_and_rand
logpdf_and_rand,
Separable

# Various bits-and-bobs. Often commiting some type piracy.
include(joinpath("util", "harmonise.jl"))
include(joinpath("util", "linear_algebra.jl"))
include(joinpath("util", "scan.jl"))
include(joinpath("util", "zygote_friendly_map.jl"))
include(joinpath("util", "zygote_rules.jl"))
Expand Down
Loading

0 comments on commit 3e28705

Please sign in to comment.