-
Notifications
You must be signed in to change notification settings - Fork 72
/
compute_pdrc.py
112 lines (94 loc) · 3.02 KB
/
compute_pdrc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# prdc
# Copyright (c) 2020-present NAVER Corp.
# MIT license
import numpy as np
import sklearn.metrics
# Public API of this module: only compute_prdc is exported by `from <module> import *`.
__all__ = ["compute_prdc"]
def compute_pairwise_distance(data_x, data_y=None, *, n_jobs=8):
    """
    Compute Euclidean distances between every row of ``data_x`` and every
    row of ``data_y``.

    Parameters
    ----------
    data_x: numpy.ndarray([N, feature_dim], dtype=np.float32)
    data_y: numpy.ndarray([M, feature_dim], dtype=np.float32), optional
        Defaults to ``data_x``, which yields the N x N self-distance matrix.
    n_jobs: int, optional (keyword-only)
        Number of parallel jobs forwarded to
        ``sklearn.metrics.pairwise_distances``. Defaults to 8, the value
        that was previously hard-coded.
    Returns
    -------
    numpy.ndarray([N, M]) of pairwise distances; entry [i, j] is the
    Euclidean distance between data_x[i] and data_y[j].
    """
    if data_y is None:
        data_y = data_x
    return sklearn.metrics.pairwise_distances(
        data_x, data_y, metric="euclidean", n_jobs=n_jobs
    )
def get_kth_value(unsorted, k, axis=-1):
    """
    Return the k-th smallest value (1-indexed) along ``axis``.

    Parameters
    ----------
    unsorted: numpy.ndarray of any dimensionality.
    k: int
        1-based rank to select; must satisfy 1 <= k <= unsorted.shape[axis].
    axis: int
        Axis along which to select (default: the last axis).
    Returns
    -------
    numpy.ndarray with ``axis`` removed, holding the k-th smallest values
    (a scalar for 1-D input).

    Raises
    ------
    ValueError
        If ``k`` is outside ``[1, unsorted.shape[axis]]``.
    """
    if k < 1:
        # np.partition would read k - 1 == -1 as "last position" and silently
        # return the maximum instead of failing.
        raise ValueError("k must be >= 1, got {}".format(k))
    # After partitioning, position k - 1 along `axis` holds exactly the value
    # a full sort would put there, i.e. the k-th smallest. Unlike slicing the
    # first k argpartition indices with `[..., :k]` (which always cuts the
    # last axis), this honours `axis` and also works when k equals the axis
    # length. np.partition raises ValueError when k - 1 is out of bounds.
    partitioned = np.partition(unsorted, k - 1, axis=axis)
    return np.take(partitioned, k - 1, axis=axis)
def compute_nearest_neighbour_distances(input_features, nearest_k):
    """
    Radius of each sample's ``nearest_k``-nearest-neighbour ball.

    Parameters
    ----------
    input_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
    nearest_k: int
    Returns
    -------
    numpy.ndarray([N]): for every row, its (nearest_k + 1)-th smallest
    distance to the rows of ``input_features`` (itself included).
    """
    self_distances = compute_pairwise_distance(input_features)
    # Each row's distance to itself is expected to be its smallest entry, so
    # the (nearest_k + 1)-th smallest value is the distance to the
    # nearest_k-th *other* neighbour.
    return get_kth_value(self_distances, k=nearest_k + 1, axis=-1)
def compute_prdc(real_features, fake_features, nearest_k):
    """
    Computes precision, recall, density, and coverage given two manifolds.

    Each set's manifold is approximated by the union of balls centred on its
    samples, where a sample's ball radius is the distance to its
    ``nearest_k``-th nearest neighbour within its own set.

    Parameters
    ----------
    real_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
    fake_features: numpy.ndarray([M, feature_dim], dtype=np.float32)
    nearest_k: int
        Neighbourhood size used for the ball radii. Each set needs more than
        ``nearest_k`` samples.
    Returns
    -------
    dict of precision, recall, density, and coverage:
        - precision: fraction of fake samples strictly inside at least one
          real ball.
        - recall: fraction of real samples strictly inside at least one
          fake ball.
        - density: mean over fake samples of the number of real balls that
          contain it, divided by ``nearest_k``.
        - coverage: fraction of real samples whose ball contains at least
          one fake sample.
    """
    print(
        "Num real: {} Num fake: {}".format(
            real_features.shape[0], fake_features.shape[0]
        )
    )

    real_nearest_neighbour_distances = compute_nearest_neighbour_distances(
        real_features, nearest_k
    )
    fake_nearest_neighbour_distances = compute_nearest_neighbour_distances(
        fake_features, nearest_k
    )
    # distance_real_fake[i, j] = distance(real_i, fake_j), shape [N, M].
    distance_real_fake = compute_pairwise_distance(real_features, fake_features)

    # Recall: is real_i strictly inside the ball of some fake_j?
    # The [N, M] mask is only needed here, so keep it transient.
    recall = (
        (distance_real_fake < np.expand_dims(fake_nearest_neighbour_distances, axis=0))
        .any(axis=1)
        .mean()
    )

    # in_real_ball[i, j]: fake_j lies strictly inside the ball of real_i.
    # Precision and density both need this mask, so build it once instead
    # of repeating the [N, M] comparison.
    in_real_ball = distance_real_fake < np.expand_dims(
        real_nearest_neighbour_distances, axis=1
    )
    precision = in_real_ball.any(axis=0).mean()
    density = (1.0 / float(nearest_k)) * in_real_ball.sum(axis=0).mean()

    # A real ball contains some fake sample iff its closest fake sample is
    # strictly inside the radius.
    coverage = (
        distance_real_fake.min(axis=1) < real_nearest_neighbour_distances
    ).mean()

    return dict(precision=precision, recall=recall, density=density, coverage=coverage)