diff --git a/extension_templates/distributions.py b/extension_templates/distributions.py new file mode 100644 index 00000000..64339338 --- /dev/null +++ b/extension_templates/distributions.py @@ -0,0 +1,379 @@ +"""Extension template for probability distributions - simple pattern.""" +# todo: write an informative docstring for the file or module, remove the above +# todo: add an appropriate copyright notice for your estimator +# estimators contributed to skpro should have the copyright notice at the top +# estimators of your own do not need to have permissive or BSD-3 copyright + +# todo: uncomment the following line, enter authors' GitHub IDs +# __author__ = [authorGitHubID, anotherAuthorGitHubID] + +from skpro.distributions.base import BaseDistribution + +# todo: add any necessary imports here - no soft dependency imports + +# todo: for imports of skpro soft dependencies: +# make sure to fill in the "python_dependencies" tag with the package import name +# import soft dependencies only inside methods of the class, not at the top of the file + + +# todo: change class name and write docstring +class ClassName(BaseDistribution): + """Custom probability distribution. todo: write docstring. + + todo: describe your custom probability distribution here + + Parameters + ---------- + parama : float or np.ndarray + descriptive explanation of parama + paramb : float or np.ndarray, optional (default='default') + descriptive explanation of paramb + """ + + # todo: fill out estimator tags here + # tags are inherited from parent class if they are not set + # tags inherited from base are "safe defaults" which can usually be left as-is + _tags = { + # packaging info + # -------------- + "authors": ["author1", "author2"], # authors, GitHub handles + "maintainers": ["maintainer1", "maintainer2"], # maintainers, GitHub handles + # author = significant contribution to code at some point + # maintainer = algorithm maintainer role, "owner" + # specify one or multiple authors and maintainers, only for skpro contribution + # remove maintainer tag if maintained by skpro/sktim core team + # + "python_version": None, # PEP 440 python version specifier to limit versions + "python_dependencies": None, # PEP 440 python dependencies specifier, + # e.g., "numba>0.53", or a list, e.g., ["numba>0.53", "numpy>=1.19.0"] + # delete if no python dependencies or version limitations + # + # estimator tags + # -------------- + "distr:measuretype": "continuous", # one of "discrete", "continuous", "mixed" + # these tags should correspond to which methods are numerically exact + # and which are approximations, e.g., using Monte Carlo + "capabilities:approx": ["pdfnorm", "energy"], + "capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"], + # leave the broadcast_init tag as-is, this tag exists for compatibility with + # distributions deviating from assumptions on input parameters, e.g., Empirical + "broadcast_init": "on", + } + + # todo: fill init + # params should be written to self and never changed + # super call must not be removed, change class name + # parameter checks can go after super call + def __init__(self, param1, param2="param2default", index=None, columns=None): + # all distributions must have index and columns arg with None defaults + # this is to ensure pandas-like behaviour + + # todo: write any hyper-parameters and components to self + self.param1 = param1 + self.param2 = param2 + + # leave this as is + super().__init__(index=index, columns=columns) + + # todo: optional, parameter checking logic (if 
applicable) should happen here + # if writes derived values to self, should *not* overwrite self.parama etc + # instead, write to self._parama, self._newparam (starting with _) + + # todo: implement as many of the following methods as possible + # if not implemented, the base class will try to fill it in + # from the other implemented methods + # at least _ppf, or sample should be implemented for the distribution to be usable + # if _ppf is implemented, sample does not need to be implemented (uses ppf sampling) + + # todo: consider implementing + # if not implemented, uses Monte Carlo estimate via sample + def _mean(self): + """Return expected value of the distribution. + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + expected value of distribution (entry-wise) + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing + # if not implemented, uses Monte Carlo estimate via sample + def _var(self): + r"""Return element/entry-wise variance of the distribution. + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + variance of the distribution (entry-wise) + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing - only for continuous or mixed distributions + # at least one of _pdf and _log_pdf should be implemented + # if not implemented, returns exp of log_pdf + def _pdf(self, x): + """Probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + pdf values at the given points + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing - only for continuous or mixed distributions + # at least one of _pdf and _log_pdf should be implemented + # if not implemented, returns log of pdf + def _log_pdf(self, x): + """Logarithmic probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing - only for discrete or mixed distributions + # at least one of _pmf and _log_pmf should be implemented + # if not implemented, returns exp of log_pmf + def _pmf(self, x): + """Probability mass function. 
+ + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pmf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + pmf values at the given points + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing - only for discrete or mixed distributions + # at least one of _pmf and _log_pmf should be implemented + # if not implemented, returns log of pmf + def _log_pmf(self, x): + """Logarithmic probability mass function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pmf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + log pmf values at the given points + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing + # at least one of _ppf and sample must be implemented + # if not implemented, uses Monte Carlo estimate based on sample + def _cdf(self, x): + """Cumulative distribution function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing + # at least one of _ppf and sample must be implemented + # if not implemented, uses bisection method on cdf + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + pass + + # todo: consider implementing + # if not implemented, uses Monte Carlo estimate via sample + def _energy_self(self): + r"""Energy of self, w.r.t. self. + + :math:`\mathbb{E}[|X-Y|]`, where :math:`X, Y` are i.i.d. copies of self. + + Private method, to be implemented by subclasses. + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing + # if not implemented, uses Monte Carlo estimate via sample + def _energy_x(self, x): + r"""Energy of self, w.r.t. a constant frame x. + + :math:`\mathbb{E}[|X-x|]`, where :math:`X` is a copy of self, + and :math:`x` is a constant. + + Private method, to be implemented by subclasses. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to compute energy w.r.t. to + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. 
the given points + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: consider implementing + # at least one of _ppf and sample must be implemented + # if not implemented, uses _ppf for sampling (inverse cdf on uniform) + def sample(self, n_samples=None): + """Sample from the distribution. + + Parameters + ---------- + n_samples : int, optional, default = None + + Returns + ------- + if `n_samples` is `None`: + returns a sample that contains a single sample from `self`, + in `pd.DataFrame` mtype format convention, with `index` and `columns` as `self` + if n_samples is `int`: + returns a `pd.DataFrame` that contains `n_samples` i.i.d. samples from `self`, + in `pd-multiindex` mtype format convention, with same `columns` as `self`, + and `MultiIndex` that is product of `RangeIndex(n_samples)` and `self.index` + """ + param1 = self._bc_params["param1"] # returns broadcast params to x.shape + param2 = self._bc_params["param2"] # returns broadcast params to x.shape + + res = "do_sth_with(" + param1 + param2 + ")" # replace this by internal logic + return res + + # todo: return default parameters, so that a test instance can be created + # required for automated unit and integration testing of estimator + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + + # todo: set the testing parameters for the estimators + # Testing parameters can be dictionary or list of dictionaries + # + # this can, if required, use: + # class properties (e.g., inherited); parent class test case + # imported objects such as estimators from skpro or sklearn + # important: all such imports should be *inside get_test_params*, not at the top + # since imports are used only at testing time + # + # A parameter dictionary must be returned *for all values* of parameter_set, + # i.e., "parameter_set not available" errors should never be raised. + # + # A good parameter set should primarily satisfy two criteria, + # 1. Chosen set of parameters should have a low testing time, + # ideally in the magnitude of few seconds for the entire test suite. + # This is vital for the cases where default values result in + # "big" models which not only increases test time but also + # run into the risk of test workers crashing. + # 2. There should be a minimum two such parameter sets with different + # sets of values to ensure a wide range of code coverage is provided. 
+ # + # example 1: specify params as dictionary + # any number of params can be specified + # params = {"est": value0, "parama": value1, "paramb": value2} + # + # example 2: specify params as list of dictionary + # note: Only first dictionary will be used by create_test_instance + # params = [{"est": value1, "parama": value2}, + # {"est": value3, "parama": value4}] + # + # example 3: parameter set depending on param_set value + # note: only needed if a separate parameter set is needed in tests + # if parameter_set == "special_param_set": + # params = {"est": value1, "parama": value2} + # return params + # + # # "default" params + # params = {"est": value3, "parama": value4} + # return params diff --git a/extension_templates/regression.py b/extension_templates/regression.py index 99ed1fea..00151ff5 100644 --- a/extension_templates/regression.py +++ b/extension_templates/regression.py @@ -319,8 +319,6 @@ def get_test_params(cls, parameter_set="default"): # # The parameter_set argument is not used for most automated, module level tests. # It can be used in custom, estimator specific tests, for "special" settings. - # For classification, this is also used in tests for reference settings, - # such as published in benchmarking studies, or for identity testing. # A parameter dictionary must be returned *for all values* of parameter_set, # i.e., "parameter_set not available" errors should never be raised. # diff --git a/extension_templates/survival.py b/extension_templates/survival.py index da0994bc..f270761d 100644 --- a/extension_templates/survival.py +++ b/extension_templates/survival.py @@ -306,8 +306,6 @@ def get_test_params(cls, parameter_set="default"): # # The parameter_set argument is not used for most automated, module level tests. # It can be used in custom, estimator specific tests, for "special" settings. - # For classification, this is also used in tests for reference settings, - # such as published in benchmarking studies, or for identity testing. # A parameter dictionary must be returned *for all values* of parameter_set, # i.e., "parameter_set not available" errors should never be raised. 
# diff --git a/skpro/distributions/base/_base.py b/skpro/distributions/base/_base.py index 5b5bd2db..d28207eb 100644 --- a/skpro/distributions/base/_base.py +++ b/skpro/distributions/base/_base.py @@ -11,7 +11,6 @@ import pandas as pd from skpro.base import BaseObject -from skpro.utils.pandas import df_map from skpro.utils.validation._dependencies import _check_estimator_deps @@ -23,13 +22,22 @@ class BaseDistribution(BaseObject): "object_type": "distribution", # type of object, e.g., 'distribution' "python_version": None, # PEP 440 python version specifier to limit versions "python_dependencies": None, # string or str list of pkg soft dependencies - "reserved_params": ["index", "columns"], - "capabilities:approx": ["energy", "mean", "var", "pdfnorm"], + # default parameter settings for MC estimates + # ------------------------------------------- + # these are used in default implementations of mean, var, energy, pdfnorm, ppf "approx_mean_spl": 1000, # sample size used in MC estimates of mean "approx_var_spl": 1000, # sample size used in MC estimates of var "approx_energy_spl": 1000, # sample size used in MC estimates of energy "approx_spl": 1000, # sample size used in other MC estimates "bisect_iter": 1000, # max iters for bisection method in ppf + # which methods are approximate (not numerically exact) should be listed here + "capabilities:approx": ["energy", "mean", "var", "pdfnorm"], + # broadcasting and parameter settings + # ----------------------------------- + # used to control broadcasting of parameters + "reserved_params": ["index", "columns"], + "broadcast_params": None, # list of params to broadcast + "broadcast_init": "off", # whether to auto-broadcast params in __init__ } def __init__(self, index=None, columns=None): @@ -39,6 +47,37 @@ def __init__(self, index=None, columns=None): super().__init__() _check_estimator_deps(self) + self._init_shape_bc(index=index, columns=columns) + + def _init_shape_bc(self, index=None, columns=None): + """Initialize shape and broadcasting of distribution parameters. + + Subclasses may choose to override this, if + default broadcasting and pre-initialization is not desired or applicable, + e.g., distribution parameters are not array-like. + + If overridden, must set ``self._shape``: this should be an empty tuple + if the distribution is scalar, or a pair of integers otherwise. + """ + if self.get_tags()["broadcast_init"] == "off": + if index is None and columns is None: + self._shape = () + else: + self._shape = (len(index), len(columns)) + return None + + # if broadcast_init is on or other, run this auto-init + bc_params, shape, is_scalar = self._get_bc_params_dict(return_shape=True) + self._bc_params = bc_params + self._is_scalar = is_scalar + self._shape = shape + + if index is None and self.ndim > 0: + self.index = pd.RangeIndex(shape[0]) + + if columns is None and self.ndim > 0: + self.columns = pd.RangeIndex(shape[1]) + @property def loc(self): """Location indexer. @@ -74,11 +113,19 @@ def iloc(self): @property def shape(self): """Shape of self, a pair (2-tuple).""" - return (len(self.index), len(self.columns)) + return self._shape + + @property + def ndim(self): + """Number of dimensions of self. 
2 if array, 0 if scalar.""" + return len(self._shape) def __len__(self): """Length of self, number of rows.""" - return len(self.index) + shape = self._shape + if len(shape) == 0: + return 1 + return shape[0] def _loc(self, rowidx=None, colidx=None): if rowidx is not None: @@ -96,14 +143,16 @@ def _subset_params(self, rowidx, colidx): subset_param_dict = {} for param, val in params.items(): - if val is not None: - arr = np.array(val) - else: - arr = None + if val is None: + subset_param_dict[param] = None + continue + arr = np.array(val) # if len(arr.shape) == 0: # do nothing with arr - if len(arr.shape) >= 1 and rowidx is not None: - arr = arr[rowidx] + if len(arr.shape) == 1 and colidx is not None: + arr = arr[colidx] + if len(arr.shape) == 2 and rowidx is not None: + arr = arr[rowidx, :] if len(arr.shape) >= 2 and colidx is not None: arr = arr[:, colidx] if np.issubdtype(arr.dtype, np.integer): @@ -168,7 +217,7 @@ def _method_error_msg(self, method="this method", severity="warn", fill_in=None) else: return msg - def _get_bc_params(self, *args, dtype=None, oned_as="row"): + def _get_bc_params(self, *args, dtype=None, oned_as="row", return_shape=False): """Fully broadcast tuple of parameters given param shapes and index, columns. Parameters @@ -183,19 +232,25 @@ def _get_bc_params(self, *args, dtype=None, oned_as="row"): oned_as : str, optional, "row" (default) or "col" If 'row', then 1D arrays are treated as row vectors. If 'column', then 1D arrays are treated as column vectors. + return_shape : bool, optional, default=False + If True, return shape tuple, and a boolean tuple + indicating which parameters are scalar. Returns ------- Tuple of float or integer arrays Each element of the tuple represents a different broadcastable distribution parameter. + shape : Tuple, only returned if ``return_shape`` is True + Shape of the broadcasted parameters. + Pair of row/column if not scalar, empty tuple if scalar. + is_scalar : Tuple of bools, only returned if ``return_is_scalar`` is True + Each element of the tuple is True if the corresponding parameter is scalar. """ number_of_params = len(args) if number_of_params == 0: # Handle case where no positional arguments are provided - params = self.get_params() - params.pop("index") - params.pop("columns") + params = self._get_dist_params() args = tuple(params.values()) number_of_params = len(args) @@ -216,39 +271,190 @@ def row_to_col(arr): bc = np.broadcast_arrays(*args_as_np) if dtype is not None: bc = [array.astype(dtype) for array in bc] - return bc[:number_of_params] + bc = bc[:number_of_params] + if return_shape: + shape = bc[0].shape + is_scalar = tuple([arr.ndim == 0 for arr in bc]) + return bc, shape, is_scalar + return bc + + def _get_bc_params_dict( + self, dtype=None, oned_as="row", return_shape=False, **kwargs + ): + """Fully broadcast dict of parameters given param shapes and index, columns. + + Parameters + ---------- + kwargs : float, int, array of floats, or array of ints (1D or 2D) + Distribution parameters that are to be made broadcastable. If no positional + arguments are provided, all parameters of `self` are used except for `index` + and `columns`. + dtype : str, optional + broadcasted arrays are cast to all have datatype `dtype`. If None, then no + datatype casting is done. + oned_as : str, optional, "row" (default) or "col" + If 'row', then 1D arrays are treated as row vectors. If 'column', then 1D + arrays are treated as column vectors. 
+ return_shape : bool, optional, default=False + If True, return shape tuple, and a boolean tuple + indicating which parameters are scalar. + + Returns + ------- + dict of float or integer arrays + Each element of the tuple represents a different broadcastable distribution + parameter. + shape : Tuple, only returned if ``return_shape`` is True + Shape of the broadcasted parameters. + Pair of row/column if not scalar, empty tuple if scalar. + is_scalar : Tuple of bools, only returned if ``return_is_scalar`` is True + Each element of the tuple is True if the corresponding parameter is scalar. + """ + number_of_params = len(kwargs) + if number_of_params == 0: + # Handle case where no positional arguments are provided + kwargs = self._get_dist_params() + number_of_params = len(kwargs) + + def row_to_col(arr): + """Convert 1D arrays to 2D col arrays, leave 2D arrays unchanged.""" + if arr.ndim == 1 and oned_as == "col": + return arr.reshape(-1, 1) + return arr + + kwargs_as_np = {k: row_to_col(np.array(v)) for k, v in kwargs.items()} + + if hasattr(self, "index") and self.index is not None: + kwargs_as_np["index"] = self.index.to_numpy().reshape(-1, 1) + if hasattr(self, "columns") and self.columns is not None: + kwargs_as_np["columns"] = self.columns.to_numpy() + + bc_params = self.get_tags()["broadcast_params"] + if bc_params is None: + bc_params = kwargs_as_np.keys() + + args_as_np = [kwargs_as_np[k] for k in bc_params] + bc = np.broadcast_arrays(*args_as_np) + if dtype is not None: + bc = [array.astype(dtype) for array in bc] + + shape = () + for i, k in enumerate(bc_params): + kwargs_as_np[k] = row_to_col(bc[i]) + if bc[i].ndim > 0: + shape = bc[i].shape + + # special case: user provided iterables so it broadcasts to 1D + # this is interpreted as a row vector, i.e., one multivariate distr + if len(shape) == 1: + shape = (1, shape[0]) + for k, v in kwargs_as_np.items(): + kwargs_as_np[k] = np.expand_dims(v, 0) + + if return_shape: + is_scalar = tuple([arr.ndim == 0 for arr in bc]) + return kwargs_as_np, shape, is_scalar + return kwargs_as_np + + def _boilerplate(self, method, columns=None, **kwargs): + """Broadcasting boilerplate for distribution methods. + + Used to link public methods to private methods, + handles coercion, broadcasting, and checks. 
+ + Parameters + ---------- + method : str + Name of the method to be called, e.g., '_pdf' + columns : None (default) or pd.Index coercible + If not None, set return columns to this value + kwargs : dict + Keyword arguments to the method + Checks and broadcasts are applied to all values in kwargs + + Examples + -------- + >>> self._boilerplate('_pdf', x=x) # doctest: +SKIP + >>> # calls self._pdf(x=x_inner), broadcasting x to self's shape in x_inner + """ + kwargs_inner = kwargs.copy() + d = self + + for k, x in kwargs.items(): + # if x is a DataFrame, subset and reorder distribution to match it + if isinstance(x, pd.DataFrame): + d = self.loc[x.index, x.columns] + x_inner = x.values + # else, coerce to a numpy array if needed + # then, broadcast it to the shape of self + else: + x_inner = self._coerce_to_self_index_np(x, flatten=False) + kwargs_inner[k] = x_inner + + # pass the broadcasted values to the private method + res = getattr(d, method)(**kwargs_inner) + + # if the result is not a DataFrame, coerce it to one + # ensure the index and columns are the same as d + if not isinstance(res, pd.DataFrame) and self.ndim > 1: + if columns is not None: + res_cols = pd.Index(columns) + else: + res_cols = d.columns + res = pd.DataFrame(res, index=d.index, columns=res_cols) + # if numpy scalar, convert to python scalar, e.g., float + if isinstance(res, np.ndarray) and self.ndim == 0: + res = res[()] + return res def pdf(self, x): r"""Probability density function. - Let :math:`X` be a random variables with the distribution of `self`, - taking values in `(N, n)` `DataFrame`-s + Let :math:`X` be a random variable with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s Let :math:`x\in \mathbb{R}^{N\times n}`. By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the :math:`(i,j)`-th entry. - The output of this method, for input `x` representing :math:`x`, - is a `DataFrame` with same columns and indices as `self`, + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, and entries :math:`p_{X_{ij}}(x_{ij})`. + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrates to the weight of the continuous part. + Parameters ---------- - x : `pandas.DataFrame` or 2D np.ndarray + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` representing :math:`x`, as above Returns ------- - `DataFrame` with same columns and index as `self` + ``pd.DataFrame`` with same columns and index as ``self`` containing :math:`p_{X_{ij}}(x_{ij})`, as above """ - if self._has_implementation_of("log_pdf"): + return self._boilerplate("_pdf", x=x) + + def _pdf(self, x): + """Probability density function. + + Private method, to be implemented by subclasses. 
+ """ + self_has_logpdf = self._has_implementation_of("log_pdf") + self_has_logpdf = self_has_logpdf or self._has_implementation_of("_log_pdf") + if self_has_logpdf: approx_method = ( "by exponentiating the output returned by the log_pdf method, " "this may be numerically unstable" ) warn(self._method_error_msg("pdf", fill_in=approx_method)) - return df_map(self.log_pdf(x=x))(np.exp) + x_df = pd.DataFrame(x, index=self.index, columns=self.columns) + res = self.log_pdf(x=x_df) + if isinstance(res, pd.DataFrame): + res = res.values + return np.exp(res) raise NotImplementedError(self._method_error_msg("pdf", "error")) @@ -257,43 +463,82 @@ def log_pdf(self, x): Numerically more stable than calling pdf and then taking logartihms. - Let :math:`X` be a random variables with the distribution of `self`, - taking values in `(N, n)` `DataFrame`-s + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in `(N, n)` ``DataFrame``-s Let :math:`x\in \mathbb{R}^{N\times n}`. By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the :math:`(i,j)`-th entry. - The output of this method, for input `x` representing :math:`x`, - is a `DataFrame` with same columns and indices as `self`, + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, and entries :math:`\log p_{X_{ij}}(x_{ij})`. - If `self` has a mixed or discrete distribution, this returns + If ``self`` has a mixed or discrete distribution, this returns the weighted continuous part of `self`'s distribution instead of the pdf, i.e., the marginal pdf integrate to the weight of the continuous part. Parameters ---------- - x : `pandas.DataFrame` or 2D np.ndarray + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` representing :math:`x`, as above Returns ------- - `DataFrame` with same columns and index as `self` + ``pd.DataFrame`` with same columns and index as ``self`` containing :math:`\log p_{X_{ij}}(x_{ij})`, as above """ - if self._has_implementation_of("pdf"): + return self._boilerplate("_log_pdf", x=x) + + def _log_pdf(self, x): + """Logarithmic probability density function. + + Private method, to be implemented by subclasses. + """ + if self._has_implementation_of("pdf") or self._has_implementation_of("_pdf"): approx_method = ( "by taking the logarithm of the output returned by the pdf method, " "this may be numerically unstable" ) warn(self._method_error_msg("log_pdf", fill_in=approx_method)) - return df_map(self.pdf(x=x))(np.log) + x_df = pd.DataFrame(x, index=self.index, columns=self.columns) + res = self.pdf(x=x_df) + if isinstance(res, pd.DataFrame): + res = res.values + return np.log(res) raise NotImplementedError(self._method_error_msg("log_pdf", "error")) def cdf(self, x): - """Cumulative distribution function.""" + r"""Cumulative distribution function. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the + :math:`(i,j)`-th entry. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`F_{X_{ij}}(x_{ij})`. 
+ + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`F_{X_{ij}}(x_{ij})`, as above + """ + return self._boilerplate("_cdf", x=x) + + def _cdf(self, x): + """Cumulative distribution function. + + Private method, to be implemented by subclasses. + """ N = self.get_tag("approx_spl") approx_method = ( "by approximating the expected value by the indicator function on " @@ -301,6 +546,9 @@ def cdf(self, x): ) warn(self._method_error_msg("mean", fill_in=approx_method)) + if not isinstance(x, pd.DataFrame): + x = pd.DataFrame(x, index=self.index, columns=self.columns) + # TODO: ensure this works for scalar x splx = pd.concat([x] * N, keys=range(N)) spl = self.sample(N) ind = splx <= spl @@ -308,8 +556,46 @@ def cdf(self, x): return ind.groupby(level=1, sort=False).mean() def ppf(self, p): - """Quantile function = percent point function = inverse cdf.""" - if self._has_implementation_of("cdf"): + r"""Quantile function = percent point function = inverse cdf. + + Let :math:`X` be a random variable with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`p\in [0, 1]^{N\times n}`. + By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the + :math:`(i,j)`-th entry. + + The output of this method, for input ``p`` representing :math:`p`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`F^{-1}_{X_{ij}}(p_{ij})`. + + Parameters + ---------- + p : ``pandas.DataFrame`` or 2D np.ndarray + representing :math:`p`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`F^{-1}_{X_{ij}}(p_{ij})`, as above + """ + return self._boilerplate("_ppf", p=p) + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Private method, to be implemented by subclasses. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + if self._has_implementation_of("cdf") or self._has_implementation_of("_cdf"): from scipy.optimize import bisect max_iter = self.get_tag("bisect_iter") @@ -319,26 +605,45 @@ def ppf(self, p): ) warn(self._method_error_msg("cdf", fill_in=approx_method)) - result = pd.DataFrame(index=p.index, columns=p.columns, dtype="float") - for ix in p.index: - for col in p.columns: + def bisect_unb(opt_fun, **kwargs): + """Unbound version of bisect.""" + left_bd = -1e6 + right_bd = 1e6 + while opt_fun(left_bd) > 0: + left_bd *= 10 + while opt_fun(right_bd) < 0: + right_bd *= 10 + return bisect(opt_fun, left_bd, right_bd, maxiter=max_iter, **kwargs) + + shape = self.shape + + # TODO: remove duplications in the code below + # requires cdf to accept numpy, or allow subsetting to produce scalar + if len(shape) == 0: + + def opt_fun(x): + """Optimization function, to find x s.t. cdf(x) = p.""" + return self.cdf(x) - p # noqa: B023 + + result = bisect_unb(opt_fun) + return result + + n_row, n_col = self.shape + result = np.array([[0.0] * n_col] * n_row, dtype=float) + + for i in range(n_row): + for j in range(n_col): + ix = self.index[i] + col = self.columns[j] d_ix = self.loc[[ix], [col]] - p_ix = p.loc[ix, col] + p_ix = p[i, j] def opt_fun(x): """Optimization function, to find x s.t. 
cdf(x) = p_ix.""" x = pd.DataFrame(x, index=[ix], columns=[col]) # noqa: B023 return d_ix.cdf(x).values[0][0] - p_ix # noqa: B023 - left_bd = -1e6 - right_bd = 1e6 - while opt_fun(left_bd) > 0: - left_bd *= 10 - while opt_fun(right_bd) < 0: - right_bd *= 10 - result.loc[ix, col] = bisect( - opt_fun, left_bd, right_bd, maxiter=max_iter - ) + result[i, j] = bisect_unb(opt_fun) return result raise NotImplementedError(self._method_error_msg("ppf", "error")) @@ -346,27 +651,82 @@ def opt_fun(x): def energy(self, x=None): r"""Energy of self, w.r.t. self or a constant frame x. - Let :math:`X, Y` be i.i.d. random variables with the distribution of `self`. + Let :math:`X, Y` be i.i.d. random variables with the distribution of ``self``. + + If ``x`` is ``None``, returns :math:`\mathbb{E}[|X-Y|]` (per row), + "self-energy". + If ``x`` is passed, returns :math:`\mathbb{E}[|X-x|]` (per row), "energy wrt x". - If `x` is `None`, returns :math:`\mathbb{E}[|X-Y|]` (for each row), - "self-energy" (of the row marginal distribution). - If `x` is passed, returns :math:`\mathbb{E}[|X-x|]` (for each row), - "energy wrt x" (of the row marginal distribution). + The CRPS is related to energy: + it holds that + :math:`\mbox{CRPS}(\mbox{self}, y)` = `self.energy(y) - 0.5 * self.energy()`. Parameters ---------- x : None or pd.DataFrame, optional, default=None - if pd.DataFrame, must have same rows and columns as `self` + if ``pd.DataFrame``, must have same rows and columns as ``self`` + + Returns + ------- + ``pd.DataFrame`` with same rows as ``self``, single column ``"energy"`` + each row contains one float, self-energy/energy as described above. + """ + if x is None: + return self._boilerplate("_energy_self", columns=["energy"]) + return self._boilerplate("_energy_x", x=x, columns=["energy"]) + + def _energy_self(self): + r"""Energy of self, w.r.t. self. + + :math:`\mathbb{E}[|X-Y|]`, where :math:`X, Y` are i.i.d. copies of self. + + Private method, to be implemented by subclasses. + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points + """ + return self._energy_default() + + def _energy_x(self, x): + r"""Energy of self, w.r.t. a constant frame x. + + :math:`\mathbb{E}[|X-x|]`, where :math:`X` is a copy of self, + and :math:`x` is a constant. + + Private method, to be implemented by subclasses. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to compute energy w.r.t. to + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points + """ + return self._energy_default(x) + + def _energy_default(self, x=None): + """Energy of self, w.r.t. self or a constant frame x. + + Default implementation, using Monte Carlo estimates. + + Parameters + ---------- + x : None or 2D np.ndarray, same shape as ``self`` + values to compute energy w.r.t. to Returns ------- - pd.DataFrame with same rows as `self`, single column `"energy"` - each row contains one float, self-energy/energy as described above. + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points """ # we want to approximate E[abs(X-Y)] # if x = None, X,Y are i.i.d. 
copies of self # if x is not None, X=x (constant), Y=self - approx_spl_size = self.get_tag("approx_energy_spl") approx_method = ( "by approximating the energy expectation by the arithmetic mean of " @@ -380,6 +740,7 @@ def energy(self, x=None): splx = self.sample(N) sply = self.sample(N) else: + x = pd.DataFrame(x, index=self.index, columns=self.columns) splx = pd.concat([x] * N, keys=range(N)) sply = self.sample(N) @@ -387,19 +748,27 @@ def energy(self, x=None): spl = splx - sply energy = spl.apply(np.linalg.norm, axis=1, ord=1) energy = energy.groupby(level=1, sort=False).mean() - energy = pd.DataFrame(energy, index=self.index, columns=["energy"]) + if self.ndim > 0: + energy = pd.DataFrame(energy, index=self.index, columns=["energy"]) return energy def mean(self): r"""Return expected value of the distribution. - Let :math:`X` be a random variable with the distribution of `self`. + Let :math:`X` be a random variable with the distribution of ``self``. Returns the expectation :math:`\mathbb{E}[X]` Returns ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) + ``pd.DataFrame`` with same rows, columns as ``self`` + expected value of distribution (entry-wise) + """ + return self._boilerplate("_mean") + + def _mean(self): + """Return expected value of the distribution. + + Private method, to be implemented by subclasses. """ approx_spl_size = self.get_tag("approx_mean_spl") approx_method = ( @@ -414,13 +783,21 @@ def mean(self): def var(self): r"""Return element/entry-wise variance of the distribution. - Let :math:`X` be a random variable with the distribution of `self`. - Returns :math:`\mathbb{V}[X] = \mathbb{E}\left(X - \mathbb{E}[X]\right)^2` + Let :math:`X` be a random variable with the distribution of ``self``. + Returns :math:`\mathbb{V}[X] = \mathbb{E}\left(X - \mathbb{E}[X]\right)^2`, + where the square is element-wise. Returns ------- - pd.DataFrame with same rows, columns as `self` - variance of distribution (entry-wise) + ``pd.DataFrame`` with same rows, columns as ``self`` + variance of distribution (entry-wise) + """ + return self._boilerplate("_var") + + def _var(self): + """Return element/entry-wise variance of the distribution. + + Private method, to be implemented by subclasses. """ approx_spl_size = self.get_tag("approx_var_spl") approx_method = ( @@ -467,13 +844,44 @@ def pdfnorm(self, a=2): spl_df = pd.concat(spl, keys=range(approx_spl_size)) return spl_df.groupby(level=1, sort=False).mean() - def _coerce_to_self_index_df(self, x): + def _coerce_to_self_index_df(self, x, flatten=True): + """Coerce input to type similar to self. + + If self is not scalar with index and columns, + coerces x to a pd.DataFrame with index and columns as self. + + If self is scalar, coerces x to a scalar (0D) np.ndarray. + """ + x = np.array(x) + if flatten: + x = x.reshape(1, -1) + df_shape = self.shape + x = np.broadcast_to(x, df_shape) + if self.ndim != 0: + df = pd.DataFrame(x, index=self.index, columns=self.columns) + return df + return x + + def _coerce_to_self_index_np(self, x, flatten=False): + """Coerce input to type similar to self. + + Coerces x to a np.ndarray with same shape as self. + Broadcasts x to self.shape, if necessary, via np.broadcast_to. 
+ + Parameters + ---------- + x : array-like, np.ndarray coercible + input to be coerced to self + flatten : bool, optional, default=True + if True, flattens x before broadcasting + if False, broadcasts x as is + """ x = np.array(x) - x = x.reshape(1, -1) + if flatten: + x = x.reshape(1, -1) df_shape = self.shape x = np.broadcast_to(x, df_shape) - df = pd.DataFrame(x, index=self.index, columns=self.columns) - return df + return x def quantile(self, alpha): """Return entry-wise quantiles, in Proba/pred_quantiles mtype format. @@ -539,16 +947,21 @@ def sample(self, n_samples=None): def gen_unif(): np_unif = np.random.uniform(size=self.shape) - return pd.DataFrame(np_unif, index=self.index, columns=self.columns) + if self.ndim > 0: + return pd.DataFrame(np_unif, index=self.index, columns=self.columns) + return np_unif # if ppf is implemented, we use inverse transform sampling - if self._has_implementation_of("ppf"): + if self._has_implementation_of("_ppf") or self._has_implementation_of("ppf"): if n_samples is None: return self.ppf(gen_unif()) - else: - pd_smpl = [self.ppf(gen_unif()) for _ in range(n_samples)] + # else, we generate n_samples i.i.d. samples + pd_smpl = [self.ppf(gen_unif()) for _ in range(n_samples)] + if self.ndim > 0: df_spl = pd.concat(pd_smpl, keys=range(n_samples)) - return df_spl + else: + df_spl = pd.DataFrame(pd_smpl) + return df_spl raise NotImplementedError(self._method_error_msg("sample", "error")) diff --git a/skpro/distributions/empirical.py b/skpro/distributions/empirical.py index 1ddaad39..0886d18c 100644 --- a/skpro/distributions/empirical.py +++ b/skpro/distributions/empirical.py @@ -67,6 +67,8 @@ def __init__(self, spl, weights=None, time_indep=True, index=None, columns=None) if columns is None: columns = spl.columns + self._shape = (len(index), len(columns)) + super().__init__(index=index, columns=columns) # initialized sorted samples diff --git a/skpro/distributions/fisk.py b/skpro/distributions/fisk.py index aa292e48..6bdd2dc6 100644 --- a/skpro/distributions/fisk.py +++ b/skpro/distributions/fisk.py @@ -12,7 +12,7 @@ class Fisk(BaseDistribution): r"""Fisk distribution, aka log-logistic distribution. - The Fisk distibution is parametrized by a scale parameter :math:`\alpha` + The Fisk distribution is parametrized by a scale parameter :math:`\alpha` and a shape parameter :math:`\beta`, such that the cumulative distribution function (CDF) is given by: @@ -38,80 +38,118 @@ class Fisk(BaseDistribution): "capabilities:approx": ["energy", "pdfnorm"], "capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"], "distr:measuretype": "continuous", + "broadcast_init": "on", } def __init__(self, alpha=1, beta=1, index=None, columns=None): self.alpha = alpha self.beta = beta - self.index = index - self.columns = columns - # todo: untangle index handling - # and broadcast of parameters. - # move this functionality to the base class - # important: if only one argument, it is a lenght-1-tuple, deal with this - self._alpha, self._beta = self._get_bc_params(self.alpha, self.beta) - shape = self._alpha.shape + super().__init__(index=index, columns=columns) - if index is None: - index = pd.RangeIndex(shape[0]) + def _mean(self): + """Return expected value of the distribution. 
- if columns is None: - columns = pd.RangeIndex(shape[1]) + Returns + ------- + 2D np.ndarray, same shape as ``self`` + expected value of distribution (entry-wise) + """ + alpha = self._bc_params["alpha"] + beta = self._bc_params["beta"] - super().__init__(index=index, columns=columns) + mean_arr = fisk.mean(scale=alpha, c=beta) + return mean_arr - def mean(self): - r"""Return expected value of the distribution. + def _var(self): + r"""Return element/entry-wise variance of the distribution. - Let :math:`X` be a random variable with the distribution of `self`. - Returns the expectation :math:`\mathbb{E}[X]` + Returns + ------- + 2D np.ndarray, same shape as ``self`` + variance of the distribution (entry-wise) + """ + alpha = self._bc_params["alpha"] + beta = self._bc_params["beta"] + + var_arr = fisk.var(scale=alpha, c=beta) + return var_arr + + def _pdf(self, x): + """Probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at Returns ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + pdf values at the given points """ - mean_arr = fisk.mean(scale=self._alpha, c=self._beta) - return pd.DataFrame(mean_arr, index=self.index, columns=self.columns) + alpha = self._bc_params["alpha"] + beta = self._bc_params["beta"] - def var(self): - r"""Return element/entry-wise variance of the distribution. + pdf_arr = fisk.pdf(x, scale=alpha, c=beta) + return pdf_arr + + def _log_pdf(self, x): + """Logarithmic probability density function. - Let :math:`X` be a random variable with the distribution of `self`. - Returns :math:`\mathbb{V}[X] = \mathbb{E}\left(X - \mathbb{E}[X]\right)^2` + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at Returns ------- - pd.DataFrame with same rows, columns as `self` - variance of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points """ - var_arr = fisk.var(scale=self._alpha, c=self._beta) - return pd.DataFrame(var_arr, index=self.index, columns=self.columns) - - def pdf(self, x): - """Probability density function.""" - d = self.loc[x.index, x.columns] - pdf_arr = fisk.pdf(x.values, scale=d.alpha, c=d.beta) - return pd.DataFrame(pdf_arr, index=x.index, columns=x.columns) - - def log_pdf(self, x): - """Logarithmic probability density function.""" - d = self.loc[x.index, x.columns] - lpdf_arr = fisk.logpdf(x.values, scale=d.alpha, c=d.beta) - return pd.DataFrame(lpdf_arr, index=x.index, columns=x.columns) - - def cdf(self, x): - """Cumulative distribution function.""" - d = self.loc[x.index, x.columns] - cdf_arr = fisk.cdf(x.values, scale=d.alpha, c=d.beta) - return pd.DataFrame(cdf_arr, index=x.index, columns=x.columns) - - def ppf(self, p): - """Quantile function = percent point function = inverse cdf.""" - d = self.loc[p.index, p.columns] - icdf_arr = fisk.ppf(p.values, scale=d.alpha, c=d.beta) - return pd.DataFrame(icdf_arr, index=p.index, columns=p.columns) + alpha = self._bc_params["alpha"] + beta = self._bc_params["beta"] + + lpdf_arr = fisk.logpdf(x, scale=alpha, c=beta) + return lpdf_arr + + def _cdf(self, x): + """Cumulative distribution function. 
+ + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + alpha = self._bc_params["alpha"] + beta = self._bc_params["beta"] + + cdf_arr = fisk.cdf(x, scale=alpha, c=beta) + return cdf_arr + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + alpha = self._bc_params["alpha"] + beta = self._bc_params["beta"] + + icdf_arr = fisk.ppf(p, scale=alpha, c=beta) + return icdf_arr @classmethod def get_test_params(cls, parameter_set="default"): diff --git a/skpro/distributions/laplace.py b/skpro/distributions/laplace.py index 5cfaa4cf..0ff88c8e 100644 --- a/skpro/distributions/laplace.py +++ b/skpro/distributions/laplace.py @@ -10,11 +10,26 @@ class Laplace(BaseDistribution): - """Laplace distribution. + r"""Laplace distribution. + + This distribution is univariate, without correlation between dimensions + for the array-valued case. + + The Laplace distribution is parametrized by mean :math:`\mu` and + scale :math:`b`, such that the pdf is + + .. math:: f(x) = \frac{1}{2b} \exp\left(-\frac{|x - \mu|}{b}\right) + + The mean :math:`\mu` is represented by the parameter ``mu``, + and the scale :math:`b` by the parameter ``scale``. + + It should be noted that this parametrization differs from the mean/standard + deviation parametrization, which is also common in the literature. + The standard deviation of this distribution is :math:`\sqrt{2} s`. Parameters ---------- - mean : float or array of float (1D or 2D) + mu : float or array of float (1D or 2D) mean of the distribution scale : float or array of float (1D or 2D), must be positive scale parameter of the distribution, same as standard deviation / sqrt(2) @@ -32,115 +47,160 @@ class Laplace(BaseDistribution): "capabilities:approx": ["pdfnorm"], "capabilities:exact": ["mean", "var", "energy", "pdf", "log_pdf", "cdf", "ppf"], "distr:measuretype": "continuous", + "broadcast_init": "on", } def __init__(self, mu, scale, index=None, columns=None): self.mu = mu self.scale = scale - self.index = index - self.columns = columns - # todo: untangle index handling - # and broadcast of parameters. - # move this functionality to the base class - self._mu, self._scale = self._get_bc_params(self.mu, self.scale) - shape = self._mu.shape + super().__init__(index=index, columns=columns) + + def _energy_self(self): + r"""Energy of self, w.r.t. self. - if index is None: - index = pd.RangeIndex(shape[0]) + :math:`\mathbb{E}[|X-Y|]`, where :math:`X, Y` are i.i.d. copies of self. - if columns is None: - columns = pd.RangeIndex(shape[1]) + Private method, to be implemented by subclasses. - super().__init__(index=index, columns=columns) + Returns + ------- + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points + """ + energy_arr = self._bc_params["scale"] + if energy_arr.ndim > 0: + energy_arr = np.sum(energy_arr, axis=1) * 1.5 + return energy_arr - def energy(self, x=None): - r"""Energy of self, w.r.t. self or a constant frame x. + def _energy_x(self, x): + r"""Energy of self, w.r.t. a constant frame x. - Let :math:`X, Y` be i.i.d. random variables with the distribution of `self`. 
+ :math:`\mathbb{E}[|X-x|]`, where :math:`X` is a copy of self, + and :math:`x` is a constant. - If `x` is `None`, returns :math:`\mathbb{E}[|X-Y|]` (per row), "self-energy". - If `x` is passed, returns :math:`\mathbb{E}[|X-x|]` (per row), "energy wrt x". + Private method, to be implemented by subclasses. Parameters ---------- - x : None or pd.DataFrame, optional, default=None - if pd.DataFrame, must have same rows and columns as `self` + x : 2D np.ndarray, same shape as ``self`` + values to compute energy w.r.t. to Returns ------- - pd.DataFrame with same rows as `self`, single column `"energy"` - each row contains one float, self-energy/energy as described above. + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points """ - if x is None: - sc_arr = self._scale - energy_arr = np.sum(sc_arr, axis=1) * 1.5 - energy = pd.DataFrame(energy_arr, index=self.index, columns=["energy"]) - else: - d = self.loc[x.index, x.columns] - mu_arr, sc_arr = d.mu, d.scale - y_arr = np.abs((x.values - mu_arr) / sc_arr) - c_arr = y_arr + np.exp(-y_arr) - energy_arr = np.sum(sc_arr * c_arr, axis=1) - energy = pd.DataFrame(energy_arr, index=self.index, columns=["energy"]) - return energy - - def mean(self): - r"""Return expected value of the distribution. - - Let :math:`X` be a random variable with the distribution of `self`. - Returns the expectation :math:`\mathbb{E}[X]` + mu = self._bc_params["mu"] + sc = self._bc_params["scale"] + + y_arr = np.abs((x - mu) / sc) + c_arr = y_arr + np.exp(-y_arr) + energy_arr = sc * c_arr + if energy_arr.ndim > 0: + energy_arr = np.sum(energy_arr, axis=1) + return energy_arr + + def _mean(self): + """Return expected value of the distribution. Returns ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + expected value of distribution (entry-wise) """ - mean_arr = self._mu - return pd.DataFrame(mean_arr, index=self.index, columns=self.columns) + return self._bc_params["mu"] - def var(self): + def _var(self): r"""Return element/entry-wise variance of the distribution. - Let :math:`X` be a random variable with the distribution of `self`. - Returns :math:`\mathbb{V}[X] = \mathbb{E}\left(X - \mathbb{E}[X]\right)^2` + Returns + ------- + 2D np.ndarray, same shape as ``self`` + variance of the distribution (entry-wise) + """ + sc = self._bc_params["scale"] + sd_arr = np.sqrt(2) * sc + return sd_arr**2 + + def _pdf(self, x): + """Probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + pdf values at the given points + """ + mu = self._bc_params["mu"] + sc = self._bc_params["scale"] + pdf_arr = np.exp(-np.abs((x - mu) / sc)) + pdf_arr = pdf_arr / (2 * sc) + return pdf_arr + + def _log_pdf(self, x): + """Logarithmic probability density function. 
+ + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at Returns ------- - pd.DataFrame with same rows, columns as `self` - variance of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points """ - sd_arr = self._scale / np.sqrt(2) - return pd.DataFrame(sd_arr, index=self.index, columns=self.columns) ** 2 - - def pdf(self, x): - """Probability density function.""" - d = self.loc[x.index, x.columns] - pdf_arr = np.exp(-np.abs((x.values - d.mu) / d.scale)) - pdf_arr = pdf_arr / (2 * d.scale) - return pd.DataFrame(pdf_arr, index=x.index, columns=x.columns) - - def log_pdf(self, x): - """Logarithmic probability density function.""" - d = self.loc[x.index, x.columns] - lpdf_arr = -np.abs((x.values - d.mu) / d.scale) - lpdf_arr = lpdf_arr - np.log(2 * d.scale) - return pd.DataFrame(lpdf_arr, index=x.index, columns=x.columns) - - def cdf(self, x): - """Cumulative distribution function.""" - d = self.loc[x.index, x.columns] - sgn_arr = np.sign(x.values - d.mu) - exp_arr = np.exp(-np.abs((x.values - d.mu) / d.scale)) + mu = self._bc_params["mu"] + sc = self._bc_params["scale"] + lpdf_arr = -np.abs((x - mu) / sc) + lpdf_arr = lpdf_arr - np.log(2 * sc) + return lpdf_arr + + def _cdf(self, x): + """Cumulative distribution function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + mu = self._bc_params["mu"] + sc = self._bc_params["scale"] + + sgn_arr = np.sign(x - mu) + exp_arr = np.exp(-np.abs((x - mu) / sc)) cdf_arr = 0.5 + 0.5 * sgn_arr * (1 - exp_arr) - return pd.DataFrame(cdf_arr, index=x.index, columns=x.columns) - - def ppf(self, p): - """Quantile function = percent point function = inverse cdf.""" - d = self.loc[p.index, p.columns] - sgn_arr = np.sign(p.values - 0.5) - icdf_arr = d.mu - d.scale * sgn_arr * np.log(1 - 2 * np.abs(p.values - 0.5)) - return pd.DataFrame(icdf_arr, index=p.index, columns=p.columns) + return cdf_arr + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + mu = self._bc_params["mu"] + sc = self._bc_params["scale"] + + sgn_arr = np.sign(p - 0.5) + icdf_arr = mu - sc * sgn_arr * np.log(1 - 2 * np.abs(p - 0.5)) + return icdf_arr @classmethod def get_test_params(cls, parameter_set="default"): diff --git a/skpro/distributions/logistic.py b/skpro/distributions/logistic.py index 3c8b3705..30b8b44d 100644 --- a/skpro/distributions/logistic.py +++ b/skpro/distributions/logistic.py @@ -10,7 +10,15 @@ class Logistic(BaseDistribution): - """Logistic distribution. + r"""Logistic distribution. + + The logistic distribution is parametrized by a mean parameter :math:`\mu`, + and scale parameter :math:`s`, such that the cdf is given by: + + .. math:: F(x) = \frac{1}{1 + \exp\left(-\frac{x - \mu}{s}\right)} + + The scale parameter :math:`s` is represented by the parameter ``scale``, + and the mean parameter :math:`\mu` by the parameter ``mu``. 
Parameters ---------- @@ -32,43 +40,26 @@ class Logistic(BaseDistribution): "capabilities:approx": ["pdfnorm", "energy"], "capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"], "distr:measuretype": "continuous", + "broadcast_init": "on", } def __init__(self, mu, scale, index=None, columns=None): self.mu = mu self.scale = scale - self.index = index - self.columns = columns - - # todo: untangle index handling - # and broadcast of parameters. - # move this functionality to the base class - self._mu, self._scale = self._get_bc_params(self.mu, self.scale) - shape = self._mu.shape - - if index is None: - index = pd.RangeIndex(shape[0]) - - if columns is None: - columns = pd.RangeIndex(shape[1]) super().__init__(index=index, columns=columns) - def mean(self): - r"""Return expected value of the distribution. - - Let :math:`X` be a random variable with the distribution of `self`. - Returns the expectation :math:`\mathbb{E}[X]` + def _mean(self): + """Return expected value of the distribution. Returns ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + expected value of distribution (entry-wise) """ - mean_arr = self._mu - return pd.DataFrame(mean_arr, index=self.index, columns=self.columns) + return self._bc_params["mu"] - def var(self): + def _var(self): r"""Return variance of the distribution. Let :math:`X` be a random variable with the distribution of `self`. @@ -79,35 +70,88 @@ def var(self): pd.DataFrame with same rows, columns as `self` variance of distribution (entry-wise) """ - var_arr = (self._scale**2 * np.pi**2) / 3 - return pd.DataFrame(var_arr, index=self.index, columns=self.columns) - - def pdf(self, x): - """Probability density function.""" - d = self.loc[x.index, x.columns] - numerator = np.exp(-(x.values - d.mu) / d.scale) - denominator = d.scale * ((1 + np.exp(-(x.values - d.mu) / d.scale)) ** 2) + scale = self._bc_params["scale"] + var_arr = (scale**2 * np.pi**2) / 3 + return var_arr + + def _pdf(self, x): + """Probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + pdf values at the given points + """ + mu = self._bc_params["mu"] + scale = self._bc_params["scale"] + + numerator = np.exp(-(x - mu) / scale) + denominator = scale * ((1 + np.exp(-(x - mu) / scale)) ** 2) pdf_arr = numerator / denominator - return pd.DataFrame(pdf_arr, index=x.index, columns=x.columns) - - def log_pdf(self, x): - """Logarithmic probability density function.""" - d = self.loc[x.index, x.columns] - y = -(x.values - d.mu) / d.scale - lpdf_arr = y - np.log(d.scale) - 2 * np.logaddexp(0, y) - return pd.DataFrame(lpdf_arr, index=x.index, columns=x.columns) - - def cdf(self, x): - """Cumulative distribution function.""" - d = self.loc[x.index, x.columns] - cdf_arr = (1 + np.tanh((x.values - d.mu) / (2 * d.scale))) / 2 - return pd.DataFrame(cdf_arr, index=x.index, columns=x.columns) - - def ppf(self, p): - """Quantile function = percent point function = inverse cdf.""" - d = self.loc[p.index, p.columns] - ppf_arr = d.mu + d.scale * np.log(p.values / (1 - p.values)) - return pd.DataFrame(ppf_arr, index=p.index, columns=p.columns) + return pdf_arr + + def _log_pdf(self, x): + """Logarithmic probability density function. 
+ + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points + """ + mu = self._bc_params["mu"] + scale = self._bc_params["scale"] + + y = -(x - mu) / scale + lpdf_arr = y - np.log(scale) - 2 * np.logaddexp(0, y) + return lpdf_arr + + def _cdf(self, x): + """Cumulative distribution function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + mu = self._bc_params["mu"] + scale = self._bc_params["scale"] + + cdf_arr = (1 + np.tanh((x - mu) / (2 * scale))) / 2 + return cdf_arr + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + mu = self._bc_params["mu"] + scale = self._bc_params["scale"] + + ppf_arr = mu + scale * np.log(p / (1 - p)) + return ppf_arr @classmethod def get_test_params(cls, parameter_set="default"): diff --git a/skpro/distributions/lognormal.py b/skpro/distributions/lognormal.py index 135afa70..6f919798 100644 --- a/skpro/distributions/lognormal.py +++ b/skpro/distributions/lognormal.py @@ -36,28 +36,15 @@ class LogNormal(BaseDistribution): _tags = { "authors": ["bhavikar04", "fkiraly"], - "capabilities:approx": ["pdflognorm"], - "capabilities:exact": ["mean", "var", "energy", "pdf", "log_pdf", "cdf", "ppf"], + "capabilities:approx": ["energy", "pdfnorm"], + "capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"], "distr:measuretype": "continuous", + "broadcast_init": "on", } def __init__(self, mu, sigma, index=None, columns=None): self.mu = mu self.sigma = sigma - self.index = index - self.columns = columns - - # todo: untangle index handling - # and broadcast of parameters. - # move this functionality to the base class - self._mu, self._sigma = self._get_bc_params() - shape = self._mu.shape - - if index is None: - index = pd.RangeIndex(shape[0]) - - if columns is None: - columns = pd.RangeIndex(shape[1]) super().__init__(index=index, columns=columns) @@ -99,62 +86,112 @@ def __init__(self, mu, sigma, index=None, columns=None): # energy = pd.DataFrame(energy_arr, index=self.index, columns=["energy"]) # return energy - def mean(self): - r"""Return expected value of the distribution. - - Let :math:`X` be a random variable with the distribution of `self`. - Returns the expectation :math:`\mathbb{E}[X]` + def _mean(self): + """Return expected value of the distribution. Returns ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + expected value of distribution (entry-wise) """ - mean_arr = np.exp(self._mu + self._sigma**2 / 2) - return pd.DataFrame(mean_arr, index=self.index, columns=self.columns) + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] - def var(self): - r"""Return element/entry-wise variance of the distribution. + mean_arr = np.exp(mu + sigma**2 / 2) + return mean_arr - Let :math:`X` be a random variable with the distribution of `self`. - Returns :math:`\mathbb{V}[X] = \mathbb{E}\left(X - \mathbb{E}[X]\right)^2` + def _var(self): + r"""Return element/entry-wise variance of the distribution. 
Returns ------- - pd.DataFrame with same rows, columns as `self` - variance of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + variance of the distribution (entry-wise) """ - mu = self._mu - sigma = self._sigma + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + sd_arr = np.exp(2 * mu + 2 * sigma**2) - np.exp(2 * mu + sigma**2) - return pd.DataFrame(sd_arr, index=self.index, columns=self.columns) ** 2 - - def pdf(self, x): - """Probability density function.""" - d = self.loc[x.index, x.columns] - pdf_arr = np.exp(-0.5 * ((np.log(x.values) - d.mu) / d.sigma) ** 2) - pdf_arr = pdf_arr / (x.values * d.sigma * np.sqrt(2 * np.pi)) - return pd.DataFrame(pdf_arr, index=x.index, columns=x.columns) - - def log_pdf(self, x): - """Logarithmic probability density function.""" - d = self.loc[x.index, x.columns] - lpdf_arr = -0.5 * ((np.log(x.values) - d.mu) / d.sigma) ** 2 - lpdf_arr = lpdf_arr - np.log(x.values * d.sigma * np.sqrt(2 * np.pi)) - return pd.DataFrame(lpdf_arr, index=x.index, columns=x.columns) - - def cdf(self, x): - """Cumulative distribution function.""" - d = self.loc[x.index, x.columns] - cdf_arr = 0.5 + 0.5 * erf((np.log(x.values) - d.mu) / (d.sigma * np.sqrt(2))) - return pd.DataFrame(cdf_arr, index=x.index, columns=x.columns) - - def ppf(self, p): - """Quantile function = percent point function = inverse cdf.""" - d = self.loc[p.index, p.columns] - icdf_arr = d.mu + d.sigma * np.sqrt(2) * erfinv(2 * p.values - 1) + return sd_arr + + def _pdf(self, x): + """Probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + pdf values at the given points + """ + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + pdf_arr = np.exp(-0.5 * ((np.log(x) - mu) / sigma) ** 2) + pdf_arr = pdf_arr / (x * sigma * np.sqrt(2 * np.pi)) + return pdf_arr + + def _log_pdf(self, x): + """Logarithmic probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points + """ + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + lpdf_arr = -0.5 * ((np.log(x) - mu) / sigma) ** 2 + lpdf_arr = lpdf_arr - np.log(x * sigma * np.sqrt(2 * np.pi)) + return lpdf_arr + + def _cdf(self, x): + """Cumulative distribution function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + cdf_arr = 0.5 + 0.5 * erf((np.log(x) - mu) / (sigma * np.sqrt(2))) + return cdf_arr + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. 
+ + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + icdf_arr = mu + sigma * np.sqrt(2) * erfinv(2 * p - 1) icdf_arr = np.exp(icdf_arr) - return pd.DataFrame(icdf_arr, index=p.index, columns=p.columns) + return icdf_arr @classmethod def get_test_params(cls, parameter_set="default"): diff --git a/skpro/distributions/normal.py b/skpro/distributions/normal.py index 766c6d6e..3e32985e 100644 --- a/skpro/distributions/normal.py +++ b/skpro/distributions/normal.py @@ -11,7 +11,18 @@ class Normal(BaseDistribution): - """Normal distribution (skpro native). + r"""Normal distribution (skpro native). + + This distribution is univariate, without correlation between dimensions + for the array-valued case. + + The normal distribution is parametrized by mean :math:`\mu` and + standard deviation :math:`\sigma`, such that the pdf is + + .. math:: f(x) = \frac{1}{\sigma \sqrt{2\pi}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right) # noqa E501 + + The mean :math:`\mu` is represented by the parameter ``mu``, + and the standard deviation :math:`\sigma` by the parameter ``sigma``. Parameters ---------- @@ -33,110 +44,158 @@ class Normal(BaseDistribution): "capabilities:approx": ["pdfnorm"], "capabilities:exact": ["mean", "var", "energy", "pdf", "log_pdf", "cdf", "ppf"], "distr:measuretype": "continuous", + "broadcast_init": "on", } def __init__(self, mu, sigma, index=None, columns=None): self.mu = mu self.sigma = sigma - self.index = index - self.columns = columns - # todo: untangle index handling - # and broadcast of parameters. - # move this functionality to the base class - self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma) - shape = self._mu.shape + super().__init__(index=index, columns=columns) - if index is None: - index = pd.RangeIndex(shape[0]) + def _energy_self(self): + r"""Energy of self, w.r.t. self. - if columns is None: - columns = pd.RangeIndex(shape[1]) + :math:`\mathbb{E}[|X-Y|]`, where :math:`X, Y` are i.i.d. copies of self. - super().__init__(index=index, columns=columns) + Private method, to be implemented by subclasses. - def energy(self, x=None): - r"""Energy of self, w.r.t. self or a constant frame x. + Returns + ------- + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points + """ + sigma = self._bc_params["sigma"] + energy_arr = 2 * sigma / np.sqrt(np.pi) + if energy_arr.ndim > 0: + energy_arr = np.sum(energy_arr, axis=1) + return energy_arr - Let :math:`X, Y` be i.i.d. random variables with the distribution of `self`. + def _energy_x(self, x): + r"""Energy of self, w.r.t. a constant frame x. - If `x` is `None`, returns :math:`\mathbb{E}[|X-Y|]` (per row), "self-energy". - If `x` is passed, returns :math:`\mathbb{E}[|X-x|]` (per row), "energy wrt x". + :math:`\mathbb{E}[|X-x|]`, where :math:`X` is a copy of self, + and :math:`x` is a constant. + + Private method, to be implemented by subclasses. Parameters ---------- - x : None or pd.DataFrame, optional, default=None - if pd.DataFrame, must have same rows and columns as `self` + x : 2D np.ndarray, same shape as ``self`` + values to compute energy w.r.t. to Returns ------- - pd.DataFrame with same rows as `self`, single column `"energy"` - each row contains one float, self-energy/energy as described above. + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. 
the given points """ - if x is None: - sd_arr = self._sigma - energy_arr = 2 * np.sum(sd_arr, axis=1) / np.sqrt(np.pi) - energy = pd.DataFrame(energy_arr, index=self.index, columns=["energy"]) - else: - mu_arr, sd_arr = self._mu, self._sigma - c_arr = (x - mu_arr) * (2 * self.cdf(x) - 1) + 2 * sd_arr**2 * self.pdf(x) - energy_arr = np.sum(c_arr, axis=1) - energy = pd.DataFrame(energy_arr, index=self.index, columns=["energy"]) - return energy - - def mean(self): - r"""Return expected value of the distribution. - - Let :math:`X` be a random variable with the distribution of `self`. - Returns the expectation :math:`\mathbb{E}[X]` + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + cdf = self.cdf(x) + pdf = self.pdf(x) + energy_arr = (x - mu) * (2 * cdf - 1) + 2 * sigma**2 * pdf + if energy_arr.ndim > 0: + energy_arr = np.sum(energy_arr, axis=1) + return energy_arr + + def _mean(self): + """Return expected value of the distribution. Returns ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + expected value of distribution (entry-wise) """ - mean_arr = self._mu - return pd.DataFrame(mean_arr, index=self.index, columns=self.columns) + return self._bc_params["mu"] - def var(self): + def _var(self): r"""Return element/entry-wise variance of the distribution. - Let :math:`X` be a random variable with the distribution of `self`. - Returns :math:`\mathbb{V}[X] = \mathbb{E}\left(X - \mathbb{E}[X]\right)^2` + Returns + ------- + 2D np.ndarray, same shape as ``self`` + variance of the distribution (entry-wise) + """ + return self._bc_params["sigma"] ** 2 + + def _pdf(self, x): + """Probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at Returns ------- - pd.DataFrame with same rows, columns as `self` - variance of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + pdf values at the given points """ - sd_arr = self._sigma - return pd.DataFrame(sd_arr, index=self.index, columns=self.columns) ** 2 - - def pdf(self, x): - """Probability density function.""" - d = self.loc[x.index, x.columns] - pdf_arr = np.exp(-0.5 * ((x.values - d.mu) / d.sigma) ** 2) - pdf_arr = pdf_arr / (d.sigma * np.sqrt(2 * np.pi)) - return pd.DataFrame(pdf_arr, index=x.index, columns=x.columns) - - def log_pdf(self, x): - """Logarithmic probability density function.""" - d = self.loc[x.index, x.columns] - lpdf_arr = -0.5 * ((x.values - d.mu) / d.sigma) ** 2 - lpdf_arr = lpdf_arr - np.log(d.sigma * np.sqrt(2 * np.pi)) - return pd.DataFrame(lpdf_arr, index=x.index, columns=x.columns) - - def cdf(self, x): - """Cumulative distribution function.""" - d = self.loc[x.index, x.columns] - cdf_arr = 0.5 + 0.5 * erf((x.values - d.mu) / (d.sigma * np.sqrt(2))) - return pd.DataFrame(cdf_arr, index=x.index, columns=x.columns) - - def ppf(self, p): - """Quantile function = percent point function = inverse cdf.""" - d = self.loc[p.index, p.columns] - icdf_arr = d.mu + d.sigma * np.sqrt(2) * erfinv(2 * p.values - 1) - return pd.DataFrame(icdf_arr, index=p.index, columns=p.columns) + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + pdf_arr = np.exp(-0.5 * ((x - mu) / sigma) ** 2) + pdf_arr = pdf_arr / (sigma * np.sqrt(2 * np.pi)) + return pdf_arr + + def _log_pdf(self, x): + """Logarithmic probability density function. 
+ + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points + """ + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + lpdf_arr = -0.5 * ((x - mu) / sigma) ** 2 + lpdf_arr = lpdf_arr - np.log(sigma * np.sqrt(2 * np.pi)) + return lpdf_arr + + def _cdf(self, x): + """Cumulative distribution function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + cdf_arr = 0.5 + 0.5 * erf((x - mu) / (sigma * np.sqrt(2))) + return cdf_arr + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + icdf_arr = mu + sigma * np.sqrt(2) * erfinv(2 * p - 1) + return icdf_arr @classmethod def get_test_params(cls, parameter_set="default"): diff --git a/skpro/distributions/t.py b/skpro/distributions/t.py index cd34c3f4..e2386ef6 100644 --- a/skpro/distributions/t.py +++ b/skpro/distributions/t.py @@ -37,47 +37,38 @@ class TDistribution(BaseDistribution): "capabilities:approx": ["pdfnorm", "energy"], "capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"], "distr:measuretype": "continuous", + "broadcast_init": "on", } def __init__(self, mu, sigma, df=1, index=None, columns=None): self.mu = mu self.sigma = sigma self.df = df - self.index = index - self.columns = columns - - self._mu, self._sigma, self._df = self._get_bc_params( - self.mu, self.sigma, self.df - ) - shape = self._mu.shape - - if index is None: - index = pd.RangeIndex(shape[0]) - - if columns is None: - columns = pd.RangeIndex(shape[1]) super().__init__(index=index, columns=columns) - def mean(self): - r"""Return expected value of the distribution. - - Let :math:`X` be a random variable with the distribution of `self`. - Returns the expectation :math:`\mathbb{E}[X]`. The expectation, - :math:`\mathbb{E}[X]`, as infinite if :math:`\nu \le 1`. + def _mean(self): + """Return expected value of the distribution. Returns ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) + 2D np.ndarray, same shape as ``self`` + expected value of distribution (entry-wise) """ - mean_arr = self._mu.copy() - if (self._df <= 1).any(): + mean_arr = self._bc_params["mu"] + df = self._bc_params["df"] + + if self.ndim == 0: + if df <= 1: + return np.inf + return mean_arr + + if (df <= 1).any(): mean_arr = mean_arr.astype(np.float32) - mean_arr[self._df <= 1] = np.inf - return pd.DataFrame(mean_arr, index=self.index, columns=self.columns) + mean_arr[df <= 1] = np.inf + return mean_arr - def var(self): + def _var(self): r"""Return element/entry-wise variance of the distribution. Let :math:`X` be a random variable with the distribution of `self`. 
@@ -96,64 +87,121 @@ def var(self): pd.DataFrame with same rows, columns as `self` variance of distribution (entry-wise) """ - df_arr = self._df.copy() - df_arr = df_arr.astype(np.float32) + sigma = self._bc_params["sigma"] + df = self._bc_params["df"] + df_arr = df.astype(np.float32) + + if self.ndim == 0: + if df <= 2: + return np.inf + return sigma**2 * df / (df - 2) + df_arr[df_arr <= 2] = np.inf mask = (df_arr > 2) & (df_arr != np.inf) - df_arr[mask] = self._sigma[mask] ** 2 * df_arr[mask] / (df_arr[mask] - 2) - return pd.DataFrame(df_arr, index=self.index, columns=self.columns) - - def pdf(self, x): - """Probability density function.""" - d = self.loc[x.index, x.columns] - pdf_arr = gamma((d._df + 1) / 2) - pdf_arr = pdf_arr / (np.sqrt(np.pi * d._df) * gamma(d._df / 2)) - pdf_arr = pdf_arr * (1 + ((x - d._mu) / d._sigma) ** 2 / d._df) ** ( - -(d._df + 1) / 2 - ) - pdf_arr = pdf_arr / d._sigma - return pd.DataFrame(pdf_arr, index=x.index, columns=x.columns) - - def log_pdf(self, x): - """Logarithmic probability density function.""" - d = self.loc[x.index, x.columns] - lpdf_arr = loggamma((d._df + 1) / 2) - lpdf_arr = lpdf_arr - 0.5 * np.log(d._df * np.pi) - lpdf_arr = lpdf_arr - loggamma(d._df / 2) - lpdf_arr = lpdf_arr - ((d._df + 1) / 2) * np.log( - 1 + ((x - d._mu) / d._sigma) ** 2 / d._df - ) - lpdf_arr = lpdf_arr - np.log(d._sigma) - return pd.DataFrame(lpdf_arr, index=x.index, columns=x.columns) - - def cdf(self, x): - """Cumulative distribution function.""" - d = self.loc[x.index, x.columns] - x_ = (x - d._mu) / d._sigma - cdf_arr = x_ * gamma((d._df + 1) / 2) - cdf_arr = cdf_arr * hyp2f1(0.5, (d._df + 1) / 2, 3 / 2, -(x_**2) / d._df) - cdf_arr = 0.5 + cdf_arr / (np.sqrt(np.pi * d._df) * gamma(d._df / 2)) - return pd.DataFrame(cdf_arr, index=x.index, columns=x.columns) - - def ppf(self, p): - """Quantile function = percent point function = inverse cdf.""" - d = self.loc[p.index, p.columns] - ppf_arr = p.to_numpy(copy=True) - ppf_arr[p.values == 0.5] = 0.0 - ppf_arr[p.values <= 0] = -np.inf - ppf_arr[p.values >= 1] = np.inf - - mask1 = (p.values < 0.5) & (p.values > 0) - mask2 = (p.values < 1) & (p.values > 0.5) - ppf_arr[mask1] = 1 / betaincinv(0.5 * d._df[mask1], 0.5, 2 * ppf_arr[mask1]) - ppf_arr[mask2] = 1 / betaincinv( - 0.5 * d._df[mask2], 0.5, 2 * (1 - ppf_arr[mask2]) - ) + df_arr[mask] = sigma[mask] ** 2 * df_arr[mask] / (df_arr[mask] - 2) + return df_arr + + def _pdf(self, x): + """Probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + pdf values at the given points + """ + df = self._bc_params["df"] + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + pdf_arr = gamma((df + 1) / 2) + pdf_arr = pdf_arr / (np.sqrt(np.pi * df) * gamma(df / 2)) + pdf_arr = pdf_arr * (1 + ((x - mu) / sigma) ** 2 / df) ** (-(df + 1) / 2) + pdf_arr = pdf_arr / sigma + return pdf_arr + + def _log_pdf(self, x): + """Logarithmic probability density function. 
+ + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points + """ + df = self._bc_params["df"] + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + lpdf_arr = loggamma((df + 1) / 2) + lpdf_arr = lpdf_arr - 0.5 * np.log(df * np.pi) + lpdf_arr = lpdf_arr - loggamma(df / 2) + lpdf_arr = lpdf_arr - ((df + 1) / 2) * np.log(1 + ((x - mu) / sigma) ** 2 / df) + lpdf_arr = lpdf_arr - np.log(sigma) + return lpdf_arr + + def _cdf(self, x): + """Cumulative distribution function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + df = self._bc_params["df"] + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + x_ = (x - mu) / sigma + cdf_arr = x_ * gamma((df + 1) / 2) + cdf_arr = cdf_arr * hyp2f1(0.5, (df + 1) / 2, 3 / 2, -(x_**2) / df) + cdf_arr = 0.5 + cdf_arr / (np.sqrt(np.pi * df) * gamma(df / 2)) + return cdf_arr + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + df = self._bc_params["df"] + mu = self._bc_params["mu"] + sigma = self._bc_params["sigma"] + + ppf_arr = p.copy() + ppf_arr[p == 0.5] = 0.0 + ppf_arr[p <= 0] = -np.inf + ppf_arr[p >= 1] = np.inf + + mask1 = (p < 0.5) & (p > 0) + mask2 = (p < 1) & (p > 0.5) + ppf_arr[mask1] = 1 / betaincinv(0.5 * df[mask1], 0.5, 2 * ppf_arr[mask1]) + ppf_arr[mask2] = 1 / betaincinv(0.5 * df[mask2], 0.5, 2 * (1 - ppf_arr[mask2])) ppf_arr[mask1 | mask2] = np.sqrt(ppf_arr[mask1 | mask2] - 1) - ppf_arr[mask1 | mask2] = np.sqrt(d._df[mask1 | mask2]) * ppf_arr[mask1 | mask2] + ppf_arr[mask1 | mask2] = np.sqrt(df[mask1 | mask2]) * ppf_arr[mask1 | mask2] ppf_arr[mask1] = -ppf_arr[mask1] - ppf_arr = d._sigma * ppf_arr + d._mu - return pd.DataFrame(ppf_arr, index=p.index, columns=p.columns) + ppf_arr = sigma * ppf_arr + mu + return ppf_arr @classmethod def get_test_params(cls, parameter_set="default"): diff --git a/skpro/distributions/tests/test_base_scalar.py b/skpro/distributions/tests/test_base_scalar.py new file mode 100644 index 00000000..a0efe306 --- /dev/null +++ b/skpro/distributions/tests/test_base_scalar.py @@ -0,0 +1,90 @@ +"""Test base class logic for scalar distributions. + +Distributions behave polymorphically - if no index is passed, +all methods should accept and return scalars. + +This is tested by using the Normal distribution as a test case, +which invokes the boilerplate for scalar distributions. 
+""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) + +__author__ = ["fkiraly"] + +import numpy as np +import pandas as pd + +from skpro.distributions.normal import Normal + + +def test_scalar_distribution(): + """Test scalar distribution logic.""" + # test params + mu = 1 + sigma = 2 + + # instantiate distribution + d = Normal(mu=mu, sigma=sigma) + assert d.ndim == 0 + assert d.shape == () + assert d.index is None + assert d.columns is None + + # test scalar input + x = 0.5 + assert np.isscalar(d.mean()) + assert np.isscalar(d.var()) + assert np.isscalar(d.energy()) + assert np.isscalar(d.energy(x)) + assert np.isscalar(d.pdf(x)) + assert np.isscalar(d.log_pdf(x)) + assert np.isscalar(d.cdf(x)) + assert np.isscalar(d.ppf(x)) + assert np.isscalar(d.sample()) + + spl_mult = d.sample(5) + assert isinstance(spl_mult, pd.DataFrame) + assert spl_mult.shape == (5, 1) + assert spl_mult.index.equals(pd.RangeIndex(5)) + + +def test_broadcast_ambiguous(): + """Test broadcasting in cases of ambiguous parameter dimensions.""" + mu = [1] + sigma = 2 + # this should result in 2D array distribution + # anything that is not scalar is broadcast to 2D + d = Normal(mu=mu, sigma=sigma) + assert d.ndim == 2 + assert d.shape == (1, 1) + assert d.index.equals(pd.RangeIndex(1)) + assert d.columns.equals(pd.RangeIndex(1)) + + def is_expected_2d(output, col=None): + assert isinstance(output, pd.DataFrame) + assert output.ndim == 2 + assert output.shape == (1, 1) + assert output.index.equals(pd.RangeIndex(1)) + if col is None: + col = pd.RangeIndex(1) + assert output.columns.equals(pd.Index(col)) + return True + + # test scalar input + x = 0.5 + + assert is_expected_2d(d.mean()) + assert is_expected_2d(d.var()) + assert is_expected_2d(d.energy(), ["energy"]) + assert is_expected_2d(d.energy(x), ["energy"]) + assert is_expected_2d(d.pdf(x)) + assert is_expected_2d(d.log_pdf(x)) + assert is_expected_2d(d.cdf(x)) + assert is_expected_2d(d.ppf(x)) + assert is_expected_2d(d.sample()) + + spl_mult = d.sample(5) + assert isinstance(spl_mult, pd.DataFrame) + assert spl_mult.shape == (5, 1) + assert isinstance(spl_mult.index, pd.MultiIndex) + assert spl_mult.index.nlevels == 2 + assert spl_mult.columns.equals(pd.RangeIndex(1)) diff --git a/skpro/distributions/weibull.py b/skpro/distributions/weibull.py index 013bad7a..f246bcd5 100644 --- a/skpro/distributions/weibull.py +++ b/skpro/distributions/weibull.py @@ -11,7 +11,15 @@ class Weibull(BaseDistribution): - """Weibull distribution. + r"""Weibull distribution. + + The Weibull distribution is parametrized by scale parameter :math:`\lambda`, + and shape parameter :math:`k`, such that the cdf is given by: + + .. math:: F(x) = 1 - \exp\left(-\left(\frac{x}{\lambda}\right)^k\right) + + The scale parameter :math:`\lambda` is represented by the parameter ``scale``, + and the shape parameter :math:`k` by the parameter ``k``. Parameters ---------- @@ -33,29 +41,16 @@ class Weibull(BaseDistribution): "capabilities:approx": ["pdfnorm", "energy"], "capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"], "distr:measuretype": "continuous", + "broadcast_init": "on", } def __init__(self, scale, k, index=None, columns=None): self.scale = scale self.k = k - self.index = index - self.columns = columns - - # todo: untangle index handling - # and broadcast of parameters. 
-        # move this functionality to the base class
-        self._scale, self._k = self._get_bc_params(self.scale, self.k)
-        shape = self._scale.shape
-
-        if index is None:
-            index = pd.RangeIndex(shape[0])
-
-        if columns is None:
-            columns = pd.RangeIndex(shape[1])
 
         super().__init__(index=index, columns=columns)
 
-    def mean(self):
+    def _mean(self):
         r"""Return expected value of the distribution.
 
         For Weibull distribution, expectation is given by,
@@ -63,13 +58,15 @@ def mean(self):
 
         Returns
         -------
-        pd.DataFrame with same rows, columns as `self`
-            expected value of distribution (entry-wise)
+        2D np.ndarray, same shape as ``self``
+            expected value of distribution (entry-wise)
         """
-        mean_arr = self._scale * gamma(1 + 1 / self._k)
-        return pd.DataFrame(mean_arr, index=self.index, columns=self.columns)
+        scale = self._bc_params["scale"]
+        k = self._bc_params["k"]
+        mean_arr = scale * gamma(1 + 1 / k)
+        return mean_arr
 
-    def var(self):
+    def _var(self):
         r"""Return element/entry-wise variance of the distribution.
 
         For Weibull distribution, variance is given by
@@ -77,47 +74,95 @@ def var(self):
 
         Returns
         -------
-        pd.DataFrame with same rows, columns as `self`
-            variance of distribution (entry-wise)
+        2D np.ndarray, same shape as ``self``
+            variance of the distribution (entry-wise)
+        """
+        scale = self._bc_params["scale"]
+        k = self._bc_params["k"]
+
+        left_gamma = gamma(1 + 2 / k)
+        right_gamma = gamma(1 + 1 / k) ** 2
+        var_arr = scale**2 * (left_gamma - right_gamma)
+        return var_arr
+
+    def _pdf(self, x):
+        """Probability density function.
+
+        Parameters
+        ----------
+        x : 2D np.ndarray, same shape as ``self``
+            values to evaluate the pdf at
+
+        Returns
+        -------
+        2D np.ndarray, same shape as ``self``
+            pdf values at the given points
         """
-        left_gamma = gamma(1 + 2 / self._k)
-        right_gamma = gamma(1 + 1 / self._k) ** 2
-        var_arr = self._scale**2 * (left_gamma - right_gamma)
-        return pd.DataFrame(var_arr, index=self.index, columns=self.columns)
-
-    def pdf(self, x):
-        """Probability density function."""
-        d = self.loc[x.index, x.columns]
-        # if x.values[i] < 0, then pdf_arr[i] = 0
-        pdf_arr = (
-            (d.k / d.scale)
-            * (x.values / d.scale) ** (d.k - 1)
-            * np.exp(-((x.values / d.scale) ** d.k))
-        )
-        return pd.DataFrame(pdf_arr, index=x.index, columns=x.columns)
-
-    def log_pdf(self, x):
-        """Logarithmic probability density function."""
-        d = self.loc[x.index, x.columns]
-        lpdf_arr = (
-            np.log(d.k / d.scale)
-            + (d.k - 1) * np.log(x.values / d.scale)
-            - (x.values / d.scale) ** d.k
-        )
-        return pd.DataFrame(lpdf_arr, index=x.index, columns=x.columns)
-
-    def cdf(self, x):
-        """Cumulative distribution function."""
-        d = self.loc[x.index, x.columns]
-        # if x.values[i] < 0, then cdf_arr[i] = 0
-        cdf_arr = 1 - np.exp(-((x.values / d.scale) ** d.k))
-        return pd.DataFrame(cdf_arr, index=x.index, columns=x.columns)
-
-    def ppf(self, p):
-        """Quantile function = percent point function = inverse cdf."""
-        d = self.loc[p.index, p.columns]
-        ppf_arr = d.scale * (-np.log(1 - p.values)) ** (1 / d.k)
-        return pd.DataFrame(ppf_arr, index=p.index, columns=p.columns)
+        k = self._bc_params["k"]
+        scale = self._bc_params["scale"]
+
+        pdf_arr = (k / scale) * (x / scale) ** (k - 1) * np.exp(-((x / scale) ** k))
+        pdf_arr[x < 0] = 0  # if x < 0, pdf = 0
+        return pdf_arr
+
+    def _log_pdf(self, x):
+        """Logarithmic probability density function.
+ + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points + """ + k = self._bc_params["k"] + scale = self._bc_params["scale"] + + lpdf_arr = np.log(k / scale) + (k - 1) * np.log(x / scale) - (x / scale) ** k + lpdf_arr[x < 0] = -np.inf # if x < 0, pdf = 0, so log pdf = -inf + return lpdf_arr + + def _cdf(self, x): + """Cumulative distribution function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + k = self._bc_params["k"] + scale = self._bc_params["scale"] + + cdf_arr = 1 - np.exp(-((x / scale) ** k)) + cdf_arr[x < 0] = 0 # if x < 0, cdf = 0 + return cdf_arr + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + k = self._bc_params["k"] + scale = self._bc_params["scale"] + + ppf_arr = scale * (-np.log(1 - p)) ** (1 / k) + return ppf_arr @classmethod def get_test_params(cls, parameter_set="default"): diff --git a/skpro/registry/_tags.py b/skpro/registry/_tags.py index 0cf68f12..bcae4604 100644 --- a/skpro/registry/_tags.py +++ b/skpro/registry/_tags.py @@ -191,6 +191,18 @@ "int", "max iters for bisection method in ppf", ), + ( + "broadcast_params", + "distribution", + ("list", "str"), + "distribution parameters to broadcast, complement is not broadcast", + ), + ( + "broadcast_init", + "distribution", + "str", + "whether to initialize broadcast parameters in __init__, 'on' or 'off'", + ), # --------------- # BaseProbaMetric # ---------------
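The net effect of the refactor above: concrete distributions implement only the private ``_mean``/``_var``/``_pdf``/``_log_pdf``/``_cdf``/``_ppf`` methods on plain 2D arrays taken from ``self._bc_params``, while ``BaseDistribution`` supplies parameter broadcasting (enabled per class via the new ``broadcast_init`` tag) and the pandas/scalar boilerplate. A minimal usage sketch, mirroring the behaviour exercised in ``test_base_scalar.py`` above:

import numpy as np
import pandas as pd

from skpro.distributions.normal import Normal

# scalar parameters give a 0-dimensional distribution; methods return scalars
d = Normal(mu=1, sigma=2)
assert d.shape == ()
assert np.isscalar(d.mean())
assert np.isscalar(d.cdf(0.5))

# non-scalar parameters are broadcast to a 2D shape by the base class;
# public methods then return pandas DataFrames indexed like the distribution,
# while the private _pdf/_cdf/... only ever see the broadcast ndarrays
d2 = Normal(mu=[1], sigma=2)
assert d2.shape == (1, 1)
assert isinstance(d2.pdf(0.5), pd.DataFrame)
assert d2.index.equals(pd.RangeIndex(1))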