From 1f1d18060a9ac99879f53895375741fc4619f01e Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Wed, 9 Apr 2025 10:47:05 -0300 Subject: [PATCH 1/4] Implement Adjusted Mutual Information --- Project.toml | 11 ++++++- ext/SpecialFunctionsExt.jl | 64 ++++++++++++++++++++++++++++++++++++++ src/Clustering.jl | 4 +++ src/mutualinfo.jl | 63 +++++++++++++++++++++++++++++++------ test/mutualinfo.jl | 26 +++++++++++++--- test/runtests.jl | 1 + 6 files changed, 155 insertions(+), 14 deletions(-) create mode 100644 ext/SpecialFunctionsExt.jl diff --git a/Project.toml b/Project.toml index b68444e7..f6460845 100644 --- a/Project.toml +++ b/Project.toml @@ -9,12 +9,20 @@ NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +[weakdeps] +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[extensions] +SpecialFunctionsExt = "SpecialFunctions" + [compat] Distances = "0.10.9" NearestNeighbors = "0.4" +SpecialFunctions = ">= 0.8" Statistics = "1" StatsBase = "0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34" julia = "1" @@ -26,9 +34,10 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CodecZlib", "Statistics", "LinearAlgebra", "SparseArrays", "Distances", "Random", "DelimitedFiles", "StableRNGs", "Test"] +test = 
["CodecZlib", "Statistics", "LinearAlgebra", "SparseArrays", "Distances", "Random", "DelimitedFiles", "SpecialFunctions", "StableRNGs", "Test"] diff --git a/ext/SpecialFunctionsExt.jl b/ext/SpecialFunctionsExt.jl new file mode 100644 index 00000000..d4289cd1 --- /dev/null +++ b/ext/SpecialFunctionsExt.jl @@ -0,0 +1,64 @@ +module SpecialFunctionsExt # Should be same name as the file (just like a normal package) + +using SpecialFunctions: loggamma +using StatsBase: counts +using Statistics: middle + +import Clustering: _mutualinfo + +function _mutualinfo(::Val{:adjusted}, a, b; aggregate::Symbol = :mean) + norm_f = if aggregate === :mean + middle + elseif aggregate === :geomean + (a,b) -> sqrt(a*b) + elseif aggregate === :max + max + elseif aggregate === :min + min + else + throw(ArgumentError("Valid options for `aggregate` are: `mean`, `geomean`, `max`, `min`")) + end + + return _mutualinfo(counts(a, b)) do hck, hc, hk, rows, cols, N + mi = hc - hck + emi = _expectedmutualinfo(rows, cols, N) + normalizer = norm_f(hc, hk) + denominator = normalizer - emi + (mi - emi) / denominator + end +end + +# Adjusted Mutual Information + +function _expectedmutualinfo(a, b, n_samples) + nijs = 1:max(maximum(a), maximum(b)) + + term1 = nijs ./ n_samples + + log_ab = [log(a[i]) + log(b[j]) for i in eachindex(a), j in eachindex(b)] + log_Nnij = log(n_samples) .+ log.(nijs) + + gln_a = loggamma.(a .+ 1) + gln_b = loggamma.(b .+ 1) + gln_Na = loggamma.(n_samples .- a .+ 1) + gln_Nb = loggamma.(n_samples .- b .+ 1) + gln_Nnij = loggamma.(nijs .+ 1) .+ loggamma.(n_samples + 1) + + emi = zero(Float64) + for i in eachindex(a), j in eachindex(b) + nij_idxs = max(1, a[i] - n_samples + b[j]):min(a[i], b[j]) + for nij in nij_idxs + term2 = log_Nnij[nij] - log_ab[i,j] + gln = ( + gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j] - gln_Nnij[nij] - + loggamma(a[i] - nij + 1) - loggamma(b[j] - nij + 1) - + loggamma(n_samples - a[i] - b[j] + nij + 1) + ) + term3 = exp(gln) + emi += (term1[nij] * term2 * 
term3) + end + end + return emi +end + +end # module diff --git a/src/Clustering.jl b/src/Clustering.jl index fae9ee82..0c5ee6d3 100644 --- a/src/Clustering.jl +++ b/src/Clustering.jl @@ -100,4 +100,8 @@ module Clustering include("hclust.jl") include("deprecate.jl") + + if !isdefined(Base, :get_extension) + include("../ext/SpecialFunctionsExt.jl") + end end diff --git a/src/mutualinfo.jl b/src/mutualinfo.jl index 65b0a527..817ba413 100644 --- a/src/mutualinfo.jl +++ b/src/mutualinfo.jl @@ -1,6 +1,6 @@ # Mutual Information -function _mutualinfo(A::AbstractMatrix{<:Integer}, normed::Bool) +@inline function _mutualinfo(f, A::AbstractMatrix{<:Integer}) N = sum(A) (N == 0.0) && return 0.0 @@ -14,16 +14,21 @@ function _mutualinfo(A::AbstractMatrix{<:Integer}, normed::Bool) hc = entArows/N + log(N) hk = entAcols/N + log(N) - mi = hc - hck - return if normed - 2*mi/(hc+hk) + f(hck,hc,hk,rows,cols,N) +end + +function _mutualinfo(::Val{T}, a, b; kwargs...) where T + if T === :adjusted && isempty(kwargs) + error("Error: mutualinfo(): `method=:adjusted` requires SpecialFunctions package to be loaded") + elseif T !== :classic && T !== :normalized + throw(ArgumentError("mutualinfo(): `method=:$(T)` is not supported")) else - mi + throw(ArgumentError("mutualinfo(): unsupported kwargs used. See the `mutualinfo` docstring for more information")) end end """ - mutualinfo(a, b; normed=true) -> Float64 + mutualinfo(a, b; method=:normalized, kwargs...) -> Float64 Compute the *mutual information* between the two clusterings of the same data points. @@ -31,11 +36,51 @@ data points. `a` and `b` can be either [`ClusteringResult`](@ref) instances or assignments vectors (`AbstractVector{<:Integer}`). -If `normed` parameter is `true` the return value is the normalized mutual information (symmetric uncertainty), -see "Data Mining Practical Machine Tools and Techniques", Witten & Frank 2005. 
+`method` can be one of `:classic`, `:normalized` (default), or `:adjusted`, to calculate the +original mutual information score, the normalized mutual information, or the adjusted mutual +information respectively. + +When `method=:adjusted`, the `aggregate` kwarg determines how the normalizer +in the denominator is computed. It can be one of: +- `:mean`: The arithmetic mean of two values (default) +- `:geomean`: The geometric mean of two values +- `:max`: The highest of two values +- `:min`: The lowest of two values # References > Vinh, Epps, and Bailey, (2009). *Information theoretic measures for clusterings comparison*. > Proceedings of the 26th Annual International Conference on Machine Learning - ICML ‘09. + +> "Data Mining: Practical Machine Learning Tools and Techniques", Witten & Frank 2005. """ -mutualinfo(a, b; normed::Bool=true) = _mutualinfo(counts(a, b), normed) +function mutualinfo(a, b; method::Union{Nothing, Symbol} = nothing, normed::Union{Nothing, Bool} = nothing, kwargs...) + # Disallow `method` and `normed` to be used together + if isnothing(method) + isnothing(normed) || Base.depwarn("`normed` kwarg is deprecated, please use `method=:normalized` instead of `normed=true`, and `method=:classic` instead of `normed=false'", :mutualinfo) + method = if isnothing(normed) || normed + :normalized + else + :classic + end + else + isnothing(normed) || throw(ArgumentError("`normed` kwarg is not compatible with `method` kwarg")) + end + # Little hack to ensure the correct error is thrown + if method === :adjusted && length(kwargs) >= 1 && :aggregate ∉ keys(kwargs) + method = :classic + end + + _mutualinfo(Val(method), a, b; kwargs...)
+end + +function _mutualinfo(::Val{:normalized}, a, b) + return _mutualinfo(counts(a, b)) do hck, hc, hk, _, _, _ + mi = hc - hck + 2*mi/(hc+hk) + end +end +function _mutualinfo(::Val{:classic}, a, b) + return _mutualinfo(counts(a, b)) do hck, hc, _, _, _, _ + hc - hck + end +end diff --git a/test/mutualinfo.jl b/test/mutualinfo.jl index f2caeed3..a29e203b 100644 --- a/test/mutualinfo.jl +++ b/test/mutualinfo.jl @@ -6,13 +6,31 @@ using Clustering # https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html a1 = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3] a2 = [1, 1, 1, 1, 1, 2, 3, 3, 1, 2, 2, 2, 2, 2, 3, 3, 3] - @test mutualinfo(a1, a2, normed=false) ≈ 0.39 atol=1.0e-2 - @test mutualinfo(a1, a2) ≈ 0.36 atol=1.0e-2 + @test mutualinfo(a1, a2; method=:classic) ≈ 0.39 atol=1.0e-2 + @test mutualinfo(a1, a2;) ≈ 0.36 atol=1.0e-2 + @test mutualinfo(a1, a2; method=:adjusted) ≈ 0.2602 atol=1.0e-4 + @test mutualinfo(a1, a2; method=:adjusted, aggregate=:geomean) ≈ 0.2602 atol=1.0e-4 + @test mutualinfo(a1, a2; method=:adjusted, aggregate=:max) ≈ 0.2547 atol=1.0e-4 + @test mutualinfo(a1, a2; method=:adjusted, aggregate=:min) ≈ 0.2659 atol=1.0e-4 + + # test deprecated kwarg + @test mutualinfo(a1, a2; normed=false) ≈ 0.39 atol=1.0e-2 + @test mutualinfo(a1, a2; normed=true) ≈ 0.36 atol=1.0e-2 # https://doi.org/10.1186/1471-2105-7-380 a1 = [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 1, 2] a2 = [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4] - @test mutualinfo(a1, a2, normed=false) ≈ 0.6 atol=0.1 - @test mutualinfo(a1, a2) ≈ 0.5 atol=0.1 + @test mutualinfo(a1, a2; method=:classic) ≈ 0.6 atol=0.1 + @test mutualinfo(a1, a2; method=:normalized) ≈ 0.5 atol=0.1 + @test mutualinfo(a1, a2; method=:adjusted, aggregate=:mean) ≈ 0.3839 atol=1.0e-4 + @test mutualinfo(a1, a2; method=:adjusted, aggregate=:geomean) ≈ 0.3861 atol=1.0e-4 + @test mutualinfo(a1, a2; method=:adjusted, aggregate=:max) ≈ 0.3437 atol=1.0e-4 + @test mutualinfo(a1, a2; 
method=:adjusted, aggregate=:min) ≈ 0.4348 atol=1.0e-4 + + # test errors + @test_throws "ArgumentError: `normed` kwarg is not compatible with `method` kwarg" mutualinfo(a1, a2; method=:adjusfted, normed=false) + @test_throws "ArgumentError: mutualinfo(): `method=:adjusfted` is not supported" mutualinfo(a1, a2; method=:adjusfted, aggregate=:min) + @test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:adjusted, notaggregate=:min) + @test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:classic, notaggregate=:min) end diff --git a/test/runtests.jl b/test/runtests.jl index 2e6a7894..2dfe14da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test using Random using LinearAlgebra using SparseArrays +using SpecialFunctions using StableRNGs using Statistics From 5a539ac4efd1f8f0d36e694b4d1a0f09cb7527dd Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sun, 18 May 2025 10:56:13 -0300 Subject: [PATCH 2/4] tests: Fix tests for older Julia Also a couple of clarity fixes --- src/mutualinfo.jl | 10 +++++----- test/mutualinfo.jl | 17 ++++++++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/mutualinfo.jl b/src/mutualinfo.jl index 817ba413..65b0afca 100644 --- a/src/mutualinfo.jl +++ b/src/mutualinfo.jl @@ -55,17 +55,17 @@ in the denominator is computed. It can be one of: """ function mutualinfo(a, b; method::Union{Nothing, Symbol} = nothing, normed::Union{Nothing, Bool} = nothing, kwargs...) 
# Disallow `method` and `normed` to be used together - if isnothing(method) - isnothing(normed) || Base.depwarn("`normed` kwarg is deprecated, please use `method=:normalized` instead of `normed=true`, and `method=:classic` instead of `normed=false'", :mutualinfo) - method = if isnothing(normed) || normed - :normalized - else - :classic - end + if method === nothing + (normed === nothing) || Base.depwarn("`normed` kwarg is deprecated, please use `method=:normalized` instead of `normed=true`, and `method=:classic` instead of `normed=false`", :mutualinfo) + method = if (normed === nothing) || normed + :normalized + else + :classic + end else - isnothing(normed) || throw(ArgumentError("`normed` kwarg is not compatible with `method` kwarg")) + (normed === nothing) || throw(ArgumentError("`normed` kwarg is not compatible with `method` kwarg")) end - # Little hack to ensure the correct error is thrown + # Little hack to ensure the invalid kwargs error is thrown if method === :adjusted && length(kwargs) >= 1 && :aggregate ∉ keys(kwargs) method = :classic end diff --git a/test/mutualinfo.jl b/test/mutualinfo.jl index a29e203b..9fe89ecb 100644 --- a/test/mutualinfo.jl +++ b/test/mutualinfo.jl @@ -27,10 +27,17 @@ using Clustering @test mutualinfo(a1, a2; method=:adjusted, aggregate=:max) ≈ 0.3437 atol=1.0e-4 @test mutualinfo(a1, a2; method=:adjusted, aggregate=:min) ≈ 0.4348 atol=1.0e-4 - # test errors - @test_throws "ArgumentError: `normed` kwarg is not compatible with `method` kwarg" mutualinfo(a1, a2; method=:adjusfted, normed=false) - @test_throws "ArgumentError: mutualinfo(): `method=:adjusfted` is not supported" mutualinfo(a1, a2; method=:adjusfted, aggregate=:min) - @test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:adjusted, notaggregate=:min) - @test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:classic, notaggregate=:min) + # test errors.
More precise tests on Julia 1.8+ when supported + if VERSION >= v"1.8" + @test_throws "ArgumentError: `normed` kwarg is not compatible with `method` kwarg" mutualinfo(a1, a2; method=:adjusted, normed=false) + @test_throws "ArgumentError: mutualinfo(): `method=:adjusfted` is not supported" mutualinfo(a1, a2; method=:adjusfted, aggregate=:min) + @test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:adjusted, notaggregate=:min) + @test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:classic, notaggregate=:min) + else + @test_throws ArgumentError mutualinfo(a1, a2; method=:adjusted, normed=false) + @test_throws ArgumentError mutualinfo(a1, a2; method=:adjusfted, aggregate=:min) + @test_throws ArgumentError mutualinfo(a1, a2; method=:adjusted, notaggregate=:min) + @test_throws ArgumentError mutualinfo(a1, a2; method=:classic, notaggregate=:min) + end end From b4dd79472474ab824f3eac7cfd9d0706d74740a1 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 10 Apr 2025 12:11:43 -0300 Subject: [PATCH 3/4] Project.toml: Clean up --- Project.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index f6460845..b793659d 100644 --- a/Project.toml +++ b/Project.toml @@ -30,14 +30,11 @@ julia = "1" [extras] CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" -Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CodecZlib", "Statistics", "LinearAlgebra", 
"SparseArrays", "Distances", "Random", "DelimitedFiles", "SpecialFunctions", "StableRNGs", "Test"] +test = ["CodecZlib", "LinearAlgebra", "SparseArrays", "DelimitedFiles", "SpecialFunctions", "StableRNGs", "Test"] From 275b9dc90d00dd36b5cd7cf28986798d91f0b192 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sun, 18 May 2025 11:26:36 -0300 Subject: [PATCH 4/4] CI: Fix codecov and macOS arch --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 788566ec..85694bf1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: arch: x64 - os: macos-latest version: '1' - arch: x64 + arch: aarch64 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -48,7 +48,7 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v5 with: - file: lcov.info + files: lcov.info docs: name: Documentation runs-on: ubuntu-latest