Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify inference for predictions functionality #51

Merged
merged 35 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b310e3a
Modify inference for predictions functionality
albertpod Jan 24, 2023
bc6ad92
Merge branch 'main' into dev-predict
albertpod Jan 24, 2023
abc5edb
Make format
albertpod Jan 24, 2023
b67474c
WIP: update inference function
albertpod Jan 30, 2023
5fe1ac6
Make format
albertpod Jan 30, 2023
1c9eb9b
WIP: Change inference
albertpod Jan 30, 2023
11077a0
Update inference
albertpod Feb 1, 2023
a9cb543
Merge branch 'main' into dev-predict
albertpod Feb 2, 2023
39c91ce
Add tests
albertpod Feb 6, 2023
4b223ed
Merge branch 'main' into dev-predict
albertpod Feb 7, 2023
b7aed19
Merge branch 'main' of https://github.com/biaslab/RxInfer.jl into dev…
albertpod Feb 13, 2023
c77e146
Merge branch 'main' into dev-predict
albertpod Feb 22, 2023
cef2171
Merge branch 'main' into dev-predict
albertpod Mar 6, 2023
528164d
Merge main into dev-predict
albertpod Jun 19, 2023
976a38e
Merge branch 'main' into dev-predict
albertpod Jul 23, 2023
31f35f5
Merge branch 'main' into dev-predict
albertpod Sep 9, 2023
eb9c691
Merge branch 'main' into dev-predict
albertpod Sep 12, 2023
ccf4489
fix: fix datavar tests
bvdmitri Sep 13, 2023
7f53ab7
improve check data is missing
bvdmitri Sep 13, 2023
3589fa0
more tests
bvdmitri Sep 13, 2023
b4a29fb
Update inference function
albertpod Sep 13, 2023
9f59b0d
Make format
albertpod Sep 13, 2023
447efcc
Make format
albertpod Sep 13, 2023
bb993ac
Update src/inference.jl
albertpod Sep 18, 2023
829b921
Update src/inference.jl
albertpod Sep 18, 2023
5eb6c46
Update src/inference.jl
albertpod Sep 18, 2023
d98d3c3
Update src/inference.jl
albertpod Sep 18, 2023
b0dd137
Add prediction test for coin model
albertpod Sep 18, 2023
2d2ac7f
Update tests
albertpod Sep 18, 2023
3eb5cda
Make format
albertpod Sep 18, 2023
1a54d0f
fix tests
bvdmitri Sep 18, 2023
9dd6738
fix inference tests
bvdmitri Sep 18, 2023
bdafd5d
fix examples
bvdmitri Sep 18, 2023
47ccf2b
update: Bump version to 2.12.0
bvdmitri Sep 19, 2023
c309034
Merge branch 'main' into dev-predict
bvdmitri Sep 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 129 additions & 26 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import DataStructures: CircularBuffer

using MacroTools # for `@autoupdates`

import ReactiveMP: israndom, isdata, isconst, isproxy, isanonymous
import ReactiveMP: israndom, isdata, isconst, isproxy, isanonymous, allows_missings
import ReactiveMP: CountingReal

import ProgressMeter

# Fetch the prediction stream for a single variable, or the joint prediction stream
# for an array of variables, by dispatching to ReactiveMP's `getprediction(s)`.
obtain_prediction(variable::AbstractVariable) = getprediction(variable)
obtain_prediction(variables::AbstractArray{<:AbstractVariable}) = getpredictions(variables)

# Fetch the marginal stream for a single variable or an array of variables.
# NOTE(review): `SkipInitial()` presumably suppresses the pre-inference initial
# marginal emission — confirm against the ReactiveMP `getmarginal` docs.
obtain_marginal(variable::AbstractVariable, strategy = SkipInitial()) = getmarginal(variable, strategy)
obtain_marginal(variables::AbstractArray{<:AbstractVariable}, strategy = SkipInitial()) = getmarginals(variables, strategy)

Expand All @@ -24,19 +27,19 @@ assign_message!(variable::AbstractVariable, message) = setmes
# Marker types selecting how updates are accumulated for a variable:
# `KeepEach` stores the update from every iteration, `KeepLast` only the final one
# (see the `returnvars`/`predictvars` documentation in `inference`).
struct KeepEach end
struct KeepLast end

# `make_actor` builds a Rocket.jl actor that accumulates marginal updates for a
# variable (or an array of variables) according to the `KeepEach`/`KeepLast`
# strategy; the optional `capacity` variants bound the history with circular storage.
# NOTE(review): methods are defined for both `RandomVariable` and the wider
# `AbstractVariable` — the `RandomVariable` methods are more specific and win
# dispatch for random variables; verify both sets are intentionally kept.
make_actor(::RandomVariable, ::KeepEach) = keep(Marginal)
make_actor(::Array{<:RandomVariable, N}, ::KeepEach) where {N} = keep(Array{Marginal, N})
make_actor(x::AbstractArray{<:RandomVariable}, ::KeepEach) = keep(typeof(similar(x, Marginal)))
make_actor(::AbstractVariable, ::KeepEach) = keep(Marginal)
make_actor(::Array{<:AbstractVariable, N}, ::KeepEach) where {N} = keep(Array{Marginal, N})
make_actor(x::AbstractArray{<:AbstractVariable}, ::KeepEach) = keep(typeof(similar(x, Marginal)))

# `KeepEach` with a `capacity` keeps only the most recent `capacity` updates.
make_actor(::RandomVariable, ::KeepEach, capacity::Integer) = circularkeep(Marginal, capacity)
make_actor(::Array{<:RandomVariable, N}, ::KeepEach, capacity::Integer) where {N} = circularkeep(Array{Marginal, N}, capacity)
make_actor(x::AbstractArray{<:RandomVariable}, ::KeepEach, capacity::Integer) = circularkeep(typeof(similar(x, Marginal)), capacity)
make_actor(::AbstractVariable, ::KeepEach, capacity::Integer) = circularkeep(Marginal, capacity)
make_actor(::Array{<:AbstractVariable, N}, ::KeepEach, capacity::Integer) where {N} = circularkeep(Array{Marginal, N}, capacity)
make_actor(x::AbstractArray{<:AbstractVariable}, ::KeepEach, capacity::Integer) = circularkeep(typeof(similar(x, Marginal)), capacity)

# `KeepLast` retains only the latest update (a single slot, or a buffer shaped like `x`).
make_actor(::RandomVariable, ::KeepLast) = storage(Marginal)
make_actor(x::AbstractArray{<:RandomVariable}, ::KeepLast) = buffer(Marginal, size(x))
make_actor(::AbstractVariable, ::KeepLast) = storage(Marginal)
make_actor(x::AbstractArray{<:AbstractVariable}, ::KeepLast) = buffer(Marginal, size(x))

# `capacity` is irrelevant for `KeepLast` — only the newest value is retained anyway.
make_actor(::RandomVariable, ::KeepLast, capacity::Integer) = storage(Marginal)
make_actor(x::AbstractArray{<:RandomVariable}, ::KeepLast, capacity::Integer) = buffer(Marginal, size(x))
make_actor(::AbstractVariable, ::KeepLast, capacity::Integer) = storage(Marginal)
make_actor(x::AbstractArray{<:AbstractVariable}, ::KeepLast, capacity::Integer) = buffer(Marginal, size(x))

## Inference ensure update

Expand Down Expand Up @@ -141,6 +144,12 @@ function __inference_check_dicttype(keyword::Symbol, ::T) where {T}
""")
end

__inference_check_dataismissing(d) = (ismissing(d) || any(ismissing, d))

# Build a single-entry `NamedTuple` of `missing` placeholders for the data variable
# named `s`. The result is merged into the `data` argument so that the inference
# engine treats the variable as a prediction target.
# For an array-shaped entry the placeholder is a `Vector{Missing}` of matching length;
# for a scalar `DataVariable` it is a single `missing`.
# Idiom fix: use the tuple literal `(s,)` and `fill` directly instead of allocating
# temporary arrays via `Tuple([s])` and `[repeat([missing], length(d))]`.
__inference_fill_predictions(s::Symbol, d::AbstractArray) = NamedTuple{(s,)}((fill(missing, length(d)),))
__inference_fill_predictions(s::Symbol, d::DataVariable) = NamedTuple{(s,)}((missing,))
albertpod marked this conversation as resolved.
Show resolved Hide resolved

## Inference results postprocessing

# TODO: Make this function a part of the public API?
Expand Down Expand Up @@ -185,16 +194,17 @@ This structure is used as a return value from the [`inference`](@ref) function.

See also: [`inference`](@ref)
"""
struct InferenceResult{P, F, M, R, E}
# Container for the results of a single run of the `inference` function.
# Type parameters correspond one-to-one to the fields below.
struct InferenceResult{P, A, F, M, R, E}
posteriors :: P # computed posterior marginals, keyed by variable name
predictions :: A # computed predictions for data variables, keyed by variable name
free_energy :: F # Bethe free energy values, or `nothing` when not requested
model :: M # the instantiated factor graph model
returnval :: R # value returned by the model construction
error :: E # error captured during inference (when `catch_exception = true`), or `nothing`
end

Base.iterate(results::InferenceResult) = iterate((getfield(results, :posteriors), getfield(results, :free_energy), getfield(results, :model), getfield(results, :returnval), getfield(results, :error)))
Base.iterate(results::InferenceResult, any) = iterate((getfield(results, :posteriors), getfield(results, :free_energy), getfield(results, :model), getfield(results, :returnval), getfield(results, :error)), any)
# Enable tuple-style destructuring of the result, e.g.
# `posteriors, predictions, free_energy, model, returnval, error = results`.
Base.iterate(results::InferenceResult) = iterate((getfield(results, :posteriors), getfield(results, :predictions), getfield(results, :free_energy), getfield(results, :model), getfield(results, :returnval), getfield(results, :error)))
Base.iterate(results::InferenceResult, any) = iterate((getfield(results, :posteriors), getfield(results, :predictions), getfield(results, :free_energy), getfield(results, :model), getfield(results, :returnval), getfield(results, :error)), any)

# An inference run is considered successful iff no error was captured in `result.error`.
issuccess(result::InferenceResult) = !iserror(result)
iserror(result::InferenceResult) = !isnothing(result.error)
Expand All @@ -209,6 +219,13 @@ function Base.show(io::IO, result::InferenceResult)
join(io, keys(getfield(result, :posteriors)), ", ")
print(io, ")\n")

if !isempty(getfield(result, :predictions))
print(io, rpad(" Predictions", lcolumnlen), " | ")
print(io, "available for (")
join(io, keys(getfield(result, :predictions)), ", ")
print(io, ")\n")
end

if !isnothing(getfield(result, :free_energy))
print(io, rpad(" Free Energy:", lcolumnlen), " | ")
print(IOContext(io, :compact => true, :limit => true, :displaysize => (1, 80)), result.free_energy)
Expand Down Expand Up @@ -271,6 +288,7 @@ unwrap_free_energy_option(option::Type{T}) where {T <: Real} = (true, T, Countin
meta = nothing,
options = nothing,
returnvars = nothing,
predictvars = nothing,
iterations = nothing,
free_energy = false,
free_energy_diagnostics = BetheFreeEnergyDefaultChecks,
Expand All @@ -294,6 +312,7 @@ For more information about some of the arguments, please check below.
- `meta = nothing`: meta specification object, optional, may be required for some models, see `@meta`
- `options = nothing`: model creation options, optional, see `ModelInferenceOptions`
- `returnvars = nothing`: return structure info, optional, defaults to return everything at each iteration, see below for more information
- `predictvars = nothing`: return structure info, optional, see below for more information
- `iterations = nothing`: number of iterations, optional, defaults to `nothing`, the inference engine does not distinguish between variational message passing or Loopy belief propagation or expectation propagation iterations, see below for more information
- `free_energy = false`: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. `Float64`, for better efficiency, but disables automatic differentiation packages, such as ForwardDiff.jl
- `free_energy_diagnostics = BetheFreeEnergyDefaultChecks`: free energy diagnostic checks, optional, by default checks for possible `NaN`s and `Inf`s. `nothing` disables all checks.
Expand Down Expand Up @@ -405,6 +424,27 @@ result = inference(
)
```

- ### `predictvars`

`predictvars` specifies the variables which should be predicted. In the model definition these variables are specified
as `datavar`s, although they should not be passed inside the `data` argument.

Similar to `returnvars`, `predictvars` accepts a `NamedTuple` or `Dict`. There are two specifications:
- `KeepLast`: saves the last update for a variable, ignoring any intermediate results during iterations
- `KeepEach`: saves all updates for a variable for all iterations

Example:

```julia
result = inference(
...,
predictvars = (
o = KeepLast(),
τ = KeepEach()
)
)
```

- ### `iterations`

Specifies the number of variational (or loopy belief propagation) iterations. By default set to `nothing`, which is equivalent of doing 1 iteration.
Expand Down Expand Up @@ -478,8 +518,8 @@ See also: [`InferenceResult`](@ref), [`rxinference`](@ref)
function inference(;
# `model`: specifies a model generator, with the help of the `Model` function
model::ModelGenerator,
# NamedTuple or Dict with data, required
data,
# NamedTuple or Dict with data, optional if predictvars are specified
data = nothing,
# NamedTuple or Dict with initial marginals, optional, defaults to empty
initmarginals = nothing,
# NamedTuple or Dict with initial messages, optional, defaults to empty
Expand All @@ -492,6 +532,8 @@ function inference(;
options = nothing,
# Return structure info, optional, defaults to return everything at each iteration
returnvars = nothing,
# Prediction structure info, optional, defaults to `nothing` (auto-derived from `missing` data entries)
albertpod marked this conversation as resolved.
Show resolved Hide resolved
predictvars = nothing,
# Number of iterations, defaults to 1, we do not distinguish between VMP or Loopy belief or EP iterations
iterations = nothing,
# Do we compute FE, optional, defaults to false
Expand All @@ -512,7 +554,9 @@ function inference(;
# catch exceptions during the inference procedure, optional, defaults to false
catch_exception = false
)
__inference_check_dicttype(:data, data)
if isnothing(data) && isnothing(predictvars)
error("""One of the keyword arguments `data` or `predictvars` must be specified""")
albertpod marked this conversation as resolved.
Show resolved Hide resolved
end
__inference_check_dicttype(:initmarginals, initmarginals)
__inference_check_dicttype(:initmessages, initmessages)
__inference_check_dicttype(:callbacks, callbacks)
Expand Down Expand Up @@ -559,7 +603,45 @@ function inference(;
returnvars = Dict(variable => returnoption for (variable, value) in pairs(vardict) if (israndom(value) && !isanonymous(value)))
end

# Assume that the prediction variables are specified as `datavars` inside the `@model` block, e.g. `pred = datavar(Float64, n)`.
# Verify that `predictvars` is not `nothing` and that `data` does not have any missing values.
if !isnothing(predictvars)
for (variable, value) in pairs(vardict)
if !isnothing(data) && haskey(predictvars, variable) && haskey(data, variable)
@warn "$(variable) is present in both `data` and `predictvars`. The values in `data` will be ignored."
end
# The following logic creates and adds predictions to the data as missing values.
if isdata(value) && haskey(predictvars, variable) # Verify that the value is of a specified data type and is included in `predictvars`.
if allows_missings(value) # Allow missing values, otherwise raise an error.
predictions = __inference_fill_predictions(variable, value)
data = isnothing(data) ? predictions : merge(data, predictions)
else
error("`predictvars` does not allow missing values for $(variable). Please add the following line: `$(variable) ~ datavar{...} where {allow_missing = true }`")
end
elseif isdata(value) && haskey(data, variable) && __inference_check_dataismissing(data[variable]) # The variable may be of a specified data type and contain missing values.
if allows_missings(value)
predictvars = merge(predictvars, Dict(variable => KeepLast()))
else
error("datavar $(variable) has missings inside but does not allow it. Add `where {allow_missing = true }`")
end
end
end
else # In this case, the prediction functionality should only be performed if the data allows missings and actually contains missing values.
foreach(
(variable, value) -> if isdata(value) && __inference_check_dataismissing(data[variable]) && !allows_missings(value)
error("datavar $(variable) has missings inside but does not allow it. Add `where {allow_missing = true }`")
else
nothing
end, keys(vardict), values(vardict)
)
predictvars = Dict(
variable => KeepLast() for (variable, value) in pairs(vardict) if
(isdata(value) && haskey(data, variable) && allows_missings(value) && __inference_check_dataismissing(data[variable]) && !isanonymous(value))
)
end

__inference_check_dicttype(:returnvars, returnvars)
__inference_check_dicttype(:predictvars, predictvars)

# Use `__check_has_randomvar` to filter out unknown or non-random variables in the `returnvar` specification
__check_has_randomvar(vardict, variable) = begin
Expand All @@ -573,11 +655,24 @@ function inference(;
return haskey_check && israndom_check
end

# Second, for each random variable entry we create an actor
actors = Dict(variable => make_actor(vardict[variable], value) for (variable, value) in pairs(returnvars) if __check_has_randomvar(vardict, variable))
# Use `__check_has_prediction` to filter out unknown predictions variables in the `predictvar` specification
__check_has_prediction(vardict, variable) = begin
haskey_check = haskey(vardict, variable)
isdata_check = haskey_check ? isdata(vardict[variable]) : false
if warn && !haskey_check
@warn "`predictvars` object has `$(variable)` specification, but model has no variable named `$(variable)`. The `$(variable)` specification is ignored. Use `warn = false` to suppress this warning."
elseif warn && haskey_check && !isdata_check
@warn "`predictvars` object has `$(variable)` specification, but model has no **data** variable named `$(variable)`. The `$(variable)` specification is ignored. Use `warn = false` to suppress this warning."
end
return haskey_check && isdata_check
end

# Second, for each random variable and predicting variable entry we create an actor
actors_rv = Dict(variable => make_actor(vardict[variable], value) for (variable, value) in pairs(returnvars) if __check_has_randomvar(vardict, variable))
actors_pr = Dict(variable => make_actor(vardict[variable], value) for (variable, value) in pairs(predictvars) if __check_has_prediction(vardict, variable))

# At third, for each random variable entry we create a boolean flag to track their updates
updates = Dict(variable => MarginalHasBeenUpdated(false) for (variable, _) in pairs(actors))
# At third, for each variable entry we create a boolean flag to track their updates
updates = Dict(variable => MarginalHasBeenUpdated(false) for (variable, _) in pairs(merge(actors_rv, actors_pr)))

_iterations = something(iterations, 1)
_iterations isa Integer || error("`iterations` argument must be of type Integer or `nothing`")
Expand All @@ -591,10 +686,15 @@ function inference(;

try
on_marginal_update = inference_get_callback(callbacks, :on_marginal_update)
subscriptions = Dict(variable => subscribe!(obtain_marginal(vardict[variable]) |> ensure_update(fmodel, on_marginal_update, variable, updates[variable]), actor) for (variable, actor) in pairs(actors))
subscriptions_rv = Dict(variable => subscribe!(obtain_marginal(vardict[variable]) |> ensure_update(fmodel, on_marginal_update, variable, updates[variable]), actor) for (variable, actor) in pairs(actors_rv))
subscriptions_pr = Dict(variable => subscribe!(obtain_prediction(vardict[variable]) |> ensure_update(fmodel, on_marginal_update, variable, updates[variable]), actor) for (variable, actor) in pairs(actors_pr))

is_free_energy, S, T = unwrap_free_energy_option(free_energy)

if !isempty(actors_pr) && is_free_energy
error("Cannot compute Bethe Free Energy for models with prediction variables. Please set `free_energy = false`.")
albertpod marked this conversation as resolved.
Show resolved Hide resolved
end

if is_free_energy
fe_actor = ScoreActor(S, _iterations, 1)
fe_objective = BetheFreeEnergy(BetheFreeEnergyDefaultMarginalSkipStrategy, AsapScheduler(), free_energy_diagnostics)
Expand Down Expand Up @@ -626,7 +726,9 @@ function inference(;
else
foreach(filter(pair -> isdata(last(pair)) && !isproxy(last(pair)), pairs(vardict))) do pair
varname = first(pair)
haskey(data, varname) || error("Data entry `$(varname)` is missing in `data` argument. Double check `data = ($(varname) = ???, )`")
haskey(data, varname) || error(
"Data entry `$(varname)` is missing in `data` or `predictvars` arguments. Double check `data = ($(varname) = ???, )` or `predictvars = ($(varname) = ???, )`"
)
end
end

Expand Down Expand Up @@ -668,7 +770,7 @@ function inference(;
end
end

for (_, subscription) in pairs(subscriptions)
for (_, subscription) in pairs(merge(subscriptions_pr, subscriptions_rv))
unsubscribe!(subscription)
end

Expand All @@ -683,10 +785,11 @@ function inference(;

unsubscribe!(fe_subscription)

posterior_values = Dict(variable => __inference_postprocess(postprocess, getvalues(actor)) for (variable, actor) in pairs(actors))
posterior_values = Dict(variable => __inference_postprocess(postprocess, getvalues(actor)) for (variable, actor) in pairs(actors_rv))
predicted_values = Dict(variable => __inference_postprocess(postprocess, getvalues(actor)) for (variable, actor) in pairs(actors_pr))
fe_values = !isnothing(fe_actor) ? score_snapshot_iterations(fe_actor, executed_iterations) : nothing

return InferenceResult(posterior_values, fe_values, fmodel, freturval, potential_error)
return InferenceResult(posterior_values, predicted_values, fe_values, fmodel, freturval, potential_error)
end

## ------------------------------------------------------------------------ ##
Expand Down
19 changes: 18 additions & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ function ReactiveMP.activate!(model::FactorGraphModel)

filter!(c -> isconnected(c), getconstant(model))
foreach(r -> activate!(r, options), getrandom(model))
foreach(d -> activate!(d, options), getdata(model))
foreach(n -> activate!(n, options), getnodes(model))
end

Expand Down Expand Up @@ -399,7 +400,23 @@ function ReactiveMP.make_node(
return node, var
else
combinedvars = combineLatest(ReactiveMP.getmarginal.(args, IncludeAll()), PushNew())
mappedvars = combinedvars |> map(Message, (vars) -> Message(PointMass(fform(map((d) -> ReactiveMP.getpointmass(ReactiveMP.getdata(d)), vars)...)), false, false, nothing))

# Check if some of the `DataVariable` allow for missing values
possibly_missings = any(allows_missings, filter(arg -> arg isa ReactiveMP.DataVariable, args))
# If `missing` values are allowed, then the result type is a `Union` of `Message{Missing}` and `Message{PointMass}`
result_type = possibly_missings ? Union{Message{Missing}, Message{PointMass}} : Message{PointMass}
# By convention, if the result happens to be missing, the result is a `Message{Missing}` instead of `Message{PointMass}`
mapping_fn = let possibly_missings = possibly_missings
(vars) -> begin
result = fform(map((d) -> ReactiveMP.getpointmass(ReactiveMP.getdata(d)), vars)...)
return if (possibly_missings && ismissing(result))
Message{Missing, Nothing}(missing, false, false, nothing)
else
Message{PointMass, Nothing}(PointMass(result), false, false, nothing)
end
end
end
mappedvars = combinedvars |> map(result_type, mapping_fn)
output = mappedvars |> share_recent()
var = push!(model, ReactiveMP.datavar(DataVariableCreationOptions(output, true, false), ReactiveMP.name(autovar), Any))
foreach(filter(ReactiveMP.isdata, args)) do datavar
Expand Down
Loading
Loading