-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial integration with ExponentialFamilyProjection #408
Conversation
end | ||
|
||
# # gradient function | ||
## I think this is wrong. This is not a gradient on the manifolds. It is just Euclidean gradient. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We did projection on the last line of this function, X after is actually laying in the Tangent space of the point p.
ef = convert(ExponentialFamilyDistribution, M, p) | ||
ifisher = cholinv(Hermitian(fisherinformation(ef))) | ||
X = ExponentialFamilyProjection.ExponentialFamilyManifolds.partition_point(M, ifisher * ForwardDiff.gradient((p) -> targetfn(M, p, data), p)) | ||
X = ExponentialFamilyProjection.ExponentialFamilyManifolds.ManifoldsBase.project(M, p, X) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe ManifoldsBase.project!(M, X, p, X)
?
## Option 3 | ||
# T = ExponentialFamily.exponential_family_typetag(q_out) | ||
# s = sampling_optimized(q_out) | ||
# d = fit_mle(typeof(s), q_out_samples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this option?
@@ -279,14 +279,15 @@ end | |||
## We create a lambda-like callable structure to improve type inference and make it more stable | |||
## However it is not fully inferrable due to dynamic tags and variable constraints, but still better than just a raw lambda callback | |||
|
|||
struct MessageMapping{F, T, C, N, M, A, X, R} | |||
struct MessageMapping{F, T, C, N, M, A, X, R, K} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it could be valuable to group fields here into smaller structures. Like there is a certainly cluster of names.
But this is minor comment.
(FactorizedJoint((Exponential(0.5),)), Exponential(3), Rayleigh(1 / (sqrt(2 * 3))), x -> sqrt(x)) | ||
# (FactorizedJoint((Exponential(0.5),)), Exponential(3), Geometric(1 - exp(-0.5)), x->ceil(x)) | ||
# (FactorizedJoint((Exponential(0.5),)), Exponential(3), Pareto(3, 0.5), x->3*exp(x) ) ##exponential family projection errors | ||
# (FactorizedJoint((Exponential(0.5),)), Exponential(30), Beta(0.5, 1), x->exp(-x) ) ##exponential family projection errors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test does work # (FactorizedJoint((Exponential(0.5),)), Exponential(30), Beta(0.5, 1), x->exp(-x) ) ##exponential family projection errors
This PR adds initial support for
ExponentialFamilyProjection
inReactiveMP.jl
. It also introduces a new feature for the inference backend that uses fallback rules when explicit message-passing update rules are not available. This functionality acts globally for all nodes, similar tometa
. Currently, only a simple fallback is implemented, and it is applicable only to stochastic nodes.Remaining items:
CVIProjection
This PR is accompanied by a corresponding PR to
RxInfer.jl
. Inference tests are available inRxInfer
, and this PR includes tests for individual rules only.As a note, @ismailsenoz has an improved version of the rules, which require additional time for testing and setup due to dependencies on the
AdvancedHMC
library. The plan is to integrateAdvancedHMC
-based rules following this PR.