From cd98384e18021099b7d16580546e17b3a602189a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 20:51:14 -0400 Subject: [PATCH 1/9] Rename VarTransformation to TransportFunction --- src/transport.jl | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/transport.jl b/src/transport.jl index e3970f55..82b0e1c3 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -180,44 +180,44 @@ end """ - struct VarTransformation <: Function + struct TransportFunction <: Function Transforms a variate from one measure to a variate of another. -In general `VarTransformation` should not be called directly, call +In general `TransportFunction` should not be called directly, call [`transport_to`](@ref) instead. """ -struct VarTransformation{NU,MU} <: Function +struct TransportFunction{NU,MU} <: Function ν::NU μ::MU - function VarTransformation{NU,MU}(ν::NU, μ::MU) where {NU,MU} + function TransportFunction{NU,MU}(ν::NU, μ::MU) where {NU,MU} return new{NU,MU}(ν, μ) end - function VarTransformation(ν::NU, μ::MU) where {NU,MU} + function TransportFunction(ν::NU, μ::MU) where {NU,MU} check_dof(ν, μ) return new{NU,MU}(ν, μ) end end -@inline transport_to(ν, μ) = VarTransformation(ν, μ) +@inline transport_to(ν, μ) = TransportFunction(ν, μ) -function Base.:(==)(a::VarTransformation, b::VarTransformation) +function Base.:(==)(a::TransportFunction, b::TransportFunction) return a.ν == b.ν && a.μ == b.μ end -Base.@propagate_inbounds function (f::VarTransformation)(x) +Base.@propagate_inbounds function (f::TransportFunction)(x) return transport_def(f.ν, f.μ, checked_var(f.μ, x)) end -@inline function InverseFunctions.inverse(f::VarTransformation{NU,MU}) where {NU,MU} - return VarTransformation{MU,NU}(f.μ, f.ν) +@inline function InverseFunctions.inverse(f::TransportFunction{NU,MU}) where {NU,MU} + return TransportFunction{MU,NU}(f.μ, f.ν) end -function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x) +function ChangesOfVariables.with_logabsdet_jacobian(f::TransportFunction, x) y = f(x) logpdf_src = logdensityof(f.μ, x) logpdf_trg = logdensityof(f.ν, y) @@ -228,18 +228,18 @@ function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x) end -Base.:(∘)(::typeof(identity), f::VarTransformation) = f -Base.:(∘)(f::VarTransformation, ::typeof(identity)) = f +Base.:(∘)(::typeof(identity), f::TransportFunction) = f +Base.:(∘)(f::TransportFunction, ::typeof(identity)) = f -function Base.:∘(outer::VarTransformation, inner::VarTransformation) +function Base.:∘(outer::TransportFunction, inner::TransportFunction) if !(outer.μ == inner.ν || isequal(outer.μ, inner.ν) || outer.μ ≈ inner.ν) - throw(ArgumentError("Cannot compose VarTransformation if source of outer doesn't equal target of inner.")) + throw(ArgumentError("Cannot compose TransportFunction if source of outer doesn't equal target of inner.")) end - return VarTransformation(outer.ν, inner.μ) + return TransportFunction(outer.ν, inner.μ) end -function Base.show(io::IO, f::VarTransformation) +function Base.show(io::IO, f::TransportFunction) print(io, Base.typename(typeof(f)).name, "(") show(io, f.ν) print(io, ", ") @@ -247,7 +247,7 @@ function Base.show(io::IO, f::VarTransformation) print(io, ")") end -Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f) +Base.show(io::IO, M::MIME"text/plain", f::TransportFunction) = show(io, f) """ From 6ec304223110423240a67e3ffb72139a4c9252e1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 20:52:17 -0400 Subject: [PATCH 2/9] Rename test_vartransform to test_transport --- src/interface.jl | 4 ++-- test/transport.jl | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d27ee505..646ad92f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -13,7 +13,7 @@ using InverseFunctions: inverse using ChangesOfVariables: with_logabsdet_jacobian export test_interface -export test_vartransform +export test_transport export basemeasure_depth export proxy export insupport @@ -66,7 +66,7 @@ function test_interface(μ::M) where {M} end -function test_vartransform(ν, μ) +function test_transport(ν, μ) supertype(x::Real) = Real supertype(x::AbstractArray{<:Real,N}) where N = AbstractArray{<:Real,N} diff --git a/test/transport.jl b/test/transport.jl index 0cb6e55a..10367722 100644 --- a/test/transport.jl +++ b/test/transport.jl @@ -1,6 +1,6 @@ using Test -using MeasureBase.Interface: transport_to, test_vartransform +using MeasureBase.Interface: transport_to, test_transport using MeasureBase: StdUniform, StdExponential, StdLogistic using MeasureBase: Dirac @@ -8,13 +8,13 @@ using MeasureBase: Dirac @testset "transport_to" begin for μ0 in [StdUniform(), StdExponential(), StdLogistic()], ν0 in [StdUniform(), StdExponential(), StdLogistic()] @testset "transport_to (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin - test_vartransform(ν0, μ0) - test_vartransform(2.2 * ν0, 3 * μ0) - test_vartransform(ν0, μ0^1) - test_vartransform(ν0^1, μ0) - test_vartransform(ν0^3, μ0^3) - test_vartransform(ν0^(2,3,2), μ0^(3,4)) - test_vartransform(2.2 * ν0^(2,3,2), 3 * μ0^(3,4)) + test_transport(ν0, μ0) + test_transport(2.2 * ν0, 3 * μ0) + test_transport(ν0, μ0^1) + test_transport(ν0^1, μ0) + test_transport(ν0^3, μ0^3) + test_transport(ν0^(2,3,2), μ0^(3,4)) + test_transport(2.2 * ν0^(2,3,2), 3 * μ0^(3,4)) @test_throws ArgumentError transport_to(ν0, μ0)(rand(μ0^12)) @test_throws ArgumentError transport_to(ν0^3, μ0^3)(rand(μ0^(3,4))) end @@ -22,10 +22,10 @@ using MeasureBase: Dirac @testset "transfrom from/to Dirac" begin μ = Dirac(4.2) - test_vartransform(StdExponential()^0, μ) - test_vartransform(StdExponential()^(0,0,0), μ) - test_vartransform(μ, StdExponential()^static(0)) - test_vartransform(μ, StdExponential()^(static(0),static(0))) + test_transport(StdExponential()^0, μ) + test_transport(StdExponential()^(0,0,0), μ) + test_transport(μ, StdExponential()^static(0)) + test_transport(μ, StdExponential()^(static(0),static(0))) @test_throws ArgumentError transport_to(StdExponential()^1, μ) @test_throws ArgumentError transport_to(μ, StdExponential()^1) end From 435796d71853361f3c6611aea333938034575d35 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 20:54:47 -0400 Subject: [PATCH 3/9] Rename NoVarTransform to NoTransport --- src/interface.jl | 4 ++-- src/transport.jl | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 646ad92f..b2e4cef2 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -6,7 +6,7 @@ using Reexport using MeasureBase: basemeasure_depth, proxy using MeasureBase: insupport, basemeasure_sequence, commonbase -using MeasureBase: transport_to, NoVarTransform +using MeasureBase: transport_to, NoTransport using DensityInterface: logdensityof using InverseFunctions: inverse @@ -72,7 +72,7 @@ function test_transport(ν, μ) @testset "transport_to $μ to $ν" begin x = rand(μ) - @test !(@inferred(transport_to(ν, μ)(x)) isa NoVarTransform) + @test !(@inferred(transport_to(ν, μ)(x)) isa NoTransport) f = transport_to(ν, μ) y = f(x) @test @inferred(inverse(f)(y)) ≈ x diff --git a/src/transport.jl b/src/transport.jl index 82b0e1c3..8b4eb5df 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -41,12 +41,12 @@ to_origin(ν::NU, ::Any) where NU = NoTransformOrigin{NU}(ν) """ - struct MeasureBase.NoVarTransform{NU,MU} end + struct MeasureBase.NoTransport{NU,MU} end Indicates that no transformation from a measure of type `MU` to a measure of type `NU` could be found. """ -struct NoVarTransform{NU,MU} end +struct NoTransport{NU,MU} end """ @@ -120,7 +120,7 @@ See [`transport_to`](@ref). function transport_def end transport_def(::Any, ::Any, x::NoTransformOrigin) = x -transport_def(::Any, ::Any, x::NoVarTransform) = x +transport_def(::Any, ::Any, x::NoTransport) = x function transport_def(ν, μ, x) _vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x) @@ -175,8 +175,8 @@ function _vartransform_with_intermediate(ν, m, μ, x) end # Prevent infinite recursion in case vartransform_intermediate doesn't change type: -@inline _vartransform_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}() -@inline _vartransform_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}() +@inline _vartransform_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}() +@inline _vartransform_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}() """ From 09533f566207461b6166cca37b2b8a806f5da80e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 20:58:07 -0400 Subject: [PATCH 4/9] Rename checked_var to checked_arg --- src/combinators/power.jl | 4 ++-- src/combinators/transformedmeasure.jl | 4 ++-- src/getdof.jl | 10 +++++----- src/primitives/dirac.jl | 2 +- src/primitives/lebesgue.jl | 4 ++-- src/transport.jl | 10 +++++----- test/getdof.jl | 14 +++++++------- 7 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index db5b72c9..9a7a12f4 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -114,7 +114,7 @@ end @inline getdof(::PowerMeasure{<:Any, NTuple{N,Base.OneTo{StaticInt{0}}}}) where N = static(0) -@propagate_inbounds function checked_var(μ::PowerMeasure, x::AbstractArray{<:Any}) +@propagate_inbounds function checked_arg(μ::PowerMeasure, x::AbstractArray{<:Any}) @boundscheck begin sz_μ = map(length, μ.axes) sz_x = size(x) @@ -125,6 +125,6 @@ end return x end -function checked_var(μ::PowerMeasure, x::Any) +function checked_arg(μ::PowerMeasure, x::Any) throw(ArgumentError("Size of variate doesn't match size of power measure")) end diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index ee471861..2b2fde77 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -82,8 +82,8 @@ _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof _pushfwd_dof(MU, R, getdof(ν.origin)) end -# Bypass `checked_var`, would require potentially costly transformation: -@inline checked_var(::PushforwardMeasure, x) = x +# Bypass `checked_arg`, would require potentially costly transformation: +@inline checked_arg(::PushforwardMeasure, x) = x @inline transport_origin(ν::PushforwardMeasure) = ν.origin diff --git a/src/getdof.jl b/src/getdof.jl index b0c8d864..c2bcac70 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -60,18 +60,18 @@ struct NoVarCheck{MU,T} end """ - MeasureBase.checked_var(μ::MU, x::T)::T + MeasureBase.checked_arg(μ::MU, x::T)::T Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not, return `NoVarCheck{MU,T}()` if not check can be performed. """ -function checked_var end +function checked_arg end # Prevent infinite recursion: @propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T} -@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x) +@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_arg(mu_base, x) -@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) +@propagate_inbounds checked_arg(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) _checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ -ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback +ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_var_pullback diff --git a/src/primitives/dirac.jl b/src/primitives/dirac.jl index 2575382c..0c10e723 100644 --- a/src/primitives/dirac.jl +++ b/src/primitives/dirac.jl @@ -32,7 +32,7 @@ insupport(d::Dirac, x) = x == d.x @inline getdof(::Dirac) = static(0) -@propagate_inbounds function checked_var(μ::Dirac, x) +@propagate_inbounds function checked_arg(μ::Dirac, x) @boundscheck insupport(μ, x) || throw(ArgumentError("Invalid variate for measure")) x end diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index 7049e88a..52844708 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -43,8 +43,8 @@ logdensity_def(::CountingMeasure, ::LebesgueMeasure, x) = Inf @inline getdof(::Lebesgue) = static(1) -@inline checked_var(::Lebesgue, x::Real) = x +@inline checked_arg(::Lebesgue, x::Real) = x -@propagate_inbounds function checked_var(::Lebesgue, x::Any) +@propagate_inbounds function checked_arg(::Lebesgue, x::Any) @boundscheck throw(ArgumentError("Invalid variate type for measure")) end diff --git a/src/transport.jl b/src/transport.jl index 8b4eb5df..9d8b4126 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -140,8 +140,8 @@ end function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x) x_o = to_origin(μ, x) - # If μ is a pushforward then checked_var may have been bypassed, so check now: - y_o = transport_def(ν_o, μ_o, checked_var(μ_o, x_o)) + # If μ is a pushforward then checked_arg may have been bypassed, so check now: + y_o = transport_def(ν_o, μ_o, checked_arg(μ_o, x_o)) y = from_origin(ν, y_o) return y end @@ -154,8 +154,8 @@ end function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) x_o = to_origin(μ, x) - # If μ is a pushforward then checked_var may have been bypassed, so check now: - y = transport_def(ν, μ_o, checked_var(μ_o, x_o)) + # If μ is a pushforward then checked_arg may have been bypassed, so check now: + y = transport_def(ν, μ_o, checked_arg(μ_o, x_o)) return y end @@ -209,7 +209,7 @@ end Base.@propagate_inbounds function (f::TransportFunction)(x) - return transport_def(f.ν, f.μ, checked_var(f.μ, x)) + return transport_def(f.ν, f.μ, checked_arg(f.μ, x)) end @inline function InverseFunctions.inverse(f::TransportFunction{NU,MU}) where {NU,MU} diff --git a/test/getdof.jl b/test/getdof.jl index c8d3953b..6ef08c40 100644 --- a/test/getdof.jl +++ b/test/getdof.jl @@ -1,6 +1,6 @@ using Test -using MeasureBase: getdof, check_dof, checked_var +using MeasureBase: getdof, check_dof, checked_arg using MeasureBase: StdUniform, StdExponential, StdLogistic using ChainRulesTestUtils: test_rrule using Static: static @@ -18,18 +18,18 @@ using Static: static @test_throws ArgumentError check_dof(μ2, μ0) test_rrule(check_dof, μ0, StdUniform()) - @test @inferred(checked_var(μ0, x0)) === x0 - @test_throws ArgumentError checked_var(μ0, x2) - test_rrule(checked_var, μ0, x0) + @test @inferred(checked_arg(μ0, x0)) === x0 + @test_throws ArgumentError checked_arg(μ0, x2) + test_rrule(checked_arg, μ0, x0) @test @inferred(getdof(μ2)) == 6 @test (check_dof(μ2, StdUniform()^(1,6,1)); true) @test_throws ArgumentError check_dof(μ2, μ0) test_rrule(check_dof, μ2, StdUniform()^(1,6,1)) - @test @inferred(checked_var(μ2, x2)) === x2 - @test_throws ArgumentError checked_var(μ2, x0) - test_rrule(checked_var, μ2, x2) + @test @inferred(checked_arg(μ2, x2)) === x2 + @test_throws ArgumentError checked_arg(μ2, x0) + test_rrule(checked_arg, μ2, x2) @test @inferred(getdof((StdExponential()^3)^(static(0),static(0)))) === static(0) end From 42a27e2ba2593cc4981537747d8674e733ab383b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 20:58:47 -0400 Subject: [PATCH 5/9] Rename NoVarCheck to NoArgCheck --- src/getdof.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/getdof.jl b/src/getdof.jl index c2bcac70..946dd038 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -51,24 +51,24 @@ ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_do """ - MeasureBase.NoVarCheck{MU,T} + MeasureBase.NoArgCheck{MU,T} Indicates that there is no way to check of a values of type `T` are variate of measures of type `MU`. """ -struct NoVarCheck{MU,T} end +struct NoArgCheck{MU,T} end """ MeasureBase.checked_arg(μ::MU, x::T)::T Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not, -return `NoVarCheck{MU,T}()` if not check can be performed. +return `NoArgCheck{MU,T}()` if not check can be performed. """ function checked_arg end # Prevent infinite recursion: -@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T} +@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoArgCheck{MU,T} @propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_arg(mu_base, x) @propagate_inbounds checked_arg(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) From e9546379294a3e1b95fbb239b26eec78b4c91f75 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 21:00:04 -0400 Subject: [PATCH 6/9] Rename _default_checked_var to _default_checked_arg --- src/getdof.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/getdof.jl b/src/getdof.jl index 946dd038..1a405ad4 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -68,10 +68,10 @@ return `NoArgCheck{MU,T}()` if not check can be performed. function checked_arg end # Prevent infinite recursion: -@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoArgCheck{MU,T} -@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_arg(mu_base, x) +@propagate_inbounds _default_checked_arg(::Type{MU}, ::MU, ::T) where {MU,T} = NoArgCheck{MU,T} +@propagate_inbounds _default_checked_arg(::Type{MU}, mu_base, x) where MU = checked_arg(mu_base, x) -@propagate_inbounds checked_arg(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) +@propagate_inbounds checked_arg(mu::MU, x) where MU = _default_checked_arg(MU, basemeasure(mu), x) -_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ -ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_var_pullback +_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ +ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback From 617f1bce3c03871aeae5560848197b0fa14373ff Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 21:03:29 -0400 Subject: [PATCH 7/9] Rename internal transport helper functions --- src/transport.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/transport.jl b/src/transport.jl index 9d8b4126..b9ae2a67 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -123,7 +123,7 @@ transport_def(::Any, ::Any, x::NoTransformOrigin) = x transport_def(::Any, ::Any, x::NoTransport) = x function transport_def(ν, μ, x) - _vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x) + _transport_with_intermediate(ν, _checked_transport_origin(ν), _checked_transport_origin(μ), μ, x) end @@ -132,13 +132,13 @@ function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where MU throw(ArgumentError("Measure of type $MU and its origin must have separate types")) end -@inline function _checked_vartransform_origin(μ::MU) where MU +@inline function _checked_transport_origin(μ::MU) where MU μ_o = transport_origin(μ) _origin_must_have_separate_type(MU, μ_o) end -function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x) +function _transport_with_intermediate(ν, ν_o, μ_o, μ, x) x_o = to_origin(μ, x) # If μ is a pushforward then checked_arg may have been bypassed, so check now: y_o = transport_def(ν_o, μ_o, checked_arg(μ_o, x_o)) @@ -146,37 +146,37 @@ function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x) return y end -function _vartransform_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x) +function _transport_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x) y_o = transport_def(ν_o, μ, x) y = from_origin(ν, y_o) return y end -function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) +function _transport_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) x_o = to_origin(μ, x) # If μ is a pushforward then checked_arg may have been bypassed, so check now: y = transport_def(ν, μ_o, checked_arg(μ_o, x_o)) return y end -function _vartransform_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x) - _vartransform_with_intermediate(ν, _vartransform_intermediate(ν, μ), μ, x) +function _transport_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x) + _transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x) end -@inline _vartransform_intermediate(ν, μ) = _vartransform_intermediate(getdof(ν), getdof(μ)) -@inline _vartransform_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ -@inline _vartransform_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform() +@inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ)) +@inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ +@inline _transport_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform() -function _vartransform_with_intermediate(ν, m, μ, x) +function _transport_with_intermediate(ν, m, μ, x) z = transport_def(m, μ, x) y = transport_def(ν, m, z) return y end # Prevent infinite recursion in case vartransform_intermediate doesn't change type: -@inline _vartransform_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}() -@inline _vartransform_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}() +@inline _transport_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}() +@inline _transport_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}() """ From 552efd6ed82d7a095856e6f8129186b05f83fdd2 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 21:06:34 -0400 Subject: [PATCH 8/9] Clarify NoVolCorr and WithVolCorr docstrings --- src/transport.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transport.jl b/src/transport.jl index b9ae2a67..0195ab03 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -262,7 +262,7 @@ abstract type TransformVolCorr end NoVolCorr() Indicate that density calculations should ignore the volume element of -var transformations. Should only be used in special cases in which +variate transformations. Should only be used in special cases in which the volume element has already been taken into account in a different way. """ @@ -272,7 +272,7 @@ struct NoVolCorr <: TransformVolCorr end WithVolCorr() Indicate that density calculations should take the volume element of -var transformations into account (typically via the +variate transformations into account (typically via the log-abs-det-Jacobian of the transform). """ struct WithVolCorr <: TransformVolCorr end From 90c88f62642d75f123011304819e380f6e59cabe Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 20 Jun 2022 11:31:40 -0400 Subject: [PATCH 9/9] Increase package version to v0.12 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fbce68d7..e95dd49e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MeasureBase" uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14" authors = ["Chad Scherrer and contributors"] -version = "0.11.0" +version = "0.12.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"