Add VGGish as FADFeature implementation
- in preparation for removing it from the FAD class
jollyjonson committed Dec 10, 2023
1 parent 673e8c8 commit 2b1ae4e
Showing 3 changed files with 298 additions and 2 deletions.
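
For orientation, a minimal usage sketch of the new class (an illustration based only on the interface shown in this diff; it assumes FADFeature imposes no requirements beyond __call__, input_sample_rate_in_hz, and output_dim):

import tensorflow as tf

from frechet_audio_distance.vggish import VGGish

# Instantiate the extractor; the TF-Hub checkpoint is downloaded on first use.
vggish = VGGish(step_size_in_s=0.5)

# One second of audio at the model's expected sample rate, with a batch dim.
audio = tf.zeros((1, vggish.input_sample_rate_in_hz), dtype=tf.float32)

embeddings = vggish(audio)  # shape: (num_windows, 128)
assert embeddings.shape[-1] == vggish.output_dim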
230 changes: 230 additions & 0 deletions frechet_audio_distance/vggish.py
@@ -0,0 +1,230 @@
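"""VGGish audio embedding model wrapped as an `FADFeature`.

Computes 128-dimensional VGGish embeddings from raw audio using the
publicly released TF-Hub checkpoint of the original AudioSet model.
"""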
import os

import numpy as np
import tensorflow as tf

from .feature import FADFeature

VGGISH_PUBLIC_MODEL_CHECKPOINT_URL: str = (
"https://storage.googleapis.com/tfhub-modules/google/vggish/1.tar.gz"
)


class VGGish(FADFeature):

def __init__(self, step_size_in_s: float = .5):
self._step_size_in_samples = int(
round(step_size_in_s * self.input_sample_rate_in_hz)
)
self.model = self._init_vggish_model()

@property
def input_sample_rate_in_hz(self) -> float:
return self.sample_rate_in_hz

    @property
    def output_dim(self) -> int:
        return self.embedding_size

@tf.function
def __call__(self, audio: tf.Tensor) -> tf.Tensor: # pragma: no cover
mel_feature = self._extract_mel_features(audio)
embeddings = self.model(mel_feature)
return embeddings

# Parameters used in the VGGish model by the original authors.
# Content copied straight from
# https://github.com/tensorflow/models/tree/master/research/audioset/vggish

# Architectural constants.
num_frames = 96 # Frames in input mel-spectrogram patch.
embedding_size = 128 # Size of embedding layer.

# Hyperparameters used in feature and example generation.
sample_rate_in_hz = 16000
stft_window_length_seconds = 0.025
    stft_hop_length_seconds = 0.010
num_mel_bins = 64
mel_min_hz = 125
mel_max_hz = 7500
    # Offset used for stabilized log of input mel-spectrogram.
    log_offset = 0.01
example_window_seconds = 0.96 # Each example contains 96 10ms frames
example_hop_seconds = 0.96 # with zero overlap.

var_names = [
"vggish/conv1/weights:0",
"vggish/conv1/biases:0",
"vggish/conv2/weights:0",
"vggish/conv2/biases:0",
"vggish/conv3/conv3_1/weights:0",
"vggish/conv3/conv3_1/biases:0",
"vggish/conv3/conv3_2/weights:0",
"vggish/conv3/conv3_2/biases:0",
"vggish/conv4/conv4_1/weights:0",
"vggish/conv4/conv4_1/biases:0",
"vggish/conv4/conv4_2/weights:0",
"vggish/conv4/conv4_2/biases:0",
"vggish/fc1/fc1_1/weights:0",
"vggish/fc1/fc1_1/biases:0",
"vggish/fc1/fc1_2/weights:0",
"vggish/fc1/fc1_2/biases:0",
"vggish/fc2/weights:0",
"vggish/fc2/biases:0",
]

    # Spectrogram parameters.
_num_mel_bins = 64
_log_additive_offset = 0.001
_log_floor = 1e-12
_window_length_secs = 0.025
_hop_length_secs = 0.010

_window_length_samples = int(round(sample_rate_in_hz
* _window_length_secs))
_hop_length_samples = int(
round(sample_rate_in_hz * _hop_length_secs)
)
_fft_length = 2 ** int(
np.ceil(np.log(_window_length_samples) / np.log(2.0))
)

    # Spectrogram-to-mel transform matrix.
_spec_to_mel_mat = tf.signal.linear_to_mel_weight_matrix(
num_mel_bins=_num_mel_bins,
num_spectrogram_bins=_fft_length // 2 + 1,
sample_rate=sample_rate_in_hz,
lower_edge_hertz=mel_min_hz,
upper_edge_hertz=mel_max_hz,
dtype=tf.dtypes.float32,
)

@staticmethod
def _normalize_audio(
audio_batch: tf.Tensor) -> tf.Tensor: # pragma: no cover
min_ratio_for_normalization = tf.convert_to_tensor(
0.1, dtype=audio_batch.dtype
) # = 10**(max_db/-20) with max_db = 20
normalization_coeff = tf.maximum(
min_ratio_for_normalization,
tf.reduce_max(audio_batch, axis=-1, keepdims=True),
)
return audio_batch / normalization_coeff

@staticmethod
def _stabilized_log(
x: tf.Tensor, additive_offset: float, floor: float
) -> tf.Tensor: # pragma: no cover
"""TF version of mfcc_mel.StabilizedLog."""
return tf.math.log(tf.math.maximum(x, floor) + additive_offset)

def _extract_mel_features(
self,
audio_batch: tf.Tensor) -> tf.Tensor: # pragma: no cover
normalized_audio_batch = self._normalize_audio(audio_batch)
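        # Frame the audio into overlapping 1 s windows (sample_rate samples
        # each); every window later yields one 96x64 mel patch.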
framed_audio = tf.signal.frame(
normalized_audio_batch,
VGGish.sample_rate_in_hz,
self._step_size_in_samples,
)
batched_framed_audio = tf.reshape(
framed_audio,
(
tf.shape(framed_audio)[0] * tf.shape(framed_audio)[1],
tf.shape(framed_audio)[2],
),
)
return tf.map_fn(self._log_mel_spectrogram, batched_framed_audio)

def _log_mel_spectrogram(self, audio: tf.Tensor
) -> tf.Tensor: # pragma: no cover
spectrogram = tf.abs(
tf.signal.stft(
tf.cast(audio, tf.dtypes.float32),
frame_length=self._window_length_samples,
frame_step=self._hop_length_samples,
fft_length=self._fft_length,
window_fn=tf.signal.hann_window,
)
)
        # tf.signal.stft yields 98 frames per 1 s window at these settings
        # (1 + (16000 - 400) // 160), but the model expects 96 frames
        # (VGGish.num_frames), so the first and last mel frame are dropped;
        # this has a negligible effect on the embeddings.
        mel = tf.matmul(spectrogram, self._spec_to_mel_mat)[1:-1]
return self._stabilized_log(
mel, self._log_additive_offset, self._log_floor
)

@classmethod
def _init_vggish_model(cls) -> tf.keras.Model:
model_path = os.path.dirname(
tf.keras.utils.get_file(
"vggish_model.tar.gz",
VGGISH_PUBLIC_MODEL_CHECKPOINT_URL,
extract=True,
cache_subdir="vggish",
)
)
return cls._assign_weights_to_model(
cls._load_vggish_weights(model_path),
cls._build_vggish_as_keras_model(),
)

@staticmethod
def _load_vggish_weights(saved_model_path: str) -> list[tf.Variable]:
weights = []
loaded_obj = tf.saved_model.load(saved_model_path)
for weight_name_in_orig_model in VGGish.var_names:
            # Accessing the loaded object's protected `_variables` attribute
            # was the only way found to retrieve the original weights; this
            # may break with future TF versions.
for weight_var in loaded_obj._variables:
if weight_var.name == weight_name_in_orig_model:
weights.append(weight_var)
return weights

@staticmethod
def _assign_weights_to_model(
weights: list[tf.Variable], keras_model: tf.keras.Model
) -> tf.keras.Model:
for layer in keras_model.layers:
for w in layer.trainable_weights:
w.assign(weights.pop(0))
assert len(weights) == 0
return keras_model

    @staticmethod
    def _build_vggish_as_keras_model() -> tf.keras.Model:
conv_layer_kwargs = {
"kernel_size": (3, 3),
"strides": (1, 1),
"padding": "SAME",
"activation": "relu",
}
pool_layer_kwargs = {"strides": (2, 2), "padding": "SAME"}

input_layer = tf.keras.layers.Input(
shape=(VGGish.num_frames, VGGish.num_mel_bins)
)
x = tf.reshape(
input_layer, [-1, VGGish.num_frames,
VGGish.num_mel_bins, 1]
)
x = tf.keras.layers.Conv2D(64, **conv_layer_kwargs)(x)
x = tf.keras.layers.MaxPool2D(**pool_layer_kwargs)(x)
x = tf.keras.layers.Conv2D(128, **conv_layer_kwargs)(x)
x = tf.keras.layers.MaxPool2D(**pool_layer_kwargs)(x)
x = tf.keras.layers.Conv2D(256, **conv_layer_kwargs)(x)
x = tf.keras.layers.Conv2D(256, **conv_layer_kwargs)(x)
x = tf.keras.layers.MaxPool2D(**pool_layer_kwargs)(x)
x = tf.keras.layers.Conv2D(512, **conv_layer_kwargs)(x)
x = tf.keras.layers.Conv2D(512, **conv_layer_kwargs)(x)
x = tf.keras.layers.MaxPool2D(**pool_layer_kwargs)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(4096, activation="relu")(x)
x = tf.keras.layers.Dense(4096, activation="relu")(x)
x = tf.keras.layers.Dense(
VGGish.embedding_size, activation=None
)(x)
embedding = tf.identity(x, name="embedding")
return tf.keras.Model(inputs=[input_layer], outputs=[embedding])
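
A quick check of the frame-count arithmetic behind the trimming in _log_mel_spectrogram (a standalone sketch using the constants defined above):

sample_rate = 16000
window = int(round(sample_rate * 0.025))  # 400 samples per STFT window
hop = int(round(sample_rate * 0.010))     # 160 samples per hop

# tf.signal.stft frame count for one 1 s (16000-sample) analysis window:
num_stft_frames = 1 + (sample_rate - window) // hop  # 1 + 15600 // 160 = 98
assert num_stft_frames == 98
# Dropping the first and last mel frame leaves the 96 frames
# (VGGish.num_frames) the Keras model expects.
assert num_stft_frames - 2 == 96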
29 changes: 27 additions & 2 deletions tests/test_signals.py
@@ -6343,7 +6343,7 @@
]
)

EXPECTED_EMBEDDING = [
EXPECTED_EMBEDDING_FROM_TEST_INPUT = np.array([
-0.43252096,
-0.2533051,
-0.03891921,
@@ -6472,6 +6472,31 @@
-0.3327835,
-0.11444974,
-0.2591061,
]
])

EXPECTED_EMBEDDING_FROM_1S_1KHZ_AUDIO = np.array([
[-0.5641128, -0.33779827, 0.10225712, 0.11712199, -0.2527151, -0.8443508,
-0.25981572, 0.16792098, -0.95116365, -0.5236671, -0.20232812, -0.70718294,
-0.14579584, -0.07575923, 0.011736862, -0.087973, -0.48265845, 0.43599775,
-0.56677413, -0.22023663, -0.09335762, -0.06987491, -0.44508684,
-0.34173065, 0.08512226, 0.20812036, 0.026984155, -0.39980593, -0.04928696,
-0.3781742, -0.45287308, 0.4446297, 0.22336236, 0.81643176, 0.7233289,
-0.16205387, 0.10287689, 0.25761876, -0.4615103, -0.041211754, 0.76116765,
-0.61563987, 0.47007957, 0.049090154, 0.043110907, 0.46758178, 1.1136638,
0.64291775, 0.06420031, 0.009346962, 0.011331797, -0.7456863, -0.6542363,
-0.13669437, -0.18817484, -0.5706415, 0.3763576, -0.35318202, 0.22129454,
0.18423738, 0.009897411, -0.17536326, -0.2605414, -0.5434814, -0.22638988,
-0.21893057, 0.2974208, -0.25475532, -0.15125813, -0.1255785, 0.38621482,
0.304079, -0.32799524, -0.20186952, 0.23161955, -0.29869, -0.28763267,
0.55693054, 0.21256548, 0.16890164, -0.109760635, 0.12395367, -0.3432478,
-0.40493533, 0.6784371, 0.6638149, 0.47149, -0.011334762, -0.06363778,
0.50288475, -0.6486062, -0.6451371, -0.33297127, 0.27541167, 0.21874958,
0.48818892, 0.05484093, -0.21785353, -0.41137308, -0.21630639, -0.08038767,
0.342594, -0.24396843, 0.4491117, -0.82695603, -0.5479838, 0.30423838,
-0.63839394, 0.44466624, 0.04355956, -0.096717015, -0.448792, -0.30190068,
-0.20647198, -0.40457177, 0.09705981, -0.12927327, -0.08804405, -0.3071276,
0.13788812, 0.10009873, 0.13193467, -0.17012759, 0.052889988, -0.29090428,
-0.6154481, -0.2091115, -0.32193142]
])

REGRESSION_TEST_EXPECTED_VALUE = 1.0506146845541693
41 changes: 41 additions & 0 deletions tests/test_vggish.py
@@ -0,0 +1,41 @@
import unittest

import numpy as np
import tensorflow as tf

from frechet_audio_distance.vggish import VGGish
from .test_signals import (
    EXPECTED_EMBEDDING_FROM_1S_1KHZ_AUDIO,
    EXPECTED_EMBEDDING_FROM_TEST_INPUT,
    VGGISH_TEST_INPUT,
)


class VGGishTests(unittest.TestCase):

def test_vggish_model_computes_correct_embeddings_from_given_feature(self):
actual_embeddings = VGGish().model(VGGISH_TEST_INPUT)
self.assertTrue(
np.allclose(actual_embeddings, EXPECTED_EMBEDDING_FROM_TEST_INPUT,
atol=1e-6)
)

    @staticmethod
    def _generate_1s_1khz_test_signal() -> tf.Tensor:
        # 9 s of a stationary 1 kHz sine: with a 0.5 s step every 1 s
        # analysis window contains identical samples, so all windows yield
        # the same embedding and broadcast-compare against a single
        # expected row.
        test_signal_len_in_s = 9.
        time = (np.arange(0, test_signal_len_in_s * VGGish.sample_rate_in_hz)
                / VGGish.sample_rate_in_hz)
        test_signal_1khz = np.sin(2. * np.pi * 1000. * time)
        test_signal_1khz_with_batch_dim = np.expand_dims(test_signal_1khz, 0)
        return tf.convert_to_tensor(test_signal_1khz_with_batch_dim,
                                    dtype=tf.float32)

def test_vggish_class_computes_correct_embeddings_from_audio(self):
actual_embeddings = VGGish()(self._generate_1s_1khz_test_signal())
self.assertTrue(np.allclose(actual_embeddings,
EXPECTED_EMBEDDING_FROM_1S_1KHZ_AUDIO))

def test_vggish_output_dim_is_correct(self):
self.assertEqual(VGGish().output_dim, 128)


if __name__ == '__main__':
unittest.main()
