Working Storylines Prototype
Working prototype for Storylines (a minimal fitting sketch follows this list):
- baseline analysis
- Linear Regression + MLE
- Bayesian Linear Regression + MAP/MCMC
- Bayesian Hierarchical Regression + MAP/MCMC
- Gaussian Process Regression + MAP
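
A minimal fitting sketch for the hierarchical model (everything here is illustrative: `x_train`, `y_train`, the import path, and the sampler settings are placeholders, not part of this commit; `model_bhm` is defined in notebooks/ai4storylines/models.py below):

    import jax.random as jrandom
    from numpyro.infer import MCMC, NUTS
    from models import model_bhm  # assuming models.py is importable

    # x_train: (M, D) covariates per climate model; y_train: (M, N) responses at N locations
    mcmc = MCMC(NUTS(model_bhm), num_warmup=500, num_samples=1_000)
    mcmc.run(jrandom.PRNGKey(0), num_spatial=y_train.shape[1], x=x_train, y=y_train)
    posterior = mcmc.get_samples()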
jejjohnson committed Mar 6, 2024
1 parent ae1467b commit b20db84
Showing 16 changed files with 28,312 additions and 670 deletions.
225 changes: 111 additions & 114 deletions notebooks/ai4storylines/0.1_main_idea.ipynb
3,902 changes: 3,346 additions & 556 deletions notebooks/ai4storylines/1.1_baseline_sa_lr.ipynb
1,916 changes: 1,916 additions & 0 deletions notebooks/ai4storylines/1.1_lr_mle.ipynb
753 changes: 753 additions & 0 deletions notebooks/ai4storylines/2.1_blr_map.ipynb
3,650 changes: 3,650 additions & 0 deletions notebooks/ai4storylines/2.2_blr_mcmc.ipynb
777 changes: 777 additions & 0 deletions notebooks/ai4storylines/2.3_blr_mcmc_sa.ipynb
4,482 changes: 4,482 additions & 0 deletions notebooks/ai4storylines/3.1_bhm_mcmc.ipynb
877 changes: 877 additions & 0 deletions notebooks/ai4storylines/3.2_bhm_mcmc_sa.ipynb
900 changes: 900 additions & 0 deletions notebooks/ai4storylines/4.1_gp_map.ipynb
File renamed without changes.
1,645 changes: 1,645 additions & 0 deletions notebooks/ai4storylines/dev/3.1_baseline_bhm_vi.ipynb
4,276 changes: 4,276 additions & 0 deletions notebooks/ai4storylines/dev/4.2_baseline_bhm_gp_mcmc.ipynb
3,048 changes: 3,048 additions & 0 deletions notebooks/ai4storylines/dev/5.1_baseline_bhm_gp_mcmc.ipynb
2,106 changes: 2,106 additions & 0 deletions notebooks/ai4storylines/dev/5.2_baseline_bhm_gp_vi.ipynb
251 changes: 251 additions & 0 deletions notebooks/ai4storylines/models.py
@@ -0,0 +1,251 @@
from dataclasses import dataclass
from typing import Callable, Dict
from functools import partial
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import Float, Array
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive



@dataclass
class ModelPredictorMCMC:
model: Callable
posterior_params: Dict

    def predict(self, x: Float[Array, "M D"], rng_key=jrandom.PRNGKey(0), **kwargs) -> Float[Array, "P M N"]:
        # posterior-predictive samples: one prediction per posterior draw, (P, M, N)
        predict_fn = Predictive(self.model, posterior_samples=self.posterior_params, parallel=True, **kwargs)
        return predict_fn(rng_key, x=x)["obs"]

    def gradient(self, x: Float[Array, "N D"], rng_key=jrandom.PRNGKey(0), **kwargs) -> Float[Array, "P N M D"]:
        # forward-mode Jacobian of the predictive samples w.r.t. the inputs
        fn = partial(self.predict, rng_key=rng_key, **kwargs)
        gradient_fn = jax.jacfwd(fn)
        out = gradient_fn(x)[:, 0, ...]
        return out

    def gradient_batch(self, x: Float[Array, "N D"], rng_key=jrandom.PRNGKey(0), **kwargs) -> Float[Array, "P N M D"]:
        raise NotImplementedError
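

# Example usage (a sketch; `mcmc` is assumed to be a fitted numpyro.infer.MCMC
# object and `x_test` a placeholder (M, D) array — neither is part of this file;
# the wrapped model must accept `x` only, e.g. functools.partial(model_bhm, num_spatial)):
#
#   predictor = ModelPredictorMCMC(model=model_fn, posterior_params=mcmc.get_samples())
#   y_samps = predictor.predict(x_test)   # posterior-predictive samples, (P, M, N)
#   dy_dx = predictor.gradient(x_test)    # sensitivity of predictions to the covariates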


def model_bhm(num_spatial: int, x: Float[Array, "M D"], y: Float[Array, "M N"]=None):
if len(x.shape) != 2:
x = x[None,:]
num_models, num_covariates = x.shape
# noise
scale = numpyro.sample("scale", dist.LogNormal(0.0, 10.0))
z_scale_mu = numpyro.sample("z_scale_mean", dist.Normal(0.0, 10.0))
z_scale_scale = numpyro.sample("z_scale_scale", dist.LogNormal(0.0, 10.0))
# priors on mean
loc_prior_mean = numpyro.sample("loc_prior_mean", dist.Normal(0.0, 10.))
loc_prior_scale = numpyro.sample("loc_prior_scale", dist.LogNormal(0.0, 10.))
# priors on bias
bias_prior_mean = numpyro.sample("bias_prior_mean", dist.Normal(0.0, 10.))
bias_prior_scale = numpyro.sample("bias_prior_scale", dist.LogNormal(0.0, 10.))
# spatial locations
with numpyro.plate("spatial", num_spatial):
        # sample bias, (N,)
        bias: Float[Array, "N"] = numpyro.sample("bias", dist.Normal(bias_prior_mean, bias_prior_scale))
        # sample per-location noise scale, (N,)
        z_scale: Float[Array, "N"] = numpyro.sample("z_scale", dist.LogNormal(z_scale_mu, z_scale_scale))
with numpyro.plate("covariates", num_covariates), numpyro.plate("spatial", num_spatial):
# sample weights, (N,D)
weight: Float[Array, "N D"] = numpyro.sample("weight", dist.Normal(loc_prior_mean, loc_prior_scale))
# conditional latent variable, (M,N)
    z: Float[Array, "M N"] = jnp.einsum("ND,MD->MN", weight, x) + bias + z_scale
# Data Likelihood
numpyro.sample("obs", dist.Normal(z, scale), obs=y)


@dataclass
class LinearRegression:
num_spatial: int

def model(self, x: Float[Array, "M D"], y: Float[Array, "M N"]=None):
if len(x.shape) != 2:
x = x[None,:]
num_models, num_covariates = x.shape

        # weights: (N, D)
        loc = numpyro.param("loc", init_value=jnp.ones((self.num_spatial, num_covariates)))
        # bias: (N,)
        bias = numpyro.param("bias", init_value=jnp.ones((self.num_spatial,)))
        # observation noise scale: ()
        scale = numpyro.param("scale", init_value=2.0, constraint=dist.constraints.greater_than(0.0))

        with numpyro.plate("models", num_models):
            # linear predictor, (M, N)
            z = jnp.einsum("ND,MD->MN", loc, x) + bias
            numpyro.sample("obs", dist.Normal(z, scale).to_event(1), obs=y)

@dataclass
class BayesianLinearRegression:
num_spatial: int

def model(self, x: Float[Array, "M D"], y: Float[Array, "M N"]=None):
if len(x.shape) != 2:
x = x[None,:]
num_models, num_covariates = x.shape

        # weights: (N, D)
        loc = numpyro.sample("loc", dist.Normal(0.0, 10.0), sample_shape=(self.num_spatial, num_covariates))
        # bias: (N,)
        bias = numpyro.sample("bias", dist.Normal(0.0, 10.0), sample_shape=(self.num_spatial,))
        # observation noise scale: ()
        scale = numpyro.sample("scale", dist.LogNormal(0.0, 10.0))

        with numpyro.plate("models", num_models):
            # linear predictor, (M, N)
            z = jnp.einsum("ND,MD->MN", loc, x) + bias
            numpyro.sample("obs", dist.Normal(z, scale).to_event(1), obs=y)



def sqeuclidean_distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Squared Euclidean distance between two vectors."""
    return jnp.sum((x - y) ** 2)


def kernel_rbf(x: jnp.ndarray, y: jnp.ndarray, length_scale=1.0, variance=1.0) -> jnp.ndarray:
    """RBF kernel, k(x, y) = variance * exp(-||x - y||^2 / length_scale^2)."""
    return variance * jnp.exp(-sqeuclidean_distance(x / length_scale, y / length_scale))


def gram(kernel_fn, x: jnp.ndarray, y: jnp.ndarray, *args, **kwargs) -> jnp.ndarray:
    """Gram matrix K[i, j] = kernel_fn(x[i], y[j]), built with a double vmap."""
    return jax.vmap(lambda x1: jax.vmap(lambda y1: kernel_fn(x1, y1, *args, **kwargs))(y))(x)
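

# Quick shape check for the Gram matrix (a sketch; the points are placeholders):
#
#   S = jrandom.normal(jrandom.PRNGKey(0), (5, 2))
#   K = gram(kernel_rbf, S, S, length_scale=2.0, variance=1.5)
#   assert K.shape == (5, 5)  # symmetric, with `variance` on the diagonal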


def model_bhm_gp(num_spatial: int, x: Float[Array, "M D"], S: Float[Array, "N"], y: Float[Array, "M N"]=None):
    num_models, num_features = x.shape
    # the spatial size comes from S (the num_spatial argument is overridden)
    num_spatial = S.shape[0]
# noise
scale = numpyro.sample("scale", dist.LogNormal(0.0, 10.0))
z_scale_mu = numpyro.sample("z_scale_mean", dist.LogNormal(0.0, 10.0))
z_scale_scale = numpyro.sample("z_scale_scale", dist.LogNormal(0.0, 10.0))
# priors on mean
loc_prior_mean = numpyro.sample("loc_prior_mean", dist.Normal(0.0, 10.))
loc_prior_scale = numpyro.sample("loc_prior_scale", dist.LogNormal(0.0, 10.))
# priors on bias
bias_prior_mean = numpyro.sample("bias_prior_mean", dist.Normal(0.0, 10.))
bias_prior_scale = numpyro.sample("bias_prior_scale", dist.LogNormal(0.0, 10.))
# prior on spatial process
variance = numpyro.sample("variance", dist.LogNormal(0.0, 10.0))
length_scale = numpyro.sample("length_scale", dist.LogNormal(0.0, 10.0))
K: Float[Array, "N N"] = gram(kernel_rbf, S, S, length_scale=length_scale, variance=variance)
K += 1e-6 * jnp.eye(K.shape[0])

# weights & biases
with numpyro.plate("spatial", num_spatial), numpyro.plate("covariates", num_features):
# sample weights, (N,D)
weight: Float[Array, "N D"] = numpyro.sample("weight", dist.Normal(loc_prior_mean, loc_prior_scale))

with numpyro.plate("spatial", num_spatial):
# sample bias
bias: Float[Array, "N"] = numpyro.sample("bias", dist.Normal(bias_prior_mean, bias_prior_scale))
# same noise
z_scale: Float[Array, "N"] = numpyro.sample("z_scale", dist.LogNormal(z_scale_mu, z_scale_scale))

# sample spatial process
f: Float[Array, "M"] = numpyro.sample("f", dist.MultivariateNormal(loc=jnp.zeros(S.shape[0]), covariance_matrix=K))

# conditional latent variable, (M,N)
z: Float[Array, "M N"] = jnp.einsum("DN,MD->MN",weight, x) + f + bias + z_scale

# data likelihood
numpyro.sample("obs", dist.Normal(z, scale), obs=y)


def model_gp(x: Float[Array, "M D"], S: Float[Array, "N"], y: Float[Array, "M N"]=None):
num_models, num_features = x.shape
num_spatial = S.shape[0]
# noise
scale = numpyro.sample("scale", dist.HalfCauchy(scale=10))
# prior on spatial process
variance = numpyro.sample("variance", dist.HalfNormal(5.0))
length_scale = numpyro.sample("length_scale", dist.HalfNormal(5.0))
    K: Float[Array, "N N"] = gram(kernel_rbf, S, S, length_scale=length_scale, variance=variance)
    # observation noise plus jitter on the diagonal
    K += (scale + 1e-6) * jnp.eye(K.shape[0])

    # weights & biases
    with numpyro.plate("spatial", num_spatial), numpyro.plate("covariates", num_features):
        # sample weights, (D, N)
        weight: Float[Array, "D N"] = numpyro.sample("weight", dist.Normal(0.0, 10.0))

    with numpyro.plate("spatial", num_spatial):
        # sample bias, (N,)
        bias: Float[Array, "N"] = numpyro.sample("bias", dist.Normal(0.0, 10.0))

    with numpyro.plate("model", num_models):
        # linear mean function, (M, N)
        z: Float[Array, "M N"] = numpyro.deterministic("z", jnp.einsum("DN,MD->MN", weight, x) + bias)
        # data likelihood: one N-dimensional correlated draw per model
        f: Float[Array, "M N"] = numpyro.sample(
            "obs",
            dist.MultivariateNormal(loc=z, covariance_matrix=K),
            obs=y,
        )
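
# Posterior-predictive draws from `model_gp` at new covariates (a sketch;
# `posterior` is assumed to come from a fitted MCMC object and `x_new`,
# `S_train` are placeholders):
#
#   post_pred = Predictive(model_gp, posterior_samples=posterior)
#   y_new = post_pred(jrandom.PRNGKey(1), x=x_new, S=S_train)["obs"]  # (P, M', N)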

def gp_mean(
    s: Float[Array, "N 2"],
    x: Float[Array, "D"],
    weight_x: Float[Array, "N D"],
    weight_s: Float[Array, "2"],
    bias: Float[Array, ""],
):
    """Linear mean over spatial coordinates and covariates for one model, (N,)."""
    return jnp.einsum("D,ND->N", weight_s, s) + jnp.einsum("ND,D->N", weight_x, x) + bias

def model_gp_v2(x: Float[Array, "M D"], S: Float[Array, "N"], y: Float[Array, "M N"]=None):
num_models, num_features = x.shape
num_spatial = S.shape[0]
# noise
scale = numpyro.sample("scale", dist.HalfCauchy(scale=10))
# prior on spatial process
variance = numpyro.sample("variance", dist.LogNormal(0.0, 10.0))
length_scale = numpyro.sample("length_scale", dist.LogNormal(0.0, 10.0))
K: Float[Array, "N N"] = gram(kernel_rbf, S, S, length_scale=length_scale, variance=variance)
K += (scale + 1e-6) * jnp.eye(K.shape[0])

    # spatial weights, (2,)
    weight_s: Float[Array, "2"] = numpyro.sample("weight_s", dist.Normal(0.0, 2.0), sample_shape=(2,))

    # weights & biases
    with numpyro.plate("spatial", num_spatial), numpyro.plate("covariates", num_features):
        # sample weights, (D, N)
        weight: Float[Array, "D N"] = numpyro.sample("weight", dist.Normal(0.0, 10.0))

    with numpyro.plate("spatial", num_spatial):
        # sample bias, (N,)
        bias: Float[Array, "N"] = numpyro.sample("bias", dist.Normal(0.0, 10.0))

    # mean function evaluated per climate model via vmap, (M, N)
    z = jax.vmap(gp_mean, in_axes=(None, 0, None, None, None))(S, x, weight.T, weight_s, bias)
    # data likelihood: correlated across the N spatial locations
    f: Float[Array, "M N"] = numpyro.sample(
        "obs",
        dist.MultivariateNormal(loc=z, covariance_matrix=K),
        obs=y,
    )
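
# Shape check (a sketch): `gp_mean` returns an (N,) mean for one model's
# covariates; vmapping over the M rows of x gives z of shape (M, N), matching
# the MultivariateNormal batch (M) and event (N) shapes above:
#
#   assert z.shape == (num_models, num_spatial)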



