From 58920c82c27bbe7797dd3b363eb62f9061b854ef Mon Sep 17 00:00:00 2001 From: Valentin Belyaev Date: Tue, 6 Feb 2024 16:33:06 +0300 Subject: [PATCH] Add R2 metric + test --- .gitignore | 2 +- setup.py | 2 +- .../__pycache__/__init__.cpython-310.pyc | Bin 220 -> 220 bytes .../__pycache__/boosting.cpython-310.pyc | Bin 5361 -> 5361 bytes .../__pycache__/xgboost.cpython-310.pyc | Bin 9261 -> 9374 bytes src/ensemble/catboost.py | 56 ++++++++++++++++ .../classification.cpython-310.pyc | Bin 3648 -> 3648 bytes .../__pycache__/regression.cpython-310.pyc | Bin 895 -> 1350 bytes src/metrics/regression.py | 23 +++++-- ...test_r2_score.cpython-310-pytest-7.4.4.pyc | Bin 0 -> 2410 bytes tests/metrics/regression/test_r2_score.py | 63 ++++++++++++++++++ 11 files changed, 140 insertions(+), 6 deletions(-) create mode 100644 src/ensemble/catboost.py create mode 100644 tests/metrics/regression/__pycache__/test_r2_score.cpython-310-pytest-7.4.4.pyc create mode 100644 tests/metrics/regression/test_r2_score.py diff --git a/.gitignore b/.gitignore index 3a0b019..1eddeb1 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ src/__pycache__ build dist -tulia.egg-info \ No newline at end of file +tulia.egg-info diff --git a/setup.py b/setup.py index 754278e..f9064da 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -VERSION = '0.2.1' +VERSION = '0.3.0' DESCRIPTION = 'numpy based machine learning package with sklearn-like API' with open("README.md", "r") as fn: diff --git a/src/ensemble/__pycache__/__init__.cpython-310.pyc b/src/ensemble/__pycache__/__init__.cpython-310.pyc index 2c38bf5c08389fccfe149be4cb9d2fd75d661c71..be836631b94345015a6fa034aeeece191987aba4 100644 GIT binary patch delta 18 Ycmcb^c!!ZIpO=@50SJ^GCvsf^04a0?DF6Tf delta 18 Ycmcb^c!!ZIpO=@50SH3nCURW@04d@GH~;_u diff --git a/src/ensemble/__pycache__/boosting.cpython-310.pyc b/src/ensemble/__pycache__/boosting.cpython-310.pyc index d8f4cb4142de3da9aaf234914726db1f2df2a234..cdb0c24847a3e8397c99d2939bfe65ca03d3aa2b 100644 GIT binary patch delta 19 ZcmeyU`B9TApO=@50SJ^GH*!4}0RT3i1nd9+ delta 19 ZcmeyU`B9TApO=@50SIhuH*!4}0RT5A1p)v7 diff --git a/src/ensemble/__pycache__/xgboost.cpython-310.pyc b/src/ensemble/__pycache__/xgboost.cpython-310.pyc index 47242cb845ade29b0efdb2eac1462c6369bdd81a..890dc234061b401c6fc03d9fa3799673b6069069 100644 GIT binary patch delta 888 zcmZvZKWNlo7{>GFlH4VCxi-E3ZLiJURjRZW1*w9A)Wr&(+UlUE^gR21=|44)oFG@U z{^?Y3>9=)oadL2wh&TvNZgvnsp}2O|!C8F27Dr`P-6X*3N+B7fe9>PiB414 z&rBj|8a9A~M0)lOE8`H+-iVT$oWp1nv;O=hz1=5<(_qDkl_W>GfEBEgA%nF`(xC%8 z)|E4-XW4?se1FNkbJ=68Gj7;)b)nrRYCEoHlh99IcXjmGC2~lzx%u+CVL~~K0c`Oh$_Gps$MfyAy6$aWZi}D=YNpUeJ-_!QXWvA~rri9La3M2R zLp@=$37dmFvc|u)P>AnksX|r!pVAxYPy`Rl%>T+t*f5I)I+US;SqrMvM#jNrHK-%w zPBoi>2C|ui#cHfv(j zh9`yNr%{ze5ioI47REN3gM+!T7DAy|1&Bw+QS(s1N%e#+CJK2_Le`kDX1pMU1}IIu zp~?1fLG>;B?}GXfe_nqflP){eJelg^bE8$(@f+-xQ6Ss=b90?o*95y;z*QFdyR5kC zl8Avl@|TIpzxX9WcJl36C#eRJ?J_lJRo8^;Ou!Z1K7{@3NU+epBEpV np.ndarray: + pass + + def _predict(self, x: np.ndarray) -> Union[np.ndarray, float, int]: + pass + + def _encode_cat_features(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: + cat_feature_idxs = [] + + encoded_features = [] + for idx in cat_feature_idxs: + encoded_feature = [] + option_count = {} + total_count = {} + for i, x_sample in enumerate(x[:, idx]): + ctr = (option_count.get(x_sample, 0) + 0.05) / (total_count.get(x_sample, 0) + 1) + encoded_feature.append(ctr) + + if y[i] == 1: + option_count[x_sample] = option_count.get(x_sample, 0) + 1 + total_count[x_sample] = total_count.get(x_sample, 0) + 1 + + encoded_features.append(np.array(encoded_feature)) + + x[:, cat_feature_idxs] = encoded_features + return x diff --git a/src/metrics/__pycache__/classification.cpython-310.pyc b/src/metrics/__pycache__/classification.cpython-310.pyc index 974888af35980b83ddae0feee4cf6ff4d0ab2c99..ce7a319f9639045a867bada467ea711f165b9744 100644 GIT binary patch delta 52 zcmX>gb3leWpO=@50SJ^Gw{7IM+&S3~2xW delta 52 zcmX>gb3leWpO=@50SGKFY}v?d$;)<&xul@z)?_ck(=TyrB;^KllEbrID&U@SmZ2fo zoym*bela>Gha2LTlF float: """ Calculate mean-squared error. - :param y_true: Target labels. - :param y_pred: Target predictions. + :param y_true: Target labels (n_examples, ). + :param y_pred: Target predictions (n_examples, ). :return: Loss. """ n_examples = len(y_true) @@ -17,9 +17,24 @@ def mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray) -> float: def mean_absolute_error(y_true: np.ndarray, y_pred: np.ndarray) -> float: """ Calculate mean-absolute error. - :param y_true: Target labels. - :param y_pred: Target predictions. + :param y_true: Target labels (n_examples, ). + :param y_pred: Target predictions (n_examples, ). :return: Loss. """ error = np.mean(np.abs(y_true - y_pred)) return error + + +def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: + """ + Calculate R-squared. + :param y_true: Target labels (n_examples, ). + :param y_pred: Target predictions (n_examples, ). + :return: R-squared score. + """ + + tss = np.sum((y_true - np.mean(y_true))**2) + rss = np.sum((y_true - y_pred)**2) + + r_squared = 1 - rss / tss + return r_squared diff --git a/tests/metrics/regression/__pycache__/test_r2_score.cpython-310-pytest-7.4.4.pyc b/tests/metrics/regression/__pycache__/test_r2_score.cpython-310-pytest-7.4.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9fb0d82f4432bfe91093313ca697f2901f910a4 GIT binary patch literal 2410 zcmbtV&2Jk;6yII%de>{mDT0QUuSS7FjHs13X%i5wiz>tgMHL7Mt&CP1&&1hk?X}*l zQ!LgOiUb$_01$_mn{OQ8j<_Mg4GC#aNUg9|>H*XX>H+cItevKj-h5NdB?;REwI`J zr+{-grhle5HZxcr$09SCg=2{on2qBcE3y)fWj4pkIL=P6;L3LZ(tYyB1Fo*+Csm+Rfp=}QB7sqK#R0kiPbyWAQvgOas!>^1_t(gtjAiM zi;XzX^e^>+8JnUE_`DyC<7}h;EEkvL-mgODP~<)20C^F zdwzhq#(ddZBO_HY4=GhD%@ODX2ErISw!iA*zjyC`@xwp&u8lND_o6UxbU=LA^@QLM zEnZyi_E%Rb!v16`EnbzAwdouRlPNT-`l13X;CZuk5k@;f5Q2Vv6-L}KG?#6{$KkeqYHZrBU(Ox^{z6@=Sf;C8$=PYSN^ zb~sHZ<+o%${IK&jKnTf`?oJ48FLKf2!0UJsu8ay!zVAlR<1}`GGbp1kFpuCA$&W0l ztFy}GS!H8ZdEJq|XXmW0judzfH+uNqjW-$}U~NQW&kJ~`akJmt^`Lzd!dsX+(P#&k z+9BhCXi$Kt^9cN=XaH{k7Xtm13z*@IMAdFTIVW?0Gd|;;$j206VlbBiiLX2$R+ZJV zs%y6Tn+7kSdMxuE9LqdS@=!&F{hxU{GEaYmdBy?r@-uuont7&V2{G@GS4T6?Vui;r zk9cLpM={SnG4qOBKy0fRKvVGzn?Dk*-a;JZ~M`%8}_2W$J}qyDbiP}a^*~7O&@xo3zD37vsb>Fy}S7IRe|3I zKfubFe`FF)$C^z9gbk6ws_r-cR{ik#C=*709|%$$|0&