
Prior

BoundedMixin

Mixin class that adds bounds checking to log_prob.

This mixin should be placed BEFORE the main prior class in the inheritance list (e.g., class MyPrior(BoundedMixin, SequentialTransformPrior)) to ensure the bounds check is applied before delegating to the base prior's log_prob.

Classes using this mixin can override xmin and xmax attributes to set bounds. By default, the bounds are (-inf, inf), meaning no bounds checking.

The mixin returns -inf for values outside [xmin, xmax].

Note: This is a mixin and should not be used standalone. It relies on the parameter_names attribute provided by the Prior class and on the log_prob implementation of the next class in the inheritance list.
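
The ordering requirement comes from Python's method resolution order (MRO). When the mixin is listed first, its log_prob is found first, and super().log_prob(...) then hands off to the next class. Below is a minimal, library-independent sketch of that pattern. ToyPrior, ToyBoundedMixin, HalfToyPrior and WrongOrder are hypothetical stand-ins, not Jim's implementation.

```python
import math


class ToyPrior:
    """Hypothetical base prior with unnormalized log-density -x**2 / 2."""

    parameter_names = ["x"]

    def log_prob(self, z):
        x = z[self.parameter_names[0]]
        return -0.5 * x * x


class ToyBoundedMixin:
    """Hypothetical mixin: -inf outside [xmin, xmax], otherwise delegate."""

    xmin = -math.inf  # default: no lower bound
    xmax = math.inf   # default: no upper bound

    def log_prob(self, z):
        # parameter_names is supplied by the prior class, not by the mixin.
        x = z[self.parameter_names[0]]
        if x < self.xmin or x > self.xmax:
            return -math.inf
        # super() resolves to the next class in the MRO (ToyPrior here).
        return super().log_prob(z)


class HalfToyPrior(ToyBoundedMixin, ToyPrior):
    xmin = 0.0  # override the lower bound


class WrongOrder(ToyPrior, ToyBoundedMixin):
    xmin = 0.0  # ignored: ToyPrior.log_prob is found first and never calls super()


prior = HalfToyPrior()
print(prior.log_prob({"x": 1.0}))            # -0.5
print(prior.log_prob({"x": -1.0}))           # -inf
print(WrongOrder().log_prob({"x": -1.0}))    # -0.5: the bounds check never runs
```

Jim's priors are evaluated with JAX. Under jax.jit, a Python `if` on a traced value fails, so a JAX version of this check would need something like jnp.where instead.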

CombinePrior

Bases: CompositePrior

Multivariate prior constructed by joining multiple independent priors.

The joint log-probability is the sum of the individual log-probabilities, which is valid when all component priors are independent.
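
For independent components the joint density factorizes as p(a, b) = p(a) p(b). It follows that log p(a, b) = log p(a) + log p(b), and that each component can be sampled on its own. Below is a NumPy sketch with two hypothetical one-dimensional components, a standard normal for "a" and a uniform on [0, 2] for "b". It mirrors the dict-in, scalar-out shape of log_prob but is not Jim's code.

```python
import numpy as np


def normal_log_prob(z):
    # Hypothetical component: standard normal on parameter "a".
    return -0.5 * z["a"] ** 2 - 0.5 * np.log(2 * np.pi)


def uniform_log_prob(z):
    # Hypothetical component: uniform on [0, 2] for "b" (density 1/2 inside, 0 outside).
    b = z["b"]
    return np.where((b >= 0.0) & (b <= 2.0), -np.log(2.0), -np.inf)


def combined_log_prob(z, components):
    # Independence: the joint log-probability is the sum of component log-probabilities.
    return sum(lp(z) for lp in components)


def combined_sample(rng, n_samples):
    # Independent components are sampled separately and merged into one dict.
    return {"a": rng.standard_normal(n_samples), "b": rng.uniform(0.0, 2.0, n_samples)}


components = [normal_log_prob, uniform_log_prob]
z = {"a": 0.0, "b": 1.0}
print(combined_log_prob(z, components))  # -0.5*log(2*pi) - log(2) ≈ -1.6121
samples = combined_sample(np.random.default_rng(0), 5)
print({k: v.shape for k, v in samples.items()})  # {'a': (5,), 'b': (5,)}
```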

Attributes:

- base_prior (tuple[Prior, ...]): Independent component priors.
- parameter_names: Names of all parameters in the combined prior.

Methods:

- __init__: Construct the prior from a list of independent component priors.
- log_prob: Evaluate the joint log-probability as the sum of component log-probabilities.
- sample: Sample from all component priors independently.

__init__(priors: list[Prior])

Parameters:

- priors (list[Prior], required): List of independent prior objects to combine.

log_prob(z: dict[str, Float]) -> FloatScalar

Evaluate the joint log-probability as the sum of component log-probabilities.

Parameters:

- z (dict[str, Float], required): Dictionary of parameter values.

Returns:

- FloatScalar: Sum of log-probabilities from all component priors.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from all component priors independently.

Parameters:

- rng_key (Key, required): JAX PRNG key (split internally for each component).
- n_samples (int, required): Number of samples to draw.

Returns:

- dict[str, Float[Array, ' n_samples']]: Combined samples from all component priors, keyed by parameter name.

CompositePrior

Bases: Prior

Composite prior consisting of multiple component priors.

Base class for SequentialTransformPrior and CombinePrior. Used to build complex prior distributions from simpler ones.

Attributes:

- base_prior (tuple[Prior, ...]): Component prior objects.
- parameter_names: Names of all parameters in this composite prior.

Methods:

- __init__: Construct the composite prior from a list of component priors.
- trace_prior_parent: Recursively collect all leaf (non-composite) priors.

__init__(priors: list[Prior])

Parameters:

- priors (list[Prior], required): List of component prior objects.

trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

- output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

- list[Prior]: List of all leaf prior objects in this composite.
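
The recursion can be sketched as a depth-first walk. Composite priors recurse into their components, and anything else is appended to the shared accumulator. Leaf and Composite are hypothetical toy classes. Only the accumulator with a None default follows the documented signature.

```python
from typing import Optional


class Leaf:
    """Hypothetical non-composite prior, identified by name only."""

    def __init__(self, name: str):
        self.name = name


class Composite:
    """Hypothetical composite holding its components in base_prior."""

    def __init__(self, priors: list):
        self.base_prior = tuple(priors)

    def trace_prior_parent(self, output: Optional[list] = None) -> list:
        # A fresh list per top-level call (a mutable default would leak across calls).
        if output is None:
            output = []
        for prior in self.base_prior:
            if isinstance(prior, Composite):
                prior.trace_prior_parent(output)  # recurse into nested composites
            else:
                output.append(prior)  # leaf: collect it
        return output


tree = Composite([Leaf("m1"), Composite([Leaf("m2"), Composite([Leaf("chi")])])])
print([leaf.name for leaf in tree.trace_prior_parent()])  # ['m1', 'm2', 'chi']
```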

CosinePrior

Bases: BoundedMixin, SequentialTransformPrior

Prior with PDF proportional to cos(x) over [-pi/2, pi/2].
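
The integral of cos(x) over [-pi/2, pi/2] is 2, so the normalized density is p(x) = cos(x)/2. Its CDF is F(x) = (1 + sin x)/2, and inverting it gives x = arcsin(2u - 1) for u ~ U(0, 1). The NumPy sketch below uses that inverse CDF to illustrate the distribution. Jim builds CosinePrior as a SequentialTransformPrior, and its specific transforms may differ.

```python
import numpy as np


def cosine_log_prob(x):
    # log(cos(x) / 2) inside [-pi/2, pi/2]; -inf outside (the role BoundedMixin plays).
    x = np.asarray(x, dtype=float)
    inside = (x >= -np.pi / 2) & (x <= np.pi / 2)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(inside, np.log(np.cos(x) / 2.0), -np.inf)


def cosine_sample(rng, n_samples):
    # Inverse-CDF sampling: u ~ U(0, 1), x = arcsin(2u - 1).
    u = rng.uniform(0.0, 1.0, n_samples)
    return np.arcsin(2.0 * u - 1.0)


# Midpoint-rule check that the density integrates to 1 over its support.
n = 10_000
h = np.pi / n
mid = -np.pi / 2 + (np.arange(n) + 0.5) * h
total = np.sum(np.exp(cosine_log_prob(mid))) * h
print(round(float(total), 6))  # 1.0
samples = cosine_sample(np.random.default_rng(0), 100_000)
```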

GaussianPrior

Bases: SequentialTransformPrior

Gaussian (normal) prior with specified mean and standard deviation.

Attributes:

- mu (float): Mean of the distribution.
- sigma (float): Standard deviation of the distribution.

Methods:

- __init__: Construct the prior from a mean, a standard deviation, and a single parameter name.

__init__(mu: float, sigma: float, parameter_names: list[str])

Parameters:

- mu (float, required): Mean of the distribution.
- sigma (float, required): Standard deviation of the distribution.
- parameter_names (list[str], required): List with a single parameter name.
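
Because GaussianPrior is a SequentialTransformPrior, one plausible construction is a standard normal base followed by the map x = mu + sigma * u. That choice of base and transforms is an assumption here. The change of variables then gives log p(x) = log N((x - mu)/sigma; 0, 1) - log(sigma). The plain-Python sketch below checks that identity against the closed-form Gaussian log-density.

```python
import math


def std_normal_log_prob(u):
    return -0.5 * u * u - 0.5 * math.log(2.0 * math.pi)


def gaussian_log_prob_via_transform(x, mu, sigma):
    # Assumed construction: x = mu + sigma * u with u ~ N(0, 1).
    # Pull x back to u, then subtract log|dx/du| = log(sigma).
    u = (x - mu) / sigma
    return std_normal_log_prob(u) - math.log(sigma)


def gaussian_log_prob_direct(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


print(gaussian_log_prob_via_transform(3.0, 1.0, 2.0))  # ≈ -2.1121
print(gaussian_log_prob_direct(3.0, 1.0, 2.0))         # ≈ -2.1121
```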

LogisticDistribution

Bases: Prior

One-dimensional logistic distribution prior.

Methods:

- sample: Sample from a logistic distribution.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from a logistic distribution.

Parameters:

- rng_key (Key, required): JAX PRNG key.
- n_samples (int, required): Number of samples to draw.

Returns:

- dict[str, Float[Array, ' n_samples']]: Dict mapping parameter name to samples of shape (n_samples,).
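
Assume the standard logistic distribution (location 0, scale 1); the page does not state the parameters. Its density is then p(x) = exp(-x) / (1 + exp(-x))^2, and its inverse CDF is the logit, x = log(u / (1 - u)). The NumPy sketch below implements both and returns the sample in the same dict-of-arrays shape as sample. The parameter name "x" is a hypothetical placeholder.

```python
import numpy as np


def logistic_log_prob(x):
    # log p(x) = -x - 2 * log(1 + exp(-x)); logaddexp keeps this finite for large |x|.
    x = np.asarray(x, dtype=float)
    return -x - 2.0 * np.logaddexp(0.0, -x)


def logistic_sample(rng, n_samples, name="x"):
    # Inverse-CDF sampling: the logit of a uniform draw, keyed by parameter name.
    u = rng.uniform(0.0, 1.0, n_samples)
    return {name: np.log(u) - np.log1p(-u)}


print(float(logistic_log_prob(0.0)))  # -2*log(2) ≈ -1.3863
draws = logistic_sample(np.random.default_rng(0), 100_000)
```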

PowerLawPrior

Bases: SequentialTransformPrior

Power-law prior over [xmin, xmax] with exponent alpha.

Attributes:

- xmin (float): Lower bound of the interval (must be positive).
- xmax (float): Upper bound of the interval.
- alpha (float): Power-law exponent.

Methods:

- __init__: Construct the prior from the bounds, the exponent, and a single parameter name.

__init__(xmin: float, xmax: float, alpha: float, parameter_names: list[str])

Parameters:

- xmin (float, required): Lower bound (must be positive).
- xmax (float, required): Upper bound.
- alpha (float, required): Power-law exponent.
- parameter_names (list[str], required): List with a single parameter name.
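
For a density proportional to x**alpha on [xmin, xmax] with xmin > 0, the normalization constant and the inverse CDF have closed forms. When alpha = -1 the prior becomes log-uniform, which is why xmin must be positive. The NumPy sketch below illustrates the distribution itself. It does not claim to reproduce Jim's internal transform, in particular how Jim treats alpha = -1.

```python
import numpy as np


def power_law_log_prob(x, xmin, xmax, alpha):
    # Normalized log-density of p(x) ∝ x**alpha on [xmin, xmax]; -inf outside.
    x = np.asarray(x, dtype=float)
    inside = (x >= xmin) & (x <= xmax)
    if alpha == -1.0:
        log_norm = -np.log(np.log(xmax / xmin))  # log-uniform case
    else:
        a1 = alpha + 1.0
        log_norm = np.log(a1 / (xmax**a1 - xmin**a1))  # ratio is positive for any alpha != -1
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(inside, log_norm + alpha * np.log(x), -np.inf)


def power_law_sample(rng, n_samples, xmin, xmax, alpha):
    # Inverse-CDF sampling from u ~ U(0, 1).
    u = rng.uniform(0.0, 1.0, n_samples)
    if alpha == -1.0:
        return xmin * (xmax / xmin) ** u
    a1 = alpha + 1.0
    return (xmin**a1 + u * (xmax**a1 - xmin**a1)) ** (1.0 / a1)


def midpoint_integral(alpha, xmin=1.0, xmax=10.0, n=200_000):
    h = (xmax - xmin) / n
    mid = xmin + (np.arange(n) + 0.5) * h
    return float(np.sum(np.exp(power_law_log_prob(mid, xmin, xmax, alpha))) * h)


print([round(midpoint_integral(a), 6) for a in (-2.5, -1.0, 2.0)])  # [1.0, 1.0, 1.0]
```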

Prior

Bases: Module

Base class for prior distributions.

This class should not be used directly. It provides a common interface and bookkeeping for parameter names and transforms.

Methods:

- __call__: Alias for log_prob.
- __init__: Set the parameter names for this prior.
- add_name: Convert a flat parameter array to a named dict.
- log_prob: Evaluate the log-probability of a sample.
- sample: Draw samples from the prior.

Attributes:

- is_normalized (bool): Return True if this prior is a proper probability distribution (integrates to 1).
- n_dims (int): Number of parameters in this prior.

is_normalized: bool (property)

Return True if this prior is a proper probability distribution (integrates to 1).

Defaults to False for safety. All built-in Jim priors override this property to return True. A custom prior should override it to return True only after verifying that ∫ exp(log_prob(x)) dx = 1.

Samplers that compute the Bayesian evidence (NSS, SMC) require a normalized prior, and Jim raises an error at construction time if this property returns False for those backends.
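
Normalization matters for evidence estimates. If the prior density is off by a constant factor c, the evidence Z = ∫ L(θ) π(θ) dθ is off by the same factor, so log Z shifts by log c. For a one-dimensional prior, a quick numerical check before overriding is_normalized can be sketched as follows. integrate_density is a hypothetical helper, and the midpoint rule on a finite grid gives evidence of normalization, not proof.

```python
import numpy as np


def integrate_density(log_prob, lo, hi, n=100_000):
    # Midpoint rule for the integral of exp(log_prob(x)) over [lo, hi].
    h = (hi - lo) / n
    mid = lo + (np.arange(n) + 0.5) * h
    return float(np.sum(np.exp(log_prob(mid))) * h)


def normalized(x):
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)  # standard normal


def unnormalized(x):
    return -0.5 * x**2  # missing the 1/sqrt(2*pi) factor


print(round(integrate_density(normalized, -10.0, 10.0), 4))    # 1.0
print(round(integrate_density(unnormalized, -10.0, 10.0), 4))  # 2.5066, i.e. sqrt(2*pi)
```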

n_dims: int (property)

Number of parameters in this prior.

__call__(x: dict[str, Float]) -> FloatScalar

Alias for log_prob.

__init__(parameter_names: list[str])

Parameters:

- parameter_names (list[str], required): List of parameter names for this prior.

add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Convert a flat parameter array to a named dict.

Parameters:

- x (Float[Array, n_dims], required): Array of parameter values, shape (n_dims,).

Returns:

- dict[str, Float]: Dict mapping parameter names to scalar values.

log_prob(z: dict[str, Float]) -> FloatScalar (abstract method)

Evaluate the log-probability of a sample.

Parameters:

- z (dict[str, Float], required): Dict of parameter values.

Returns:

- FloatScalar: Log-probability scalar.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']] (abstract method)

Draw samples from the prior.

Parameters:

- rng_key (Key, required): JAX PRNG key.
- n_samples (int, required): Number of samples to draw.

Returns:

- dict[str, Float[Array, ' n_samples']]: Dict mapping parameter names to arrays of shape (n_samples,).
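
The contract above can be illustrated with a plain-Python stand-in for the base class. PriorSketch and ExponentialToy are hypothetical. The real Prior subclasses Module and works with JAX arrays and PRNG keys, whereas this sketch uses an ABC and a NumPy generator. It shows what a subclass must supply (log_prob and sample) and what the base provides (n_dims, add_name, the __call__ alias, and a conservative is_normalized).

```python
import math
from abc import ABC, abstractmethod

import numpy as np


class PriorSketch(ABC):
    """Hypothetical stand-in for the documented Prior interface."""

    def __init__(self, parameter_names: list[str]):
        self.parameter_names = list(parameter_names)

    @property
    def n_dims(self) -> int:
        return len(self.parameter_names)

    @property
    def is_normalized(self) -> bool:
        return False  # conservative default, as documented

    def add_name(self, x) -> dict:
        # Map a flat array of shape (n_dims,) to {name: value}.
        return dict(zip(self.parameter_names, x))

    def __call__(self, x: dict):
        return self.log_prob(x)  # alias for log_prob

    @abstractmethod
    def log_prob(self, z: dict): ...

    @abstractmethod
    def sample(self, rng, n_samples: int) -> dict: ...


class ExponentialToy(PriorSketch):
    """Hypothetical 1-D prior with density exp(-t) for t >= 0."""

    @property
    def is_normalized(self) -> bool:
        return True  # the integral of exp(-t) over [0, inf) is 1

    def log_prob(self, z):
        t = z[self.parameter_names[0]]
        return -t if t >= 0 else -math.inf

    def sample(self, rng, n_samples):
        return {self.parameter_names[0]: rng.exponential(1.0, n_samples)}


prior = ExponentialToy(["t"])
z = prior.add_name(np.array([2.0]))
print(prior(z))  # -2.0
```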

RayleighPrior

Bases: BoundedMixin, SequentialTransformPrior

Rayleigh distribution prior with scale parameter sigma.

Attributes:

- sigma (float): Scale parameter of the Rayleigh distribution.

Methods:

- __init__: Construct the prior from a scale parameter and a single parameter name.

__init__(sigma: float, parameter_names: list[str])

Parameters:

- sigma (float, required): Scale parameter of the Rayleigh distribution.
- parameter_names (list[str], required): List with a single parameter name.
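
The Rayleigh density with scale sigma is p(x) = (x / sigma^2) exp(-x^2 / (2 sigma^2)) for x ≥ 0, and its CDF is 1 - exp(-x^2 / (2 sigma^2)). Inverting the CDF gives x = sigma * sqrt(-2 log(1 - u)). The NumPy sketch below illustrates the distribution rather than Jim's transform chain. Returning -inf for x ≤ 0 mirrors what BoundedMixin provides.

```python
import numpy as np


def rayleigh_log_prob(x, sigma):
    # log p(x) = log(x) - 2 log(sigma) - x**2 / (2 sigma**2) for x > 0; -inf otherwise.
    x = np.asarray(x, dtype=float)
    with np.errstate(divide="ignore", invalid="ignore"):
        lp = np.log(x) - 2.0 * np.log(sigma) - x**2 / (2.0 * sigma**2)
    return np.where(x > 0, lp, -np.inf)


def rayleigh_sample(rng, n_samples, sigma):
    # Inverse CDF: x = sigma * sqrt(-2 log(1 - u)), u ~ U[0, 1).
    u = rng.uniform(0.0, 1.0, n_samples)
    return sigma * np.sqrt(-2.0 * np.log1p(-u))


draws = rayleigh_sample(np.random.default_rng(0), 200_000, sigma=2.0)
print(draws.mean())  # close to the analytic mean sigma * sqrt(pi / 2) ≈ 2.5066
```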

SequentialTransformPrior

Bases: CompositePrior

Prior distribution transformed by a sequence of bijective transforms.

Attributes:

- base_prior (tuple[Prior, ...]): The base prior to transform (must contain exactly one prior).
- transforms (tuple[BijectiveTransform, ...]): Transforms applied sequentially in the forward direction.
- parameter_names (tuple[str, ...]): Names of the parameters after all transforms.

Methods:

- __init__: Construct the prior from a base prior and an ordered list of transforms.
- log_prob: Evaluate the log-probability of a transformed sample z.
- sample: Sample by drawing from the base prior and applying all transforms.
- transform: Apply all transforms sequentially (forward direction).

__init__(base_prior: list[Prior], transforms: list[BijectiveTransform])

Parameters:

- base_prior (list[Prior], required): A single-element list containing the base prior.
- transforms (list[BijectiveTransform], required): Ordered list of bijective transforms to apply to samples from the base prior.

log_prob(z: dict[str, Float]) -> FloatScalar

Evaluate the log-probability of a transformed sample z.

Applies the inverse transforms in reverse order, accumulating log-Jacobian determinants, then evaluates the base prior.

Parameters:

- z (dict[str, Float], required): Sample in the transformed (output) space.

Returns:

- FloatScalar: Log-probability of z under the induced distribution.
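
With forward transforms f_1, ..., f_n, the procedure above evaluates log p(z) = log p_base(x_0) + Σ_k log|d f_k^{-1}/dx_k|, where x_n = z and x_{k-1} = f_k^{-1}(x_k). The sketch below does this in plain Python with a hypothetical one-dimensional Transform1D in place of BijectiveTransform, which works on dicts. It uses a Uniform[0, 1] base, scales by 5, then offsets by 2, so the result should be the density of Uniform[2, 7], namely -log 5.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass
class Transform1D:
    """Hypothetical 1-D bijection with its inverse and inverse log-Jacobian."""

    forward: Callable[[float], float]
    inverse: Callable[[float], float]
    inverse_log_jac: Callable[[float], float]  # log|d f^{-1}/dy| evaluated at y


def uniform01_log_prob(x):
    return 0.0 if 0.0 <= x <= 1.0 else -math.inf


def sequential_transform(x, transforms):
    # Forward direction: apply the transforms in order.
    for t in transforms:
        x = t.forward(x)
    return x


def sequential_log_prob(z, base_log_prob, transforms):
    # Reverse order: pull z back to the base space, accumulating inverse log-Jacobians.
    log_jac = 0.0
    for t in reversed(transforms):
        log_jac += t.inverse_log_jac(z)
        z = t.inverse(z)
    return base_log_prob(z) + log_jac


scale = Transform1D(lambda x: 5.0 * x, lambda y: y / 5.0, lambda y: -math.log(5.0))
offset = Transform1D(lambda x: x + 2.0, lambda y: y - 2.0, lambda y: 0.0)
transforms = [scale, offset]  # Uniform[0, 1] -> Uniform[0, 5] -> Uniform[2, 7]

print(sequential_transform(0.5, transforms))                     # 4.5
print(sequential_log_prob(4.5, uniform01_log_prob, transforms))  # -log(5) ≈ -1.6094
print(sequential_log_prob(8.0, uniform01_log_prob, transforms))  # -inf
```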

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample by drawing from the base prior and applying all transforms.

Parameters:

- rng_key (Key, required): JAX PRNG key.
- n_samples (int, required): Number of samples to draw.

Returns:

- dict[str, Float[Array, ' n_samples']]: Transformed samples keyed by parameter name.

transform(x: dict[str, Float]) -> dict[str, Float]

Apply all transforms sequentially (forward direction).

Parameters:

- x (dict[str, Float], required): Sample in the base prior space.

Returns:

- dict[str, Float]: Transformed sample.

SinePrior

Bases: BoundedMixin, SequentialTransformPrior

Prior with PDF proportional to sin(x) over [0, pi].
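
The integral of sin(x) over [0, pi] is 2, so the normalized density is p(x) = sin(x)/2, with CDF (1 - cos x)/2. Equivalently, if cos(x) is uniform on [-1, 1] then x has exactly this density. That is why a sine prior on a polar angle gives isotropic directions. The NumPy sketch below uses that equivalence for sampling. It illustrates the distribution, not necessarily Jim's transform chain.

```python
import numpy as np


def sine_log_prob(x):
    # log(sin(x) / 2) inside [0, pi]; -inf outside (the role BoundedMixin plays).
    x = np.asarray(x, dtype=float)
    inside = (x >= 0.0) & (x <= np.pi)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(inside, np.log(np.sin(x) / 2.0), -np.inf)


def sine_sample(rng, n_samples):
    # If cos(x) ~ U(-1, 1), then P(x <= t) = (1 - cos t) / 2, the target CDF.
    return np.arccos(rng.uniform(-1.0, 1.0, n_samples))


angles = sine_sample(np.random.default_rng(0), 100_000)
print(angles.mean())  # close to pi / 2 by symmetry
```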

StandardNormalDistribution

Bases: Prior

One-dimensional standard normal (Gaussian) distribution prior.

Methods:

- sample: Sample from a standard normal distribution.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from a standard normal distribution.

Parameters:

- rng_key (Key, required): JAX PRNG key.
- n_samples (int, required): Number of samples to draw.

Returns:

- dict[str, Float[Array, ' n_samples']]: Dict mapping parameter name to samples of shape (n_samples,).

UniformDistribution

Bases: Prior

One-dimensional uniform distribution prior over [0, 1].

Methods:

- sample: Sample from a uniform distribution.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from a uniform distribution.

Parameters:

- rng_key (Key, required): JAX PRNG key.
- n_samples (int, required): Number of samples to draw.

Returns:

- dict[str, Float[Array, ' n_samples']]: Dict mapping parameter name to samples of shape (n_samples,).

UniformPrior

Bases: SequentialTransformPrior

Uniform prior over a finite interval [xmin, xmax].

Attributes:

- xmin (float): Lower bound of the interval.
- xmax (float): Upper bound of the interval.

UniformSpherePrior

Bases: CombinePrior

Uniform prior over a sphere, parameterized by magnitude, polar angle, and azimuth.

Methods:

- __init__: Construct the magnitude, polar-angle, and azimuth component priors from a base parameter name and a maximum magnitude.

__init__(parameter_names: list[str], max_mag: float = 1.0)

Parameters:

- parameter_names (list[str], required): Single-element list with the base parameter name. Expands to <name>_mag, <name>_theta, and <name>_phi.
- max_mag (float, default 1.0): Maximum magnitude of the vector.
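
The component priors are not spelled out above. This sketch assumes a magnitude uniform on [0, max_mag], a sine-distributed polar angle theta on [0, pi], and an azimuth phi uniform on [0, 2 pi). Under those assumptions the direction is isotropic. Uniform magnitude, however, is not the same as uniform density inside the ball. The NumPy sketch draws samples under the documented <name>_mag / <name>_theta / <name>_phi naming and converts them to Cartesian components to check isotropy. The function names and the base name "chi" are hypothetical.

```python
import numpy as np


def sample_uniform_sphere(rng, n_samples, name="s", max_mag=1.0):
    # Assumed decomposition: uniform magnitude, sine-distributed polar angle, uniform azimuth.
    return {
        f"{name}_mag": rng.uniform(0.0, max_mag, n_samples),
        f"{name}_theta": np.arccos(rng.uniform(-1.0, 1.0, n_samples)),
        f"{name}_phi": rng.uniform(0.0, 2.0 * np.pi, n_samples),
    }


def to_cartesian(samples, name="s"):
    # Standard spherical-to-Cartesian conversion; returns shape (n_samples, 3).
    r = samples[f"{name}_mag"]
    theta = samples[f"{name}_theta"]
    phi = samples[f"{name}_phi"]
    return np.stack(
        [r * np.sin(theta) * np.cos(phi), r * np.sin(theta) * np.sin(phi), r * np.cos(theta)],
        axis=-1,
    )


s = sample_uniform_sphere(np.random.default_rng(1), 100_000, name="chi", max_mag=0.99)
xyz = to_cartesian(s, name="chi")
unit = xyz / np.linalg.norm(xyz, axis=-1, keepdims=True)
print(sorted(s))                                # ['chi_mag', 'chi_phi', 'chi_theta']
print(np.abs(unit.mean(axis=0)).max() < 0.01)   # True: mean direction ≈ 0 (isotropic)
```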