-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
James Bristow
committed
Feb 7, 2024
1 parent
fe32913
commit 4c320ee
Showing
9 changed files
with
1,570 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Pyro | ||
|
||
See https://pyro.ai/examples/index.html |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2e5c7973", | ||
"metadata": {}, | ||
"source": [ | ||
"# https://pyro.ai/examples/ekf.html" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "9ebb6829", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"'1.8.6'" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import os\n", | ||
"import math\n", | ||
"\n", | ||
"import torch\n", | ||
"import pyro\n", | ||
"import pyro.distributions as dist\n", | ||
"from pyro.infer.autoguide import AutoDelta\n", | ||
"from pyro.optim import Adam\n", | ||
"from pyro.infer import SVI, Trace_ELBO, config_enumerate\n", | ||
"from pyro.contrib.tracking.extended_kalman_filter import EKFState\n", | ||
"from pyro.contrib.tracking.distributions import EKFDistribution\n", | ||
"from pyro.contrib.tracking.dynamic_models import NcvContinuous\n", | ||
"from pyro.contrib.tracking.measurements import PositionMeasurement\n", | ||
"\n", | ||
"pyro.__version__" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "87dcc0f7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dt = 1e-2\n", | ||
"num_frames = 10\n", | ||
"dim = 4\n", | ||
"\n", | ||
"# Continuous model\n", | ||
"ncv = NcvContinuous(dim, 2.0)\n", | ||
"\n", | ||
"# Truth trajectory\n", | ||
"xs_truth = torch.zeros(num_frames, dim)\n", | ||
"# initial direction\n", | ||
"theta0_truth = 0.0\n", | ||
"# initial state\n", | ||
"with torch.no_grad():\n", | ||
" xs_truth[0, :] = torch.tensor([0.0, 0.0, math.cos(theta0_truth), math.sin(theta0_truth)])\n", | ||
" for frame_num in range(1, num_frames):\n", | ||
" # sample independent process noise\n", | ||
" dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))\n", | ||
" xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx\n", | ||
" \n", | ||
"# Measurements\n", | ||
"measurements = []\n", | ||
"mean = torch.zeros(2)\n", | ||
"# no correlations\n", | ||
"cov = 1e-5 * torch.eye(2)\n", | ||
"with torch.no_grad():\n", | ||
" # sample independent measurement noise\n", | ||
" dzs = pyro.sample('dzs', dist.MultivariateNormal(mean, cov).expand((num_frames,)))\n", | ||
" # compute measurement means\n", | ||
" zs = xs_truth[:, :2] + dzs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "046c16cb", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/jbris/miniconda3/envs/data_assim/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11070). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)\n", | ||
" Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"loss: -15.20763874053955\n", | ||
"loss: -15.339868545532227\n", | ||
"loss: -15.413694381713867\n", | ||
"loss: -15.473196029663086\n", | ||
"loss: -15.507671356201172\n", | ||
"loss: -15.523503303527832\n", | ||
"loss: -15.5301513671875\n", | ||
"loss: -15.532791137695312\n", | ||
"loss: -15.533793449401855\n", | ||
"loss: -15.534193992614746\n", | ||
"loss: -15.534348487854004\n", | ||
"loss: -15.534411430358887\n", | ||
"loss: -15.534439086914062\n", | ||
"loss: -15.534448623657227\n", | ||
"loss: -15.534452438354492\n", | ||
"loss: -15.534453392028809\n", | ||
"loss: -15.534455299377441\n", | ||
"loss: -15.534454345703125\n", | ||
"loss: -15.534454345703125\n", | ||
"loss: -15.534454345703125\n", | ||
"loss: -15.534455299377441\n", | ||
"loss: -15.534455299377441\n", | ||
"loss: -15.534454345703125\n", | ||
"loss: -15.534453392028809\n", | ||
"loss: -15.534456253051758\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"def model(data):\n", | ||
" # a HalfNormal can be used here as well\n", | ||
" R = pyro.sample('pv_cov', dist.HalfCauchy(2e-6)) * torch.eye(4)\n", | ||
" Q = pyro.sample('measurement_cov', dist.HalfCauchy(1e-6)) * torch.eye(2)\n", | ||
" # observe the measurements\n", | ||
" pyro.sample('track_{}'.format(i), EKFDistribution(xs_truth[0], R, ncv,\n", | ||
" Q, time_steps=num_frames),\n", | ||
" obs=data)\n", | ||
"\n", | ||
"guide = AutoDelta(model) # MAP estimation\n", | ||
"\n", | ||
"optim = pyro.optim.Adam({'lr': 2e-2})\n", | ||
"svi = SVI(model, guide, optim, loss=Trace_ELBO(retain_graph=True))\n", | ||
"\n", | ||
"pyro.set_rng_seed(0)\n", | ||
"pyro.clear_param_store()\n", | ||
"\n", | ||
"for i in range(250):\n", | ||
" loss = svi.step(zs)\n", | ||
" if not i % 10:\n", | ||
" print('loss: ', loss)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "bb429939", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[<pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35346980>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35347c10>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35347d60>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35345d20>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35344730>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ffe80>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ff0d0>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354fded0>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ff9d0>,\n", | ||
" <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354feb90>]" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"R = guide()['pv_cov'] * torch.eye(4)\n", | ||
"Q = guide()['measurement_cov'] * torch.eye(2)\n", | ||
"ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames)\n", | ||
"states= ekf_dist.filter_states(zs)\n", | ||
"states" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "88a05fba", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Hierarchical mixed-effect hidden Markov models | ||
|
||
Note: This is a cleaned-up version of the seal experiments in [Bingham et al 2019] that is a simplified variant of some of the analysis in the [momentuHMM harbour seal example](https://github.com/bmcclintock/momentuHMM/blob/master/vignettes/harbourSealExample.R) [McClintock et al 2018]. | ||
|
||
Recent advances in sensor technology have made it possible to capture the movements of multiple wild animals within a single population at high spatiotemporal resolution over long periods of time [McClintock et al 2013, Towner et al 2016]. Discrete state-space models, where the latent state is thought of as corresponding to a behavior state such as "foraging" or "resting", have become popular computational tools for analyzing these new datasets thanks to their interpretability and tractability. | ||
|
||
This example applies several different hierarchical discrete state-space models to location data recorded from a colony of harbour seals on foraging excursions in the North Sea [McClintock et al 2013]. | ||
The raw data are irregularly sampled time series (roughly 5-15 minutes between samples) of GPS coordinates and diving activity for each individual in the colony (10 male and 7 female) over the course of a single day recorded by lightweight tracking devices physically attached to each animal by researchers. They have been preprocessed using the momentuHMM example code into smoothed, temporally regular series of step sizes, turn angles, and diving activity for each individual. | ||
|
||
The models are special cases of a time-inhomogeneous discrete state space model | ||
whose state transition distribution is specified by a hierarchical generalized linear mixed model (GLMM). | ||
At each timestep `t`, for each individual trajectory `b` in each group `a`, we have | ||
|
||
``` | ||
logit(p(x[t,a,b] = state i | x[t-1,a,b] = state j)) = | ||
(epsilon_G[a] + epsilon_I[a,b] + Z_I[a,b].T @ beta1 + Z_G[a].T @ beta2 + Z_T[t,a,b].T @ beta3)[i,j] | ||
``` | ||
|
||
where `a,b` correspond to plate indices, `epsilon_G` and `epsilon_I` are independent random variables for each group and individual within each group respectively, `Z`s are covariates, and `beta`s are parameter vectors. | ||
|
||
The random variables `epsilon` may be either discrete or continuous. | ||
If continuous, they are normally distributed. | ||
If discrete, they are sampled from a set of three possible values shared across the innermost plate of a particular variable. | ||
That is, for each individual trajectory `b` in each group `a`, we sample single random effect values for an entire trajectory: | ||
|
||
``` | ||
iota_G[a] ~ Categorical(pi_G) | ||
epsilon_G[a] = Theta_G[iota_G[a]] | ||
iota_I[a,b] ~ Categorical(pi_I[a]) | ||
epsilon_I[a,b] = Theta_I[a][iota_I[a,b]] | ||
``` | ||
|
||
Here `pi_G`, `Theta_G`, `pi_I`, and `Theta_I` are all learnable real-valued parameter vectors and `epsilon` values are batches of vectors the size of state transition matrices. | ||
|
||
Observations `y[t,a,b]` are represented as sequences of real-valued step lengths and turn angles, modelled by zero-inflated Gamma and von Mises likelihoods respectively. | ||
The seal models also include a third observed variable indicating the amount of diving activity between successive locations, which we model with a zero-inflated Beta distribution following [McClintock et al 2018]. | ||
|
||
We grouped animals by sex and implemented versions of this model with (i) no random effects (as a baseline), and with random effects present at the (ii) group, (iii) individual, or (iv) group+individual levels. Unlike the models in [Towner et al 2016], we do not consider fixed effects on any of the parameters. | ||
|
||
# References | ||
* [Obermeyer et al 2019] Obermeyer, F.\*, Bingham, E.\*, Jankowiak, M.\*, Chiu, J., Pradhan, N., Rush, A., and Goodman, N. Tensor Variable Elimination for Plated Factor Graphs, 2019 | ||
* [McClintock et al 2013] McClintock, B. T., Russell, D. J., Matthiopoulos, J., and King, R. Combining individual animal movement and ancillary biotelemetry data to investigate population-level activity budgets. Ecology, 94(4):838–849, 2013 | ||
* [McClintock et al 2018] McClintock, B. T. and Michelot,T. momentuhmm: R package for generalized hidden markov models of animal movement. Methods in Ecology and Evolution, 9(6): 1518–1530, 2018. doi: 10.1111/2041-210X.12995 | ||
* [Towner et al 2016] Towner, A. V., Leos-Barajas, V., Langrock, R., Schick, R. S., Smale, M. J., Kaschke, T., Jewell, O. J., and Papastamatiou, Y. P. Sex-specific and individual preferences for hunting strategies in white sharks. Functional Ecology, 30(8):1397–1407, 2016. |
Empty file.
Oops, something went wrong.