Prior
BoundedMixin
Mixin class that adds bounds checking to log_prob.
This mixin should be placed BEFORE the main prior class in the inheritance list
(e.g., class MyPrior(BoundedMixin, SequentialTransformPrior)) so that, following Python's method resolution order,
the bounds check runs before delegating to the base prior's log_prob.
Classes using this mixin can override the xmin and xmax attributes to set bounds.
By default, the bounds are (-inf, inf), meaning no bounds checking is applied.
The mixin returns -inf for values outside [xmin, xmax].
Note: This is a mixin and should not be used standalone. It relies on the
parameter_names attribute provided by the Prior class and on the
log_prob implementation of the base class.
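The delegation pattern can be illustrated with a minimal, self-contained sketch in plain Python. StandardNormalSketch, BoundedMixinSketch and TruncatedNormalSketch are hypothetical stand-ins, not Jim classes; Jim's priors are JAX/Equinox modules and presumably use array operations rather than a Python if-statement. The point is the method resolution order: because the mixin is listed first, its log_prob runs first and calls super().log_prob only for in-bounds values.

```python
import math


class StandardNormalSketch:
    """Hypothetical base prior over a single parameter (not Jim's class)."""

    def __init__(self, parameter_names):
        self.parameter_names = list(parameter_names)

    def log_prob(self, z):
        x = z[self.parameter_names[0]]
        return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


class BoundedMixinSketch:
    """Hypothetical re-creation of the bounds-checking mixin pattern."""

    xmin = -math.inf  # default bounds: no checking
    xmax = math.inf

    def log_prob(self, z):
        # Relies on parameter_names being provided by the base prior class.
        x = z[self.parameter_names[0]]
        if x < self.xmin or x > self.xmax:
            return -math.inf
        # Delegate to the next class in the MRO, i.e. the base prior.
        return super().log_prob(z)


class TruncatedNormalSketch(BoundedMixinSketch, StandardNormalSketch):
    # Mixin first, base prior second. The bounds check only masks
    # out-of-range values; it does not renormalize the density.
    xmin = 0.0


prior = TruncatedNormalSketch(["x"])
print(prior.log_prob({"x": -1.0}))  # -inf
print([cls.__name__ for cls in TruncatedNormalSketch.__mro__])
# ['TruncatedNormalSketch', 'BoundedMixinSketch', 'StandardNormalSketch', 'object']
```

Swapping the base order (base prior first) would make the base class's log_prob run without ever reaching the bounds check, which is why the mixin must come first.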
CombinePrior
Bases: CompositePrior
Multivariate prior constructed by joining multiple independent priors.
The joint log-probability is the sum of the individual log-probabilities, which is valid when all component priors are independent.
Attributes:
| Name | Type | Description |
|---|---|---|
| `base_prior` | `tuple[Prior, ...]` | Independent component priors. |
| `parameter_names` | | Names of all parameters in the combined prior. |
Methods:
| Name | Description |
|---|---|
| `__init__` | Initialize from a list of independent component priors. |
| `log_prob` | Evaluate the joint log-probability as the sum of component log-probabilities. |
| `sample` | Sample from all component priors independently. |
`__init__(priors: list[Prior])`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `priors` | `list[Prior]` | List of independent prior objects to combine. | required |
`log_prob(z: dict[str, Float]) -> FloatScalar`
Evaluate the joint log-probability as the sum of component log-probabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `z` | `dict[str, Float]` | Dictionary of parameter values. | required |
Returns:
| Type | Description |
|---|---|
| `FloatScalar` | Sum of log-probabilities from all component priors. |
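Because the components are independent, the joint density factorizes and its logarithm is a sum. A minimal plain-Python sketch of that rule; make_gaussian_log_prob, make_uniform_log_prob and combined_log_prob are hypothetical helpers standing in for real component priors, and each component reads only its own key from the shared dict:

```python
import math


def make_gaussian_log_prob(name, mu, sigma):
    """Hypothetical 1-D Gaussian component that reads z[name]."""
    def log_prob(z):
        r = (z[name] - mu) / sigma
        return -0.5 * r * r - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    return log_prob


def make_uniform_log_prob(name, xmin, xmax):
    """Hypothetical 1-D uniform component on [xmin, xmax]."""
    def log_prob(z):
        x = z[name]
        return -math.log(xmax - xmin) if xmin <= x <= xmax else -math.inf
    return log_prob


def combined_log_prob(components, z):
    # log p(a, b, ...) = log p(a) + log p(b) + ... for independent components.
    return sum(lp(z) for lp in components)


components = [
    make_gaussian_log_prob("a", 0.0, 1.0),
    make_uniform_log_prob("b", 0.0, 2.0),
]
total = combined_log_prob(components, {"a": 0.0, "b": 1.0})
```

A single out-of-bounds component drives the sum to -inf, so the joint density is zero there as well.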
`sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]`
Sample from all component priors independently.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Key` | JAX PRNG key (split internally for each component). | required |
| `n_samples` | `int` | Number of samples to draw. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Float[Array, ' n_samples']]` | Combined samples from all component priors, keyed by parameter name. |
CompositePrior
Bases: Prior
Composite prior consisting of multiple component priors.
Base class for SequentialTransformPrior and CombinePrior.
Used to build complex prior distributions from simpler ones.
Attributes:
| Name | Type | Description |
|---|---|---|
| `base_prior` | `tuple[Prior, ...]` | Component prior objects. |
| `parameter_names` | | Names of all parameters in this composite prior. |
Methods:
| Name | Description |
|---|---|
| `__init__` | Initialize the composite prior. |
| `trace_prior_parent` | Recursively collect all leaf (non-composite) priors. |
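The recursive collection can be pictured as a depth-first walk over the base_prior tree. This is a sketch under assumptions: Leaf, Composite and collect_leaves are hypothetical, children are assumed to live in base_prior (as the attribute table suggests), and the real method's signature and traversal order may differ.

```python
from dataclasses import dataclass


@dataclass
class Leaf:
    """Hypothetical non-composite prior."""
    name: str


@dataclass
class Composite:
    """Hypothetical composite prior; children are stored in base_prior."""
    base_prior: tuple


def collect_leaves(prior, output=None):
    """Depth-first, left-to-right collection of leaf priors."""
    if output is None:
        output = []
    if isinstance(prior, Composite):
        for child in prior.base_prior:
            collect_leaves(child, output)
    else:
        output.append(prior)
    return output


tree = Composite((Leaf("a"), Composite((Leaf("b"), Leaf("c"))), Leaf("d")))
print([leaf.name for leaf in collect_leaves(tree)])  # ['a', 'b', 'c', 'd']
```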
CosinePrior
Bases: BoundedMixin, SequentialTransformPrior
Prior with PDF proportional to cos(x) over [-pi/2, pi/2].
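Normalizing cos(x) over [-pi/2, pi/2] gives p(x) = cos(x)/2, with CDF F(x) = (1 + sin x)/2, so inverse-CDF sampling uses x = arcsin(2u - 1) for u uniform on [0, 1). The plain-Python sketch below (hypothetical function names) evaluates that density directly; Jim instead assembles CosinePrior from BoundedMixin and a SequentialTransformPrior, whose exact base prior and transforms are not listed here.

```python
import math
import random


def cosine_log_prob(x):
    """Log of p(x) = cos(x) / 2 on [-pi/2, pi/2]; -inf outside."""
    if x < -math.pi / 2 or x > math.pi / 2:
        return -math.inf
    c = math.cos(x)
    return math.log(c / 2.0) if c > 0.0 else -math.inf


def cosine_sample(rng, n_samples):
    """Inverse-CDF sampling: F(x) = (1 + sin x) / 2, hence x = asin(2u - 1)."""
    return [math.asin(2.0 * rng.random() - 1.0) for _ in range(n_samples)]


samples = cosine_sample(random.Random(0), 5)
```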
GaussianPrior
Bases: SequentialTransformPrior
Gaussian (normal) prior with specified mean and standard deviation.
Attributes:
| Name | Type | Description |
|---|---|---|
| `mu` | `float` | Mean of the distribution. |
| `sigma` | `float` | Standard deviation of the distribution. |
Methods:
| Name | Description |
|---|---|
| `__init__` | Initialize with a mean, a standard deviation, and a parameter name. |
`__init__(mu: float, sigma: float, parameter_names: list[str])`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `float` | Mean of the distribution. | required |
| `sigma` | `float` | Standard deviation of the distribution. | required |
| `parameter_names` | `list[str]` | List with a single parameter name. | required |
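Since GaussianPrior is a SequentialTransformPrior, one natural construction (assumed here, not confirmed by this page) is a standard normal base followed by a scale by sigma and a shift by mu. The change-of-variables rule then gives log p(x) = log N((x - mu)/sigma; 0, 1) - log sigma, which the sketch checks against the closed-form Gaussian. All function names are hypothetical.

```python
import math
import random


def std_normal_log_prob(u):
    return -0.5 * u * u - 0.5 * math.log(2.0 * math.pi)


def gaussian_log_prob_via_transform(x, mu, sigma):
    # Invert x = mu + sigma * u, then add log|du/dx| = -log(sigma).
    u = (x - mu) / sigma
    return std_normal_log_prob(u) - math.log(sigma)


def gaussian_log_prob_closed_form(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def gaussian_sample(rng, mu, sigma, n_samples):
    # Forward direction: draw u ~ N(0, 1), then scale and shift.
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n_samples)]
```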
LogisticDistribution
Bases: Prior
One-dimensional logistic distribution prior.
Methods:
| Name | Description |
|---|---|
| `sample` | Sample from a logistic distribution. |
`sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]`
Sample from a logistic distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Key` | JAX PRNG key. | required |
| `n_samples` | `int` | Number of samples to draw. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Float[Array, ' n_samples']]` | Dict mapping parameter name to samples of shape `(n_samples,)`. |
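Assuming the standard logistic distribution (location 0, scale 1; the page does not state the parameters), samples can be drawn by inverse-CDF: F(x) = 1/(1 + e^(-x)), so x = log(u / (1 - u)). A plain-Python sketch with hypothetical names, plus the matching log-density written in a numerically stable, symmetric form:

```python
import math
import random


def logistic_sample(rng, n_samples):
    """Inverse-CDF sampling of the standard logistic distribution."""
    out = []
    for _ in range(n_samples):
        u = rng.random()       # in [0, 1)
        while u == 0.0:        # avoid log(0); 1 - u is always > 0
            u = rng.random()
        out.append(math.log(u / (1.0 - u)))
    return out


def logistic_log_prob(x):
    """log p(x) for p(x) = e^-x / (1 + e^-x)^2, using symmetry in x for stability."""
    a = abs(x)
    return -a - 2.0 * math.log1p(math.exp(-a))
```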
PowerLawPrior
Bases: SequentialTransformPrior
Power-law prior over [xmin, xmax] with exponent alpha.
Attributes:
| Name | Type | Description |
|---|---|---|
| `xmin` | `float` | Lower bound of the interval (must be positive). |
| `xmax` | `float` | Upper bound of the interval. |
| `alpha` | `float` | Power-law exponent. |
Methods:
| Name | Description |
|---|---|
| `__init__` | Initialize with bounds, an exponent, and a parameter name. |
`__init__(xmin: float, xmax: float, alpha: float, parameter_names: list[str])`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `xmin` | `float` | Lower bound (must be positive). | required |
| `xmax` | `float` | Upper bound. | required |
| `alpha` | `float` | Power-law exponent. | required |
| `parameter_names` | `list[str]` | List with a single parameter name. | required |
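The page does not spell out the density; assuming p(x) proportional to x^alpha on [xmin, xmax], the normalization and inverse CDF have closed forms, with alpha = -1 as the log-uniform special case. The sketch below (hypothetical names, plain Python) illustrates one way to sample and evaluate such a prior; it is not Jim's transform chain. The positivity requirement on xmin is what keeps log(x) and x^(alpha + 1) well defined.

```python
import math
import random


def power_law_log_norm(xmin, xmax, alpha):
    """log of Z = integral of x^alpha over [xmin, xmax] (requires 0 < xmin < xmax)."""
    if alpha == -1.0:
        return math.log(math.log(xmax / xmin))
    a = alpha + 1.0
    return math.log((xmax**a - xmin**a) / a)


def power_law_log_prob(x, xmin, xmax, alpha):
    if x < xmin or x > xmax:
        return -math.inf
    return alpha * math.log(x) - power_law_log_norm(xmin, xmax, alpha)


def power_law_ppf(u, xmin, xmax, alpha):
    """Inverse CDF: maps u in [0, 1] to x in [xmin, xmax]."""
    if alpha == -1.0:
        return xmin * (xmax / xmin) ** u
    a = alpha + 1.0
    return (xmin**a + u * (xmax**a - xmin**a)) ** (1.0 / a)


def power_law_sample(rng, xmin, xmax, alpha, n_samples):
    return [power_law_ppf(rng.random(), xmin, xmax, alpha) for _ in range(n_samples)]
```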
Prior
Bases: Module
Base class for prior distributions.
This class should not be used directly. It provides a common interface and bookkeeping for parameter names and transforms.
Methods:
| Name | Description |
|---|---|
| `__call__` | Alias for `log_prob`. |
| `__init__` | Initialize with a list of parameter names. |
| `add_name` | Convert a flat parameter array to a named dict. |
| `log_prob` | Evaluate the log-probability of a sample. |
| `sample` | Draw samples from the prior. |
Attributes:
| Name | Type | Description |
|---|---|---|
| `is_normalized` | `bool` | Return True if this prior is a proper probability distribution (integrates to 1). |
| `n_dims` | `int` | Number of parameters in this prior. |
`is_normalized: bool` (property)
Return True if this prior is a proper probability distribution (integrates to 1).
Defaults to False for safety. All built-in Jim priors override this to True.
Custom priors should override this property to return True only after
verifying that ∫ exp(log_prob(x)) dx = 1.
Samplers that compute the Bayesian evidence (NSS, SMC) require a normalized prior; Jim raises an error at construction time if this property returns False for those backends.
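Before overriding is_normalized, a quick numerical check of a one-dimensional custom log_prob can catch a missing normalization constant. The helpers below are a hypothetical plain-Python sketch using the midpoint rule on a finite interval; they are only meaningful when the support is bounded or the tails outside [lo, hi] are negligible, and multi-dimensional priors would need a different approach (e.g. Monte Carlo integration).

```python
import math


def integrate_density(log_prob, lo, hi, n=100_000):
    """Midpoint-rule estimate of the integral of exp(log_prob(x)) over [lo, hi]."""
    h = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * h
        total += math.exp(log_prob(x))  # exp(-inf) == 0.0 outside the support
    return total * h


def looks_normalized(log_prob, lo, hi, tol=1e-4):
    return abs(integrate_density(log_prob, lo, hi) - 1.0) < tol
```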
`n_dims: int` (property)
Number of parameters in this prior.
`__call__(x: dict[str, Float]) -> FloatScalar`
Alias for log_prob.
`__init__(parameter_names: list[str])`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `parameter_names` | `list[str]` | List of parameter names for this prior. | required |
`add_name(x: Float[Array, n_dims]) -> dict[str, Float]`
Convert a flat parameter array to a named dict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, n_dims]` | Array of parameter values, shape `(n_dims,)`. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Float]` | Dict mapping parameter names to scalar values. |
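Conceptually this pairs each name in parameter_names with the entry at the same position. A hypothetical stand-alone sketch, using plain Python lists instead of JAX arrays and an explicit length check that Jim may or may not perform:

```python
def add_name(parameter_names, x):
    """Map a flat sequence of values to {name: value}, matching by position."""
    if len(x) != len(parameter_names):
        raise ValueError(f"expected {len(parameter_names)} values, got {len(x)}")
    return dict(zip(parameter_names, x))


print(add_name(["a", "b"], [0.1, 2.0]))  # {'a': 0.1, 'b': 2.0}
```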
`log_prob(z: dict[str, Float]) -> FloatScalar` (abstractmethod)
Evaluate the log-probability of a sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `z` | `dict[str, Float]` | Dict of parameter values. | required |
Returns:
| Type | Description |
|---|---|
| `FloatScalar` | Log-probability scalar. |
`sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]` (abstractmethod)
Draw samples from the prior.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Key` | JAX PRNG key. | required |
| `n_samples` | `int` | Number of samples to draw. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Float[Array, ' n_samples']]` | Dict mapping parameter names to arrays of shape `(n_samples,)`. |
RayleighPrior
Bases: BoundedMixin, SequentialTransformPrior
Rayleigh distribution prior with scale parameter sigma.
Attributes:
| Name | Type | Description |
|---|---|---|
| `sigma` | `float` | Scale parameter of the Rayleigh distribution. |
Methods:
| Name | Description |
|---|---|
| `__init__` | Initialize with a scale parameter and a parameter name. |
`__init__(sigma: float, parameter_names: list[str])`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sigma` | `float` | Scale parameter of the Rayleigh distribution. | required |
| `parameter_names` | `list[str]` | List with a single parameter name. | required |
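For reference, the Rayleigh density with scale sigma is p(x) = (x / sigma^2) exp(-x^2 / (2 sigma^2)) for x >= 0, with CDF 1 - exp(-x^2 / (2 sigma^2)), so inverse-CDF sampling gives x = sigma * sqrt(-2 log(1 - u)). A plain-Python sketch with hypothetical names; Jim's class composes BoundedMixin with a SequentialTransformPrior, and its particular transforms are not shown on this page.

```python
import math
import random


def rayleigh_log_prob(x, sigma):
    """log p(x) for the Rayleigh distribution; -inf for x <= 0."""
    if x <= 0.0:
        return -math.inf
    return math.log(x) - 2.0 * math.log(sigma) - x * x / (2.0 * sigma * sigma)


def rayleigh_sample(rng, sigma, n_samples):
    # rng.random() is in [0, 1), so 1 - u is in (0, 1] and the log is finite.
    return [sigma * math.sqrt(-2.0 * math.log(1.0 - rng.random())) for _ in range(n_samples)]
```

The sample mean should approach sigma * sqrt(pi / 2), which is a cheap sanity check on the sampler.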
SequentialTransformPrior
Bases: CompositePrior
Prior distribution transformed by a sequence of bijective transforms.
Attributes:
| Name | Type | Description |
|---|---|---|
| `base_prior` | `tuple[Prior, ...]` | The base prior to transform (must have length 1). |
| `transforms` | `tuple[BijectiveTransform, ...]` | Transforms applied sequentially in the forward direction. |
| `parameter_names` | `tuple[str, ...]` | Names of the parameters after all transforms. |
Methods:
| Name | Description |
|---|---|
| `__init__` | Initialize from a base prior and a list of transforms. |
| `log_prob` | Evaluate the log-probability of a transformed sample z. |
| `sample` | Sample by drawing from the base prior and applying all transforms. |
| `transform` | Apply all transforms sequentially (forward direction). |
`__init__(base_prior: list[Prior], transforms: list[BijectiveTransform])`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `base_prior` | `list[Prior]` | A single-element list containing the base prior. | required |
| `transforms` | `list[BijectiveTransform]` | Ordered list of bijective transforms to apply to samples from the base prior. | required |
`log_prob(z: dict[str, Float]) -> FloatScalar`
Evaluate the log-probability of a transformed sample z.
Applies the inverse transforms in reverse order, accumulating log-Jacobian determinants, then evaluates the base prior.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `z` | `dict[str, Float]` | Sample in the transformed (output) space. | required |
Returns:
| Type | Description |
|---|---|
| `FloatScalar` | Log-probability of z under the induced distribution. |
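The bookkeeping can be sketched for a single scalar parameter. This is a hypothetical plain-Python illustration, not Jim's BijectiveTransform API: each toy transform (Scale, Shift, ExpTransform) returns its inverse together with log|d(inverse)/dy|, transformed_log_prob walks the chain in reverse while accumulating those terms, and transform applies the chain forward.

```python
import math
from dataclasses import dataclass


@dataclass
class Scale:
    """Hypothetical bijection y = s * x."""
    s: float

    def forward(self, x):
        return self.s * x

    def inverse(self, y):
        # Pre-image and log|dx/dy| of the inverse map.
        return y / self.s, -math.log(abs(self.s))


@dataclass
class Shift:
    """Hypothetical bijection y = x + c."""
    c: float

    def forward(self, x):
        return x + self.c

    def inverse(self, y):
        return y - self.c, 0.0


class ExpTransform:
    """Hypothetical bijection y = exp(x), defined for y > 0."""

    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y), -math.log(y)


def std_normal_log_prob(u):
    return -0.5 * u * u - 0.5 * math.log(2.0 * math.pi)


def transform(x, transforms):
    """Apply the transforms in order (forward direction)."""
    for t in transforms:
        x = t.forward(x)
    return x


def transformed_log_prob(z, base_log_prob, transforms):
    """log p_Z(z) = log p_X(x) + sum of log|d(inverse)/dz| along the chain."""
    log_jac = 0.0
    for t in reversed(transforms):  # undo the last transform first
        z, lj = t.inverse(z)
        log_jac += lj
    return base_log_prob(z) + log_jac
```

With [Scale(2.0), Shift(1.0)] on a standard normal base this reproduces N(1, 2^2), and with [ExpTransform()] it gives the log-normal density.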
`sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]`
Sample by drawing from the base prior and applying all transforms.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Key` | JAX PRNG key. | required |
| `n_samples` | `int` | Number of samples to draw. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Float[Array, ' n_samples']]` | Transformed samples keyed by parameter name. |
`transform(x: dict[str, Float]) -> dict[str, Float]`
Apply all transforms sequentially (forward direction).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
dict[str, Float]
|
Sample in the base prior space. |
required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Float]` | Transformed sample. |
SinePrior
StandardNormalDistribution
Bases: Prior
One-dimensional standard normal (Gaussian) distribution prior.
Methods:
| Name | Description |
|---|---|
| `sample` | Sample from a standard normal distribution. |
`sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]`
Sample from a standard normal distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Key` | JAX PRNG key. | required |
| `n_samples` | `int` | Number of samples to draw. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Float[Array, ' n_samples']]` | Dict mapping parameter name to samples of shape `(n_samples,)`. |
UniformDistribution
Bases: Prior
One-dimensional uniform distribution prior over [0, 1].
Methods:
| Name | Description |
|---|---|
| `sample` | Sample from a uniform distribution. |
`sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]`
Sample from a uniform distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Key` | JAX PRNG key. | required |
| `n_samples` | `int` | Number of samples to draw. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Float[Array, ' n_samples']]` | Dict mapping parameter name to samples of shape `(n_samples,)`. |
UniformPrior
Bases: SequentialTransformPrior
Uniform prior over a finite interval [xmin, xmax].
Attributes:
| Name | Type | Description |
|---|---|---|
| `xmin` | `float` | Lower bound of the interval. |
| `xmax` | `float` | Upper bound of the interval. |
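Since UniformPrior is a SequentialTransformPrior and UniformDistribution covers [0, 1], a plausible construction (assumed, not stated on this page) is the affine map x = xmin + (xmax - xmin) u. Its log-Jacobian is constant, so the density is flat at 1/(xmax - xmin) inside the interval. A hypothetical plain-Python sketch:

```python
import math
import random


def uniform_sample(rng, xmin, xmax, n_samples):
    # Forward map from u in [0, 1) to [xmin, xmax).
    return [xmin + (xmax - xmin) * rng.random() for _ in range(n_samples)]


def uniform_log_prob(x, xmin, xmax):
    # Inverse map u = (x - xmin) / (xmax - xmin) has log|du/dx| = -log(xmax - xmin).
    if x < xmin or x > xmax:
        return -math.inf
    return -math.log(xmax - xmin)
```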
UniformSpherePrior
Bases: CombinePrior
Uniform prior over a sphere, parameterized by magnitude, polar angle, and azimuth.
Methods:
| Name | Description |
|---|---|
| `__init__` | Initialize with a base parameter name and a maximum magnitude. |
`__init__(parameter_names: list[str], max_mag: float = 1.0)`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `parameter_names` | `list[str]` | Single-element list with the base parameter name. Expands to three parameters: magnitude, polar angle, and azimuth. | required |
| `max_mag` | `float` | Maximum magnitude of the vector. | `1.0` |
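For the angular part, a direction is uniform on the sphere when cos(theta) is uniform on [-1, 1] (so theta has density sin(theta)/2 on [0, pi]) and phi is uniform on [0, 2 pi). The sketch below samples that way; drawing the magnitude uniformly on [0, max_mag] is an assumption, since this page does not say how the magnitude is distributed. sample_sphere and to_cartesian are hypothetical names.

```python
import math
import random


def sample_sphere(rng, n_samples, max_mag=1.0):
    """Return a list of (mag, theta, phi) tuples."""
    out = []
    for _ in range(n_samples):
        mag = max_mag * rng.random()                 # assumption: uniform magnitude
        theta = math.acos(1.0 - 2.0 * rng.random())  # polar angle; cos(theta) uniform on (-1, 1]
        phi = 2.0 * math.pi * rng.random()           # azimuth in [0, 2 pi)
        out.append((mag, theta, phi))
    return out


def to_cartesian(mag, theta, phi):
    return (
        mag * math.sin(theta) * math.cos(phi),
        mag * math.sin(theta) * math.sin(phi),
        mag * math.cos(theta),
    )
```

Sampling theta as acos(1 - 2u) rather than uniformly on [0, pi] is what avoids clustering of directions near the poles.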