Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial integration with ExponentialFamilyProjection #408

Merged
merged 36 commits into from
Jul 19, 2024
Merged

Conversation

bvdmitri
Copy link
Member

@bvdmitri bvdmitri commented Jul 18, 2024

This PR adds initial support for ExponentialFamilyProjection in ReactiveMP.jl. It also introduces a new feature for the inference backend that uses fallback rules when explicit message-passing update rules are not available. This functionality acts globally for all nodes, similar to meta. Currently, only a simple fallback is implemented, and it is applicable only to stochastic nodes.

Remaining items:

  • Document the arguments for CVIProjection

This PR is accompanied by a corresponding PR to RxInfer.jl. Inference tests are available in RxInfer, and this PR includes tests for individual rules only.

As a note, @ismailsenoz has an improved version of the rules, which require additional time for testing and setup due to dependencies on the AdvancedHMC library. The plan is to integrate AdvancedHMC-based rules following this PR.

@bvdmitri bvdmitri requested a review from albertpod July 18, 2024 11:22
@bvdmitri bvdmitri marked this pull request as ready for review July 19, 2024 08:45
end

# # gradient function
## I think this is wrong. This is not a gradient on the manifolds. It is just Euclidean gradient.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We did projection on the last line of this function, X after is actually laying in the Tangent space of the point p.

ef = convert(ExponentialFamilyDistribution, M, p)
ifisher = cholinv(Hermitian(fisherinformation(ef)))
X = ExponentialFamilyProjection.ExponentialFamilyManifolds.partition_point(M, ifisher * ForwardDiff.gradient((p) -> targetfn(M, p, data), p))
X = ExponentialFamilyProjection.ExponentialFamilyManifolds.ManifoldsBase.project(M, p, X)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe ManifoldsBase.project!(M, X, p, X)?

## Option 3
# T = ExponentialFamily.exponential_family_typetag(q_out)
# s = sampling_optimized(q_out)
# d = fit_mle(typeof(s), q_out_samples)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this option?

@@ -279,14 +279,15 @@ end
## We create a lambda-like callable structure to improve type inference and make it more stable
## However it is not fully inferrable due to dynamic tags and variable constraints, but still better than just a raw lambda callback

struct MessageMapping{F, T, C, N, M, A, X, R}
struct MessageMapping{F, T, C, N, M, A, X, R, K}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be valuable to group fields here into smaller structures. Like there is a certainly cluster of names.
But this is minor comment.

(FactorizedJoint((Exponential(0.5),)), Exponential(3), Rayleigh(1 / (sqrt(2 * 3))), x -> sqrt(x))
# (FactorizedJoint((Exponential(0.5),)), Exponential(3), Geometric(1 - exp(-0.5)), x->ceil(x))
# (FactorizedJoint((Exponential(0.5),)), Exponential(3), Pareto(3, 0.5), x->3*exp(x) ) ##exponential family projection errors
# (FactorizedJoint((Exponential(0.5),)), Exponential(30), Beta(0.5, 1), x->exp(-x) ) ##exponential family projection errors
Copy link
Contributor

@Nimrais Nimrais Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test does work # (FactorizedJoint((Exponential(0.5),)), Exponential(30), Beta(0.5, 1), x->exp(-x) ) ##exponential family projection errors

@bvdmitri bvdmitri merged commit c379f7b into main Jul 19, 2024
3 checks passed
@bvdmitri bvdmitri deleted the dev-ef-projection branch July 19, 2024 11:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants