Initial integration with ExponentialFamilyProjection #408
Merged
Changes from 35 commits
d3296c6
play around with auto node functions
bvdmitri fd7fb85
Merge branch 'main' into dev-ef-projection
bvdmitri 341a7ef
fix comp issue
bvdmitri e5800e1
2prev
bvdmitri f7a3c36
2prev
bvdmitri 3b0a9a3
add tests for the nodefunctions
bvdmitri 925c14b
remove duplicate fn definition
bvdmitri 48b6536
add rule fallbacks
bvdmitri 5cf9eaf
add generic fallback
bvdmitri 3bc9b1e
rulefallback for deltanode
bvdmitri 9880e17
2prev
bvdmitri 90a06f5
cvi projection rules
bvdmitri ada4be9
create cvi projection extension
bvdmitri 503a067
2prev
bvdmitri 5aa71a8
2prev
bvdmitri 0dc0c86
fix example
bvdmitri 4504541
tests for generic rule fallback
bvdmitri bc9ed5b
use DivisionOf structure
bvdmitri 32609b8
prettify code
bvdmitri dba0e3e
multi input
bvdmitri 1d65cf3
fix delta marginal rule
bvdmitri 31f0dde
add more tests
bvdmitri 0bbdc4c
2prev
bvdmitri 7138876
22prev
bvdmitri b864037
more tests
bvdmitri e56a975
update
bvdmitri ce37534
reimplement extension export
bvdmitri cb08e1b
2prev
bvdmitri d272194
comment
bvdmitri b50b74e
style: make format
bvdmitri 7ed4659
fix tests
bvdmitri 348ce9b
update the documentation with docstrings
bvdmitri 856f8a5
add documentation to `CVIProjection`
bvdmitri ebc1b21
more tests
bvdmitri 8a4ceae
disable debug in output rules
bvdmitri 279d474
uncomment tests
bvdmitri
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
module ReactiveMPProjectionExt | ||
|
||
using ReactiveMP, ExponentialFamily, Distributions, ExponentialFamilyProjection, BayesBase, Random, LinearAlgebra, FastCholesky | ||
|
||
struct DivisionOf{A, B} | ||
numerator::A | ||
denumerator::B | ||
end | ||
|
||
BayesBase.insupport(d::DivisionOf, p) = insupport(d.numerator, p) && insupport(d.denumerator, p) | ||
BayesBase.logpdf(d::DivisionOf, p) = logpdf(d.numerator, p) - logpdf(d.denumerator, p) | ||
|
||
function BayesBase.prod(::GenericProd, something, division::DivisionOf) | ||
return prod(GenericProd(), division, something) | ||
end | ||
|
||
function BayesBase.prod(::GenericProd, division::DivisionOf, something) | ||
if division.denumerator == something | ||
return division.numerator | ||
else | ||
return ProductOf(division, something) | ||
end | ||
end | ||
|
||
include("layout/cvi_projection.jl") | ||
include("rules/in.jl") | ||
include("rules/out.jl") | ||
include("rules/marginals.jl") | ||
|
||
# This will enable the extension and make `CVIProjection` compatible with delta nodes | ||
# Otherwise it should throw an error suggesting users to install `ExponentialFamilyProjection` | ||
# See `approximations/cvi_projection.jl` | ||
ReactiveMP.is_delta_node_compatible(::ReactiveMP.CVIProjection) = Val(true) | ||
|
||
end |
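The cancellation behaviour of `DivisionOf` can be sketched without any of the packages involved. The snippet below is only an illustration under stated assumptions: `ToyGaussian`, `ToyDivisionOf`, `toy_logpdf` and `toy_prod` are hypothetical stand-ins, not the `BayesBase` or `ReactiveMP` API. It shows that the log-density of a ratio is a difference of log-densities, and that multiplying `q / m` by `m` gives back `q`.

```julia
# Hypothetical stand-ins; these are not the ReactiveMP/BayesBase types.
struct ToyGaussian
    mean::Float64
    var::Float64
end

toy_logpdf(d::ToyGaussian, x) = -0.5 * (log(2π * d.var) + (x - d.mean)^2 / d.var)

struct ToyDivisionOf{A, B}
    numerator::A
    denumerator::B   # spelling follows the field name used in the diff
end

# log(q / m) = log q - log m
toy_logpdf(d::ToyDivisionOf, x) = toy_logpdf(d.numerator, x) - toy_logpdf(d.denumerator, x)

# (q / m) * m collapses back to q; anything else is kept as a lazy pair
toy_prod(d::ToyDivisionOf, other) = d.denumerator == other ? d.numerator : (d, other)
toy_prod(other, d::ToyDivisionOf) = toy_prod(d, other)

q = ToyGaussian(1.0, 2.0)
m = ToyGaussian(0.0, 4.0)
msg = ToyDivisionOf(q, m)

recovered = toy_prod(msg, m)   # the denominator cancels, leaving `q`
```

Note that the cancellation relies on `==` between the denominator and the other factor, so it only triggers when the two objects compare equal.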
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
|
||
using Rocket | ||
import ReactiveMP: | ||
deltafn_rule_layout, | ||
deltafn_apply_layout, | ||
AbstractDeltaNodeDependenciesLayout, | ||
DeltaFnDefaultRuleLayout, | ||
DeltaFnNode, | ||
getmarginal, | ||
functionalform, | ||
tag, | ||
Marginalisation, | ||
MessageMapping, | ||
DefferedMessage, | ||
with_statics, | ||
apply_pipeline_stage, | ||
messageout, | ||
messagein, | ||
connect! | ||
|
||
""" | ||
CVIProjectionApproximationDeltaFnRuleLayout | ||
|
||
Custom rule layout for the Delta node in case of the CVI projection approximation method: | ||
|
||
# Layout | ||
|
||
In order to compute: | ||
|
||
- `q_out`: mirrors the posterior marginal on the `out` edge | ||
- `q_ins`: uses inbound message on the `out` edge and all inbound messages on the `ins` edges | ||
- `m_out`: uses the posterior over `out`, message from `out` and the joint over the `ins` edges | ||
- `m_in_k`: uses the inbound message on the `in_k` edge and `q_ins` | ||
""" | ||
struct CVIProjectionApproximationDeltaFnRuleLayout <: AbstractDeltaNodeDependenciesLayout end | ||
|
||
deltafn_rule_layout(::DeltaFnNode, ::CVIProjection, inverse::Nothing) = CVIProjectionApproximationDeltaFnRuleLayout() | ||
|
||
function deltafn_rule_layout(::DeltaFnNode, ::CVIProjection, inverse::Any) | ||
@warn "CVI projection approximation does not accept the inverse function. Ignoring the provided inverse." | ||
return CVIProjectionApproximationDeltaFnRuleLayout() | ||
end | ||
|
||
# This function declares how to compute `q_out` locally around `DeltaFn` | ||
function deltafn_apply_layout(::CVIProjectionApproximationDeltaFnRuleLayout, ::Val{:q_out}, factornode::DeltaFnNode, meta, pipeline_stages, scheduler, addons, rulefallback) | ||
return deltafn_apply_layout(DeltaFnDefaultRuleLayout(), Val(:q_out), factornode, meta, pipeline_stages, scheduler, addons, rulefallback) | ||
end | ||
|
||
# This function declares how to compute `q_ins` locally around `DeltaFn` | ||
function deltafn_apply_layout(::CVIProjectionApproximationDeltaFnRuleLayout, ::Val{:q_ins}, factornode::DeltaFnNode, meta, pipeline_stages, scheduler, addons, rulefallback) | ||
return deltafn_apply_layout(DeltaFnDefaultRuleLayout(), Val(:q_ins), factornode, meta, pipeline_stages, scheduler, addons, rulefallback) | ||
end | ||
|
||
# This function declares how to compute `m_out` | ||
function deltafn_apply_layout(::CVIProjectionApproximationDeltaFnRuleLayout, ::Val{:m_out}, factornode::DeltaFnNode, meta, pipeline_stages, scheduler, addons, rulefallback) | ||
let interface = factornode.out | ||
msgs_names = Val{(:out,)}() | ||
msgs_observable = combineLatestUpdates((messagein(factornode.out),), PushNew()) | ||
|
||
marginal_names = Val{(:out, :ins)}() | ||
marginals_observable = combineLatestUpdates((getmarginal(factornode.localmarginals.marginals[1]), getmarginal(factornode.localmarginals.marginals[2])), PushNew()) | ||
|
||
fform = functionalform(factornode) | ||
vtag = tag(interface) | ||
vconstraint = Marginalisation() | ||
|
||
vmessageout = combineLatest((msgs_observable, marginals_observable), PushNew()) | ||
|
||
mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, factornode, rulefallback) | ||
(dependencies) -> DefferedMessage(dependencies[1], dependencies[2], messagemap) | ||
end | ||
|
||
vmessageout = with_statics(factornode, vmessageout) | ||
vmessageout = vmessageout |> map(AbstractMessage, mapping) | ||
vmessageout = apply_pipeline_stage(pipeline_stages, factornode, vtag, vmessageout) | ||
vmessageout = vmessageout |> schedule_on(scheduler) | ||
|
||
connect!(messageout(interface), vmessageout) | ||
end | ||
end | ||
|
||
# This function declares how to compute `m_in` for each `k` | ||
function deltafn_apply_layout(::CVIProjectionApproximationDeltaFnRuleLayout, ::Val{:m_in}, factornode::DeltaFnNode, meta, pipeline_stages, scheduler, addons, rulefallback) | ||
return deltafn_apply_layout(DeltaFnDefaultRuleLayout(), Val(:m_in), factornode, meta, pipeline_stages, scheduler, addons, rulefallback) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
|
||
@rule DeltaFn((:in, k), Marginalisation) (q_ins::FactorizedJoint, m_in::Any, meta::DeltaMeta{M}) where {M <: CVIProjection} = begin | ||
q_ins_k = component(q_ins, k) | ||
return DivisionOf(q_ins_k, m_in) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
using TupleTools | ||
|
||
import Distributions: Distribution | ||
import BayesBase: AbstractContinuousGenericLogPdf | ||
|
||
@marginalrule DeltaFn(:ins) (m_out::Any, m_ins::ManyOf{1, Any}, meta::DeltaMeta{M}) where {M <: CVIProjection} = begin | ||
method = ReactiveMP.getmethod(meta) | ||
g = getnodefn(meta, Val(:out)) | ||
|
||
m_in = first(m_ins) | ||
# Create an `AbstractContinuousGenericLogPdf` with an unspecified domain and the transformed `logpdf` function | ||
F = promote_variate_type(variate_form(typeof(m_in)), BayesBase.AbstractContinuousGenericLogPdf) | ||
f = convert(F, UnspecifiedDomain(), (z) -> logpdf(m_out, g(z))) | ||
|
||
T = ExponentialFamily.exponential_family_typetag(m_in) | ||
prj = ProjectedTo(T, size(m_in)...; parameters = something(method.prjparams, ExponentialFamilyProjection.DefaultProjectionParameters())) | ||
q = project_to(prj, f, first(m_ins)) | ||
|
||
return FactorizedJoint((q,)) | ||
end | ||
|
||
@marginalrule DeltaFn(:ins) (m_out::Any, m_ins::ManyOf{N, Any}, meta::DeltaMeta{M}) where {N, M <: CVIProjection} = begin | ||
method = ReactiveMP.getmethod(meta) | ||
rng = method.rng | ||
pre_samples = zip(map(m_in_k -> ReactiveMP.cvilinearize(rand(rng, m_in_k, method.marginalsamples)), m_ins)...) | ||
|
||
logp_nc_drop_index = let g = getnodefn(meta, Val(:out)), pre_samples = pre_samples | ||
(z, i, pre_samples) -> begin | ||
samples = map(ttuple -> ReactiveMP.TupleTools.insertat(ttuple, i, (z,)), pre_samples) | ||
t_samples = map(s -> g(s...), samples) | ||
logpdfs = map(out -> logpdf(m_out, out), t_samples) | ||
return mean(logpdfs) | ||
end | ||
end | ||
|
||
optimize_natural_parameters = let m_ins = m_ins, logp_nc_drop_index = logp_nc_drop_index | ||
(i, pre_samples) -> begin | ||
# Create an `AbstractContinuousGenericLogPdf` with an unspecified domain and the transformed `logpdf` function | ||
df = let i = i, pre_samples = pre_samples, logp_nc_drop_index = logp_nc_drop_index | ||
(z) -> logp_nc_drop_index(z, i, pre_samples) | ||
end | ||
logp = convert(promote_variate_type(variate_form(typeof(first(m_ins))), BayesBase.AbstractContinuousGenericLogPdf), UnspecifiedDomain(), df) | ||
|
||
T = ExponentialFamily.exponential_family_typetag(m_ins[i]) | ||
prj = ProjectedTo(T, size(m_ins[i])...; parameters = something(method.prjparams, ExponentialFamilyProjection.DefaultProjectionParameters())) | ||
|
||
return project_to(prj, logp, m_ins[i]) | ||
end | ||
end | ||
|
||
return FactorizedJoint(ntuple(i -> optimize_natural_parameters(i, pre_samples), length(m_ins))) | ||
end |
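The multi-input rule above builds, for each input `i`, a function of `z` that averages `logpdf(m_out, g(...))` over pre-drawn samples of the other inputs. A minimal stand-alone sketch of that averaging step follows. The node function `g`, the log-density `log_m_out` and the helper `replace_at` are hypothetical; `replace_at` reflects my reading that `TupleTools.insertat(t, i, (z,))` replaces the `i`-th entry of the tuple, which is an assumption worth checking against the TupleTools documentation.

```julia
using Random, Statistics

g(x, y) = x + 2y                      # hypothetical node function
log_m_out(o) = -0.5 * (o - 1.0)^2     # hypothetical log-density of the `out` message

# Replace the i-th entry of a tuple with `z` (assumed behaviour of `TupleTools.insertat`)
replace_at(t::Tuple, i::Int, z) = (t[1:(i - 1)]..., z, t[(i + 1):end]...)

rng = MersenneTwister(42)
pre_samples = [(randn(rng), randn(rng)) for _ in 1:10]   # one tuple per joint sample

# z ↦ average of log m_out(g(...)) with input `i` fixed to `z` and the rest sampled
logp_drop_index(z, i) = mean(log_m_out(g(replace_at(s, i, z)...)) for s in pre_samples)

value = logp_drop_index(0.5, 1)
```

In the diff this function of `z` is then wrapped into a generic log-pdf and passed to `project_to` once per input.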
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
using ForwardDiff | ||
|
||
# cost function | ||
function targetfn(M, p, data) | ||
ef = convert(ExponentialFamilyDistribution, M, p) | ||
return -sum((d) -> logpdf(ef, d), data) | ||
end | ||
|
||
# # gradient function | ||
## I think this is wrong. This is not a gradient on the manifolds. It is just Euclidean gradient. | ||
function grad_targetfn(M, p, data) | ||
ef = convert(ExponentialFamilyDistribution, M, p) | ||
ifisher = cholinv(Hermitian(fisherinformation(ef))) | ||
X = ExponentialFamilyProjection.ExponentialFamilyManifolds.partition_point(M, ifisher * ForwardDiff.gradient((p) -> targetfn(M, p, data), p)) | ||
X = ExponentialFamilyProjection.ExponentialFamilyManifolds.ManifoldsBase.project(M, p, X) | ||
Review comment: Maybe |
||
return X | ||
end | ||
|
||
@rule DeltaFn(:out, Marginalisation) (m_out::Any, q_out::Any, q_ins::FactorizedJoint, meta::DeltaMeta{U}) where {U <: CVIProjection} = begin | ||
node_function = getnodefn(meta, Val(:out)) | ||
method = ReactiveMP.getmethod(meta) | ||
rng = method.rng | ||
q_ins_components = components(q_ins) | ||
dimensions = map(size, q_ins_components) | ||
q_ins_sample_friendly = map(q_in -> sampling_optimized(q_in), q_ins_components) | ||
## Option 1 | ||
# samples = map(i -> collect(map(q -> rand(rng, q), q_ins_sample_friendly)), 1:method.out_samples_no) | ||
# q_out_samples = map(sample -> node_function(ReactiveMP.__splitjoin(sample, dimensions)...), samples) | ||
|
||
## Option 2 | ||
samples = map(ReactiveMP.cvilinearize, map(q_in -> rand(rng, q_in, method.outsamples), q_ins_sample_friendly)) | ||
q_out_samples = map(x -> node_function(x...), zip(samples...)) | ||
|
||
## Option 3 | ||
# T = ExponentialFamily.exponential_family_typetag(q_out) | ||
# s = sampling_optimized(q_out) | ||
# d = fit_mle(typeof(s), q_out_samples) | ||
Review comment: Do we need this option? |
||
# m = DivisionOf(d, m_out) | ||
# r = project_to(ProjectedTo(T, size(q_out)...; parameters = method.prjparams), (x) -> logpdf(m, x)) | ||
# return r | ||
|
||
T = ExponentialFamily.exponential_family_typetag(q_out) | ||
q_out_ef = convert(ExponentialFamilyDistribution, q_out) | ||
conditioner = getconditioner(q_out_ef) | ||
manifold = ExponentialFamilyProjection.ExponentialFamilyManifolds.get_natural_manifold(T, size(mean(q_out_ef)), conditioner) | ||
nat_params = ExponentialFamilyProjection.ExponentialFamilyManifolds.partition_point(manifold, getnaturalparameters(q_out_ef)) | ||
|
||
f = (M, p) -> targetfn(M, p, q_out_samples) | ||
g = (M, p) -> grad_targetfn(M, p, q_out_samples) | ||
|
||
est = convert( | ||
ExponentialFamilyDistribution, | ||
manifold, | ||
ExponentialFamilyProjection.Manopt.gradient_descent( | ||
manifold, f, g, nat_params; | ||
stepsize = ExponentialFamilyProjection.Manopt.ConstantStepsize(0.1), | ||
direction = ExponentialFamilyProjection.BoundedNormUpdateRule(1), | ||
debug = missing | ||
) | ||
) | ||
# return x -> logpdf(est, x) - logpdf(m_out, x) | ||
return DivisionOf(est, m_out) | ||
end |
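The "Option 2" path of the `out` rule draws samples from each input marginal, pushes them through the node function, and then fits an exponential-family member to the resulting samples. The sketch below illustrates the same pipeline in a simplified form. It is not the method in the diff: it swaps the manifold gradient descent for the closed-form Gaussian maximum-likelihood fit, and `node_function`, `xs` and `ys` are hypothetical.

```julia
using Random, Statistics

node_function(x, y) = x * y + 1.0     # hypothetical node function

rng = MersenneTwister(42)
nsamples = 100                        # plays the role of `method.outsamples`

# Samples from two hypothetical Gaussian input marginals
xs = 1.0 .+ 0.1 .* randn(rng, nsamples)
ys = 2.0 .+ 0.1 .* randn(rng, nsamples)

# Push the joint samples through the node function
q_out_samples = map((x, y) -> node_function(x, y), xs, ys)

# Closed-form Gaussian MLE (swapped in for the gradient descent used in the diff)
est_mean = mean(q_out_samples)
est_var = var(q_out_samples; corrected = false)
```

The rule in the diff then returns `DivisionOf(est, m_out)`, so the estimate still has to be divided by the inbound message on `out`.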
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
export CVIProjection | ||
|
||
""" | ||
CVIProjection(; parameters...) | ||
|
||
A structure representing the parameters for the Conjugate Variational Inference (CVI) projection method. | ||
This structure is a subtype of `AbstractApproximationMethod` and is used to configure the settings for CVI. | ||
|
||
!!! note | ||
The `CVIProjection` method requires `ExponentialFamilyProjection` package installed in the current environment. | ||
|
||
# Parameters | ||
|
||
- `rng::R`: The random number generator used for sampling. Default is `Random.MersenneTwister(42)`. | ||
- `marginalsamples::S`: The number of samples used for approximating marginal distributions. Default is `10`. | ||
- `outsamples::S`: The number of samples used for approximating output message distributions. Default is `100`. | ||
- `prjparams::P`: Parameters for the exponential family projection. Default is `nothing`, in which case it will use `ExponentialFamilyProjection.DefaultProjectionParameters()`. | ||
|
||
!!! note | ||
The `CVIProjection` method is an experimental enhancement of the now-deprecated `CVI`, offering better stability and improved accuracy. | ||
Note that the parameters of this structure, as well as their defaults, are subject to change during the experimentation phase. | ||
""" | ||
Base.@kwdef struct CVIProjection{R, S, P} <: AbstractApproximationMethod | ||
rng::R = Random.MersenneTwister(42) | ||
marginalsamples::S = 10 | ||
outsamples::S = 100 | ||
prjparams::P = nothing # ExponentialFamilyProjection.DefaultProjectionParameters() | ||
end | ||
|
||
# This method should only be invoked if a user did not install `ExponentialFamilyProjection` | ||
# in the current Julia session | ||
check_delta_node_compatibility(::Val{false}, ::CVIProjection) = error("CVI projection requires `using ExponentialFamilyProjection` in the current session.") |
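A short sketch of how a keyword-constructed struct like this behaves. `ToyCVIProjection` is a hypothetical mirror of the definition above so that the snippet runs without `ReactiveMP`, and the symbol passed to `something` is only a placeholder for `ExponentialFamilyProjection.DefaultProjectionParameters()`.

```julia
using Random

# Hypothetical mirror of the struct in the diff, so the snippet runs without ReactiveMP
Base.@kwdef struct ToyCVIProjection{R, S, P}
    rng::R = Random.MersenneTwister(42)
    marginalsamples::S = 10
    outsamples::S = 100
    prjparams::P = nothing
end

defaults = ToyCVIProjection()
custom = ToyCVIProjection(rng = Random.MersenneTwister(1), outsamples = 500)

# `something` returns its first non-`nothing` argument; the rules in the diff
# use the same pattern to fall back to the default projection parameters
params = something(custom.prjparams, :default_projection_parameters)
```

Because `marginalsamples` and `outsamples` share the type parameter `S`, both must be given with the same type (for example, two `Int`s).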
The projection happens on the last line of this function, so afterwards `X` actually lies in the tangent space at the point `p`.
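For reference, the computation in `targetfn` and `grad_targetfn` can be written out as follows. The notation is an assumption introduced here, not taken from the PR: $\mathcal{D}$ denotes `data`, $F(p)$ the Fisher information of the distribution with natural parameters $p$, and $\operatorname{proj}_{T_p M}$ the projection onto the tangent space of the manifold $M$ at $p$.

```latex
f(p) = -\sum_{d \in \mathcal{D}} \log \mathrm{pdf}(d \mid p),
\qquad
X = \operatorname{proj}_{T_p M}\!\left( F(p)^{-1} \, \nabla_p f(p) \right)
```

Read this way, the Euclidean gradient from `ForwardDiff` is first preconditioned by the inverse Fisher information and only then projected, which is the step the reply above refers to.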