From 1590e79a93613b9fd6e8021f142310c2540dc565 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Oct 2023 18:28:31 +0100 Subject: [PATCH] WIP: add methods for soft promotion --- ext/DynamicQuantitiesMeasurementsExt.jl | 4 ++-- src/math.jl | 6 +++--- src/symbolic_dimensions.jl | 23 ++++++++++++++++++++++ src/utils.jl | 26 ++++++++++++++++++++----- test/unittests.jl | 9 +++++++++ 5 files changed, 58 insertions(+), 10 deletions(-) diff --git a/ext/DynamicQuantitiesMeasurementsExt.jl b/ext/DynamicQuantitiesMeasurementsExt.jl index 1a5294fe..b3287bdb 100644 --- a/ext/DynamicQuantitiesMeasurementsExt.jl +++ b/ext/DynamicQuantitiesMeasurementsExt.jl @@ -1,10 +1,10 @@ module DynamicQuantitiesMeasurementsExt -using DynamicQuantities: AbstractQuantity, new_quantity, dimension, ustrip, DimensionError +using DynamicQuantities: AbstractQuantity, new_quantity, dimension, ustrip, dimension_promote using Measurements: Measurements, measurement, value, uncertainty function Measurements.measurement(a::Q, b::Q) where {Q<:AbstractQuantity} - dimension(a) == dimension(b) || throw(DimensionError(a, b)) + a, b = dimension_promote(a, b) raw_measurement = measurement(ustrip(a), ustrip(b)) return new_quantity(Q, raw_measurement, dimension(a)) end diff --git a/src/math.jl b/src/math.jl index 1f8f88ba..f1643dfd 100644 --- a/src/math.jl +++ b/src/math.jl @@ -18,7 +18,7 @@ Base.:/(l, r::AbstractDimensions) = error("Please use an `AbstractQuantity` for Base.:+(l::AbstractQuantity, r::AbstractQuantity) = let - dimension(l) == dimension(r) || throw(DimensionError(l, r)) + l, r = dimension_promote(l, r) new_quantity(typeof(l), ustrip(l) + ustrip(r), dimension(l)) end Base.:-(l::AbstractQuantity) = new_quantity(typeof(l), -ustrip(l), dimension(l)) @@ -26,12 +26,12 @@ Base.:-(l::AbstractQuantity, r::AbstractQuantity) = l + (-r) Base.:+(l::AbstractQuantity, r) = let - iszero(dimension(l)) || throw(DimensionError(l, r)) + l, r = dimension_promote(l, r) new_quantity(typeof(l), ustrip(l) + r, dimension(l)) end Base.:+(l, r::AbstractQuantity) = let - iszero(dimension(r)) || throw(DimensionError(l, r)) + l, r = dimension_promote(l, r) new_quantity(typeof(r), l + ustrip(r), dimension(r)) end Base.:-(l::AbstractQuantity, r) = l + (-r) diff --git a/src/symbolic_dimensions.jl b/src/symbolic_dimensions.jl index af5926a7..6a2c030c 100644 --- a/src/symbolic_dimensions.jl +++ b/src/symbolic_dimensions.jl @@ -130,6 +130,29 @@ a function equivalent to `q -> uconvert(qout, q)`. """ uconvert(qout::AbstractQuantity{<:Any, <:SymbolicDimensions}) = Base.Fix1(uconvert, qout) +function dimension_promote(l::AbstractQuantity{<:Any,<:SymbolicDimensions}, r::AbstractQuantity{<:Any,<:SymbolicDimensions}) + if dimension(l) == dimension(r) + return l, r + else + # We can first try to make the dimensions equivalent + l_unit = l / ustrip(l) + r_raw = uconvert(l_unit, expand_units(r)) + # Ensure type stability: + r = convert(typeof(r), r_raw) + return l, r + end +end +function dimension_promote(l::AbstractQuantity{T,S}, r) where {T,S<:SymbolicDimensions} + l_raw = uconvert(Quantity(one(T), S), expand_units(l)) + l = convert(typeof(l), l_raw) + return l, r +end +function dimension_promote(l, r::AbstractQuantity{T,S}) where {T,S<:SymbolicDimensions} + r_raw = uconvert(Quantity(one(T), S), expand_units(r)) + r = convert(typeof(r), r_raw) + return l, r +end + Base.copy(d::SymbolicDimensions) = SymbolicDimensions(copy(getfield(d, :nzdims)), copy(getfield(d, :nzvals))) function Base.:(==)(l::SymbolicDimensions, r::SymbolicDimensions) nzdims_l = getfield(l, :nzdims) diff --git a/src/utils.jl b/src/utils.jl index 41ee4650..bff66193 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -59,11 +59,11 @@ function Base.isapprox(l::AbstractQuantity, r::AbstractQuantity; kws...) return isapprox(ustrip(l), ustrip(r); kws...) && dimension(l) == dimension(r) end function Base.isapprox(l, r::AbstractQuantity; kws...) - iszero(dimension(r)) || throw(DimensionError(l, r)) + l, r = dimension_promote(l, r) return isapprox(l, ustrip(r); kws...) end function Base.isapprox(l::AbstractQuantity, r; kws...) - iszero(dimension(l)) || throw(DimensionError(l, r)) + l, r = dimension_promote(l, r) return isapprox(ustrip(l), r; kws...) end Base.iszero(d::AbstractDimensions) = all_dimensions(iszero, d) @@ -72,15 +72,15 @@ Base.:(==)(l::AbstractQuantity, r::AbstractQuantity) = ustrip(l) == ustrip(r) && Base.:(==)(l, r::AbstractQuantity) = ustrip(l) == ustrip(r) && iszero(dimension(r)) Base.:(==)(l::AbstractQuantity, r) = ustrip(l) == ustrip(r) && iszero(dimension(l)) function Base.isless(l::AbstractQuantity, r::AbstractQuantity) - dimension(l) == dimension(r) || throw(DimensionError(l, r)) + l, r = dimension_promote(l, r) return isless(ustrip(l), ustrip(r)) end function Base.isless(l::AbstractQuantity, r) - iszero(dimension(l)) || throw(DimensionError(l, r)) + l, r = dimension_promote(l, r) return isless(ustrip(l), r) end function Base.isless(l, r::AbstractQuantity) - iszero(dimension(r)) || throw(DimensionError(l, r)) + l, r = dimension_promote(l, r) return isless(l, ustrip(r)) end @@ -267,3 +267,19 @@ Get the amount dimension of a quantity (e.g., mol^(uamount)). """ uamount(q::AbstractQuantity) = uamount(dimension(q)) uamount(d::AbstractDimensions) = d.amount + + +"""This function allows custom behavior for dimensionality analysis""" +@inline function dimension_promote(l::AbstractQuantity, r::AbstractQuantity) + dimension(l) == dimension(r) || throw(DimensionError(l, r)) + return l, r +end +@inline function dimension_promote(l::AbstractQuantity, r) + iszero(dimension(l)) || throw(DimensionError(l, r)) + return l, r +end +@inline function dimension_promote(l, r::AbstractQuantity) + iszero(dimension(r)) || throw(DimensionError(l, r)) + return l, r +end +# TODO: May want to have methods for arrays as well diff --git a/test/unittests.jl b/test/unittests.jl index bc3327c3..525ac4b2 100644 --- a/test/unittests.jl +++ b/test/unittests.jl @@ -615,6 +615,15 @@ end @test qs ≈ 7.5us"g" end +@testset "Soft conversion" begin + x = 1.5us"km" + y = 1.5us"m" + @test x + y == 1.5015us"km" + @test y + x == 1501.5us"m" + + # TODO: Should allow `==` for non-equal dimensions +end + @testset "Test ambiguities" begin R = DEFAULT_DIM_BASE_TYPE x = convert(R, 10)