Skip to content

Commit

Permalink
Move statistics update to statistics class
Browse files Browse the repository at this point in the history
  • Loading branch information
jollyjonson committed Dec 11, 2023
1 parent e5e1a96 commit 5a922c1
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions frechet_audio_distance/frechet_audio_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,37 @@ def __init__(self, dim: int):
tf.zeros((dim, dim), dtype=tf.float64)
)
self.mean = tf.Variable(tf.zeros((dim,), dtype=tf.float64))
self.num_items_processed = tf.Variable(0.0, dtype=tf.float64)

def update(self, data: tf.Tensor) -> None:
    """
    Folds one batch of samples into the running mean and covariance held
    by an instance of this class.

    The batch is merged with a batched Welford-style update: the deviations
    of the batch from the mean *before* and *after* the mean update are
    multiplied together, which yields the exact change of the summed outer
    products without revisiting earlier batches. The covariance is
    normalised by the total number of items processed (population
    covariance, no Bessel correction).

    An empty batch arriving before any data has been seen leaves the
    statistics at their initial values instead of turning them into NaN.

    :param data: batch of samples, one per row; cast to float64 internally.
        Expected to be 2-D with the trailing dimension matching the size of
        `self.mean`, as required by the reduction and matmul below.
    """
    data = tf.cast(data, dtype=tf.float64)
    num_items_this_update = tf.cast(tf.shape(data)[0], dtype=tf.float64)
    self.num_items_processed.assign_add(num_items_this_update)

    # Deviations from the mean as it was before this batch.
    x_norm_old = data - self.mean
    # divide_no_nan returns 0 when the item count is still 0 (empty first
    # batch); for any non-zero count it is an ordinary division.
    self.mean.assign_add(
        tf.math.divide_no_nan(
            tf.reduce_sum(x_norm_old, axis=0), self.num_items_processed
        )
    )
    # Deviations from the mean after it absorbed this batch.
    x_norm_new = data - self.mean

    # Re-weight the previous covariance from "per old count" to
    # "per new count" ...
    self.covariance.assign(
        tf.math.divide_no_nan(
            self.covariance
            * (self.num_items_processed - num_items_this_update),
            self.num_items_processed,
        )
    )
    # ... then add this batch's contribution.
    self.covariance.assign_add(
        tf.math.divide_no_nan(
            tf.matmul(tf.transpose(x_norm_old), x_norm_new),
            self.num_items_processed,
        )
    )

# Dimension handed to every _Statistics instance this class creates; taken
# from the VGGish embedding size.
_statistics_dim: int = VGGish.embedding_size
# Running statistics fed with the embeddings of the y_true inputs.
_true_statistics: _Statistics
# Running statistics fed with the embeddings of the y_pred inputs.
_pred_statistics: _Statistics

Expand Down Expand Up @@ -99,22 +127,16 @@ def update_state(
y_true_embedding, y_pred_embedding = map(
self._vggish_model, [y_true_features, y_pred_features]
)
self._num_items_processed.assign_add(
tf.cast(tf.shape(y_true_embedding)[0], dtype=tf.float64)
)
for data, statistics in zip(
[y_true_embedding, y_pred_embedding],
[self._true_statistics, self._pred_statistics],
):
self._update_statistics(
statistics, data, self._num_items_processed
)
statistics.update(data)

def result(self) -> tf.Tensor:  # pragma: no cover
    """
    Returns the current value of the metric, as produced by
    `_compute_distance`.
    """
    distance = self._compute_distance()
    return distance

def reset_state(self) -> None:
    """
    Discards everything accumulated so far by replacing the item counter
    and both statistics holders with freshly constructed ones.
    """
    # NOTE(review): _Statistics carries its own num_items_processed; confirm
    # whether this metric-level counter is still needed.
    self._num_items_processed = tf.Variable(0, dtype=tf.float64)
    dim = self._statistics_dim
    self._true_statistics = self._Statistics(dim)
    self._pred_statistics = self._Statistics(dim)

Expand Down

0 comments on commit 5a922c1

Please sign in to comment.