Skip to content

Commit

Permalink
add five-parameter logistic function as a model option (#813)
Browse files Browse the repository at this point in the history
* add five-parameter logistic function as a model

* add unit test for 5PL fit model

* change signature for  in  to match
  • Loading branch information
cosmin authored Feb 2, 2021
1 parent 34fcfeb commit bf60788
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 7 deletions.
9 changes: 9 additions & 0 deletions 5PL_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
model_type = "5PL"
model_param_dict = {
# ==== preprocess: normalize each feature ==== #
# 'norm_type':'none',
'norm_type': 'clip_0to1', # rescale to within [0, 1]

# ==== postprocess: clip final quality score ==== #
'score_clip':[0.0, 100.0], # clip to within [0, 100]
}
17 changes: 16 additions & 1 deletion python/test/train_test_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from vmaf.config import VmafConfig
from vmaf.core.train_test_model import TrainTestModel, \
LibsvmNusvrTrainTestModel, SklearnRandomForestTrainTestModel, \
MomentRandomForestTrainTestModel, SklearnExtraTreesTrainTestModel, SklearnLinearRegressionTrainTestModel
MomentRandomForestTrainTestModel, SklearnExtraTreesTrainTestModel, \
SklearnLinearRegressionTrainTestModel, Logistic5PLRegressionTrainTestModel
from vmaf.core.noref_feature_extractor import MomentNorefFeatureExtractor
from vmaf.routine import read_dataset
from vmaf.tools.misc import import_python_file
Expand Down Expand Up @@ -309,6 +310,20 @@ def test_train_predict_extratrees(self):
result = model.evaluate(xs, ys)
self.assertAlmostEqual(result['RMSE'], 0.042867322777879642, places=4)

def test_train_logistic_fit_5PL(self):
xs = Logistic5PLRegressionTrainTestModel.get_xs_from_results(self.features, [0, 1, 2, 3, 4, 5], features=['Moment_noref_feature_1st_score'])
ys = Logistic5PLRegressionTrainTestModel.get_ys_from_results(self.features, [0, 1, 2, 3, 4, 5])

xys = {}
xys.update(xs)
xys.update(ys)

model = Logistic5PLRegressionTrainTestModel({'norm_type': 'clip_0to1'}, None)
model.train(xys)
result = model.evaluate(xs, ys)

self.assertAlmostEqual(result['RMSE'], 0.3603374311919728, places=4)


class TrainTestModelWithDisYRawVideoExtractorTest(unittest.TestCase):

Expand Down
4 changes: 2 additions & 2 deletions python/vmaf/core/niqe_train_test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def _assert_dimension(cls, feature_names, results):

@classmethod
@override(TrainTestModel)
def get_xs_from_results(cls, results, indexs=None, aggregate=False):
def get_xs_from_results(cls, results, indexs=None, aggregate=False, features=None):
"""
override by altering aggregate
default to False
"""
return super(NiqeTrainTestModel, cls).get_xs_from_results(
results, indexs, aggregate)
results, indexs, aggregate, features)

@classmethod
@override(TrainTestModel)
Expand Down
84 changes: 81 additions & 3 deletions python/vmaf/core/train_test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def _delete(filename, **more):
os.remove(filename)

@classmethod
def get_xs_from_results(cls, results, indexs=None, aggregate=True):
def get_xs_from_results(cls, results, indexs=None, aggregate=True, features=None):
"""
:param results: list of BasicResult, or pandas.DataFrame
:param indexs: indices of results to be used
Expand All @@ -756,8 +756,11 @@ def get_xs_from_results(cls, results, indexs=None, aggregate=True):
# or get_ordered_list_scores_key. Instead, just get the sorted keys
feature_names = results[0].get_ordered_results()

feature_names = list(feature_names)
cls._assert_dimension(feature_names, results)
if features is not None:
feature_names = [f for f in feature_names if f in features]
else:
feature_names = list(feature_names)
cls._assert_dimension(feature_names, results)

# collect results into xs
xs = {}
Expand Down Expand Up @@ -1156,6 +1159,81 @@ def _predict(cls, model, xs_2d):
ys_label_pred = model.predict(xs_2d)
return ys_label_pred

class Logistic5PLRegressionTrainTestModel(TrainTestModel, RegressorMixin):

TYPE = '5PL'
VERSION = "0.1"

@classmethod
def _train(cls, model_param, xys_2d, **kwargs):
"""
Fit the following 5PL curve using scipy.optimize.curve_fit
Q(x) = B1 + (1/2 - 1/(1 + exp(B2 * (x - B3)))) + B4 * x + B5
H. R. Sheikh, M. F. Sabir, and A. C. Bovik,
"A statistical evaluation of recent full reference image quality assessment algorithms"
IEEE Trans. Image Process., vol. 15, no. 11, pp. 3440–3451, Nov. 2006.
:param model_param:
:param xys_2d:
:return:
"""
model_param_ = model_param.copy()

# remove keys unassociated with sklearn
if 'norm_type' in model_param_:
del model_param_['norm_type']
if 'score_clip' in model_param_:
del model_param_['score_clip']
if 'custom_clip_0to1_map' in model_param_:
del model_param_['custom_clip_0to1_map']
if 'num_models' in model_param_:
del model_param_['num_models']

from scipy.optimize import curve_fit
[[b1, b2, b3, b4, b5], _] = curve_fit(
lambda x, b1, b2, b3, b4, b5: b1 + (0.5 - 1/(1+np.exp(b2*(x-b3))))+b4*x+b5,
np.ravel(xys_2d[:, 1]),
np.ravel(xys_2d[:, 0]),
p0=0.5 * np.ones((5,)),
maxfev=20000
)

return dict(b1=b1, b2=b2, b3=b3, b4=b4, b5=b5)

@staticmethod
@override(TrainTestModel)
def _to_file(filename, param_dict, model_dict, **more):
format = more['format'] if 'format' in more else 'pkl'
supported_formats = ['pkl', 'json']
assert format in supported_formats, \
f'format must be in {supported_formats}, but got: {format}'

info_to_save = {'param_dict': param_dict,
'model_dict': model_dict.copy()}

if format == 'pkl':
with open(filename, 'wb') as file:
pickle.dump(info_to_save, file)
elif format == 'json':
with open(filename, 'wt') as file:
json.dump(info_to_save, file, indent=4)
else:
assert False

@classmethod
def _predict(cls, model, xs_2d):
b1 = model['b1']
b2 = model['b2']
b3 = model['b3']
b4 = model['b4']
b5 = model['b5']

curve = lambda x: b1 + (0.5 - 1/(1+np.exp(b2*(x-b3))))+b4*x+b5
predicted = [curve(x) for x in np.ravel(xs_2d)]

return predicted

class RawVideoTrainTestModelMixin(object):
"""
Expand Down
3 changes: 3 additions & 0 deletions resource/feature_param/ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
feature_dict = {
'SSIM_feature': ['ssim'],
}
8 changes: 7 additions & 1 deletion unittest
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#!/usr/bin/env sh

PYTHONPATH=python python3 -m unittest discover -v -s python/test/ -p '*_test.py'
if [ -z "$1" ]; then
pattern='*_test.py'
else
pattern="$1"
fi

PYTHONPATH=python python3 -m unittest discover -v -s python/test/ -p $pattern

0 comments on commit bf60788

Please sign in to comment.