Skip to content

Commit

Permalink
Resolve outstanding AbstractGPs problems
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Mar 26, 2021
1 parent 3f9ce3a commit c6e516f
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
5 changes: 3 additions & 2 deletions src/TemporalGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module TemporalGPs
using FillArrays: AbstractFill
using Zygote: _pullback

import AbstractGPs: mean, cov, logpdf, FiniteGP, AbstractGP, posterior
import AbstractGPs: mean, cov, logpdf, FiniteGP, AbstractGP, posterior, dtc, elbo

using KernelFunctions:
SimpleKernel,
Expand All @@ -32,7 +32,8 @@ module TemporalGPs
RegularSpacing,
checkpointed,
posterior,
logpdf_and_rand
logpdf_and_rand,
Separable

# Various bits-and-bobs. Often commiting some type piracy.
include(joinpath("util", "harmonise.jl"))
Expand Down
12 changes: 6 additions & 6 deletions src/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dtcify(z::AbstractVector, fx::FiniteLTISDE) = FiniteGP(dtcify(z, fx.f), fx.x, fx

dtcify(z::AbstractVector, fx::LTISDE) = LTISDE(dtcify(z, fx.f), fx.storage)

dtcify(z::AbstractVector, f::GP) = GP(f.m, dtcify(z, f.k))
dtcify(z::AbstractVector, f::GP) = GP(f.mean, dtcify(z, f.kernel))

"""
dtc(fx::FiniteLTISDE, y::AbstractVector{<:Real}, z_r::AbstractVector)
Expand Down Expand Up @@ -73,7 +73,7 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect
Σs = lgssm.emissions.fan_out.Q
marg_diags = time_ad(Val(:disabled), "marg_diags", marginals_diag, lgssm)

k = fx_dtc.f.f.k
k = fx_dtc.f.f.kernel
Cf_diags = time_ad(Val(:disabled), "Cf_diags", kernel_diagonals, k, fx_dtc.x)

# Transform a vector into a vector-of-vectors.
Expand Down Expand Up @@ -263,15 +263,15 @@ function approx_posterior_marginals(
z_r::AbstractVector,
x_r::AbstractVector,
)
fx.f.f.m isa AbstractGPs.ZeroMean || throw(error("Prior mean of GP isn't zero."))
fx.f.f.mean isa AbstractGPs.ZeroMean || throw(error("Prior mean of GP isn't zero."))

# Compute approximate posterior LGSSM.
lgssm = build_lgssm(dtcify(z_r, fx))
fx_post = posterior(lgssm, restructure(y, lgssm.emissions))

# Compute the new emission distributions + approx posterior model.
x_pr = RectilinearGrid(x_r, get_time(fx.x))
k_dtc = dtcify(z_r, fx.f.f.k)
k_dtc = dtcify(z_r, fx.f.f.kernel)
new_proj, Σs = dtc_post_emissions(k_dtc, x_pr, fx.f.storage)
new_fx_post = LGSSM(fx_post.transitions, build_emissions(new_proj, Σs))

Expand Down Expand Up @@ -318,7 +318,7 @@ function approx_posterior_marginals(
x_pr = RegularInTime(ts, x_rs)

# Compute the new emission distributions + approx posterior model.
k_dtc = dtcify(z_r, fx.f.f.k)
k_dtc = dtcify(z_r, fx.f.f.kernel)
new_proj, Σs = dtc_post_emissions(k_dtc, x_pr, fx.f.storage)
new_fx_post = LGSSM(fx_post.transitions, build_emissions(new_proj, Σs))

Expand All @@ -343,7 +343,7 @@ function approx_posterior_marginals(
fx_post = posterior(lgssm, restructure(y, lgssm.emissions))

# Compute the new emission distributions + approx posterior model.
k_dtc = dtcify(z_r, fx.f.f.k)
k_dtc = dtcify(z_r, fx.f.f.kernel)
new_proj, Σs = dtc_post_emissions(k_dtc, x_pr, fx.f.storage)
new_fx_post = LGSSM(fx_post.transitions, build_emissions(new_proj, Σs))

Expand Down
2 changes: 1 addition & 1 deletion src/util/linear_algebra.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@inline symmetric(X::AbstractMatrix) = Symmetric(X)
@inline symmetric(X::Diagonal) = X

diag_Xt_invA_X(A::Cholesky, X::AbstractVecOrMat) = diag_At_A(A.U' \ X)
diag_Xt_invA_X(A::Cholesky, X::AbstractVecOrMat) = AbstractGPs.diag_At_A(A.U' \ X)

Xt_invA_X(A::Cholesky, x::AbstractVector) = sum(abs2, A.U' \ x)

Expand Down
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ include("test_util.jl")
include(joinpath("space_time", "rectilinear_grid.jl"))
include(joinpath("space_time", "regular_in_time.jl"))
include(joinpath("space_time", "separable_kernel.jl"))
# include(joinpath("space_time", "to_gauss_markov.jl"))
# include(joinpath("space_time", "pseudo_point.jl"))
include(joinpath("space_time", "to_gauss_markov.jl"))
include(joinpath("space_time", "pseudo_point.jl"))
end
end
8 changes: 4 additions & 4 deletions test/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ using TemporalGPs:
z = randn(rng, 3)
k_sep = Separable(SEKernel(), Matern32Kernel())
@test dtcify(z, k_sep) isa DTCSeparable
@test dtcify(z, 0.5 * k_sep) isa ScaledKernel{<:Any, <:DTCSeparable}
@test dtcify(z, 0.5 * k_sep) isa ScaledKernel{<:DTCSeparable}
@test dtcify(z, transform(k_sep, 0.5)) isa TransformedKernel{<:DTCSeparable}
@test dtcify(z, k_sep + k_sep) isa KernelSum{<:DTCSeparable, <:DTCSeparable}
@test dtcify(z, k_sep + k_sep) isa KernelSum{<:Tuple{DTCSeparable, DTCSeparable}}
end

# A couple of "base" kernels used as components in more complicated kernels below.
Expand All @@ -39,7 +39,7 @@ using TemporalGPs:
),

(name="sum-separable-1", val=separable_1 + separable_2),
(name="sum-separable-2", val=1.3 * separable_1 + separable_2 * 0.95),
(name="sum-separable-2", val=1.3 * separable_1 + 0.95 * separable_2),
]

# Input locations.
Expand Down Expand Up @@ -158,7 +158,7 @@ using TemporalGPs:

# Compute approximate posterior marginals naively with missings.
f_approx_post_naive = approx_posterior(
VFE, fx_naive, naive_y_missings, f_naive(z_naive),
VFE(), fx_naive, naive_y_missings, f_naive(z_naive),
)
naive_approx_post_marginals = marginals(f_approx_post_naive(collect(x_pr)))

Expand Down

0 comments on commit c6e516f

Please sign in to comment.