From b9ae44d3f1d9309b3326bc2fbee056daae55a846 Mon Sep 17 00:00:00 2001 From: theogf Date: Thu, 23 Feb 2023 13:31:44 +0100 Subject: [PATCH 1/4] Add collection of mean, var and std for distributions --- src/MeasureTheory.jl | 2 +- src/combinators/product.jl | 5 ++++ src/parameterized/mvnormal.jl | 25 ++++++++++++++++ src/parameterized/normal.jl | 20 +++++++++++++ src/parameterized/poisson.jl | 4 +++ test/runtests.jl | 55 +++++++++++++++++++++++++++++++++++ 6 files changed, 110 insertions(+), 1 deletion(-) diff --git a/src/MeasureTheory.jl b/src/MeasureTheory.jl index e43642d7..7a7acd3d 100644 --- a/src/MeasureTheory.jl +++ b/src/MeasureTheory.jl @@ -64,7 +64,7 @@ using MeasureBase: BoundedInts, BoundedReals, CountingMeasure, IntegerDomain, In using MeasureBase: weightedmeasure, restrict using MeasureBase: AbstractTransitionKernel -import Statistics: mean, var, std +import Statistics: mean, cov, var, std import MeasureBase: likelihoodof export likelihoodof diff --git a/src/combinators/product.jl b/src/combinators/product.jl index f0c03e0b..cd8d7df3 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -57,6 +57,11 @@ end x end +for f in (:mean, :var, :std) + @eval $f(d::ProductMeasure) = map($f, marginals(d)) + @eval $f(d::For) = map($f, marginals(d)) +end + # # e.g. set(Normal(μ=2)^5, params, randn(5)) # function Accessors.set( # d::ProductMeasure{A}, diff --git a/src/parameterized/mvnormal.jl b/src/parameterized/mvnormal.jl index 31d7ebc9..f1ceb330 100644 --- a/src/parameterized/mvnormal.jl +++ b/src/parameterized/mvnormal.jl @@ -90,3 +90,28 @@ end @inline function proxy(d::MvNormal{(:μ, :Λ),Tuple{T,C}}) where {T,C<:Cholesky} affine((μ = d.μ, λ = d.Λ.L), Normal()^supportdim(d)) end + +# Statistics dispatch +for N in ((:μ,), (:σ,), (:λ,), (:Σ,), (:Λ,), (:μ, :σ), (:μ, :λ), (:μ, :Σ), (:μ, :Λ)) + expr = Expr(:block) + if first(N) == :μ + push!(expr.args, :(mean(d::MvNormal{$N}) = d.μ)) + else + push!(expr.args, :(mean(d::MvNormal{$N,Tuple{T}}) where {T} = zeros(eltype(T), supportdim(d)))) + end + cov_var = last(N) + push!(expr.args, :(var(d::MvNormal{$N}) = diag(cov(d)))) + push!(expr.args, :(std(d::MvNormal{$N}) = sqrt.(diag(cov(d))))) + if cov_var == :μ + push!(expr.args, :(cov(d::MvNormal{$N, Tuple{T}}) where {T} = I(supportdim(d)...))) + elseif cov_var == :σ + push!(expr.args, :(cov(d::MvNormal{$N}) = d.σ * d.σ')) + elseif cov_var == :λ + push!(expr.args, :(cov(d::MvNormal{$N}) = inv(d.λ' * d.λ))) + elseif cov_var == :Σ + push!(expr.args, :(cov(d::MvNormal{$N}) = Matrix(d.Σ))) + elseif cov_var == :Λ + push!(expr.args, :(cov(d::MvNormal{$N}) = inv(d.Λ'))) + end + eval(expr) +end diff --git a/src/parameterized/normal.jl b/src/parameterized/normal.jl index aa4d37a1..3ad7ceaa 100644 --- a/src/parameterized/normal.jl +++ b/src/parameterized/normal.jl @@ -199,3 +199,23 @@ function logdensity_def(p::Normal, q::Normal, x) return sqdiff / 2 end + +for N in ((:μ,), (:σ,), (:λ,), (:μ, :σ), (:μ, :λ)) + expr = Expr(:block) + if first(N) == :μ + push!(expr.args, :(mean(d::Normal{$N}) = d.μ)) + else + push!(expr.args, :(mean(d::Normal{$N,Tuple{T}}) where {T} = zero(T))) + end + cov_var = last(N) + push!(expr.args, :(var(d::Normal{$N}) = abs2(std(d)))) + if cov_var == :μ + push!(expr.args, :(std(d::Normal{$N, Tuple{T}}) where {T} = one(T))) + elseif cov_var == :σ + push!(expr.args, :(std(d::Normal{$N}) = d.σ)) + elseif cov_var == :λ + push!(expr.args, :(std(d::Normal{$N}) = inv(d.λ))) + end + eval(expr) +end + diff --git a/src/parameterized/poisson.jl b/src/parameterized/poisson.jl index 7135fd70..6e57caa3 100644 --- a/src/parameterized/poisson.jl +++ b/src/parameterized/poisson.jl @@ -38,6 +38,10 @@ function Base.rand(rng::AbstractRNG, T::Type, d::Poisson{(:logλ,)}) rand(rng, Dists.Poisson(exp(d.logλ))) end +mean(d::Poisson{(:λ,)}) = d.λ +std(d::Poisson{(:λ,)}) = sqrt(d.λ) +var(d::Poisson{(:λ,)}) = d.λ + @inline function insupport(::Poisson, x) isinteger(x) && x ≥ 0 end diff --git a/test/runtests.jl b/test/runtests.jl index 3ccf0f45..97cbe9c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test using StatsFuns using Base.Iterators: take +using Statistics using Random using LinearAlgebra using DynamicIterators: trace, TimeLift @@ -419,6 +420,33 @@ end x = rand(d) @test logdensityof(d, x) ≈ logdensityof(Dists.MvNormal(Σ), x) @test logdensityof(MvNormal(zeros(3), σ), x) ≈ logdensityof(d, x) + + D = 3 + μ = randn(D) + zero_μ = zeros(D) + σ = LowerTriangular(randn(D, D)) + Σ = σ * σ' + λ = inv(σ) + Λ = inv(Σ) + + d = MvNormal(;μ) + @test mean(d) == μ + @test cov(d) == I(D) + @testset "Cov param: $(string(cov_param))" for (cov_param, val) in [(:σ, σ), (:Σ, Σ), (:λ, λ), (:Λ, Λ)] + @eval begin + # Mean is not given + d = MvNormal(;$cov_param=$val) + @test mean(d) == $zero_μ + @test cov(d) ≈ $Σ + @test std(d) ≈ sqrt.(diag($Σ)) + @test var(d) ≈ diag($Σ) + d = MvNormal(;μ=$μ, $cov_param=$val) + @test mean(d) == $μ + @test cov(d) ≈ $Σ + @test std(d) ≈ sqrt.(diag($Σ)) + @test var(d) ≈ diag($Σ) + end + end end @testset "NegativeBinomial" begin @@ -427,10 +455,34 @@ end @testset "Normal" begin @test_broken repro(Normal, (:μ, :σ)) + μ = randn() + σ = rand() + λ = inv(σ) + d = Normal(;μ) + @test mean(d) == μ + @test var(d) == one(μ) + @test std(d) == one(μ) + @testset "std param : $(string(std_param))" for (std_param, val) in [(:σ, σ), (:λ, λ)] + @eval begin + d = Normal(;μ=$μ) + @test mean(d) == $μ + @test var(d) == one($μ) + @test std(d) == one($μ) + d = Normal(;μ=$μ, $std_param=$val) + @test mean(d) == $μ + @test var(d) == abs2($σ) + @test std(d) == $σ + end + end end @testset "Poisson" begin @test repro(Poisson, (:λ,)) + λ = rand() + d = Poisson(;λ) + @test mean(d) == λ + @test var(d) == λ + @test std(d) == sqrt(λ) end @testset "StudentT" begin @@ -449,6 +501,9 @@ end x = Vector{Int16}(undef, 10) @test rand!(d, x) isa Vector @test rand(d) isa Vector + @test mean(d) == mean.(collect(marginals(d))) + @test std(d) == std.(collect(marginals(d))) + @test var(d) == var.(collect(marginals(d))) @testset "Indexed by Generator" begin d = For((j^2 for j in 1:10)) do i From cd9c21654c848e6a442151e2e12924324cb218bb Mon Sep 17 00:00:00 2001 From: theogf Date: Tue, 9 May 2023 15:39:39 +0200 Subject: [PATCH 2/4] Bump compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 26a3a666..701352a7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MeasureTheory" uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b" authors = ["Chad Scherrer and contributors"] -version = "0.18.1" +version = "0.18.3" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" From c2f1088ac14866f9301e490daa869ed08a5557ff Mon Sep 17 00:00:00 2001 From: theogf Date: Sat, 27 Jan 2024 14:06:20 +0100 Subject: [PATCH 3/4] create Project for test --- Project.toml | 10 ++++------ test/Project.toml | 12 ++++++++++++ test/runtests.jl | 2 +- 3 files changed, 17 insertions(+), 7 deletions(-) create mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index 2c34a57d..d748373e 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,9 @@ ForwardDiff = "0.10" IfElse = "0.1" Infinities = "0.1" InverseFunctions = "0.1" +InteractiveUtils = "<0.0.1, 1" KeywordCalls = "0.2" +LinearAlgebra = "<0.0.1, 1" LogExpFunctions = "0.3.3" MLStyle = "0.4" MacroTools = "0.5" @@ -62,19 +64,15 @@ MeasureBase = "0.14" NamedTupleTools = "0.13, 0.14" PositiveFactorizations = "0.2" PrettyPrinting = "0.3, 0.4" +Random = "<0.0.1, 1" Reexport = "1" SpecialFunctions = "1, 2" Static = "0.8" StaticArraysCore = "1" +Statistics = "<0.0.1, 1" StatsBase = "0.34" StatsFuns = "0.9, 1" TransformVariables = "0.8" Tricks = "0.1" julia = "1.6" -[extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test", "Aqua"] diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..60cd250e --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,12 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" diff --git a/test/runtests.jl b/test/runtests.jl index 674e593f..f15374d7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,7 @@ using IfElse # # detect_ambiguities_options..., # ) -Aqua.test_all(MeasureBase; ambiguities = false) +# Aqua.test_all(MeasureBase; ambiguities = false) function draw2(μ) x = rand(μ) From 8728097ad536572b772085b42ae407be38cc14e9 Mon Sep 17 00:00:00 2001 From: theogf Date: Sat, 27 Jan 2024 17:16:21 +0100 Subject: [PATCH 4/4] use isapprox --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index f15374d7..499c2f8b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -462,8 +462,8 @@ end @test std(d) == one($μ) d = Normal(;μ=$μ, $std_param=$val) @test mean(d) == $μ - @test var(d) == abs2($σ) - @test std(d) == $σ + @test var(d) ≈ abs2($σ) + @test std(d) ≈ $σ end end end