KernelSum with more than 2 components #64

Closed
andreaskoher opened this issue Apr 13, 2021 · 3 comments

@andreaskoher
Contributor

Hi @willtebbutt,
At the moment, `lgssm_components(k::KernelSum ...)` assumes only two components. A quick generalization could be the following:

```julia
using KernelFunctions, TemporalGPs

# Hack: a one-element sum is treated as the kernel itself, so the recursion below terminates.
KernelFunctions.KernelSum(kernel::Kernel) = kernel

function TemporalGPs.lgssm_components(k::KernelSum, ts::AbstractVector, storage_type::TemporalGPs.StorageType)
    # Components of the first kernel...
    As_l, as_l, Qs_l, emission_proj_l, x0_l = TemporalGPs.lgssm_components(k.kernels[1], ts, storage_type)
    # ...and, recursively, of the sum of the remaining kernels.
    As_r, as_r, Qs_r, emission_proj_r, x0_r = TemporalGPs.lgssm_components(KernelSum(k.kernels[2:end]...), ts, storage_type)

    # Independent components: block-diagonal dynamics and noise, stacked states.
    As = map(TemporalGPs.blk_diag, As_l, As_r)
    as = map(vcat, as_l, as_r)
    Qs = map(TemporalGPs.blk_diag, Qs_l, Qs_r)
    emission_projections = TemporalGPs._sum_emission_projections(emission_proj_l, emission_proj_r)
    x0 = TemporalGPs.Gaussian(vcat(x0_l.m, x0_r.m), TemporalGPs.blk_diag(x0_l.P, x0_r.P))

    return As, as, Qs, emission_projections, x0
end
```

The modification `KernelFunctions.KernelSum(kernel::Kernel) = kernel` is a bit hacky, though.
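
For reference, the combination step in this proposal works because a sum of independent linear-Gaussian state-space processes is again one. That step makes `As` and `Qs` block-diagonal, stacks `as` and the initial states, and merges the emission projections. In the combined model you stack the component states, make the dynamics and the process-noise covariance block-diagonal, concatenate the emission weights and add the emission offsets. Below is a toy, single-time-step check of the deterministic part of that argument using plain `Matrix` and `Vector` values. `blk_diag2` and all the numbers are hypothetical, and the sketch does not use TemporalGPs' own types.

```julia
using LinearAlgebra

# Hypothetical two-matrix block-diagonal helper (stand-in for TemporalGPs.blk_diag).
function blk_diag2(A::AbstractMatrix, B::AbstractMatrix)
    T = promote_type(eltype(A), eltype(B))
    C = zeros(T, size(A, 1) + size(B, 1), size(A, 2) + size(B, 2))
    C[1:size(A, 1), 1:size(A, 2)] .= A
    C[size(A, 1)+1:end, size(A, 2)+1:end] .= B
    return C
end

# Two independent components at one time step (toy numbers): state x_i,
# transition A_i, emission weights h_i and emission offset c_i, so component i
# contributes dot(h_i, x_i) + c_i to the observation.
A1, h1, c1, x1 = fill(0.9, 1, 1), [1.0], 0.5, [0.3]
A2, h2, c2, x2 = [1.0 0.1; 0.0 1.0], [1.0, 0.0], -0.2, [1.0, 2.0]

# Combined model: stacked state, block-diagonal transition,
# concatenated emission weights, summed offsets.
x = vcat(x1, x2)
A = blk_diag2(A1, A2)
h = vcat(h1, h2)
c = c1 + c2

# Transitioning the stacked state equals transitioning each component...
@assert A * x ≈ vcat(A1 * x1, A2 * x2)
# ...and the combined emission equals the sum of the component emissions.
@assert dot(h, x) + c ≈ (dot(h1, x1) + c1) + (dot(h2, x2) + c2)
```

The process-noise covariances are combined in the same block-diagonal way as the transitions because the components are independent.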

@willtebbutt
Member

willtebbutt commented Apr 13, 2021

> the modification `KernelFunctions.KernelSum(kernel::Kernel) = kernel` is a bit hacky though.

Hmm yeah, we'd need to find a way to do this that avoids that particular hack, since it would type-pirate KernelFunctions :)
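
Here is a minimal sketch of one way to keep the recursion while dropping the one-argument `KernelSum` method. It recurses on the plain `Tuple` of kernels through a private helper whose base case is a 1-tuple, so no new sum ever has to be constructed. Everything in it is a hypothetical stand-in (`ToyKernel`, `ToySum`, `component`, `_sum_components`, `blk_diag2`). Each kernel's "components" are reduced to a single transition matrix. In TemporalGPs the helper would instead take `ts` and `storage_type` and return the full component tuple. The sketch also assumes the kernels are stored in a `Tuple`, so a `Vector` of kernels would need a different route, such as a fold.

```julia
# Hypothetical stand-ins for a single kernel and a sum of kernels.
struct ToyKernel
    A::Matrix{Float64}  # pretend this is the kernel's state-transition matrix
end

struct ToySum{T<:Tuple}
    kernels::T
end

# Stand-in for `lgssm_components` on a single (non-sum) kernel.
component(k::ToyKernel) = k.A

# Two-argument block-diagonal concatenation.
function blk_diag2(A::AbstractMatrix, B::AbstractMatrix)
    T = promote_type(eltype(A), eltype(B))
    C = zeros(T, size(A, 1) + size(B, 1), size(A, 2) + size(B, 2))
    C[1:size(A, 1), 1:size(A, 2)] .= A
    C[size(A, 1)+1:end, size(A, 2)+1:end] .= B
    return C
end

# Recurse on the tuple itself. The 1-tuple method is the base case, so no
# one-argument constructor for the sum type is ever needed (and nothing is pirated).
_sum_components(ks::Tuple{Any}) = component(only(ks))
_sum_components(ks::Tuple) = blk_diag2(component(first(ks)), _sum_components(Base.tail(ks)))

component(k::ToySum) = _sum_components(k.kernels)

k = ToySum((ToyKernel(fill(1.0, 1, 1)), ToyKernel(fill(2.0, 1, 1)), ToyKernel(fill(3.0, 1, 1))))
component(k)  # [1.0 0.0 0.0; 0.0 2.0 0.0; 0.0 0.0 3.0]
```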

The current limitation to two components is a hangover from the time when Stheno.jl still had its own collection of kernels, which TemporalGPs used, and whose sum only allowed two components anyway. My point is that we should definitely generalise to arbitrarily many components!

I wonder if there's a "flat" way to do this rather than a recursion? That would reflect the structure inside the `KernelSum` anyway. Something like

```julia
function TemporalGPs.lgssm_components(k::KernelSum, ts::AbstractVector, storage_type::TemporalGPs.StorageType)
    # One (As, as, Qs, emission_proj, x0) tuple per component kernel...
    comps = map(κ -> TemporalGPs.lgssm_components(κ, ts, storage_type), k.kernels)
    # ...regrouped into one collection per kind of component.
    Ass, ass, Qss, emission_projs, x0s = ntuple(i -> map(c -> c[i], comps), 5)

    # Combine across all components at each time step.
    As = map(TemporalGPs.blk_diag, Ass...)
    as = map(vcat, ass...)
    Qs = map(TemporalGPs.blk_diag, Qss...)
    emission_projections = TemporalGPs._sum_emission_projections(emission_projs...)
    x0 = TemporalGPs.Gaussian(
        reduce(vcat, getfield.(x0s, :m)),
        TemporalGPs.blk_diag(getfield.(x0s, :P)...),
    )

    return As, as, Qs, emission_projections, x0
end
```

where we generalise `blk_diag` and `_sum_emission_projections` to handle `Tuple`s and `Vector`s of things?
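
Below is a minimal sketch of what a flat, variadic `blk_diag` could look like for plain matrices. It preallocates the result and copies each block into place in a single loop. A one-argument method also accepts a `Tuple` or `Vector` of matrices. The sketch includes a similarly flat helper for `(Hs, hs)` emission projections that concatenates the `Hs` and adds the `hs`, following the two-component rule used in this thread. Both names are hypothetical, not TemporalGPs' own methods. This version returns an ordinary `Matrix`, so a `SMatrix` path would still need its own method.

```julia
# Flat block-diagonal concatenation of any number of matrices (hypothetical helper).
function blk_diag(Ms::AbstractMatrix...)
    T = mapreduce(eltype, promote_type, Ms)
    C = zeros(T, sum(size.(Ms, 1)), sum(size.(Ms, 2)))
    r, c = 0, 0  # row/column offset of the next block
    for M in Ms
        m, n = size(M)
        C[r+1:r+m, c+1:c+n] .= M
        r += m
        c += n
    end
    return C
end

# Also accept a Tuple or Vector of matrices as a single argument.
blk_diag(Ms::Union{Tuple, AbstractVector{<:AbstractMatrix}}) = blk_diag(Ms...)

# Flat combination of any number of (Hs, hs) emission projections:
# concatenate the Hs at each time step and add the offsets (hypothetical helper).
sum_emission_projections(projs::Tuple{AbstractVector, AbstractVector}...) =
    (map(vcat, first.(projs)...), sum(last.(projs)))

# Per-time-step combination of three components' transition matrices.
As_1 = [fill(1.0, 1, 1), fill(2.0, 1, 1)]  # two time steps
As_2 = [[1.0 0.1; 0.0 1.0], [1.0 0.2; 0.0 1.0]]
As_3 = [fill(0.5, 1, 1), fill(0.5, 1, 1)]
As = map(blk_diag, As_1, As_2, As_3)  # one 4×4 block-diagonal matrix per time step
```

Because `map` passes one element from each collection per call, `map(blk_diag, Ass...)` combines all components at each time step without any recursion over components.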

@andreaskoher
Contributor Author

Nice, I will give it a try.

@andreaskoher
Contributor Author

I tried the following:

```julia
using KernelFunctions, TemporalGPs, StaticArrays

# Split a collection of tuples into a tuple of collections, one per tuple field.
unzip(a) = map(x -> getfield.(a, x), fieldnames(eltype(a)))

# Concatenate, time step by time step, the vectors of any number of components.
mapvcat(as) = as
function mapvcat(as, bs, args...)
    cs = map(vcat, as, bs)
    return mapvcat(cs, args...)
end

# Block-diagonal concatenation. Base case: a single argument is returned as-is.
TemporalGPs.blk_diag(A) = A

# Collections of matrices (one per time step): combine element-wise, then recurse.
function TemporalGPs.blk_diag(As::AbstractArray, Bs::AbstractArray, args...)
    C = map(TemporalGPs.blk_diag, As, Bs)
    return TemporalGPs.blk_diag(C, args...)
end

# Static matrices: build the block-diagonal SMatrix, then recurse.
function TemporalGPs.blk_diag(A::SMatrix{DA, DA, T}, B::SMatrix{DB, DB, T}, args...) where {DA, DB, T}
    zero_AB = zeros(SMatrix{DA, DB, T})
    zero_BA = zeros(SMatrix{DB, DA, T})
    C = [[A zero_AB]; [zero_BA B]]
    return TemporalGPs.blk_diag(C, args...)
end

# Ordinary matrices: the same idea with dense arrays.
function TemporalGPs.blk_diag(A::AbstractMatrix{T}, B::AbstractMatrix{T}, args...) where {T}
    C = hvcat(
        (2, 2),
        A, zeros(T, size(A, 1), size(B, 2)), zeros(T, size(B, 1), size(A, 2)), B,
    )
    return TemporalGPs.blk_diag(C, args...)
end

# Emission projections of the form (Hs, hs): concatenate the Hs, add the hs.
TemporalGPs._sum_emission_projections(tup::Tuple{AbstractVector, AbstractVector}) = tup
function TemporalGPs._sum_emission_projections(
    (Hs_l, hs_l)::Tuple{AbstractVector, AbstractVector},
    (Hs_r, hs_r)::Tuple{AbstractVector, AbstractVector},
    args...
)
    Hs = map(vcat, Hs_l, Hs_r)
    hs = hs_l + hs_r
    return TemporalGPs._sum_emission_projections((Hs, hs), args...)
end

# Emission projections of the form (Cs, cs, Hs, hs).
TemporalGPs._sum_emission_projections(
    tup::Tuple{AbstractVector, AbstractVector, AbstractVector, AbstractVector}
) = tup
function TemporalGPs._sum_emission_projections(
    (Cs_l, cs_l, Hs_l, hs_l)::Tuple{AbstractVector, AbstractVector, AbstractVector, AbstractVector},
    (Cs_r, cs_r, Hs_r, hs_r)::Tuple{AbstractVector, AbstractVector, AbstractVector, AbstractVector},
    args...
)
    Cs = map(vcat, Cs_l, Cs_r)
    cs = cs_l + cs_r
    # Qualified: an unqualified `blk_diag` is not in scope outside TemporalGPs.
    Hs = map(TemporalGPs.blk_diag, Hs_l, Hs_r)
    hs = map(vcat, hs_l, hs_r)
    return TemporalGPs._sum_emission_projections((Cs, cs, Hs, hs), args...)
end

function TemporalGPs.lgssm_components(k::KernelSum, ts::AbstractVector, storage_type::TemporalGPs.StorageType)
    # Components of each kernel in the sum, regrouped by kind.
    components = map(κ -> TemporalGPs.lgssm_components(κ, ts, storage_type), k.kernels)
    Ass, ass, Qss, emission_projss, x0s = unzip(components)

    As = TemporalGPs.blk_diag(Ass...)
    as = mapvcat(ass...)
    Qs = TemporalGPs.blk_diag(Qss...)
    emission_projs = TemporalGPs._sum_emission_projections(emission_projss...)
    m  = vcat(getfield.(x0s, :m)...)
    P  = TemporalGPs.blk_diag(getfield.(x0s, :P)...)
    x0 = TemporalGPs.Gaussian(m, P)
    return As, as, Qs, emission_projs, x0
end
```

It is still recursive for the inner functions, though. Is it for performance reasons that you would like to avoid recursion?
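
On the recursion question, a recursive varargs method of the form `f(a) = a; f(a, b, rest...) = f(g(a, b), rest...)` computes the same thing as a left fold of the binary `g` over the arguments. `map` already accepts any number of collections, so `mapvcat` can also be written without recursion. The sketch below checks both equivalences on toy data, and all of its names are hypothetical. Whether the recursion actually costs anything is a separate question. For `Tuple` arguments of known length, Julia can often specialize and inline this kind of recursion, so the flat forms may matter more for readability than for speed. That would need benchmarking to confirm.

```julia
# Recursive varargs pattern, as used in the thread.
mapvcat_rec(xs) = xs
mapvcat_rec(xs, ys, rest...) = mapvcat_rec(map(vcat, xs, ys), rest...)

# Flat equivalent: map hands one element from each collection to vcat.
mapvcat_flat(xss...) = map(vcat, xss...)

# General pattern: "combine the first two, then recurse" is a left fold.
combine2(a, b) = vcat(a, 2 .* b)  # deliberately non-associative, so order matters
combine_rec(a) = a
combine_rec(a, b, rest...) = combine_rec(combine2(a, b), rest...)
combine_flat(xs...) = foldl(combine2, xs)

# Per-time-step vectors of three components, two time steps each.
offs1 = [[1], [2]]
offs2 = [[3], [4]]
offs3 = [[5], [6]]

mapvcat_rec(offs1, offs2, offs3)   # [[1, 3, 5], [2, 4, 6]]
mapvcat_flat(offs1, offs2, offs3)  # [[1, 3, 5], [2, 4, 6]]
combine_rec([1], [2], [3])         # [1, 4, 6]
combine_flat([1], [2], [3])        # [1, 4, 6]
```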


3 participants