Prior
Jim priors are built by composing individual prior components with CombinePrior, which joins them into a joint prior. Each component can cover one or more parameters.
CombinePrior
CombinePrior takes a list of priors and treats them as independent:
import jax.numpy as jnp
from jimgw.core.prior import CombinePrior, UniformPrior, SinePrior, CosinePrior, PowerLawPrior
prior = CombinePrior([
    UniformPrior(10.0, 80.0, ["M_c"]),             # chirp mass
    UniformPrior(0.125, 1.0, ["q"]),               # mass ratio
    UniformPrior(-0.99, 0.99, ["s1_z"]),           # z-component of the primary's spin
    UniformPrior(-0.99, 0.99, ["s2_z"]),           # z-component of the secondary's spin
    PowerLawPrior(10.0, 2000.0, 2.0, ["d_L"]),     # luminosity distance, p(d_L) ∝ d_L^2
    UniformPrior(-0.1, 0.1, ["t_c"]),              # coalescence time
    UniformPrior(0.0, 2.0 * jnp.pi, ["phase_c"]),  # coalescence phase
    SinePrior(["iota"]),                           # inclination
    UniformPrior(0.0, jnp.pi, ["psi"]),            # polarisation angle
    UniformPrior(0.0, 2.0 * jnp.pi, ["ra"]),       # right ascension
    CosinePrior(["dec"]),                          # declination
])
The order of parameters in prior.parameter_names follows the order they appear in this list.
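Because the components are independent, the joint density factorises and the joint log-probability is the sum of the component log-probabilities. The following is a minimal plain-Python sketch of that idea, not Jim's implementation: `flat_log_prob` and `combine_independent` are hypothetical helpers, and parameter values are passed as a dictionary keyed by name.

```python
import math

def flat_log_prob(xmin, xmax):
    """Return a log-density function for a flat prior on [xmin, xmax] (illustrative only)."""
    def log_prob(x):
        return -math.log(xmax - xmin) if xmin <= x <= xmax else -math.inf
    return log_prob


def combine_independent(components):
    """Hypothetical helper: build the joint log-density of independent components.

    `components` is a list of (name, log_prob) pairs. The returned parameter
    names keep the order of that list.
    """
    names = [name for name, _ in components]

    def joint_log_prob(z):
        # Independence means the joint density factorises, so the log-densities add.
        return sum(log_prob(z[name]) for name, log_prob in components)

    return names, joint_log_prob


names, joint_log_prob = combine_independent([
    ("M_c", flat_log_prob(10.0, 80.0)),
    ("q", flat_log_prob(0.125, 1.0)),
])
print(names)                                    # ['M_c', 'q']
print(joint_log_prob({"M_c": 30.0, "q": 0.5}))  # -log(70) - log(0.875)
print(joint_log_prob({"M_c": 90.0, "q": 0.5}))  # -inf: M_c is outside its bounds
```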
Basic Priors
All priors are importable from jimgw.core.prior.
UniformPrior
Flat distribution over [xmin, xmax]:
UniformPrior(xmin, xmax, ["parameter_name"])
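The flat density is \(1/(x_{max} - x_{min})\) inside the interval and zero outside, and a draw can be generated by stretching a unit-uniform number onto the interval (the inverse CDF). A minimal standalone sketch in plain Python; the function names are hypothetical and this is not Jim's code:

```python
import math
import random

def uniform_log_prob(x, xmin, xmax):
    """Log-density of a flat prior on [xmin, xmax]: -log(width) inside, -inf outside."""
    if xmin <= x <= xmax:
        return -math.log(xmax - xmin)
    return -math.inf


def uniform_sample(u, xmin, xmax):
    """Map a unit-uniform draw u in [0, 1] onto [xmin, xmax] (inverse CDF)."""
    return xmin + u * (xmax - xmin)


rng = random.Random(0)
draws = [uniform_sample(rng.random(), 10.0, 80.0) for _ in range(1_000)]
print(uniform_log_prob(30.0, 10.0, 80.0))  # equals -log(70)
print(uniform_log_prob(90.0, 10.0, 80.0))  # -inf
```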
PowerLawPrior
Power-law distribution \(p(x) \propto x^\alpha\) over [xmin, xmax]:
PowerLawPrior(xmin, xmax, alpha, ["parameter_name"])
Note
xmin must be positive.
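For \(\alpha \neq -1\) the normalised density is \(p(x) = (\alpha + 1)\,x^\alpha / (x_{max}^{\alpha+1} - x_{min}^{\alpha+1})\); for \(\alpha = -1\) it is \(p(x) = 1 / (x \log(x_{max}/x_{min}))\). These formulas suggest one reason xmin must be positive: \(\alpha \log x\) is undefined at \(x = 0\), and for \(\alpha \le -1\) the normalising integral diverges as \(x_{min} \to 0\). The sketch below is plain Python with hypothetical function names, not Jim's implementation; it evaluates the log-density, samples by inverse CDF, and numerically checks the normalisation of the d_L prior from the CombinePrior example.

```python
import math

def power_law_log_prob(x, xmin, xmax, alpha):
    """Normalised log-density of p(x) proportional to x**alpha on [xmin, xmax], with 0 < xmin < xmax."""
    if not (xmin <= x <= xmax):
        return -math.inf
    if alpha == -1.0:
        # Special case: the integral of 1/x over [xmin, xmax] is log(xmax / xmin).
        log_norm = -math.log(math.log(xmax / xmin))
    else:
        a1 = alpha + 1.0
        # a1 and (xmax**a1 - xmin**a1) always share a sign, so the ratio is positive.
        log_norm = math.log(a1 / (xmax**a1 - xmin**a1))
    return log_norm + alpha * math.log(x)


def power_law_sample(u, xmin, xmax, alpha):
    """Inverse-CDF transform of a unit-uniform draw u in [0, 1] into [xmin, xmax]."""
    if alpha == -1.0:
        return xmin * (xmax / xmin) ** u
    a1 = alpha + 1.0
    return (xmin**a1 + u * (xmax**a1 - xmin**a1)) ** (1.0 / a1)


# Midpoint-rule integral of exp(log_prob) for the d_L prior (alpha = 2); should be ~1.
xmin, xmax, alpha = 10.0, 2000.0, 2.0
n = 100_000
h = (xmax - xmin) / n
total = sum(
    math.exp(power_law_log_prob(xmin + (i + 0.5) * h, xmin, xmax, alpha)) * h
    for i in range(n)
)
print(total)
```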
SinePrior
\(p(\theta) \propto \sin(\theta)\) over \([0, \pi]\). Commonly used for inclination:
SinePrior(["iota"])
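Normalised on \([0, \pi]\), the density is \(p(\theta) = \sin(\theta)/2\), with CDF \(F(\theta) = (1 - \cos\theta)/2\). Inverting the CDF gives \(\theta = \arccos(1 - 2u)\) for a unit-uniform \(u\), which makes \(\cos\theta\) uniform on \([-1, 1]\): the polar-angle distribution of isotropically distributed directions. A minimal plain-Python sketch with hypothetical names (not Jim's code):

```python
import math
import random

def sine_log_prob(theta):
    """Normalised log-density of p(theta) = sin(theta) / 2 on [0, pi]."""
    if not (0.0 < theta < math.pi):
        return -math.inf  # the density is zero at (and outside) the endpoints
    return math.log(math.sin(theta) / 2.0)


def sine_sample(u):
    """Inverse CDF: solve (1 - cos(theta)) / 2 = u for theta."""
    return math.acos(1.0 - 2.0 * u)


# cos(theta) of the draws should be uniform on [-1, 1]: mean ~ 0, variance ~ 1/3.
rng = random.Random(42)
cos_draws = [math.cos(sine_sample(rng.random())) for _ in range(100_000)]
mean = sum(cos_draws) / len(cos_draws)
var = sum((c - mean) ** 2 for c in cos_draws) / len(cos_draws)
print(mean, var)
```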
CosinePrior
\(p(\delta) \propto \cos(\delta)\) over \([-\pi/2, \pi/2]\). Commonly used for declination:
CosinePrior(["dec"])
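Normalised on \([-\pi/2, \pi/2]\), the density is \(p(\delta) = \cos(\delta)/2\), with CDF \(F(\delta) = (1 + \sin\delta)/2\), so \(\delta = \arcsin(2u - 1)\) and \(\sin\delta\) is uniform on \([-1, 1]\). Paired with a right ascension uniform on \([0, 2\pi)\), as in the CombinePrior example, this gives sky positions uniform over the sphere. The plain-Python sketch below (hypothetical names, not Jim's code) checks that numerically: each Cartesian component of an isotropic unit vector has mean 0 and second moment 1/3.

```python
import math
import random

def cosine_log_prob(delta):
    """Normalised log-density of p(delta) = cos(delta) / 2 on [-pi/2, pi/2]."""
    if not (-math.pi / 2 < delta < math.pi / 2):
        return -math.inf  # cos(delta) vanishes at the poles
    return math.log(math.cos(delta) / 2.0)


def cosine_sample(u):
    """Inverse CDF: solve (1 + sin(delta)) / 2 = u for delta."""
    return math.asin(2.0 * u - 1.0)


# Draw (ra, dec) with ra uniform and dec cosine-distributed, then accumulate
# the first and second moments of the Cartesian unit vector.
rng = random.Random(1)
n = 100_000
sums = [0.0, 0.0, 0.0]
sq_sums = [0.0, 0.0, 0.0]
for _ in range(n):
    ra = 2.0 * math.pi * rng.random()
    dec = cosine_sample(rng.random())
    xyz = (math.cos(dec) * math.cos(ra), math.cos(dec) * math.sin(ra), math.sin(dec))
    for k in range(3):
        sums[k] += xyz[k]
        sq_sums[k] += xyz[k] ** 2
means = [s / n for s in sums]
second_moments = [s / n for s in sq_sums]
print(means, second_moments)
```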
UniformSpherePrior
Prior on a three-dimensional vector whose direction is uniform over the surface of the unit sphere, parameterised by the vector's magnitude, polar angle, and azimuthal angle. Useful for spin vectors:
from jimgw.core.prior import UniformSpherePrior
UniformSpherePrior(["s1"]) # creates s1_mag, s1_theta, s1_phi
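The three parameters are spherical coordinates, so downstream code typically converts them to Cartesian components. The sketch below assumes the common physics convention, with theta measured from the +z axis and phi measured from the +x axis in the x-y plane; `sphere_to_cartesian` is a hypothetical helper, and the convention should be confirmed against whatever your own code expects before relying on it.

```python
import math

def sphere_to_cartesian(mag, theta, phi):
    """Hypothetical helper: (magnitude, polar angle, azimuth) -> (x, y, z).

    Assumes theta is measured from the +z axis and phi from the +x axis.
    """
    return (
        mag * math.sin(theta) * math.cos(phi),
        mag * math.sin(theta) * math.sin(phi),
        mag * math.cos(theta),
    )


# Example: magnitude 0.5, tilted 60 degrees from the z-axis, zero azimuth.
s1_x, s1_y, s1_z = sphere_to_cartesian(0.5, math.pi / 3, 0.0)
print(round(s1_x, 4), round(s1_y, 4), round(s1_z, 4))  # 0.433 0.0 0.25
```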
GaussianPrior
Gaussian distribution with a given mean and standard deviation:
from jimgw.core.prior import GaussianPrior
GaussianPrior(mean, std, ["parameter_name"])
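The log-density is \(\log p(x) = -\tfrac{1}{2}\left((x - \mu)/\sigma\right)^2 - \log(\sigma\sqrt{2\pi})\), where \(\mu\) is mean and \(\sigma\) is std. A plain-Python sketch that writes it out and cross-checks it against the standard library's statistics.NormalDist (the function name is hypothetical, not Jim's API):

```python
import math
from statistics import NormalDist

def gaussian_log_prob(x, mean, std):
    """Normalised Gaussian log-density, written out explicitly."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std * math.sqrt(2.0 * math.pi))


# Cross-check against the standard library's NormalDist.
x, mean, std = 1.3, 0.5, 2.0
print(math.isclose(gaussian_log_prob(x, mean, std), math.log(NormalDist(mean, std).pdf(x))))  # True
```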
RayleighPrior
Rayleigh distribution with a given scale:
from jimgw.core.prior import RayleighPrior
RayleighPrior(sigma, ["parameter_name"])
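For scale \(\sigma\), the Rayleigh density is \(p(x) = (x/\sigma^2)\exp(-x^2/(2\sigma^2))\) for \(x \ge 0\), with CDF \(F(x) = 1 - \exp(-x^2/(2\sigma^2))\) and mean \(\sigma\sqrt{\pi/2}\). A plain-Python sketch of the log-density and inverse-CDF sampling, with hypothetical names (not Jim's implementation), assuming the sigma argument plays the role of \(\sigma\):

```python
import math
import random

def rayleigh_log_prob(x, sigma):
    """Log-density of p(x) = (x / sigma**2) * exp(-x**2 / (2 sigma**2)) for x > 0."""
    if x <= 0.0:
        return -math.inf  # the density is zero at x = 0 and has no support below it
    return math.log(x / sigma**2) - x * x / (2.0 * sigma**2)


def rayleigh_sample(u, sigma):
    """Inverse CDF of F(x) = 1 - exp(-x**2 / (2 sigma**2)), for u in [0, 1)."""
    return sigma * math.sqrt(-2.0 * math.log(1.0 - u))


rng = random.Random(7)
sigma = 0.3
draws = [rayleigh_sample(rng.random(), sigma) for _ in range(100_000)]
sample_mean = sum(draws) / len(draws)
# The sample mean should be close to sigma * sqrt(pi / 2).
print(sample_mean, sigma * math.sqrt(math.pi / 2.0))
```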
Constraints
Warning
When custom constraints are applied, the resulting prior is generally no longer normalised. Jim only ever uses the prior as an unnormalised log-probability (sampling does not require the normalisation constant), so this is fine in practice. Be aware, however, that log_prob values are not comparable across different constrained priors, and any downstream use that assumes a normalised density will give incorrect results.
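To make this concrete: a hard cut (log-probability set to \(-\infty\) outside \([x_{min}, x_{max}]\)) leaves the density inside the bounds unchanged, so it integrates to the probability mass the uncut distribution had in the interval rather than to one. The correctly normalised log-density differs only by the constant \(-\log(\text{mass})\), which cancels in the probability ratios used by MCMC samplers. A plain-Python sketch for a unit Gaussian cut to \([0, 3]\), with hypothetical names (not Jim's code):

```python
import math
from statistics import NormalDist

def truncated_log_prob(x, mean, std, xmin, xmax):
    """Gaussian log-density with a hard cut, left unnormalised (as a -inf penalty does)."""
    if not (xmin <= x <= xmax):
        return -math.inf
    return math.log(NormalDist(mean, std).pdf(x))


mean, std, xmin, xmax = 0.0, 1.0, 0.0, 3.0
# Probability mass the uncut Gaussian places inside the bounds.
mass = NormalDist(mean, std).cdf(xmax) - NormalDist(mean, std).cdf(xmin)
# A correctly normalised truncated Gaussian would add this constant inside the bounds.
log_norm_correction = -math.log(mass)

# Numerically integrating exp(log_prob) over the bounds recovers `mass`, not 1.
n = 10_000
h = (xmax - xmin) / n
integral = sum(
    math.exp(truncated_log_prob(xmin + (i + 0.5) * h, mean, std, xmin, xmax)) * h
    for i in range(n)
)
print(mass, integral, log_norm_correction)
```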
Single-parameter bounds with BoundedMixin
BoundedMixin enforces hard bounds on a single parameter: the log-probability is set to \(-\infty\) for any sample outside [xmin, xmax]. To add bounds to one of your own priors, define a subclass that lists BoundedMixin before the base prior class:
from jimgw.core.prior import BoundedMixin, GaussianPrior
class BoundedGaussianPrior(BoundedMixin, GaussianPrior):
    # Bounds used by BoundedMixin's log_prob; samples outside [xmin, xmax] get -inf
    xmin: float
    xmax: float
    def __init__(self, mean, std, xmin, xmax, parameter_names):
        # Initialise the Gaussian part, then store the bounds
        super().__init__(mean, std, parameter_names)
        self.xmin = xmin
        self.xmax = xmax
BoundedMixin must appear before the base prior class in the list of base classes so that its log_prob override is resolved first in Python's method resolution order.
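The toy classes below are hypothetical stand-ins, not Jim's BoundedMixin or GaussianPrior. They show what happens with each ordering: Python looks up log_prob along the class's MRO from left to right, so only the mixin-first class enforces the bound.

```python
import math

class ToyGaussian:
    """Hypothetical stand-in for a base prior with an unbounded log_prob."""
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def log_prob(self, x):
        z = (x - self.mean) / self.std
        return -0.5 * z * z - math.log(self.std * math.sqrt(2.0 * math.pi))


class ToyBoundedMixin:
    """Hypothetical mixin: -inf outside [xmin, xmax], otherwise defer to the next class in the MRO."""
    xmin: float
    xmax: float

    def log_prob(self, x):
        if not (self.xmin <= x <= self.xmax):
            return -math.inf
        return super().log_prob(x)


class MixinFirst(ToyBoundedMixin, ToyGaussian):
    def __init__(self, mean, std, xmin, xmax):
        super().__init__(mean, std)  # resolves to ToyGaussian.__init__
        self.xmin, self.xmax = xmin, xmax


class MixinLast(ToyGaussian, ToyBoundedMixin):
    def __init__(self, mean, std, xmin, xmax):
        super().__init__(mean, std)
        self.xmin, self.xmax = xmin, xmax


print(MixinFirst(0.0, 1.0, -1.0, 1.0).log_prob(2.0))  # -inf: the bound is enforced
print(MixinLast(0.0, 1.0, -1.0, 1.0).log_prob(2.0))   # finite: ToyGaussian.log_prob wins
print([cls.__name__ for cls in MixinFirst.__mro__])   # ['MixinFirst', 'ToyBoundedMixin', 'ToyGaussian', 'object']
```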
Multi-parameter constraints
For constraints that span multiple parameters, subclass CombinePrior and override log_prob to add a penalty that is \(0\) where the constraint holds and \(-\infty\) where it is violated. For example, to enforce \(m_1 > m_2\):
import jax.numpy as jnp
from jimgw.core.prior import CombinePrior, UniformPrior
class OrderedMassPrior(CombinePrior):
    def log_prob(self, z):
        # Log-probability of the independent component priors
        base = super().log_prob(z)
        # 0 where m1 > m2, -inf otherwise; jnp.where (not a Python if) stays traceable under jax.jit
        constraint = jnp.where(z["m1"] > z["m2"], 0.0, -jnp.inf)
        return base + constraint
prior = OrderedMassPrior([
    UniformPrior(1.0, 100.0, ["m1"]),
    UniformPrior(1.0, 100.0, ["m2"]),
])
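Because the penalty removes the region \(m_1 \le m_2\), which is half of the square \([1, 100]^2\) for these two identical uniform priors, the constrained density integrates to \(1/2\) rather than 1: exactly the normalisation caveat from the warning at the top of this section. The plain-Python sketch below mirrors the log_prob logic without JAX (all names are hypothetical, and it assumes each UniformPrior contributes \(-\log(\text{width})\) inside its bounds), then uses rejection sampling to show that about half of the unconstrained draws survive.

```python
import math
import random

def ordered_mass_log_prob(m1, m2, lo=1.0, hi=100.0):
    """Plain-Python mirror of OrderedMassPrior.log_prob: two flat priors plus a 0 / -inf penalty."""
    if not (lo <= m1 <= hi and lo <= m2 <= hi):
        return -math.inf
    base = -2.0 * math.log(hi - lo)  # sum of the two uniform log-densities
    constraint = 0.0 if m1 > m2 else -math.inf
    return base + constraint


# Rejection sampling: draw from the unconstrained product prior and keep
# only the draws whose constrained log_prob is finite.
rng = random.Random(3)
n = 100_000
kept = []
for _ in range(n):
    m1, m2 = rng.uniform(1.0, 100.0), rng.uniform(1.0, 100.0)
    if ordered_mass_log_prob(m1, m2) > -math.inf:
        kept.append((m1, m2))
acceptance = len(kept) / n
print(acceptance)
```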