From 5a922c1aa5e73444f930f7e1c307ebda7aa9fe94 Mon Sep 17 00:00:00 2001 From: jollyjonson Date: Mon, 11 Dec 2023 10:59:33 +0100 Subject: [PATCH] Move statistics update to statistics class --- .../frechet_audio_distance.py | 40 ++++++++++++++----- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/frechet_audio_distance/frechet_audio_distance.py b/frechet_audio_distance/frechet_audio_distance.py index c043855..00fd7e7 100644 --- a/frechet_audio_distance/frechet_audio_distance.py +++ b/frechet_audio_distance/frechet_audio_distance.py @@ -44,9 +44,37 @@ def __init__(self, dim: int): tf.zeros((dim, dim), dtype=tf.float64) ) self.mean = tf.Variable(tf.zeros((dim,), dtype=tf.float64)) + self.num_items_processed = tf.Variable(0.0, dtype=tf.float64) + + def update(self, data: tf.Tensor) -> None: + """ + Updates the means and covariances held by an instance of this class + """ + data = tf.cast(data, dtype=tf.float64) + num_items_this_update = tf.cast( + tf.shape(data)[0], dtype=tf.float64 + ) + self.num_items_processed.assign_add( + tf.cast(num_items_this_update, dtype=tf.float64) + ) - _statistics_dim: int = VGGishParams.EMBEDDING_SIZE - _num_items_processed: tf.Variable # dtype = tf.float64 + x_norm_old = data - self.mean + self.mean.assign_add( + tf.reduce_sum(x_norm_old, axis=0) / self.num_items_processed + ) + x_norm_new = data - self.mean + + self.covariance.assign( + self.covariance + * (self.num_items_processed - num_items_this_update) + / self.num_items_processed + ) + self.covariance.assign_add( + tf.matmul(tf.transpose(x_norm_old), x_norm_new) + / self.num_items_processed + ) + + _statistics_dim: int = VGGish.embedding_size _true_statistics: _Statistics _pred_statistics: _Statistics @@ -99,22 +127,16 @@ def update_state( y_true_embedding, y_pred_embedding = map( self._vggish_model, [y_true_features, y_pred_features] ) - self._num_items_processed.assign_add( - tf.cast(tf.shape(y_true_embedding)[0], dtype=tf.float64) - ) for data, statistics in zip( [y_true_embedding, y_pred_embedding], [self._true_statistics, self._pred_statistics], ): - self._update_statistics( - statistics, data, self._num_items_processed - ) + statistics.update(data) def result(self) -> tf.Tensor: # pragma: no cover return self._compute_distance() def reset_state(self) -> None: - self._num_items_processed = tf.Variable(0, dtype=tf.float64) self._true_statistics = self._Statistics(self._statistics_dim) self._pred_statistics = self._Statistics(self._statistics_dim)