Skip to content

Commit

Permalink
Add pypsnr fex and subclasses; add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
li-zhi committed Mar 2, 2021
1 parent 30e4cb7 commit e698b4d
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 1 deletion.
144 changes: 143 additions & 1 deletion python/test/feature_extractor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
MomentFeatureExtractor, \
PsnrFeatureExtractor, SsimFeatureExtractor, MsSsimFeatureExtractor, \
VifFrameDifferenceFeatureExtractor, \
AnsnrFeatureExtractor, VmafIntegerFeatureExtractor
AnsnrFeatureExtractor, PypsnrFeatureExtractor, VmafIntegerFeatureExtractor, \
PypsnrMaxdb100FeatureExtractor
from vmaf.core.asset import Asset
from vmaf.core.result_store import FileSystemResultStore

Expand Down Expand Up @@ -626,6 +627,147 @@ def test_run_psnr_fextractor_proc(self):
self.assertAlmostEqual(results[0]['PSNR_feature_psnr_score'], 27.645446604166665, places=8)
self.assertAlmostEqual(results[1]['PSNR_feature_psnr_score'], 31.87683660416667, places=8)

def test_run_pypsnr_fextractor(self):

ref_path, dis_path, asset, asset_original = set_default_576_324_videos_for_testing()

self.fextractor = PypsnrFeatureExtractor(
[asset, asset_original],
None, fifo_mode=True,
result_store=None
)
self.fextractor.run(parallelize=True)

results = self.fextractor.results

self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 30.755063979166664, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 38.449441057158786, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 40.9919102486235, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 60.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 60.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 60.0, places=4)

def test_run_pypsnr_fextractor_10bit(self):

ref_path, dis_path, asset, asset_original = set_default_576_324_10bit_videos_for_testing()

self.fextractor = PypsnrFeatureExtractor(
[asset, asset_original],
None, fifo_mode=True,
result_store=None
)
self.fextractor.run(parallelize=True)

results = self.fextractor.results

self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 30.780573260053277, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 38.769832063651364, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.28418847734209, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 72.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 72.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 72.0, places=4)

def test_run_pypsnr_fextractor_10bit_b(self):

ref_path, dis_path, asset, asset_original = set_default_576_324_10bit_videos_for_testing_b()

self.fextractor = PypsnrFeatureExtractor(
[asset, asset_original],
None, fifo_mode=True,
result_store=None
)
self.fextractor.run(parallelize=True)

results = self.fextractor.results

self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 32.57145231892744, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 39.03859552689696, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.28060001337217, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 72.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 72.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 72.0, places=4)

def test_run_pypsnr_fextractor_12bit(self):

ref_path, dis_path, asset, asset_original = set_default_576_324_12bit_videos_for_testing()

self.fextractor = PypsnrFeatureExtractor(
[asset, asset_original],
None, fifo_mode=True,
result_store=None
)
self.fextractor.run(parallelize=True)

results = self.fextractor.results

self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 32.577817940053734, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 39.044961148023255, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.28696563449846, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 84.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 84.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 84.0, places=4)

def test_run_pypsnr_fextractor_16bit(self):

ref_path, dis_path, asset, asset_original = set_default_576_324_16bit_videos_for_testing()

self.fextractor = PypsnrFeatureExtractor(
[asset, asset_original],
None, fifo_mode=True,
result_store=None
)
self.fextractor.run(parallelize=True)

results = self.fextractor.results

self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 32.579806240311484, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 39.046949448281005, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.288953934756215, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 108.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 108.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 108.0, places=4)

def test_run_pypsnr_fextractor_16bit_custom_max_db(self):

ref_path, dis_path, asset, asset_original = set_default_576_324_16bit_videos_for_testing()

self.fextractor = PypsnrFeatureExtractor(
[asset, asset_original],
None, fifo_mode=True,
result_store=None,
optional_dict={'max_db': 100.0}
)
self.fextractor.run(parallelize=True)

results = self.fextractor.results

self.assertAlmostEqual(results[0]['Pypsnr_feature_psnry_score'], 32.579806240311484, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnru_score'], 39.046949448281005, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_feature_psnrv_score'], 41.288953934756215, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnry_score'], 100.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnru_score'], 100.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_feature_psnrv_score'], 100.0, places=4)

def test_run_pypsnr_fextractor_maxdb100_16bit(self):

ref_path, dis_path, asset, asset_original = set_default_576_324_16bit_videos_for_testing()

self.fextractor = PypsnrMaxdb100FeatureExtractor(
[asset, asset_original],
None, fifo_mode=True,
result_store=None,
)
self.fextractor.run(parallelize=True)

results = self.fextractor.results

self.assertAlmostEqual(results[0]['Pypsnr_maxdb100_feature_psnry_score'], 32.579806240311484, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_maxdb100_feature_psnru_score'], 39.046949448281005, places=4)
self.assertAlmostEqual(results[0]['Pypsnr_maxdb100_feature_psnrv_score'], 41.288953934756215, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_maxdb100_feature_psnry_score'], 100.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_maxdb100_feature_psnru_score'], 100.0, places=4)
self.assertAlmostEqual(results[1]['Pypsnr_maxdb100_feature_psnrv_score'], 100.0, places=4)


if __name__ == '__main__':
unittest.main(verbosity=2)
111 changes: 111 additions & 0 deletions python/vmaf/core/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,117 @@ def _post_process_result(cls, result):
return result


class PypsnrFeatureExtractor(FeatureExtractor):

TYPE = "Pypsnr_feature"
VERSION = "1.0"

ATOM_FEATURES = ['psnry', 'psnru', 'psnrv']

@staticmethod
def _assert_bit_depth(ref_yuv_reader, dis_yuv_reader):
if ref_yuv_reader._is_8bit():
assert dis_yuv_reader._is_8bit()
elif ref_yuv_reader._is_10bitle():
assert dis_yuv_reader._is_10bitle()
elif ref_yuv_reader._is_12bitle():
assert dis_yuv_reader._is_12bitle()
elif ref_yuv_reader._is_16bitle():
assert dis_yuv_reader._is_16bitle()
else:
assert False, 'unknown bit depth and type'

def _get_max_db(self, ref_yuv_reader):
if self.optional_dict is not None and 'max_db' in self.optional_dict:
assert type(self.optional_dict['max_db']) == int or float
return self.optional_dict['max_db']
elif ref_yuv_reader._is_8bit():
return 60.0
elif ref_yuv_reader._is_10bitle():
return 72.0
elif ref_yuv_reader._is_12bitle():
return 84.0
elif ref_yuv_reader._is_16bitle():
return 108.0
else:
assert False, 'unknown bit depth and type'

def _generate_result(self, asset):
quality_w, quality_h = asset.quality_width_height
yuv_type = self._get_workfile_yuv_type(asset)
log_dicts = list()
with YuvReader(filepath=asset.ref_procfile_path, width=quality_w, height=quality_h,
yuv_type=yuv_type) as ref_yuv_reader:
with YuvReader(filepath=asset.dis_procfile_path, width=quality_w, height=quality_h,
yuv_type=yuv_type) as dis_yuv_reader:

self._assert_bit_depth(ref_yuv_reader, dis_yuv_reader)
max_db = self._get_max_db(ref_yuv_reader)

frm = 0
while True:
try:
ref_yuv = ref_yuv_reader.next(format='float')
dis_yuv = dis_yuv_reader.next(format='float')
except StopIteration:
break

ref_y, ref_u, ref_v = ref_yuv
dis_y, dis_u, dis_v = dis_yuv
mse_y, mse_u, mse_v = np.mean((ref_y - dis_y)**2) + 1e-16, \
np.mean((ref_u - dis_u)**2) + 1e-16, \
np.mean((ref_v - dis_v)**2) + 1e-16
psnr_y, psnr_u, psnr_v = min(10 * np.log10(1.0 / mse_y), max_db), \
min(10 * np.log10(1.0 / mse_u), max_db), \
min(10 * np.log10(1.0 / mse_v), max_db)

log_dicts.append({
'frame': frm,
'psnry': psnr_y,
'psnru': psnr_u,
'psnrv': psnr_v,
})

frm += 1

log_file_path = self._get_log_file_path(asset)
with open(log_file_path, 'wt') as log_file:
log_file.write(str(log_dicts))

@override(FeatureExtractor)
def _get_feature_scores(self, asset):

log_file_path = self._get_log_file_path(asset)

with open(log_file_path, 'rt') as log_file:
log_str = log_file.read()
log_dicts = ast.literal_eval(log_str)

feature_result = dict()
frm = 0
for log_dict in log_dicts:
assert frm == log_dict['frame']
for ft in self.ATOM_FEATURES:
feature_result.setdefault(self.get_scores_key(ft), []).append(log_dict[ft])
frm += 1

return feature_result


class PypsnrMaxdb100FeatureExtractor(PypsnrFeatureExtractor):

TYPE = "Pypsnr_maxdb100_feature"

@override(Executor)
def _custom_init(self):
super()._custom_init()
if self.optional_dict is not None:
assert 'max_db' not in self.optional_dict
if self.optional_dict is None:
self.optional_dict = dict()
self.optional_dict['max_db'] = 100.0


class PsnrFeatureExtractor(VmafexecFeatureExtractorMixin, FeatureExtractor):

TYPE = "PSNR_feature"
Expand Down

0 comments on commit e698b4d

Please sign in to comment.