diff --git a/src/GenerativeModels.jl b/src/GenerativeModels.jl
index bdce672..712d7f8 100644
--- a/src/GenerativeModels.jl
+++ b/src/GenerativeModels.jl
@@ -24,11 +24,12 @@ module GenerativeModels
     abstract type AbstractVAE <: AbstractGM end
     abstract type AbstractGAN <: AbstractGM end
 
-    include(joinpath("utils", "flux_ode_decoder.jl"))
+    include(joinpath("utils", "flux_decoders.jl"))
     include(joinpath("utils", "saveload.jl"))
     include(joinpath("utils", "utils.jl"))
 
     include(joinpath("models", "vae.jl"))
+    include(joinpath("models", "ardnet.jl"))
     include(joinpath("models", "rodent.jl"))
     include(joinpath("models", "gan.jl"))
     include(joinpath("models", "vamp.jl"))
diff --git a/src/models/ardnet.jl b/src/models/ardnet.jl
new file mode 100644
index 0000000..80054eb
--- /dev/null
+++ b/src/models/ardnet.jl
@@ -0,0 +1,48 @@
+export ARDNet
+
+const ACG = ConditionalDists.AbstractConditionalGaussian
+
+const FDCMeanGaussian = CMeanGaussian{V,<:FluxDecoder} where V
+
+"""
+    ARDNet(h::InverseGamma, p::Gaussian, e::Gaussian, d::ACG)
+
+Generative model that imposes the sparsifying ARD (*Automatic Relevance
+Determination*) prior on the weights of the decoder mapping:
+
+p(x|z) = N(x|ϕ(z),σx²)
+p(z)   = N(z|0,diag(λz²))
+p(λz)  = iG(λz|α0,β0)
+
+where the posterior on z is a multivariate Gaussian
+q(z|x) = N(z|μz,σz²)
+"""
+struct ARDNet{H<:InverseGamma,P<:Gaussian,E<:Gaussian,D<:ACG} <: AbstractGM
+    hyperprior::H
+    prior::P
+    encoder::E
+    decoder::D
+end
+
+Flux.@functor ARDNet
+
+function ConditionalDists.logpdf(p::ACG, x::AbstractArray{T}, z::AbstractArray{T},
+                                 ps::AbstractVector{T}) where T
+    μ = mean(p, z, ps)
+    σ2 = var(p, z)
+    d = x - μ
+    y = d .* d
+    y = (1 ./ σ2) .* y .+ log.(σ2) .+ T(log(2π))
+    -sum(y, dims=1) / 2
+end
+
+ConditionalDists.mean(p::FDCMeanGaussian, z::AbstractArray, ps::AbstractVector) =
+    p.mapping(z, ps)
+
+function elbo(m::ARDNet, x, y; β=1)
+    ps = reshape(rand(m.encoder), :)
+    llh = sum(logpdf(m.decoder, y, x, ps))
+    kld = sum(kl_divergence(m.encoder, m.prior))
+    lpλ = sum(logpdf(m.hyperprior, var(m.prior)))
+    llh - β*(kld - lpλ)
+end
diff --git a/src/models/rodent.jl b/src/models/rodent.jl
index 30f976a..ac6ad46 100644
--- a/src/models/rodent.jl
+++ b/src/models/rodent.jl
@@ -12,7 +12,8 @@ with ARD prior and an ODE decoder.
 * `e`: Encoder p(z|x)
 * `d`: Decoder p(x|z)
 """
-struct Rodent{P<:Gaussian,E<:CMeanGaussian,D<:CMeanGaussian} <: AbstractVAE
+struct Rodent{H,P,E<:CMeanGaussian,D<:CMeanGaussian} <: AbstractGM
+    hyperprior::InverseGamma
     prior::Gaussian
     encoder::CMeanGaussian
     decoder::CMeanGaussian
@@ -20,7 +21,7 @@ end
 
 Flux.@functor Rodent
 
-Rodent(p::P, e::E, d::D) where {P,E,D} = Rodent{P,E,D}(p,e,d)
+Rodent(h::H, p::P, e::E, d::D) where {H,P,E,D} = Rodent{H,P,E,D}(h,p,e,d)
 
 """
     Rodent(slen::Int, tlen::Int, dt::T, encoder;
@@ -95,50 +96,30 @@ function Rodent(slen::Int, tlen::Int, dt::T, encoder;
                 olen=slen*tlen) where T
     zlen = length(Flux.destructure(ode)[1]) + slen
 
+    # hyperprior
+    hyperprior = InverseGamma(T(1), T(1), zlen, true)
+
+    # prior
     μpz = NoGradArray(zeros(T, zlen))
     λ2z = ones(T, zlen) / 20
     prior = Gaussian(μpz, λ2z)
 
+    # encoder
     σ2z = ones(T, zlen) / 20
     enc_dist = CMeanGaussian{DiagVar}(encoder, σ2z)
 
+    # decoder
     σ2x = ones(T, 1) / 20
     decoder = FluxODEDecoder(slen, tlen, dt, ode, observe)
     dec_dist = CMeanGaussian{ScalarVar}(decoder, σ2x, olen)
 
-    Rodent(prior, enc_dist, dec_dist)
-end
-
-struct ConstSpecRodent{CP<:Gaussian,SP<:Gaussian,E<:ConstSpecGaussian,D<:CMeanGaussian} <: AbstractVAE
-    const_prior::CP
-    spec_prior::SP
-    encoder::E
-    decoder::D
+    Rodent(hyperprior, prior, enc_dist, dec_dist)
 end
 
-ConstSpecRodent(cp::CP, sp::SP, e::E, d::D) where {CP,SP,E,D} =
-    ConstSpecRodent{CP,SP,E,D}(cp,sp,e,d)
-
-Flux.@functor ConstSpecRodent
-
-function elbo(m::ConstSpecRodent, x::AbstractArray)
-    cz = rand(m.encoder.cnst)
-    sz = rand(m.encoder.spec, x)
-    z = cz .+ sz
-
+function elbo(m::Rodent, x::AbstractMatrix; β=1)
+    z = rand(m.encoder, x)
     llh = sum(logpdf(m.decoder, x, z))
-    ckl = sum(kl_divergence(m.encoder.cnst, m.const_prior))
-    skl = sum(kl_divergence(m.encoder.spec, m.spec_prior, sz))
-
-    llh - ckl - skl
-end
-
-function Base.show(io::IO, m::ConstSpecRodent)
-    msg = """$(typeof(m)):
-     const_prior = $(summary(m.const_prior)))
-     spec_prior = $(summary(m.spec_prior))
-     encoder = $(summary(m.encoder))
-     decoder = $(summary(m.decoder))
-    """
-    print(io, msg)
+    kld = sum(kl_divergence(m.encoder, m.prior, x))
+    lpλ = sum(logpdf(m.hyperprior, var(m.prior)))
+    llh - β*(kld - lpλ)
 end
diff --git a/src/utils/flux_ode_decoder.jl b/src/utils/flux_decoders.jl
similarity index 77%
rename from src/utils/flux_ode_decoder.jl
rename to src/utils/flux_decoders.jl
index 399137a..4e05e35 100644
--- a/src/utils/flux_ode_decoder.jl
+++ b/src/utils/flux_decoders.jl
@@ -1,4 +1,27 @@
-export FluxODEDecoder
+export FluxDecoder, FluxODEDecoder
+
+"""
+    FluxDecoder{M}(model)
+
+Simple decoder that, when called with an additional parameter vector,
+restructures it into `model` and calls `model(x)`.
+
+julia> dec = FluxDecoder(Dense(2,3))
+julia> ps = rand(9)
+julia> dec(rand(2,10), ps)
+3×10 Array{Float64,2}:
+ 0.508304  0.620386  0.423422  …  0.595583  0.551536  0.565597  0.255811
+ 1.75512   1.32246   1.57151      1.82269   1.2394    1.73934   0.844125
+ 1.45708   0.92777   1.28766      1.49607   0.863829  1.4156    0.546213
+"""
+struct FluxDecoder{M}
+    model::M
+    restructure::Function
+end
+
+FluxDecoder(m) = FluxDecoder(m, Flux.destructure(m)[2])
+(d::FluxDecoder)(x::AbstractMatrix, ps::AbstractVector) = d.restructure(ps)(x)
+(d::FluxDecoder)(x::AbstractMatrix) = d.model(x)
 
 """
     FluxODEDecoder{M}(slength::Int, tlength::Int, dt::Real,
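Below is a minimal usage sketch of the new `FluxDecoder` (not part of the patch; it assumes Flux and this package with the patch applied are on the load path, and the `Dense(2, 3)` model and all sizes are purely illustrative). It shows the mechanism the decoder relies on: `Flux.destructure` splits a model into a flat parameter vector and a closure that rebuilds the model from such a vector, which is what lets `ARDNet` treat the decoder weights `ps` as a random variable.

```julia
using Flux

# Toy mapping whose 9 parameters (2*3 weights + 3 biases) we want to
# expose as an explicit flat vector, e.g. to place the ARD prior on them.
model = Dense(2, 3)
ps, restructure = Flux.destructure(model)  # flat params + rebuild closure

dec = FluxDecoder(model)                   # stores model and restructure

x = randn(Float32, 2, 10)
dec(x)                                     # forward pass with stored parameters
dec(x, randn(Float32, length(ps)))         # rebuild Dense(2,3) from a fresh
                                           # parameter vector, then apply it
```

Storing the restructuring closure once, instead of calling `Flux.destructure` on every forward pass, keeps the per-call cost to a single rebuild; this is the path taken by `elbo(m::ARDNet, ...)` when it samples `ps = reshape(rand(m.encoder), :)` and pipes it through `ConditionalDists.mean` into `p.mapping(z, ps)`.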