Skip to content

Commit

Permalink
updated readme
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauqb committed Jul 14, 2023
1 parent 18c84ea commit 8fa122f
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 45 deletions.
42 changes: 41 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,48 @@ pip install git+https://github.com/arnauqb/blackbirds

Refer to the [docs](https://arnau.ai/blackbirds) for examples and specific API usage. Here is a basic example:

```python
import torch

from blackbirds.models.random_walk import RandomWalk
from blackbirds.infer.vi import VI
from blackbirds.posterior_estimators import TrainableGaussian
from blackbirds.simulate import simulate_and_observe_model

# random walk model
rw = RandomWalk(n_timesteps=100)

# generate synthetic data to fit to
true_p = torch.logit(torch.tensor(0.25))
true_data = rw.observe(rw.run(torch.tensor([true_p])))

# define loss to minimize
class L2Loss:
def __init__(self, model):
self.model = model
self.loss_fn = torch.nn.MSELoss()
def __call__(self, params, data):
observed_outputs = simulate_and_observe_model(self.model, params)
return self.loss_fn(observed_outputs[0], data[0])

# initialize generalized variational inference
posterior_estimator = TrainableGaussian([0.], 1.0)
prior = torch.distributions.Normal(true_p + 0.2, 1)
optimizer = torch.optim.Adam(posterior_estimator.parameters(), 1e-2)
loss = L2Loss(rw)

vi = VI(loss,
posterior_estimator=posterior_estimator,
prior=prior,
optimizer=optimizer,
w = 0) # no regularization

# train the estimator
vi.run(true_data, 1000, max_epochs_without_improvement=100)

```


# 3. Citation

TODO.
TODO. add after JOSS review.
Loading

0 comments on commit 8fa122f

Please sign in to comment.