3. Multivariate Bayesian Linear Model

import jax.numpy as jnp
from jax import random
from melvin import LaplaceApproximation
import jax
import matplotlib.pylab as plt
from functools import partial

jax.config.update("jax_enable_x64", True)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# Root PRNG key; each random draw below consumes a fresh subkey split from it,
# so the order of these statements determines the generated values.
SEED = random.PRNGKey(123068)
N_ROWS = 200  # number of observations (rows of X)
N_PARAMS = 20  # number of regression weights (columns of X)
NOISE_AMPLITUDE = 2.0  # scale applied to the standard-normal noise added to y
# Ground-truth weights, drawn from a standard normal.
SEED, _seed = random.split(SEED)
true_params = jax.random.normal(key=_seed, shape=(N_PARAMS,))
print(f"True parameters\n{true_params}")

# Design matrix of shape (N_ROWS, N_PARAMS), also standard normal.
SEED, _seed = random.split(SEED)
X = jax.random.normal(key=_seed, shape=(N_ROWS, N_PARAMS))
print(f"\nFirst 2 rows of data\n{X[:2,:]}")

# Labels: linear response plus independent Gaussian noise per row.
SEED, _seed = random.split(SEED)
eps = NOISE_AMPLITUDE*jax.random.normal(key=_seed, shape=(N_ROWS,))
y = X @ true_params + eps
print(f"\nFirst 2 labels\n{y[:2]}")
True parameters
[-0.41959653 -0.515441    0.29349301 -1.08905649  0.81179762  1.60976376
 -1.32718381  0.36175878 -1.65091768  0.72506684  0.32725878 -1.15274689
 -1.64335485 -1.25164681  1.36635926  0.63472409  0.42828669 -0.01882997
 -0.77914543  0.49000033]
First 2 rows of data
[[ 1.12425312  1.20534369 -1.49134698 -0.797011   -0.8784741  -0.92112943
  -0.76324367 -0.14223209  0.43083118  0.39633834 -2.00550183  1.14734086
  -0.97638592 -0.60403735 -0.28984657 -0.22324324  0.34924965  0.82033834
  -0.83760291 -1.30904016]
 [ 0.89877167 -1.14499604  0.42068895  0.11273099 -0.61621765  0.0912808
   0.25485903 -0.49365265 -0.24655109 -0.908175    1.78042556  0.72285995
   0.37207329  0.46716234  2.13434933  0.37722536 -1.85845726  0.81766798
   0.00409695 -1.11364174]]
First 2 labels
[-2.12118451  1.43767218]
class BayesianLinearModel(LaplaceApproximation):
    """Linear regression with Gaussian noise, fitted via a Laplace approximation.

    The parameter vector is laid out as ``[noise, w_1, ..., w_N_PARAMS]``:
    element 0 is the noise scale of the likelihood and the remaining
    elements are the regression weights.
    """

    # One [lower, upper] row per parameter. The noise scale gets a lower
    # bound of 0.0; every other entry is NaN (appears to mean "no bound" to
    # the base class -- confirm against the melvin documentation).
    param_bounds = jnp.array(
        [[0.0, jnp.nan], *([[jnp.nan, jnp.nan]] * N_PARAMS)]
    )

    def model(self, params, X):
        """Return the linear prediction of ``X`` using the weights in ``params[1:]``."""
        coefficients = params[1:]
        return jnp.matmul(X, coefficients)

    def log_prior(self, params):
        """Return the summed log prior density of the noise scale and weights."""
        sigma, coefficients = params[0], params[1:]

        # Wide priors: Exponential(scale=100) on the noise scale and
        # independent Normal(0, 100) on each weight.
        sigma_term = jax.scipy.stats.expon.logpdf(sigma, scale=100.0)
        coefficient_terms = jax.scipy.stats.norm.logpdf(
            coefficients, loc=0.0, scale=100.0
        )

        return sigma_term + jnp.sum(coefficient_terms)

    def log_likelihood(self, params, y, y_pred):
        """Return the Gaussian log likelihood of ``y`` centred on ``y_pred``."""
        sigma = params[0]
        per_point = jax.scipy.stats.norm.logpdf(y, loc=y_pred, scale=sigma)
        return jnp.sum(per_point)

# Starting point for the fit: random standard-normal weights and a fixed
# positive initial noise scale, concatenated in the [noise, weights...] layout
# that BayesianLinearModel expects.
SEED, _seed = random.split(SEED)
initial_weights = jax.random.normal(key=_seed, shape=(N_PARAMS,))
initial_noise = jnp.array([0.5])
initial_params = jnp.concatenate([initial_noise, initial_weights])

# NOTE(review): the printed summary reports fit convergence, so the fit
# presumably runs inside the constructor -- confirm against melvin docs.
model = BayesianLinearModel(
    name="Bayesian Linear Model",
    initial_params=initial_params,
    X=X,
    y=y,
)
print(model)
/opt/hostedtoolcache/Python/3.9.5/x64/lib/python3.9/site-packages/scipy/optimize/_minimize.py:524: RuntimeWarning: Method BFGS does not use Hessian information (hess).
  warn('Method %s does not use Hessian information (hess).' % method,
Laplace Approximation: Bayesian Linear Model
Base distribution: normal
Fixed Parameters: []
Fit converged successfully
Fitted Parameters: 
[
  1.8545680191993663 +/- 0.09272195095167172,	 [Lower Bound = 0.0]
  -0.3817689123091307 +/- 0.1494010936747737,
  -0.6289503767735043 +/- 0.14419398401132766,
  0.24525477936465637 +/- 0.14138182213208605,
  -1.0461024063110052 +/- 0.12647925953760678,
  1.0260546709546454 +/- 0.12825306065854983,
  1.374353106948809 +/- 0.1477602636368794,
  -1.2281199536928442 +/- 0.14627975718112793,
  0.48482383719377986 +/- 0.13575234386962834,
  -1.784897716041474 +/- 0.1293721996624074,
  0.6737377511828557 +/- 0.1388664002904215,
  0.6493333626842134 +/- 0.14375742754915977,
  -1.0380128475146542 +/- 0.14765524247581746,
  -1.2232863181923903 +/- 0.13453458593813944,
  -1.2492718564462546 +/- 0.1450679504392249,
  1.4507643357112296 +/- 0.13035399979189052,
  0.7712677082553242 +/- 0.11995210654406582,
  0.5624949147095913 +/- 0.12902459374853462,
  -0.01670756712088431 +/- 0.13251652752627707,
  -0.8227827025516227 +/- 0.15265917255400852,
  0.44370798401974804 +/- 0.13134989662352356,
]
MAP Posterior Prob = -522.4341559217474
# Compare the fitted values against the ground truth used to generate the
# data. model.params.x follows the [noise, weights...] layout, so index 0 is
# the noise scale and the remaining entries are the weights.
print(f"True parameters\n{true_params}")
print(f"\nFitted parameters\n{model.params.x[1:]}")

print(f"\nTrue noise\n{NOISE_AMPLITUDE}")
print(f"\nFitted noise\n{model.params.x[0]}")
True parameters
[-0.41959653 -0.515441    0.29349301 -1.08905649  0.81179762  1.60976376
 -1.32718381  0.36175878 -1.65091768  0.72506684  0.32725878 -1.15274689
 -1.64335485 -1.25164681  1.36635926  0.63472409  0.42828669 -0.01882997
 -0.77914543  0.49000033]

Fitted parameters
[-0.38176891 -0.62895038  0.24525478 -1.04610241  1.02605467  1.37435311
 -1.22811995  0.48482384 -1.78489772  0.67373775  0.64933336 -1.03801285
 -1.22328632 -1.24927186  1.45076434  0.77126771  0.56249491 -0.01670757
 -0.8227827   0.44370798]

True noise
2.0

Fitted noise
1.8545680191993663
# Draw 10,000 parameter samples from the fitted approximation using a fresh
# subkey. With verbose=True the call reports a score per sampling method and
# the method it selected (see the output below).
SEED, _seed = random.split(SEED,2)
samples = model.sample_params(prng_key = _seed, n_samples = 10000, verbose=True)

# Per-parameter summary statistics: mean and the 5th/95th percentiles, i.e. a
# central 90% interval. Assumes axis 0 of `samples` indexes the draws --
# TODO confirm against melvin's sample_params documentation.
params_mean = jnp.mean(samples, axis=0)
params_low = jnp.percentile(samples, q=5, axis=0)
params_upp = jnp.percentile(samples, q=95, axis=0)
Method = simple,	 Perf = -543.9751909661798
Method = importance,	 Perf = -550.7711354973359
**Best Method = simple**
# Scatter the fitted weights against the ground-truth weights with 90%
# interval error bars; points on the dashed y = x line are recovered exactly.
#
# Matplotlib interprets a two-row ``yerr`` as (lower errors, upper errors),
# so the distance from the mean down to the 5th percentile must come first
# and the distance up to the 95th percentile second. Index 0 (the noise
# scale) is skipped so only the weights are plotted.
# NOTE: the bars are centred on the MAP estimate (model.params.x) while their
# lengths are measured from the sample mean.
plt.errorbar(
    true_params,
    model.params.x[1:],
    yerr=(
        params_mean[1:] - params_low[1:],
        params_upp[1:] - params_mean[1:],
    ),
    fmt=".",
)
# Reference line for perfect recovery of the true parameters.
plt.plot(
    [-2, 2], [-2, 2], "k--", lw=1
)
plt.xlim([-2, 2])
plt.ylim([-2, 2])
plt.xlabel("True Params")
plt.ylabel("Estimated Params")
plt.show()
_images/multivariate_linear_model_7_0.png