|
| 1 | +Building blocks |
| 2 | +=============== |
| 3 | + |
| 4 | +Parameter object |
| 5 | +---------------- |
| 6 | +:py:class:`~torchtree.core.parameter.Parameter` objects play a central role in torchtree. They are used to define the parameters of models and distributions, and are involved in any kind of optimization or inference. |
| 7 | +A Parameter object contains a reference to a pytorch tensor object which can be accessed using the :py:attr:`~torchtree.core.parameter.Parameter.tensor` property. |
| 8 | +There are different ways to define a Parameter object in a JSON file. The most common way is to define a :keycode:`tensor` key associated with a list of real numbers as shown below: |
| 9 | + |
| 10 | +.. code-block:: JSON |
| 11 | +
|
| 12 | + { |
| 13 | + "id": "gtr_frequencies", |
| 14 | + "type": "Parameter", |
| 15 | + "tensor": [0.25, 0.25, 0.25, 0.25] |
| 16 | + } |
| 17 | +
|
| 18 | +Inside torchtree, this JSON object will be converted to a python object: |
| 19 | + |
| 20 | +.. code-block:: python |
| 21 | +
|
| 22 | + Parameter("gtr_frequencies", torch.tensor([0.25, 0.25, 0.25, 0.25])) |
| 23 | +
|
| 24 | +Another way to define the same object using a different initialization method is: |
| 25 | + |
| 26 | +.. code-block:: JSON |
| 27 | +
|
| 28 | + { |
| 29 | + "id": "gtr_frequencies", |
| 30 | + "type": "Parameter", |
| 31 | + "full": [4], |
| 32 | + "value": 0.25 |
| 33 | + } |
| 34 | +
|
| 35 | +which in python will be converted to: |
| 36 | + |
| 37 | +.. code-block:: python |
| 38 | + :linenos: |
| 39 | +
|
| 40 | + Parameter("gtr_frequencies", torch.full([4], 0.25)) |
| 41 | +
|
| 42 | +
|
| 43 | +TransformedParameter object |
| 44 | +--------------------------- |
| 45 | +In torchtree, :py:class:`~torchtree.core.parameter.Parameter` objects are typically considered to be unconstrained. |
| 46 | +Optimizers (such as those used in ADVI and MAP) and samplers (e.g. HMC) will change the value of the tensor they encapsulate without checking if the new value is within the parameter's domain. |
| 47 | +However, in many cases, phylogenetic models contain constrained parameters such as branch lengths (positive real numbers) or equilibrium base frequencies (positive real numbers that sum to 1). |
| 48 | +For example, the GTR model expects the equilibrium base frequencies to be positive real numbers that sum to 1 and a standard optimizer will ignore such constraints. |
| 49 | +:py:class:`~torchtree.core.parameter.TransformedParameter` objects allow moving from unconstrained to constrained spaces using `transform <https://pytorch.org/docs/stable/distributions.html#torch.distributions.transforms.Transform>`_ objects available in pytorch. |
| 50 | + |
| 51 | +We can replace the JSON object defining the GTR equilibrium base frequencies with a TransformedParameter object as shown below: |
| 52 | + |
| 53 | +.. code-block:: JSON |
| 54 | +
|
| 55 | + { |
| 56 | + "id": "gtr_frequencies", |
| 57 | + "type": "TransformedParameter", |
| 58 | + "transform": "torch.distributions.StickBreakingTransform", |
| 59 | + "x":{ |
| 60 | + "id": "gtr_frequencies_unconstrained", |
| 61 | + "type": "TransformedParameter", |
| 62 | + "type": "Parameter", |
| 63 | + "zeros": [3] |
| 64 | + } |
| 65 | + } |
| 66 | +
|
| 67 | +This is equivalent to the following python code: |
| 68 | + |
| 69 | +.. code-block:: python |
| 70 | +
|
| 71 | + import torch |
| 72 | + from torchtree import Parameter, TransformedParameter |
| 73 | +
|
| 74 | + unconstrained = Parameter("gtr_frequencies_unconstrained", torch.zeros([3])) |
| 75 | + transform = torch.distributions.StickBreakingTransform() |
| 76 | + constrained = TransformedParameter("gtr_frequencies", unconstrained, transform) |
| 77 | + |
| 78 | +An optimizer will change the value of the **gtr_frequencies_unconstrained** Parameter object and the **gtr_frequencies** (transformed) parameter will apply the StickBreakingTransform transform to the value of **gtr_frequencies_unconstrained** to update the transition rate matrix. |
| 79 | + |
| 80 | +In this example, we are using the `StickBreakingTransform <https://pytorch.org/docs/stable/distributions.html#torch.distributions.transforms.StickBreakingTransform>`_ object that will transform the unconstrained parameter **gtr_frequencies_unconstrained** to a constrained parameter **gtr_frequencies**. |
| 81 | +Note the value of the :keycode:`transform` key is a string containing the full path to the pytorch class that implements the transformation. |
| 82 | +Specifically, ``torch`` is the package name, ``distributions`` is the module name, and ``StickBreakingTransform`` is the class name. |
| 83 | + |
| 84 | + |
| 85 | +Models and CallableModels |
| 86 | +------------------------- |
| 87 | +Virtually every torchtree object that does some kind of computations inherits from the :py:class:`~torchtree.core.model.Model` class. |
| 88 | +Computations can involve Parameter and/or other Model objects. |
| 89 | +The Distribution class we described earlier is derived from the class Model since it defines a probability distribution and return a log probability. |
| 90 | +The GTR substitution model is also a Model object since its role is to calculate a transition probability matrix. |
| 91 | + |
| 92 | +A model that returns a value when called is said to be *callable* and it extends the :py:class:`~torchtree.core.model.CallableModel` abstract class. |
| 93 | +A distribution is a callable model since it returns the log probability of a sample. |
| 94 | +The class representing a tree likelihood model is also callable since it calculates the log likelihood and we will describe it further in the next section. |
| 95 | + |
0 commit comments