-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
KernelSum with more than 2 components #64
Comments
Hmm yeah, we'd need to find a way to do this that didn't make that particular hack, since it would type-pirate The current limitation to two is a hangover from the time when Stheno.jl still had its own collection of kernels, and TemporalGPs used them, in which the sum only allowed two things anyway. My point is that we should definitely generalise to arbitrarily many components! I wonder if there's a "flat" way to do this rather that a recursion? That would reflect the structure inside the
function TemporalGPs.lgssm_components(
    k::KernelSum, ts::AbstractVector, storage_type::TemporalGPs.StorageType,
)
    # One `(As, as, Qs, emission_projs, x0)` tuple per summand kernel.
    components = map(κ -> TemporalGPs.lgssm_components(κ, ts, storage_type), k.kernels)

    # Regroup by field so that each name holds one entry per kernel. Destructuring the
    # `map` result directly would instead bind `Ass` to the first kernel's whole tuple,
    # `ass` to the second kernel's, and so on.
    Ass, ass, Qss, emission_projss, x0s = ntuple(i -> map(c -> c[i], components), 5)

    # "Flat" combination: at every time step all kernels are passed to a single call,
    # which relies on n-ary `blk_diag` / `_sum_emission_projections` methods (the
    # generalisation mentioned below).
    As = map(TemporalGPs.blk_diag, Ass...)
    as = map(vcat, ass...)
    Qs = map(TemporalGPs.blk_diag, Qss...)
    emission_projections = TemporalGPs._sum_emission_projections(emission_projss...)

    # Initial state of the sum: stacked means, block-diagonal covariance.
    x0 = TemporalGPs.Gaussian(
        reduce(vcat, getfield.(x0s, :m)),
        TemporalGPs.blk_diag(getfield.(x0s, :P)...),
    )
    return As, as, Qs, emission_projections, x0
end
where we generalise |
Nice, I will give it a try.
I tried the following:
# Transpose a collection of records (tuples or structs) into a tuple of per-field
# collections, e.g. `[(1, 'a'), (2, 'b')]` -> `([1, 2], ['a', 'b'])`. Fields are read from
# `eltype(a)` rather than from the elements, so `fieldnames` must accept that element type
# (it rejects types without a definite field count, such as `Any`).
unzip(a) = Tuple(getfield.(a, name) for name in fieldnames(eltype(a)))
# Element-wise `vcat` across any number of equally long collections: entry `i` of the
# result equals `vcat(as[i], bs[i], args[1][i], ...)`. A single collection is returned
# unchanged.
mapvcat(as) = as
function mapvcat(as, bs, args...)
    # Left fold: concatenate `bs` onto `as` first, then each later collection in turn.
    return foldl((acc, next) -> map(vcat, acc, next), (bs, args...); init=as)
end
# Base case for every variadic `blk_diag` method here: once only one argument is left,
# there is nothing more to combine and it is returned unchanged.
function TemporalGPs.blk_diag(A)
    return A
end

# Collections of blocks (presumably one block per time step): combine the first two
# collections entry by entry, then fold in the remaining collections from `args`.
function TemporalGPs.blk_diag(As::AbstractArray, Bs::AbstractArray, args...)
    pairwise = map((a, b) -> TemporalGPs.blk_diag(a, b), As, Bs)
    return TemporalGPs.blk_diag(pairwise, args...)
end
# Square `SMatrix` blocks with a shared element type: the off-diagonal zero blocks take
# their sizes from the type parameters, so the combined matrix keeps static sizes.
# Any further blocks in `args` are folded in afterwards.
function TemporalGPs.blk_diag(A::SMatrix{DA, DA, T}, B::SMatrix{DB, DB, T}, args...) where {DA, DB, T}
    upper = hcat(A, zeros(SMatrix{DA, DB, T}))
    lower = hcat(zeros(SMatrix{DB, DA, T}), B)
    return TemporalGPs.blk_diag(vcat(upper, lower), args...)
end
# Dense fallback: place `A` and `B` on the diagonal of a larger matrix with zero
# off-diagonal blocks, then fold in any remaining matrices from `args`.
#
# `A` and `B` may have different element types. The zero blocks use the promoted type, so
# e.g. `Matrix{Float64}` and `Matrix{Int}` combine here. Requiring an identical eltype would
# route such pairs to the element-wise `AbstractArray` method, which then fails on scalars.
function TemporalGPs.blk_diag(A::AbstractMatrix, B::AbstractMatrix, args...)
    T = promote_type(eltype(A), eltype(B))
    C = hvcat(
        (2, 2),
        A, zeros(T, size(A, 1), size(B, 2)),
        zeros(T, size(B, 1), size(A, 2)), B,
    )
    return TemporalGPs.blk_diag(C, args...)
end
# Base case of the fold below: a single `(Hs, hs)` emission projection is returned as is.
TemporalGPs._sum_emission_projections(tup::Tuple{AbstractVector, AbstractVector}) = tup

# Combine the `(Hs, hs)` emission projections of two summand kernels, then fold in any
# further ones from `args`. `Hs` entries are concatenated with `vcat` (presumably to match
# the stacked state of the summed model), and the `hs` offsets are added.
function TemporalGPs._sum_emission_projections(
    left::Tuple{AbstractVector, AbstractVector},
    right::Tuple{AbstractVector, AbstractVector},
    args...
)
    Hs_left, hs_left = left
    Hs_right, hs_right = right
    combined = (map(vcat, Hs_left, Hs_right), hs_left + hs_right)
    return TemporalGPs._sum_emission_projections(combined, args...)
end
# Base case for the four-field form `(Cs, cs, Hs, hs)`: a single projection is returned
# as is.
TemporalGPs._sum_emission_projections(
    tup::Tuple{AbstractVector, AbstractVector, AbstractVector, AbstractVector}
) = tup

# Combine the four-field emission projections of two summand kernels, then fold in any
# further ones from `args`:
# - `Cs` entries are concatenated and the `cs` are added;
# - `Hs` entries are combined block-diagonally and `hs` entries are concatenated.
function TemporalGPs._sum_emission_projections(
    (Cs_l, cs_l, Hs_l, hs_l)::Tuple{AbstractVector, AbstractVector, AbstractVector, AbstractVector},
    (Cs_r, cs_r, Hs_r, hs_r)::Tuple{AbstractVector, AbstractVector, AbstractVector, AbstractVector},
    args...
)
    Cs = map(vcat, Cs_l, Cs_r)
    cs = cs_l + cs_r
    # Qualified like every other `blk_diag` call in this code: outside the TemporalGPs
    # module a bare `blk_diag` is undefined (unless explicitly imported) and would throw
    # an `UndefVarError` here.
    Hs = map(TemporalGPs.blk_diag, Hs_l, Hs_r)
    hs = map(vcat, hs_l, hs_r)
    return TemporalGPs._sum_emission_projections((Cs, cs, Hs, hs), args...)
end
function TemporalGPs.lgssm_components(
    k::KernelSum, ts::AbstractVector, storage_type::TemporalGPs.StorageType,
)
    # Build the state-space components of every summand kernel on its own.
    per_kernel = map(k.kernels) do kernel
        TemporalGPs.lgssm_components(kernel, ts, storage_type)
    end

    # Regroup by field: each name below holds one entry per summand kernel.
    A_parts, a_parts, Q_parts, proj_parts, x0_parts = unzip(per_kernel)

    # Merge the kernels:
    # - block-diagonal `A` and `Q`, concatenated `a`;
    # - combined emission projections;
    # - an initial state with stacked means and block-diagonal covariance.
    return (
        TemporalGPs.blk_diag(A_parts...),
        mapvcat(a_parts...),
        TemporalGPs.blk_diag(Q_parts...),
        TemporalGPs._sum_emission_projections(proj_parts...),
        TemporalGPs.Gaussian(
            vcat(getfield.(x0_parts, :m)...),
            TemporalGPs.blk_diag(getfield.(x0_parts, :P)...),
        ),
    )
end
it is still recursive though for the inner functions. Is it for performance that you would like to avoid recursiveness? |
Hi @willtebbutt,
At the moment, lgssm_components(k::KernelSum, ...) assumes only two components. A quick generalization could be the following:
The modification
# NOTE(review): type piracy. This adds a method to `KernelFunctions.KernelSum` for `Kernel`,
# neither of which is owned by this code. It also makes a single-kernel `KernelSum(...)`
# call return the kernel itself rather than a `KernelSum`. Prefer a locally owned helper.
function KernelFunctions.KernelSum(kernel::Kernel)
    return kernel
end
is a bit hacky, though. The text was updated successfully, but these errors were encountered: