Skip to content

Implement Adjusted mutual information #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
arch: x64
- os: macos-latest
version: '1'
arch: x64
arch: aarch64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand All @@ -48,7 +48,7 @@ jobs:
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v5
with:
file: lcov.info
files: lcov.info
docs:
name: Documentation
runs-on: ubuntu-latest
Expand Down
14 changes: 10 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,32 @@ NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
SpecialFunctionsExt = "SpecialFunctions"

[compat]
Distances = "0.10.9"
NearestNeighbors = "0.4"
SpecialFunctions = ">= 0.8"
Statistics = "1"
StatsBase = "0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
julia = "1"

[extras]
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CodecZlib", "Statistics", "LinearAlgebra", "SparseArrays", "Distances", "Random", "DelimitedFiles", "StableRNGs", "Test"]
test = ["CodecZlib", "LinearAlgebra", "SparseArrays", "DelimitedFiles", "SpecialFunctions", "StableRNGs", "Test"]
64 changes: 64 additions & 0 deletions ext/SpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
module SpecialFunctionsExt # Should be same name as the file (just like a normal package)

using SpecialFunctions: loggamma
using StatsBase: counts
using Statistics: middle

import Clustering: _mutualinfo

function _mutualinfo(::Val{:adjusted}, a, b; aggregate::Symbol = :mean)
norm_f = if aggregate === :mean
middle
elseif aggregate === :geomean
(a,b) -> sqrt(a*b)
elseif aggregate === :max
max
elseif aggregate === :min
min
else
throw(ArgumentError("Valid options for `aggregate` are: `mean`, `geomean`, `max`, `min`"))
end

return _mutualinfo(counts(a, b)) do hck, hc, hk, rows, cols, N
mi = hc - hck
emi = _expectedmutualinfo(rows, cols, N)
normalizer = norm_f(hc, hk)
denominator = normalizer - emi
(mi - emi) / denominator
end
end

# Adjusted Mutual Information

function _expectedmutualinfo(a, b, n_samples)
nijs = 1:max(maximum(a), maximum(b))

term1 = nijs ./ n_samples

log_ab = [log(a[i]) + log(b[j]) for i in eachindex(a), j in eachindex(b)]
log_Nnij = log(n_samples) .+ log.(nijs)

gln_a = loggamma.(a .+ 1)
gln_b = loggamma.(b .+ 1)
gln_Na = loggamma.(n_samples .- a .+ 1)
gln_Nb = loggamma.(n_samples .- b .+ 1)
gln_Nnij = loggamma.(nijs .+ 1) .+ loggamma.(n_samples + 1)

emi = zero(Float64)
for i in eachindex(a), j in eachindex(b)
nij_idxs = max(1, a[i] - n_samples + b[j]):min(a[i], b[j])
for nij in nij_idxs
term2 = log_Nnij[nij] - log_ab[i,j]
gln = (
gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j] - gln_Nnij[nij] -
loggamma(a[i] - nij + 1) - loggamma(b[j] - nij + 1) -
loggamma(n_samples - a[i] - b[j] + nij + 1)
)
term3 = exp(gln)
emi += (term1[nij] * term2 * term3)
end
end
return emi
end

end # module
4 changes: 4 additions & 0 deletions src/Clustering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,8 @@ module Clustering
include("hclust.jl")

include("deprecate.jl")

if !isdefined(Base, :get_extension)
include("../ext/SpecialFunctionsExt.jl")
end
end
63 changes: 54 additions & 9 deletions src/mutualinfo.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Mutual Information

function _mutualinfo(A::AbstractMatrix{<:Integer}, normed::Bool)
@inline function _mutualinfo(f, A::AbstractMatrix{<:Integer})
N = sum(A)
(N == 0.0) && return 0.0

Expand All @@ -14,28 +14,73 @@ function _mutualinfo(A::AbstractMatrix{<:Integer}, normed::Bool)
hc = entArows/N + log(N)
hk = entAcols/N + log(N)

mi = hc - hck
return if normed
2*mi/(hc+hk)
f(hck,hc,hk,rows,cols,N)
end

function _mutualinfo(::Val{T}, a, b; kwargs...) where T
if T === :adjusted && isempty(kwargs)
error("Error: mutualinfo(): `method=:adjusted` requires SpecialFunctions package to be loaded")
elseif T !== :classic && T !== :normalized
throw(ArgumentError("mutualinfo(): `method=:$(T)` is not supported"))
else
mi
throw(ArgumentError("mutualinfo(): unsupported kwargs used. See the `mutualinfo` docstring for more information"))
end
end

"""
mutualinfo(a, b; normed=true) -> Float64
mutualinfo(a, b; method=:normalized, kwargs...) -> Float64

Compute the *mutual information* between the two clusterings of the same
data points.

`a` and `b` can be either [`ClusteringResult`](@ref) instances or
assignments vectors (`AbstractVector{<:Integer}`).

If `normed` parameter is `true` the return value is the normalized mutual information (symmetric uncertainty),
see "Data Mining Practical Machine Tools and Techniques", Witten & Frank 2005.
`method` can be one of `:classic`, `:normalized` (default), or `:adjusted`, to calculate the
original mutual information score, the normalized mutual information, or the adjusted mutual
information respectively.

When `method=:adjusted`, the `aggregate` kwarg determines how the normalizer
in the denominator is computed. It can be one of:
- `:mean`: The arithmetic mean of two values
- `:geomean`: The geometric mean of two values
- `:max`: The highest of two values
- `:min`: The lowest of two values

# References
> Vinh, Epps, and Bailey, (2009). *Information theoretic measures for clusterings comparison*.
> Proceedings of the 26th Annual International Conference on Machine Learning - ICML ‘09.

> "Data Mining Practical Machine Tools and Techniques", Witten & Frank 2005.
"""
mutualinfo(a, b; normed::Bool=true) = _mutualinfo(counts(a, b), normed)
function mutualinfo(a, b; method::Union{Nothing, Symbol} = nothing, normed::Union{Nothing, Bool} = nothing, kwargs...)
# Disallow `method` and `normed` to be used together
if method === nothing
(normed === nothing) || Base.depwarn("`normed` kwarg is deprecated, please use `method=:normalized` instead of `normed=true`, and `method=:classic` instead of `normed=false'", :mutualinfo)
method = if (normed === nothing) || normed
:normalized
else
:classic
end
else
(normed === nothing) || throw(ArgumentError("`normed` kwarg is not compatible with `method` kwarg"))
end
# Little hack to ensure the invalid kwargs error is thrown
if method === :adjusted && length(kwargs) >= 1 && :aggregate ∉ keys(kwargs)
method = :classic
end

_mutualinfo(Val(method), a, b; kwargs...)
end

function _mutualinfo(::Val{:normalized}, a, b)
return _mutualinfo(counts(a, b)) do hck, hc, hk, _, _, _
mi = hc - hck
2*mi/(hc+hk)
end
end
function _mutualinfo(::Val{:classic}, a, b)
return _mutualinfo(counts(a, b)) do hck, hc, _, _, _, _
hc - hck
end
end
33 changes: 29 additions & 4 deletions test/mutualinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,38 @@ using Clustering
# https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
a1 = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3]
a2 = [1, 1, 1, 1, 1, 2, 3, 3, 1, 2, 2, 2, 2, 2, 3, 3, 3]
@test mutualinfo(a1, a2, normed=false) ≈ 0.39 atol=1.0e-2
@test mutualinfo(a1, a2) ≈ 0.36 atol=1.0e-2
@test mutualinfo(a1, a2; method=:classic) ≈ 0.39 atol=1.0e-2
@test mutualinfo(a1, a2;) ≈ 0.36 atol=1.0e-2
@test mutualinfo(a1, a2; method=:adjusted) ≈ 0.2602 atol=1.0e-4
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:geomean) ≈ 0.2602 atol=1.0e-4
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:max) ≈ 0.2547 atol=1.0e-4
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:min) ≈ 0.2659 atol=1.0e-4

# test deprecated kwarg
@test mutualinfo(a1, a2; normed=false) ≈ 0.39 atol=1.0e-2
@test mutualinfo(a1, a2; normed=true) ≈ 0.36 atol=1.0e-2

# https://doi.org/10.1186/1471-2105-7-380
a1 = [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 1, 2]
a2 = [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4]
@test mutualinfo(a1, a2, normed=false) ≈ 0.6 atol=0.1
@test mutualinfo(a1, a2) ≈ 0.5 atol=0.1
@test mutualinfo(a1, a2; method=:classic) ≈ 0.6 atol=0.1
@test mutualinfo(a1, a2; method=:normalized) ≈ 0.5 atol=0.1
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:mean) ≈ 0.3839 atol=1.0e-4
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:geomean) ≈ 0.3861 atol=1.0e-4
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:max) ≈ 0.3437 atol=1.0e-4
@test mutualinfo(a1, a2; method=:adjusted, aggregate=:min) ≈ 0.4348 atol=1.0e-4

# test errors. More precise tests on Julia 1.8+ when supported

if VERSION >= v"1.8"
@test_throws "ArgumentError: `normed` kwarg is not compatible with `method` kwarg" mutualinfo(a1, a2; method=:adjusted, normed=false)
@test_throws "ArgumentError: mutualinfo(): `method=:adjusfted` is not supported" mutualinfo(a1, a2; method=:adjusfted, aggregate=:min)
@test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:adjusted, notaggregate=:min)
@test_throws "ArgumentError: mutualinfo(): unsupported kwargs used." mutualinfo(a1, a2; method=:classic, notaggregate=:min)
else
@test_throws ArgumentError mutualinfo(a1, a2; method=:adjusted, normed=false)
@test_throws ArgumentError mutualinfo(a1, a2; method=:adjusfted, aggregate=:min)
@test_throws ArgumentError mutualinfo(a1, a2; method=:adjusted, notaggregate=:min)
@test_throws ArgumentError mutualinfo(a1, a2; method=:classic, notaggregate=:min)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test
using Random
using LinearAlgebra
using SparseArrays
using SpecialFunctions
using StableRNGs
using Statistics

Expand Down
Loading